Compare commits

...

6 Commits

Author SHA1 Message Date
joachim-danswer
f2de6e4c5a k 2025-12-15 09:54:24 -08:00
joachim-danswer
728f727976 updates 2025-12-12 07:21:29 -10:00
joachim-danswer
32220d08d9 temp 2025-12-11 18:11:35 -10:00
joachim-danswer
29522d81f7 db changes 2025-12-11 16:33:56 -10:00
joachim-danswer
97de25ba35 first DC 2025-11-17 17:34:40 -08:00
joachim-danswer
45f6c28605 start 2025-11-14 11:20:13 -08:00
70 changed files with 6611 additions and 3 deletions

View File

@@ -0,0 +1,53 @@
"""add cheat_sheet_context to user
Revision ID: 12h73u00mcwb
Revises: a4f23d6b71c8
Create Date: 2025-11-13 00:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "12h73u00mcwb"  # this migration's ID
down_revision = "a4f23d6b71c8"  # parent revision in the migration chain
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add the JSONB cheat-sheet context column and its scratch table."""
    # Nullable JSONB payload attached directly to the user row.
    cheat_sheet_column = sa.Column(
        "cheat_sheet_context",
        postgresql.JSONB(),
        nullable=True,
    )
    op.add_column("user", cheat_sheet_column)

    # Standalone scratch table holding raw cheat-sheet context blobs.
    op.create_table(
        "temporary_user_cheat_sheet_context",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("context", postgresql.JSONB(), nullable=False),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.text("now()"),
        ),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.text("now()"),
        ),
        sa.PrimaryKeyConstraint("id"),
    )
def downgrade() -> None:
    # Reverse of upgrade(): drop the scratch table first, then the column.
    op.drop_table("temporary_user_cheat_sheet_context")
    op.drop_column("user", "cheat_sheet_context")

View File

@@ -0,0 +1,97 @@
"""add subscription tables
Revision ID: 13a84b9c2d5f
Revises: 12h73u00mcwb
Create Date: 2025-12-12 00:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "13a84b9c2d5f"  # this migration's ID
down_revision = "12h73u00mcwb"  # parent revision in the migration chain
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create the subscription_registrations and subscription_results tables."""

    def _timestamp(name: str) -> sa.Column:
        # Timezone-aware timestamp defaulting to now() server-side.
        return sa.Column(
            name,
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        )

    # What each user subscribed to: per-cc-pair extraction prompts plus
    # standing search questions.
    op.create_table(
        "subscription_registrations",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("doc_extraction_contexts", postgresql.JSONB(), nullable=False),
        sa.Column("search_questions", postgresql.ARRAY(sa.String()), nullable=False),
        _timestamp("created_at"),
        _timestamp("updated_at"),
        sa.PrimaryKeyConstraint("id"),
        sa.ForeignKeyConstraint(["user_id"], ["user.id"]),
    )

    # Output of subscription processing, keyed by result type.
    op.create_table(
        "subscription_results",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("type", sa.String(), nullable=False),
        sa.Column("notifications", postgresql.JSONB(), nullable=False),
        _timestamp("created_at"),
        _timestamp("updated_at"),
        sa.PrimaryKeyConstraint("id"),
        sa.ForeignKeyConstraint(["user_id"], ["user.id"]),
    )
def downgrade() -> None:
    # Reverse of upgrade(): drop dependents in reverse creation order.
    op.drop_table("subscription_results")
    op.drop_table("subscription_registrations")

View File

@@ -0,0 +1,64 @@
from collections.abc import Hashable
from langgraph.graph import END
from langgraph.types import Send
from onyx.agents.agent_search.exploration.enums import DRPath
from onyx.agents.agent_search.exploration.states import MainState
def decision_router(state: MainState) -> list[Send | Hashable] | DRPath | str:
    """Route to the next graph node based on the most recently selected tool.

    The last entry of state.tools_used is either a registered tool name or a
    DRPath string; terminal/control paths are handled before tool dispatch.
    """
    if not state.tools_used:
        raise IndexError("state.tools_used cannot be empty")

    latest = state.tools_used[-1]
    tools = state.available_tools

    if latest == DRPath.THINKING.value:
        # Thinking step already happened; hand control back to the orchestrator.
        return DRPath.ORCHESTRATOR
    if latest == DRPath.END.value:
        return END
    if not tools:
        raise ValueError("No tool is available. This should not happen.")

    if latest not in tools:
        # Not a registered tool: either a terminal path or back to orchestration.
        if latest == DRPath.LOGGER.value:
            return DRPath.LOGGER
        if latest == DRPath.CLOSER.value:
            return DRPath.CLOSER
        return DRPath.ORCHESTRATOR

    destination = tools[latest].path

    # Clarification may only occur before the iteration loop starts.
    if destination == DRPath.CLARIFIER:
        raise ValueError("CLARIFIER is not a valid path during iteration")

    # Query-driven tools cannot run without queries; close out instead.
    query_driven_paths = (
        DRPath.INTERNAL_SEARCH,
        DRPath.WEB_SEARCH,
        DRPath.KNOWLEDGE_GRAPH,
        DRPath.IMAGE_GENERATION,
    )
    if destination in query_driven_paths and len(state.query_list) == 0:
        return DRPath.CLOSER

    return destination
def completeness_router(state: MainState) -> DRPath | str:
    """After the closer: loop back to the orchestrator, or proceed to logging."""
    if not state.tools_used:
        raise IndexError("tools_used cannot be empty")

    latest_path = state.tools_used[-1]
    return (
        DRPath.ORCHESTRATOR
        if latest_path == DRPath.ORCHESTRATOR.value
        else DRPath.LOGGER
    )

View File

@@ -0,0 +1,31 @@
from onyx.agents.agent_search.exploration.enums import DRPath
from onyx.agents.agent_search.exploration.enums import ResearchType
# Cap on prior chat-history entries pulled into prompts.
MAX_CHAT_HISTORY_MESSAGES = (
    3  # note: actual count is x2 to account for user and assistant messages
)

# Upper bound on parallel search branches during deep research.
MAX_DR_PARALLEL_SEARCH = 4

# TODO: test more, generally not needed/adds unnecessary iterations
MAX_NUM_CLOSER_SUGGESTIONS = (
    0  # how many times the closer can send back to the orchestrator
)

# Literal prefixes used to recognize special LLM outputs.
CLARIFICATION_REQUEST_PREFIX = "PLEASE CLARIFY:"
HIGH_LEVEL_PLAN_PREFIX = "The Plan:"

# Relative cost of one call per tool path, used for orchestration budgeting.
AVERAGE_TOOL_COSTS: dict[DRPath, float] = {
    DRPath.INTERNAL_SEARCH: 1.0,
    DRPath.KNOWLEDGE_GRAPH: 2.0,
    DRPath.WEB_SEARCH: 1.5,
    DRPath.IMAGE_GENERATION: 3.0,
    DRPath.GENERIC_TOOL: 1.5,  # TODO: see todo in OrchestratorTool
    DRPath.CLOSER: 0.0,
}

# Total budget per research type — presumably in the same units as
# AVERAGE_TOOL_COSTS; confirm against orchestrator usage.
DR_TIME_BUDGET_BY_TYPE = {
    ResearchType.THOUGHTFUL: 3.0,
    ResearchType.DEEP: 12.0,
    ResearchType.FAST: 0.5,
}

View File

@@ -0,0 +1,141 @@
# Base system message for the experimentation flow. The ---name--- placeholders
# (user_prompt, current_date, available_tool_descriptions_str,
# cheat_sheet_string, plan_instruction_insertion) are substituted before use.
BASE_SYSTEM_MESSAGE_TEMPLATE = """
You are a helpful assistant that can answer questions and help with tasks.
You should answer the user's question based on their needs. Their needs and directions \
are specified as follows:
###
---user_prompt---
###
The current date is ---current_date---.
The answer process may be complex and may involve multiple tool calls. Here is \
a description of the tools that MAY be available to you throughout the conversations \
(note though that not all tools may be available at all times, depending on the context):
###
---available_tool_descriptions_str---
###
---cheat_sheet_string---
Here are some principle reminders about how to answer the user's question:
- you will derive the answer through conversational steps with the user.
---plan_instruction_insertion---
- the answer should only be generated using the tools responses and if applicable, \
retrieved documents and information.
"""
# Asks the model to draft a plan; ---user_plan_instructions_prompt--- is
# substituted with the user's planning instructions before use.
PLAN_PROMPT_TEMPLATE = """
Now you should create a plan how to address the user's question.
###
---user_plan_instructions_prompt---
###
"""
# Orchestrator turn prompt: asks the model to pick the next tool and the
# questions/requests to send to it. (Typos in the original prompt text —
# "Your need", "guaidance", "quesries", "questons", "adrres" — are fixed.)
ORCHESTRATOR_PROMPT_TEMPLATE = """
You need to consider the conversation thus far and see what you want to do next in order \
to answer the user's question/task.
Particularly, you should consider the following:
- the ORIGINAL QUESTION
- the plan you had created early on
- the additional context provided in the system prompt, if applicable
- the questions generated so far and the corresponding answers you have received
- previous documented thinking processes, if any. In particular, if \
your previous step in the conversation was a thinking step, pay extra attention to that \
one as it should provide you with clear guidance for what to do next.
- the tools you have available, and the instructions you have been given for each tool, \
including how many queries can be generated in each iteration.
Note:
- make sure that you don't repeat yourself. If you have already asked a question of the same tool, \
do not ask it again! New questions to the same tool must be substantially different from the previous ones.
- a previous question can however be asked of a DIFFERENT tool, if there is reason to believe that the \
new tool is suitable for the question.
- you must make sure that the tool has ALL RELEVANT context to answer the question/address the \
request you are posing to it.
- NEVER answer the question directly! If you do have all of the information available AND you \
do not want to request additional information or make checks, you need to call the CLOSER tool.
Your task is to select the next tool to call and the questions/requests to ask of that tool.
"""
# First-pass extraction prompt: pull reusable user/company facts and
# search/reasoning strategies out of a finished exchange, as JSON.
# (Typos fixed: "fro the user" -> "from the user", "reasonaning" -> "reasoning".)
EXTRACTION_SYSTEM_PROMPT_TEMPLATE = """
You are an expert in identifying relevant information and strategies from extracted facts, thoughts, and \
answers provided to a user based on their question, that may be useful in INFORMING FUTURE ANSWER STRATEGIES.
As such, your extractions MUST be correct and broadly informative. You will receive the information from the user
in the next message.
Conceptual examples are:
- facts about the user, their team, or their companies that are helpful to provide context for future questions
- search strategies
- reasoning strategies
Please format the extractions as a json dictionary in this format:
{{
"user_information": {{<type of user information>: <the extracted information about the user>, ...}},
"company_information": {{<type of company information>: <the extracted information about the company>, ...}},
"search_strategies": {{<type of search strategy>: <the extracted information about the search strategy>, ...}},
"reasoning_strategies": {{<type of reasoning strategy>: <the extracted information about the reasoning strategy>, ...}},
}}
"""
# Incremental extraction prompt: diff new findings against previously stored
# knowledge and emit add/update/delete operations per category, as JSON.
# (Typos fixed: "reasonaning", "infotmation", "similar," -> "similarly,", and
# the "company" list was mistakenly described as being "about the user".)
EXTRACTION_SYSTEM_PROMPT = """
You are an expert in identifying relevant information and strategies from extracted facts, thoughts, and \
answers provided to a user based on their question, that may be useful in INFORMING FUTURE ANSWER STRATEGIES, and \
updating/extending the previous knowledge accordingly. As such, your updates MUST be correct and broadly informative.
You will receive the original knowledge and the new information from the user in the next message.
Conceptual examples are:
- facts about the user, their team, or their companies that are helpful to provide context for future questions
- search strategies
- reasoning strategies
Please format the extractions as a json dictionary in this format:
{{
"user": [<list of new information about the user, each formatted as a dictionary in this format:
{{'type': <essentially, a keyword for the type of information, like 'location', 'interest',... >',
'change_type': <'update', 'delete', or 'add'. 'update' should be selected if this type of information \
exists in the original knowledge but it should be extended/updated, 'delete' if the information in the \
original knowledge is called into question and no final determination can be made, 'add' for a new information type.>,
'information': <the actual information to be added, updated, or deleted. Do not do a rewrite, just the new \
information.>}}>,.. ],
"company": [<list of new information about the company, each formatted as a dictionary in the exact same \
format as above, except for the 'type' key, which should be company-specific instead of user-specific'.>],
"search_strategy": [<list of new information about the search strategies, each formatted as a dictionary \
in the exact same format as above, except for the 'type' key, which should be search-strategy-specific.>],
"reasoning_strategy": [<list of new information about the reasoning strategies, each formatted as a \
dictionary in the exact same format as above, except for the 'type' key, which should be reasoning-strategy-specific.>],
}}
Note:
- make absolutely sure new information about the user is actually ABOUT THE USER WHO IS ASKING THE QUESTION!
- similarly, make sure that new information about the company is actually ABOUT THE COMPANY OF THE USER WHO ASKS THE QUESTION!
- only suggest updates if there is substantially new information that should extend the same type of information \
in the original knowledge.
- keep the information concise, to the point, and in a way that would be useful to provide context to future questions.
"""
# Prompt for merging previously stored context with newly extracted
# information into one consolidated blob. (Typo fixed: "responsd".)
CONTEXT_UPDATE_SYSTEM_PROMPT = """
You are an expert in updating/modifying previous knowledge as new information becomes available.
Your task is to generate the updated information that consists of both the old and the new information.
You will receive the original context and the new information from the user in the next message.
Please respond with the consolidated information. Keep the information concise, to the point, and in a way \
that would be useful to provide context to future questions.
"""

View File

@@ -0,0 +1,112 @@
from datetime import datetime
from onyx.agents.agent_search.exploration.enums import DRPath
from onyx.agents.agent_search.exploration.enums import ResearchType
from onyx.agents.agent_search.exploration.models import DRPromptPurpose
from onyx.agents.agent_search.exploration.models import OrchestratorTool
from onyx.prompts.dr_prompts import GET_CLARIFICATION_PROMPT
from onyx.prompts.dr_prompts import KG_TYPES_DESCRIPTIONS
from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
from onyx.prompts.dr_prompts import TOOL_DIFFERENTIATION_HINTS
from onyx.prompts.dr_prompts import TOOL_QUESTION_HINTS
from onyx.prompts.prompt_template import PromptTemplate
def get_dr_prompt_orchestration_templates(
    purpose: DRPromptPurpose,
    research_type: ResearchType,
    available_tools: dict[str, OrchestratorTool],
    entity_types_string: str | None = None,
    relationship_types_string: str | None = None,
    reasoning_result: str | None = None,
    tool_calls_string: str | None = None,
) -> PromptTemplate:
    """Select the orchestration prompt for `purpose`/`research_type` and
    partially fill the tool- and knowledge-graph-related template variables.

    Raises:
        ValueError: when `purpose` is not supported for `research_type`,
            or when `purpose` is not a known DRPromptPurpose.
    """
    available_tools = available_tools or {}
    tool_names = list(available_tools.keys())

    tool_description_str = "\n\n".join(
        f"- {tool_name}: {tool.description}"
        for tool_name, tool in available_tools.items()
    )
    tool_cost_str = "\n".join(
        f"{tool_name}: {tool.cost}" for tool_name, tool in available_tools.items()
    )

    # Pairwise hints help the LLM choose between similar tools; only pairs
    # where both tools are available are included.
    tool_differentiations: list[str] = [
        TOOL_DIFFERENTIATION_HINTS[(tool_1, tool_2)]
        for tool_1 in available_tools
        for tool_2 in available_tools
        if (tool_1, tool_2) in TOOL_DIFFERENTIATION_HINTS
    ]
    tool_differentiation_hint_string = (
        "\n".join(tool_differentiations) or "(No differentiating hints available)"
    )

    # TODO: add tool delineation pairs for custom tools as well
    tool_question_hint_string = (
        "\n".join(
            "- " + TOOL_QUESTION_HINTS[tool]
            for tool in available_tools
            if tool in TOOL_QUESTION_HINTS
        )
        or "(No examples available)"
    )

    # Only describe KG entity/relationship types when the KG tool is in play.
    if DRPath.KNOWLEDGE_GRAPH.value in available_tools and (
        entity_types_string or relationship_types_string
    ):
        kg_types_descriptions = KG_TYPES_DESCRIPTIONS.build(
            possible_entities=entity_types_string or "",
            possible_relationships=relationship_types_string or "",
        )
    else:
        kg_types_descriptions = "(The Knowledge Graph is not used.)"

    if purpose == DRPromptPurpose.PLAN:
        if research_type == ResearchType.THOUGHTFUL:
            # Bug fix: the message previously claimed "FAST", but this branch
            # rejects the THOUGHTFUL research type.
            raise ValueError(
                "plan generation is not supported for THOUGHTFUL research type"
            )
        base_template = ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
    elif purpose == DRPromptPurpose.NEXT_STEP_REASONING:
        if research_type == ResearchType.THOUGHTFUL:
            base_template = ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
        else:
            raise ValueError(
                "reasoning is not separately required for DEEP time budget"
            )
    elif purpose == DRPromptPurpose.NEXT_STEP_PURPOSE:
        base_template = ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
    elif purpose == DRPromptPurpose.NEXT_STEP:
        if research_type == ResearchType.THOUGHTFUL:
            base_template = ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
        else:
            base_template = ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
    elif purpose == DRPromptPurpose.CLARIFICATION:
        if research_type == ResearchType.THOUGHTFUL:
            # Bug fix: same FAST/THOUGHTFUL mismatch as the PLAN branch.
            raise ValueError(
                "clarification is not supported for THOUGHTFUL research type"
            )
        base_template = GET_CLARIFICATION_PROMPT
    else:
        # Unreachable for valid DRPromptPurpose values; keeps mypy satisfied.
        raise ValueError(f"Invalid purpose: {purpose}")

    return base_template.partial_build(
        num_available_tools=str(len(tool_names)),
        available_tools=", ".join(tool_names),
        tool_choice_options=" or ".join(tool_names),
        current_time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        kg_types_descriptions=kg_types_descriptions,
        tool_descriptions=tool_description_str,
        tool_differentiation_hints=tool_differentiation_hint_string,
        tool_question_hints=tool_question_hint_string,
        average_tool_costs=tool_cost_str,
        reasoning_result=reasoning_result or "(No reasoning result provided.)",
        tool_calls_string=tool_calls_string or "(No tool calls provided.)",
    )

View File

@@ -0,0 +1,33 @@
from enum import Enum
class ResearchType(str, Enum):
    """Research type options for agent search operations."""

    # BASIC = "BASIC"
    LEGACY_AGENTIC = "LEGACY_AGENTIC"  # only used for legacy agentic search migrations
    THOUGHTFUL = "THOUGHTFUL"
    DEEP = "DEEP"
    FAST = "FAST"
class ResearchAnswerPurpose(str, Enum):
    """Research answer purpose options for agent search operations."""

    ANSWER = "ANSWER"  # a direct answer to the user's question
    CLARIFICATION_REQUEST = "CLARIFICATION_REQUEST"  # asking the user to clarify
class DRPath(str, Enum):
    """Node names used as routing targets in the deep-research graph."""

    CLARIFIER = "Clarifier"
    ORCHESTRATOR = "Orchestrator"
    INTERNAL_SEARCH = "Internal Search"
    GENERIC_TOOL = "Generic Tool"
    KNOWLEDGE_GRAPH = "Knowledge Graph Search"
    WEB_SEARCH = "Web Search"
    IMAGE_GENERATION = "Image Generation"
    GENERIC_INTERNAL_TOOL = "Generic Internal Tool"
    CLOSER = "Closer"
    THINKING = "Thinking"
    LOGGER = "Logger"
    END = "End"

View File

@@ -0,0 +1,88 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.exploration.conditional_edges import completeness_router
from onyx.agents.agent_search.exploration.conditional_edges import decision_router
from onyx.agents.agent_search.exploration.enums import DRPath
from onyx.agents.agent_search.exploration.nodes.dr_a0_clarification import clarifier
from onyx.agents.agent_search.exploration.nodes.dr_a1_orchestrator import orchestrator
from onyx.agents.agent_search.exploration.nodes.dr_a2_closer import closer
from onyx.agents.agent_search.exploration.nodes.dr_a3_logger import logging
from onyx.agents.agent_search.exploration.states import MainInput
from onyx.agents.agent_search.exploration.states import MainState
from onyx.agents.agent_search.exploration.sub_agents.basic_search.dr_basic_search_graph_builder import (
dr_basic_search_graph_builder,
)
from onyx.agents.agent_search.exploration.sub_agents.custom_tool.dr_custom_tool_graph_builder import (
dr_custom_tool_graph_builder,
)
from onyx.agents.agent_search.exploration.sub_agents.generic_internal_tool.dr_generic_internal_tool_graph_builder import (
dr_generic_internal_tool_graph_builder,
)
from onyx.agents.agent_search.exploration.sub_agents.image_generation.dr_image_generation_graph_builder import (
dr_image_generation_graph_builder,
)
from onyx.agents.agent_search.exploration.sub_agents.kg_search.dr_kg_search_graph_builder import (
dr_kg_search_graph_builder,
)
from onyx.agents.agent_search.exploration.sub_agents.web_search.dr_ws_graph_builder import (
dr_ws_graph_builder,
)
# from onyx.agents.agent_search.exploration.sub_agents.basic_search.dr_basic_search_2_act import search
def exploration_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for the deep research agent.

    Wires clarifier -> orchestrator -> tool sub-agents, with the closer and
    logger handling completion.
    """
    graph = StateGraph(state_schema=MainState, input=MainInput)

    ### Add nodes ###
    graph.add_node(DRPath.CLARIFIER, clarifier)
    graph.add_node(DRPath.ORCHESTRATOR, orchestrator)

    # Each tool sub-agent is a compiled sub-graph mounted as a single node.
    sub_agent_builders = {
        DRPath.INTERNAL_SEARCH: dr_basic_search_graph_builder,
        DRPath.KNOWLEDGE_GRAPH: dr_kg_search_graph_builder,
        DRPath.WEB_SEARCH: dr_ws_graph_builder,
        DRPath.IMAGE_GENERATION: dr_image_generation_graph_builder,
        DRPath.GENERIC_TOOL: dr_custom_tool_graph_builder,
        DRPath.GENERIC_INTERNAL_TOOL: dr_generic_internal_tool_graph_builder,
    }
    for sub_agent_path, sub_agent_builder in sub_agent_builders.items():
        graph.add_node(sub_agent_path, sub_agent_builder().compile())

    graph.add_node(DRPath.CLOSER, closer)
    graph.add_node(DRPath.LOGGER, logging)

    ### Add edges ###
    graph.add_edge(start_key=START, end_key=DRPath.CLARIFIER)
    graph.add_conditional_edges(DRPath.CLARIFIER, decision_router)
    graph.add_conditional_edges(DRPath.ORCHESTRATOR, decision_router)

    # Every tool sub-agent reports back to the orchestrator.
    for sub_agent_path in sub_agent_builders:
        graph.add_edge(start_key=sub_agent_path, end_key=DRPath.ORCHESTRATOR)

    graph.add_conditional_edges(DRPath.CLOSER, completeness_router)
    graph.add_edge(start_key=DRPath.LOGGER, end_key=END)

    return graph

View File

@@ -0,0 +1,180 @@
import json
import re
from collections import defaultdict
from datetime import datetime
from datetime import timedelta
from sqlalchemy.orm import Session
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_raw
from onyx.context.search.models import IndexFilters
from onyx.context.search.preprocessing.access_filters import (
build_access_filters_for_user,
)
from onyx.db.hackathon_subscriptions import get_document_ids_by_cc_pair_name
from onyx.db.hackathon_subscriptions import get_subscription_registration
from onyx.db.hackathon_subscriptions import get_subscription_result
from onyx.db.hackathon_subscriptions import save_subscription_result
from onyx.db.models import SubscriptionResult
from onyx.db.models import User
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
def _extract_json_dict(response_content: str) -> dict:
    """Best-effort parse of an LLM response into a dict.

    Tries, in order: the full response as JSON, a ```json fenced block, and
    the outermost {...} span. Returns {} when no candidate is found;
    propagates JSONDecodeError if a located candidate is malformed (matching
    the original inline behavior).
    """
    try:
        return json.loads(response_content)
    except json.JSONDecodeError:
        pass

    fenced_match = re.search(
        r"```(?:json)?\s*(\{.*?\})\s*```", response_content, re.DOTALL
    )
    if fenced_match:
        return json.loads(fenced_match.group(1))

    brace_match = re.search(r"\{.*\}", response_content, re.DOTALL)
    if brace_match:
        return json.loads(brace_match.group(0))

    logger.error(f"Failed to parse LLM response as JSON: {response_content}")
    return {}


def process_notifications(
    db_session: Session,
    llm: LLM,
    user: User | None = None,
) -> None:
    """Run the user's subscribed extraction prompts over recently updated
    documents and persist the aggregated analysis.

    For every registered extraction context (keyed by cc-pair name), fetches
    the matching documents' chunks with the user's access filters applied,
    runs the context's prompt through the LLM, and stores the combined result
    as a "document_analysis" SubscriptionResult.

    Args:
        db_session: active DB session.
        llm: LLM used for per-document analysis.
        user: the subscribing user; no-op when None or not registered.

    Fixes vs. the original: removed a no-op `search_questions` attribute
    access, hoisted the loop-invariant filters/date threshold out of the
    loops, and moved the JSON-recovery logic into `_extract_json_dict`.
    """
    if not user:
        return

    subscription_registration = get_subscription_registration(db_session, str(user.id))
    if not subscription_registration:
        return

    doc_extraction_contexts = subscription_registration.doc_extraction_contexts

    # Document index and access-control filters are invariant across the loops.
    search_settings = get_current_search_settings(db_session)
    document_index = get_default_document_index(search_settings, None)
    user_acl_filters = build_access_filters_for_user(user, db_session)
    filters = IndexFilters(
        tenant_id=get_current_tenant_id(),
        access_control_list=user_acl_filters,
    )

    # Only documents updated after September 18, 2025 are considered.
    date_threshold = datetime(2025, 9, 19) - timedelta(days=1)

    # {cc_pair_name: {analysis_type: ["[doc_id](link): value", ...]}}
    analysis_results: dict = defaultdict(lambda: defaultdict(list))
    # NOTE(review): document ids are appended here, but the list is also
    # compared against cc-pair names below — confirm which keys are intended.
    invalid_doc_extraction_context_keys: list = []

    for (
        doc_extraction_context_key,
        doc_extraction_context_value,
    ) in doc_extraction_contexts.items():
        # All document IDs and links for this connector credential pair.
        documents = get_document_ids_by_cc_pair_name(
            db_session, doc_extraction_context_key, date_threshold
        )

        for document_id, document_link in documents:
            # Retrieve all chunks for this document with access control applied.
            document_chunks = document_index.id_based_retrieval(
                chunk_requests=[VespaChunkRequest(document_id=document_id)],
                filters=filters,
                batch_retrieval=False,
            )

            # Concatenate chunk contents in chunk_id order.
            sorted_content_chunks_string = "\n".join(
                f"{chunk.chunk_id}: {chunk.content}"
                for chunk in sorted(document_chunks, key=lambda c: c.chunk_id)
            )

            # Substitute the document content into the extraction prompt.
            analysis_prompt = doc_extraction_context_value.replace(
                "---doc_content---", sorted_content_chunks_string
            )

            analysis_response = invoke_llm_raw(llm=llm, prompt=analysis_prompt)
            analysis_dict = _extract_json_dict(str(analysis_response.content))

            for analysis_type, analysis_value in analysis_dict.items():
                if not analysis_value:
                    continue
                analysis_results[doc_extraction_context_key][analysis_type].append(
                    f"[{document_id}]({document_link}): {analysis_value}"
                )
                # A "use" verdict other than "yes" invalidates this document.
                if analysis_type == "use" and analysis_value.lower() != "yes":
                    invalid_doc_extraction_context_keys.append(document_id)

    # Build the final analysis string, skipping invalidated entries.
    analysis_string_components = []
    for doc_extraction_context_key, analysis_type_dict in analysis_results.items():
        if doc_extraction_context_key in invalid_doc_extraction_context_keys:
            continue
        for analysis_type, analysis_values in analysis_type_dict.items():
            if analysis_type != "use":
                analysis_string_components.append(f"## Calls - {analysis_type}")
            for analysis_value in analysis_values:
                # Drop values that reference any invalidated document id.
                if (
                    analysis_value
                    and not any(
                        invalid_id in analysis_value
                        for invalid_id in invalid_doc_extraction_context_keys
                    )
                    and analysis_type != "use"
                ):
                    analysis_string_components.append(analysis_value)
    analysis_string = "\n \n \n ".join(analysis_string_components)

    # Persist the aggregated analysis for later retrieval.
    subscription_result = SubscriptionResult(
        user_id=user.id,
        type="document_analysis",
        notifications={"analysis": analysis_string},
    )
    save_subscription_result(db_session, subscription_result)
def get_notifications(
    db_session: Session,
    user: User,
) -> str | None:
    """Return the stored analysis string for the user's subscription result.

    Bug fix: the return annotation was `-> None` even though the function
    returns the saved "analysis" value (a string written by
    process_notifications) when a result exists; returns None otherwise.
    """
    subscription_result = get_subscription_result(db_session, str(user.id))
    if not subscription_result:
        return None
    return subscription_result.notifications["analysis"]

View File

@@ -0,0 +1,147 @@
from enum import Enum
from typing import Dict
from pydantic import BaseModel
from onyx.agents.agent_search.exploration.enums import DRPath
from onyx.agents.agent_search.exploration.sub_agents.image_generation.models import (
GeneratedImage,
)
from onyx.context.search.models import InferenceSection
from onyx.tools.tool import Tool
class OrchestratorStep(BaseModel):
    """A single next step: which tool to call and the questions to ask it."""

    tool: str
    questions: list[str]


class OrchestratorDecisonsNoPlan(BaseModel):
    """Orchestrator decision without an overall plan.

    NOTE(review): class name keeps the original "Decisons" spelling since
    renaming would break callers.
    """

    reasoning: str
    next_step: OrchestratorStep


class OrchestrationPlan(BaseModel):
    """High-level research plan plus the reasoning that produced it."""

    reasoning: str
    plan: str


class ClarificationGenerationResponse(BaseModel):
    """LLM response deciding whether a clarification question is needed."""

    clarification_needed: bool
    clarification_question: str


class DecisionResponse(BaseModel):
    """Generic reasoned decision returned by the LLM."""

    reasoning: str
    decision: str


class QueryEvaluationResponse(BaseModel):
    """Whether a proposed query is permitted, with reasoning."""

    reasoning: str
    query_permitted: bool


class OrchestrationClarificationInfo(BaseModel):
    """Clarification question posed to the user and their optional response."""

    clarification_question: str
    clarification_response: str | None = None


class WebSearchAnswer(BaseModel):
    """Indices of search-result URLs that should be opened."""

    urls_to_open_indices: list[int]


class SearchAnswer(BaseModel):
    """Answer derived from a search step, with optional supporting claims."""

    reasoning: str
    answer: str
    claims: list[str] | None = None


class TestInfoCompleteResponse(BaseModel):
    """Completeness check: whether gathered info suffices and what is missing."""

    reasoning: str
    complete: bool
    gaps: list[str]
# TODO: revisit with custom tools implementation in v2
# each tool should be a class with the attributes below, plus the actual tool implementation
# this will also allow custom tools to have their own cost
class OrchestratorTool(BaseModel):
    """Metadata the orchestrator uses to select and route to a tool."""

    tool_id: int
    name: str
    llm_path: str  # the path for the LLM to refer by
    path: DRPath  # the actual path in the graph
    description: str
    metadata: dict[str, str]
    cost: float  # relative cost of invoking this tool
    tool_object: Tool | None = None  # None for CLOSER

    class Config:
        # Tool is not a pydantic model, so arbitrary field types are allowed.
        arbitrary_types_allowed = True
class IterationInstructions(BaseModel):
    """Instructions driving one orchestration iteration."""

    iteration_nr: int
    plan: str | None
    reasoning: str
    purpose: str


class IterationAnswer(BaseModel):
    """Result of a single tool call within an iteration."""

    tool: str
    tool_id: int
    iteration_nr: int
    parallelization_nr: int  # presumably index among parallel calls in the same iteration — confirm
    question: str
    reasoning: str | None
    answer: str
    cited_documents: dict[int, InferenceSection]
    background_info: str | None = None
    claims: list[str] | None = None
    additional_data: dict[str, str] | None = None
    response_type: str | None = None
    data: dict | list | str | int | float | bool | None = None
    file_ids: list[str] | None = None
    # TODO: not ideal; rework the schema for deep research later
    is_web_fetch: bool = False
    # for image generation step-types
    generated_images: list[GeneratedImage] | None = None
    # for multi-query search tools (v2 web search and internal search)
    # TODO: Clean this up to be more flexible to tools
    queries: list[str] | None = None


class AggregatedDRContext(BaseModel):
    """Aggregated context across all iterations for answer generation."""

    context: str
    cited_documents: list[InferenceSection]
    is_internet_marker_dict: dict[str, bool]
    global_iteration_responses: list[IterationAnswer]
class DRPromptPurpose(str, Enum):
    """Which orchestration prompt is being requested."""

    PLAN = "PLAN"
    NEXT_STEP = "NEXT_STEP"
    NEXT_STEP_REASONING = "NEXT_STEP_REASONING"
    NEXT_STEP_PURPOSE = "NEXT_STEP_PURPOSE"
    CLARIFICATION = "CLARIFICATION"


class BaseSearchProcessingResponse(BaseModel):
    """Preprocessed search request: source filters, rewritten query, time filter."""

    specified_source_types: list[str]
    rewritten_query: str
    time_filter: str


# EXPLORATION TESTING
class CheatSheetContext(BaseModel):
    """Per-user cheat-sheet context kept as keyed string maps."""

    history: Dict[str, str]
    user_context: Dict[str, str]


class ExtractionResponse(BaseModel):
    """Parsed extraction output grouped by knowledge category."""

    user: list[Dict[str, str]]
    company: list[Dict[str, str]]
    search_strategy: list[Dict[str, str]]
    reasoning_strategy: list[Dict[str, str]]

View File

@@ -0,0 +1,716 @@
import re
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.messages import AIMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from sqlalchemy.orm import Session
from onyx.agents.agent_search.exploration.constants import AVERAGE_TOOL_COSTS
from onyx.agents.agent_search.exploration.constants import MAX_CHAT_HISTORY_MESSAGES
from onyx.agents.agent_search.exploration.dr_experimentation_prompts import (
BASE_SYSTEM_MESSAGE_TEMPLATE,
)
from onyx.agents.agent_search.exploration.dr_experimentation_prompts import (
PLAN_PROMPT_TEMPLATE,
)
from onyx.agents.agent_search.exploration.enums import DRPath
from onyx.agents.agent_search.exploration.enums import ResearchAnswerPurpose
from onyx.agents.agent_search.exploration.hackathon_functions import get_notifications
from onyx.agents.agent_search.exploration.hackathon_functions import (
process_notifications,
)
from onyx.agents.agent_search.exploration.models import OrchestrationClarificationInfo
from onyx.agents.agent_search.exploration.models import OrchestrationPlan
from onyx.agents.agent_search.exploration.models import OrchestratorTool
from onyx.agents.agent_search.exploration.states import MainState
from onyx.agents.agent_search.exploration.states import OrchestrationSetup
from onyx.agents.agent_search.exploration.utils import get_chat_history_string
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.chat_utils import build_citation_map_from_numbers
from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
from onyx.chat.memories import get_memories
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import DocumentSourceDescription
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
from onyx.configs.exploration_research_configs import (
EXPLORATION_TEST_USE_CALRIFIER_DEFAULT,
)
from onyx.configs.exploration_research_configs import (
EXPLORATION_TEST_USE_CORPUS_HISTORY_DEFAULT,
)
from onyx.configs.exploration_research_configs import EXPLORATION_TEST_USE_PLAN_DEFAULT
from onyx.configs.exploration_research_configs import (
EXPLORATION_TEST_USE_PLAN_UPDATES_DEFAULT,
)
from onyx.configs.exploration_research_configs import (
EXPLORATION_TEST_USE_THINKING_DEFAULT,
)
from onyx.db.chat import create_search_doc_from_saved_search_doc
from onyx.db.connector import fetch_unique_document_sources
from onyx.db.models import SearchDoc
from onyx.db.models import Tool
from onyx.db.tools import get_tools
from onyx.db.users import get_user_cheat_sheet_context
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.utils import check_number_of_tokens
from onyx.llm.utils import get_max_input_tokens
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.chat_prompts import PROJECT_INSTRUCTIONS_SEPARATOR
from onyx.prompts.dr_prompts import DEFAULT_DR_SYSTEM_PROMPT
from onyx.prompts.dr_prompts import TOOL_DESCRIPTION
from onyx.prompts.prompt_template import PromptTemplate
from onyx.prompts.prompt_utils import handle_company_awareness
from onyx.prompts.prompt_utils import handle_memories
from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.tools.tool_implementations.knowledge_graph.knowledge_graph_tool import (
KnowledgeGraphTool,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.web_search.web_search_tool import (
WebSearchTool,
)
from onyx.utils.b64 import get_image_type
from onyx.utils.b64 import get_image_type_from_bytes
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Inserted into the base system message only when plan generation is enabled
# (see `clarifier`); otherwise the placeholder is replaced with "".
_PLAN_INSTRUCTION_INSERTION = """ - Early on, the user MAY ask you to create a plan for the answer process. \
Think about the tools you have available and how you can use them, and then \
create a HIGH-LEVEL PLAN of how you want to approach the answer process."""
def _get_available_tools(
    db_session: Session,
    graph_config: GraphConfig,
    kg_enabled: bool,
    active_source_types: list[DocumentSource],
    use_clarifier: bool = False,
    use_thinking: bool = False,
) -> dict[str, OrchestratorTool]:
    """Build the orchestrator's tool registry, keyed by each tool's LLM name.

    Wraps every available configured tool in an `OrchestratorTool` with its
    routing path, description and cost, and always appends the CLOSER
    pseudo-tool (plus optional THINKING / CLARIFIER pseudo-tools).

    Args:
        db_session: active DB session used for tool lookups.
        graph_config: overall graph configuration (persona, tooling, ...).
        kg_enabled: whether the Knowledge Graph feature is enabled.
        active_source_types: connected document sources (KG requires >= 1).
        use_clarifier: include the CLARIFIER pseudo-tool.
        use_thinking: include the THINKING pseudo-tool.

    Returns:
        Mapping of LLM-facing tool name -> OrchestratorTool.

    Raises:
        ValueError: if a tool is missing from the database, if a tool name is
            unknown, or if the KG tool is available without internal search.
    """
    available_tools: dict[str, OrchestratorTool] = {}

    # FIX: previously the `kg_enabled` parameter was immediately overwritten
    # with graph_config.behavior.kg_config_settings.KG_ENABLED, making the
    # argument dead. The caller passes exactly that value, so honoring the
    # parameter preserves behavior while making the signature meaningful.
    persona = graph_config.inputs.persona
    if persona:
        include_kg = persona.name == TMP_DRALPHA_PERSONA_NAME and kg_enabled
    else:
        include_kg = False

    tool_dict: dict[int, Tool] = {
        tool.id: tool for tool in get_tools(db_session, only_enabled=True)
    }

    for tool in graph_config.tooling.tools:
        if not tool.is_available(db_session):
            logger.info(f"Tool {tool.name} is not available, skipping")
            continue

        tool_db_info = tool_dict.get(tool.id)
        if tool_db_info:
            incode_tool_id = tool_db_info.in_code_tool_id
        else:
            raise ValueError(f"Tool {tool.name} is not found in the database")

        # Map the concrete tool implementation onto a DR routing path.
        if isinstance(tool, WebSearchTool):
            llm_path = DRPath.WEB_SEARCH.value
            path = DRPath.WEB_SEARCH
        elif isinstance(tool, SearchTool):
            llm_path = DRPath.INTERNAL_SEARCH.value
            path = DRPath.INTERNAL_SEARCH
        elif isinstance(tool, KnowledgeGraphTool) and include_kg:
            # TODO (chris): move this into the `is_available` check
            if len(active_source_types) == 0:
                logger.error(
                    "No active source types found, skipping Knowledge Graph tool"
                )
                continue
            llm_path = DRPath.KNOWLEDGE_GRAPH.value
            path = DRPath.KNOWLEDGE_GRAPH
        elif isinstance(tool, ImageGenerationTool):
            llm_path = DRPath.IMAGE_GENERATION.value
            path = DRPath.IMAGE_GENERATION
        elif incode_tool_id:
            # if incode tool id is found, it is a generic internal tool
            llm_path = DRPath.GENERIC_INTERNAL_TOOL.value
            path = DRPath.GENERIC_INTERNAL_TOOL
        else:
            # otherwise it is a custom tool
            llm_path = DRPath.GENERIC_TOOL.value
            path = DRPath.GENERIC_TOOL

        # Built-in paths get curated descriptions and calibrated costs;
        # generic/custom tools fall back to their own description and unit cost.
        if path not in {DRPath.GENERIC_INTERNAL_TOOL, DRPath.GENERIC_TOOL}:
            description = TOOL_DESCRIPTION.get(path, tool.description)
            cost = AVERAGE_TOOL_COSTS[path]
        else:
            description = tool.description
            cost = 1.0

        tool_info = OrchestratorTool(
            tool_id=tool.id,
            name=tool.llm_name,
            llm_path=llm_path,
            path=path,
            description=description,
            metadata={},
            cost=cost,
            tool_object=tool,
        )
        # TODO: handle custom tools with same name as other tools (e.g., CLOSER)
        available_tools[tool.llm_name] = tool_info

    available_tool_paths = [tool.path for tool in available_tools.values()]
    # make sure KG isn't enabled without internal search
    if (
        DRPath.KNOWLEDGE_GRAPH in available_tool_paths
        and DRPath.INTERNAL_SEARCH not in available_tool_paths
    ):
        raise ValueError(
            "The Knowledge Graph is not supported without internal search tool"
        )

    # add CLOSER tool, which is always available
    available_tools[DRPath.CLOSER.value] = OrchestratorTool(
        tool_id=-1,
        name=DRPath.CLOSER.value,
        llm_path=DRPath.CLOSER.value,
        path=DRPath.CLOSER,
        description=TOOL_DESCRIPTION[DRPath.CLOSER],
        metadata={},
        cost=0.0,
        tool_object=None,
    )
    if use_thinking:
        available_tools[DRPath.THINKING.value] = OrchestratorTool(
            tool_id=102,
            name=DRPath.THINKING.value,
            llm_path=DRPath.THINKING.value,
            path=DRPath.THINKING,
            description="""This tool should be used if the next step is not particularly clear, \
or if you think you need to think through the original question and the questions and answers \
you have received so far in order to make a decision about what to do next AMONGST THE TOOLS AVAILABLE TO YOU \
IN THIS REQUEST! (Note: some tools described earlier may be excluded!).
If in doubt, use this tool. No action will be taken, just some reasoning will be done.""",
            metadata={},
            cost=0.0,
            tool_object=None,
        )
    if use_clarifier:
        available_tools[DRPath.CLARIFIER.value] = OrchestratorTool(
            tool_id=103,
            name=DRPath.CLARIFIER.value,
            llm_path=DRPath.CLARIFIER.value,
            path=DRPath.CLARIFIER,
            description="""This tool should be used ONLY if you need to have clarification on something IMPORTANT FROM \
the user. This can pertain to the original question or something you found out during the process so far.""",
            metadata={},
            cost=0.0,
            tool_object=None,
        )
    return available_tools
def _construct_uploaded_text_context(files: list[InMemoryChatFile]) -> str:
"""Construct the uploaded context from the files."""
file_contents = []
for file in files:
if file.file_type in (
ChatFileType.DOC,
ChatFileType.PLAIN_TEXT,
ChatFileType.CSV,
):
file_contents.append(file.content.decode("utf-8"))
if len(file_contents) > 0:
return "Uploaded context:\n\n\n" + "\n\n".join(file_contents)
return ""
def _construct_uploaded_image_context(
files: list[InMemoryChatFile] | None = None,
img_urls: list[str] | None = None,
b64_imgs: list[str] | None = None,
) -> list[dict[str, Any]] | None:
"""Construct the uploaded image context from the files."""
# Only include image files for user messages
if files is None:
return None
img_files = [file for file in files if file.file_type == ChatFileType.IMAGE]
img_urls = img_urls or []
b64_imgs = b64_imgs or []
if not (img_files or img_urls or b64_imgs):
return None
return cast(
list[dict[str, Any]],
[
{
"type": "image_url",
"image_url": {
"url": (
f"data:{get_image_type_from_bytes(file.content)};"
f"base64,{file.to_base64()}"
),
},
}
for file in img_files
]
+ [
{
"type": "image_url",
"image_url": {
"url": f"data:{get_image_type(b64_img)};base64,{b64_img}",
},
}
for b64_img in b64_imgs
]
+ [
{
"type": "image_url",
"image_url": {
"url": url,
},
}
for url in img_urls
],
)
def _get_existing_clarification_request(
    graph_config: GraphConfig,
) -> tuple[OrchestrationClarificationInfo, str, str] | None:
    """
    Returns the clarification info, original question, and updated chat history if
    a clarification request and response exists, otherwise returns None.
    """
    # check for clarification request and response in message history
    previous_raw_messages = graph_config.inputs.prompt_builder.raw_message_history
    # Only applicable when the most recent assistant message was a
    # clarification request; otherwise there is nothing to resume.
    if len(previous_raw_messages) == 0 or (
        previous_raw_messages[-1].research_answer_purpose
        != ResearchAnswerPurpose.CLARIFICATION_REQUEST
    ):
        return None
    # get the clarification request and response
    previous_messages = graph_config.inputs.prompt_builder.message_history
    last_message = previous_raw_messages[-1].message
    # Pair the assistant's clarification question with the user's reply
    # (the current raw query IS the clarification response on this path).
    clarification = OrchestrationClarificationInfo(
        clarification_question=last_message.strip(),
        clarification_response=graph_config.inputs.prompt_builder.raw_user_query,
    )
    # Fallbacks in case no earlier human message is found below.
    original_question = graph_config.inputs.prompt_builder.raw_user_query
    chat_history_string = "(No chat history yet available)"
    # get the original user query and chat history string before the original query
    # e.g., if history = [user query, assistant clarification request, user clarification response],
    # previous_messages = [user query, assistant clarification request], we want the user query
    # Walk backwards to the most recent human message with plain string
    # content; `i` counts from the end so `[:-i]` trims history up to (and
    # excluding) that message.
    for i, message in enumerate(reversed(previous_messages), 1):
        if (
            isinstance(message, HumanMessage)
            and message.content
            and isinstance(message.content, str)
        ):
            original_question = message.content
            chat_history_string = (
                get_chat_history_string(
                    graph_config.inputs.prompt_builder.message_history[:-i],
                    MAX_CHAT_HISTORY_MESSAGES,
                )
                or "(No chat history yet available)"
            )
            break
    return clarification, original_question, chat_history_string
def _persist_final_docs_and_citations(
db_session: Session,
context_llm_docs: list[Any] | None,
full_answer: str | None,
) -> tuple[list[SearchDoc], dict[int, int] | None]:
"""Persist final documents from in-context docs and derive citation mapping.
Returns the list of persisted `SearchDoc` records and an optional
citation map translating inline [[n]] references to DB doc indices.
"""
final_documents_db: list[SearchDoc] = []
citations_map: dict[int, int] | None = None
if not context_llm_docs:
return final_documents_db, citations_map
saved_search_docs = saved_search_docs_from_llm_docs(context_llm_docs)
for saved_doc in saved_search_docs:
db_doc = create_search_doc_from_saved_search_doc(saved_doc)
db_session.add(db_doc)
final_documents_db.append(db_doc)
db_session.flush()
cited_numbers: set[int] = set()
try:
# Match [[1]] or [[1, 2]] optionally followed by a link like ([[1]](http...))
matches = re.findall(
r"\[\[(\d+(?:,\s*\d+)*)\]\](?:\([^)]*\))?", full_answer or ""
)
for match in matches:
for num_str in match.split(","):
num = int(num_str.strip())
cited_numbers.add(num)
except Exception:
cited_numbers = set()
if cited_numbers and final_documents_db:
translations = build_citation_map_from_numbers(
cited_numbers=cited_numbers,
db_docs=final_documents_db,
)
citations_map = translations or None
return final_documents_db, citations_map
# OpenAI-style function spec for a single catch-all tool: instead of exposing
# each real tool to the LLM, this one generic entry routes any retrieval or
# action request. (Not referenced in this chunk — presumably used by a sibling
# module; confirm before removing.)
_ARTIFICIAL_ALL_ENCOMPASSING_TOOL = {
    "type": "function",
    "function": {
        "name": "run_any_knowledge_retrieval_and_any_action_tool",
        "description": "Use this tool to get ANY external information \
that is relevant to the question, or for any action to be taken, including image generation. In fact, \
ANY tool mentioned can be accessed through this generic tool. If in doubt, use this tool.",
        "parameters": {
            "type": "object",
            "properties": {
                "request": {
                    "type": "string",
                    "description": "The request to be made to the tool",
                },
            },
            "required": ["request"],
        },
    },
}
def clarifier(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> OrchestrationSetup:
    """
    Perform a quick search on the question as is and see whether a set of clarification
    questions is needed. For now this is based on the models
    """
    # Experiment toggles, snapshotted from module-level defaults.
    _EXPLORATION_TEST_USE_CALRIFIER = EXPLORATION_TEST_USE_CALRIFIER_DEFAULT
    _EXPLORATION_TEST_USE_PLAN = EXPLORATION_TEST_USE_PLAN_DEFAULT
    _EXPLORATION_TEST_USE_PLAN_UPDATES = EXPLORATION_TEST_USE_PLAN_UPDATES_DEFAULT
    _EXPLORATION_TEST_USE_CORPUS_HISTORY = EXPLORATION_TEST_USE_CORPUS_HISTORY_DEFAULT
    _EXPLORATION_TEST_USE_THINKING = EXPLORATION_TEST_USE_THINKING_DEFAULT
    # NOTE(review): hard override — this unconditionally disables plan
    # generation, making EXPLORATION_TEST_USE_PLAN_DEFAULT above dead; confirm
    # this is an intentional (temporary) experimentation switch.
    _EXPLORATION_TEST_USE_PLAN = False
    node_start_time = datetime.now()
    current_step_nr = 0
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    # Tokenizer / input-budget setup for the primary LLM.
    llm_provider = graph_config.tooling.primary_llm.config.model_provider
    llm_model_name = graph_config.tooling.primary_llm.config.model_name
    llm_tokenizer = get_tokenizer(
        model_name=llm_model_name,
        provider_type=llm_provider,
    )
    max_input_tokens = get_max_input_tokens(
        model_name=llm_model_name,
        model_provider=llm_provider,
    )
    db_session = graph_config.persistence.db_session
    original_question = graph_config.inputs.prompt_builder.raw_user_query
    # Perform a commit to ensure the message_id is set and saved
    db_session.commit()
    # get the connected tools and format for the Deep Research flow
    kg_enabled = graph_config.behavior.kg_config_settings.KG_ENABLED
    active_source_types = fetch_unique_document_sources(db_session)
    available_tools = _get_available_tools(
        db_session,
        graph_config,
        kg_enabled,
        active_source_types,
        use_clarifier=_EXPLORATION_TEST_USE_CALRIFIER,
        use_thinking=_EXPLORATION_TEST_USE_THINKING,
    )
    # Human-readable tool list for the system prompt (CLOSER is implicit).
    available_tool_descriptions_str = "\n -" + "\n -".join(
        [
            tool.name + ": " + tool.description
            for tool in available_tools.values()
            if tool.path != DRPath.CLOSER
        ]
    )
    active_source_types_descriptions = [
        DocumentSourceDescription[source_type] for source_type in active_source_types
    ]
    if len(active_source_types_descriptions) > 0:
        active_source_type_descriptions_str = "\n -" + "\n -".join(
            active_source_types_descriptions
        )
    else:
        active_source_type_descriptions_str = ""
    # Persona-specific system/task prompts (falling back to the DR default).
    if graph_config.inputs.persona:
        assistant_system_prompt = PromptTemplate(
            graph_config.inputs.persona.system_prompt or DEFAULT_DR_SYSTEM_PROMPT
        ).build()
        if graph_config.inputs.persona.task_prompt:
            assistant_task_prompt = (
                "\n\nHere are more specifications from the user:\n\n"
                + PromptTemplate(graph_config.inputs.persona.task_prompt).build()
            )
        else:
            assistant_task_prompt = ""
    else:
        assistant_system_prompt = PromptTemplate(DEFAULT_DR_SYSTEM_PROMPT).build()
        assistant_task_prompt = ""
    if graph_config.inputs.project_instructions:
        assistant_system_prompt = (
            assistant_system_prompt
            + PROJECT_INSTRUCTIONS_SEPARATOR
            + graph_config.inputs.project_instructions
        )
    user = (
        graph_config.tooling.search_tool.user
        if graph_config.tooling.search_tool
        else None
    )
    # Special-case "commands" typed as the question: notification handling
    # short-circuits the whole research flow below.
    continue_to_answer = True
    if original_question == "process_notifications" and user:
        process_notifications(db_session, llm=graph_config.tooling.fast_llm, user=user)
        continue_to_answer = False
        # Stream the notifications message
        write_custom_event(
            current_step_nr,
            MessageStart(content="", final_documents=None),
            writer,
        )
        write_custom_event(
            current_step_nr,
            MessageDelta(content="Done!"),
            writer,
        )
        write_custom_event(current_step_nr, SectionEnd(), writer)
    elif original_question == "get_notifications" and user:
        notifications = get_notifications(db_session, user)
        if notifications:
            # Stream the notifications message
            write_custom_event(
                current_step_nr,
                MessageStart(content="", final_documents=None),
                writer,
            )
            write_custom_event(
                current_step_nr,
                MessageDelta(content=notifications),
                writer,
            )
            write_custom_event(current_step_nr, SectionEnd(), writer)
        continue_to_answer = False
    if not continue_to_answer:
        # Notification command handled — end the graph immediately.
        return OrchestrationSetup(
            original_question=original_question,
            chat_history_string="",
            tools_used=[DRPath.END.value],
            query_list=[],
            iteration_nr=0,
            current_step_nr=current_step_nr,
        )
    # Enrich the system prompt with company awareness and stored memories.
    memories = get_memories(user, db_session)
    assistant_system_prompt = handle_company_awareness(assistant_system_prompt)
    assistant_system_prompt = handle_memories(assistant_system_prompt, memories)
    chat_history_string = (
        get_chat_history_string(
            graph_config.inputs.prompt_builder.message_history,
            MAX_CHAT_HISTORY_MESSAGES,
        )
        or "(No chat history yet available)"
    )
    uploaded_text_context = (
        _construct_uploaded_text_context(graph_config.inputs.files)
        if graph_config.inputs.files
        else ""
    )
    uploaded_context_tokens = check_number_of_tokens(
        uploaded_text_context, llm_tokenizer.encode
    )
    # Reject uploads that would crowd out the rest of the prompt budget.
    if uploaded_context_tokens > 0.5 * max_input_tokens:
        raise ValueError(
            f"Uploaded context is too long. {uploaded_context_tokens} tokens, "
            f"but for this model we only allow {0.5 * max_input_tokens} tokens for uploaded context"
        )
    uploaded_image_context = _construct_uploaded_image_context(
        graph_config.inputs.files
    )
    current_step_nr += 1
    clarification = None
    message_history_for_continuation: list[SystemMessage | HumanMessage | AIMessage] = (
        []
    )
    if user is not None:
        original_cheat_sheet_context = get_user_cheat_sheet_context(
            user=user, db_session=db_session
        )
    # NOTE(review): when `user` is None, `original_cheat_sheet_context` is never
    # assigned, yet it is read below and in the returned OrchestrationSetup —
    # this looks like it would raise UnboundLocalError; confirm a user is
    # always present on this path.
    if original_cheat_sheet_context:
        cheat_sheet_string = f"""\n\nHere is additional context learned that may inform the \
process (plan generation if applicable, reasoning, tool calls, etc.):\n{str(original_cheat_sheet_context)}\n###\n\n"""
    else:
        cheat_sheet_string = ""
    if _EXPLORATION_TEST_USE_PLAN:
        plan_instruction_insertion = _PLAN_INSTRUCTION_INSERTION
    else:
        plan_instruction_insertion = ""
    # Fill the base system-message template placeholders.
    system_message = (
        BASE_SYSTEM_MESSAGE_TEMPLATE.replace(
            "---user_prompt---", assistant_system_prompt
        )
        .replace("---current_date---", datetime.now().strftime("%Y-%m-%d"))
        .replace(
            "---available_tool_descriptions_str---", available_tool_descriptions_str
        )
        .replace(
            "---active_source_types_descriptions_str---",
            active_source_type_descriptions_str,
        )
        .replace(
            "---cheat_sheet_string---",
            cheat_sheet_string,
        )
        .replace("---plan_instruction_insertion---", plan_instruction_insertion)
    )
    # Seed the running LLM conversation that later nodes append to.
    message_history_for_continuation.append(SystemMessage(content=system_message))
    message_history_for_continuation.append(
        HumanMessage(
            content=f"""Here is the questions to answer:\n{original_question}"""
        )
    )
    message_history_for_continuation.append(
        AIMessage(content="""How should I proceed to answer the question?""")
    )
    # Optional up-front plan generation (disabled by the override above).
    if _EXPLORATION_TEST_USE_PLAN:
        user_plan_instructions_prompt = """Think carefully how you want to address the question. You may use multiple iterations \
of tool calls, reasoning, etc.
Note:
 - the plan should be HIGH-LEVEL! Do not specify any specific tools, but think about what you want to learn in each iteration.
 - if the question is simple, one iteration may be enough.
 - DO NOT close with 'summarize...' etc as the last steps. Just focus on the information gathering steps.
"""
        plan_prompt = PLAN_PROMPT_TEMPLATE.replace(
            "---user_plan_instructions_prompt---", user_plan_instructions_prompt
        )
        message_history_for_continuation.append(HumanMessage(content=plan_prompt))
        plan_of_record = invoke_llm_json(
            llm=graph_config.tooling.primary_llm,
            prompt=message_history_for_continuation,
            schema=OrchestrationPlan,
            timeout_override=TF_DR_TIMEOUT_SHORT,
            # max_tokens=3000,
        )
        plan_string = f"""Here is how the answer process should be broken down: {plan_of_record.plan}"""
        message_history_for_continuation.append(AIMessage(content=plan_string))
    next_tool = DRPath.ORCHESTRATOR.value
    return OrchestrationSetup(
        original_question=original_question,
        chat_history_string=chat_history_string,
        tools_used=[next_tool],
        query_list=[],
        iteration_nr=0,
        current_step_nr=current_step_nr,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="clarifier",
                node_start_time=node_start_time,
            )
        ],
        clarification=clarification,
        available_tools=available_tools,
        active_source_types=active_source_types,
        active_source_types_descriptions="\n".join(active_source_types_descriptions),
        assistant_system_prompt=assistant_system_prompt,
        assistant_task_prompt=assistant_task_prompt,
        # NOTE(review): field name looks like a typo for "uploaded_text_context",
        # but it is part of the OrchestrationSetup interface — renaming would
        # have to happen at the state-model level.
        uploaded_test_context=uploaded_text_context,
        uploaded_image_context=uploaded_image_context,
        message_history_for_continuation=message_history_for_continuation,
        cheat_sheet_context=original_cheat_sheet_context,
        use_clarifier=_EXPLORATION_TEST_USE_CALRIFIER,
        use_thinking=_EXPLORATION_TEST_USE_THINKING,
        use_plan=_EXPLORATION_TEST_USE_PLAN,
        use_plan_updates=_EXPLORATION_TEST_USE_PLAN_UPDATES,
        use_corpus_history=_EXPLORATION_TEST_USE_CORPUS_HISTORY,
    )

View File

@@ -0,0 +1,325 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import AIMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.exploration.dr_experimentation_prompts import (
ORCHESTRATOR_PROMPT_TEMPLATE,
)
from onyx.agents.agent_search.exploration.enums import DRPath
from onyx.agents.agent_search.exploration.enums import ResearchType
from onyx.agents.agent_search.exploration.models import IterationAnswer
from onyx.agents.agent_search.exploration.states import IterationInstructions
from onyx.agents.agent_search.exploration.states import MainState
from onyx.agents.agent_search.exploration.states import OrchestrationUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_raw
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
from onyx.utils.logger import setup_logger
logger = setup_logger()
_SEARCH_TOOL = {
"type": "function",
"function": {
"name": "search_tool",
"description": "This tool is the search tool whose functionality and details are \
described in he system prompt. Use it if you think you have one or more questions that you believe are \
suitable for the search tool.",
"parameters": {
"type": "object",
"properties": {
"request": {
"type": "array",
"description": "The list of questions to be asked of the search tool",
"items": {
"type": "string",
"description": "The question to be asked of the search tool",
},
},
},
"required": ["request"],
},
},
}
_THINKING_TOOL = {
"type": "function",
"function": {
"name": "thinking_tool",
"description": "This tool is used if yoi think you need to think through the original question and the \
questions and answers you have received so far in order to male a decision about what to do next. If in doubt, use this tool.",
"parameters": {
"type": "object",
"properties": {
"request": {
"type": "string",
"description": "Please generate the thinking here you want to do that leads you to the next decision. This \
should end with a recommendation of which tool to invoke next.",
},
},
},
},
}
_CLOSER_TOOL = {
"type": "function",
"function": {
"name": "closer_tool",
"description": "This tool is used to close the conversation. Use it if you think you have \
all of the information you need to answer the question, and you also do not want to request additional \
information or make checks.",
"parameters": {
"type": "object",
"properties": {
"request": {
"type": "string",
"description": "The request to be made to the thinking tool",
},
},
},
},
}
_CLARIFIER_TOOL = {
"type": "function",
"function": {
"name": "clarifier_tool",
"description": "This tool is used if you need to have clarification on something IMPORTANT from \
the user. This can pertain to the original question or something you found out during the process so far.",
},
"parameters": {
"type": "object",
"properties": {
"request": {
"type": "string",
"description": "The question you would like to ask the user to get clarification.",
},
},
},
}
# Prefix prepended to user-supplied general instructions when building the
# orchestrator's decision prompt. (Not referenced in this chunk — presumably
# consumed elsewhere; confirm before removing.)
_DECISION_SYSTEM_PROMPT_PREFIX = "Here are general instructions by the user, which \
may or may not influence the decision what to do next:\n\n"
def _get_implied_next_tool_based_on_tool_call_history(
tools_used: list[str],
) -> str | None:
"""
Identify the next tool based on the tool call history. Initially, we only support
special handling of the image generation tool.
"""
if tools_used[-1] == DRPath.IMAGE_GENERATION.value:
return DRPath.LOGGER.value
else:
return None
def orchestrator(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> OrchestrationUpdate:
    """
    LangGraph node to decide the next step in the DR process.
    """
    node_start_time = datetime.now()
    _EXPLORATION_TEST_USE_CALRIFIER = state.use_clarifier
    # NOTE(review): the next three statements are bare attribute reads with no
    # effect — presumably leftovers from experimentation; candidates for removal.
    state.use_plan
    state.use_plan_updates
    state.use_corpus_history
    _EXPLORATION_TEST_USE_THINKING = state.use_thinking
    previous_tool_call_name = state.tools_used[-1] if state.tools_used else ""
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = state.original_question
    if not question:
        raise ValueError("Question is required for orchestrator")
    iteration_nr = state.iteration_nr
    plan_of_record = state.plan_of_record
    message_history_for_continuation = state.message_history_for_continuation
    new_messages_for_continuation: list[SystemMessage | HumanMessage | AIMessage] = []
    # Replay the previous iteration's substantive tool results (excluding the
    # CLARIFIER/THINKING meta-tools) into the running LLM conversation so the
    # model can decide the next step with full context.
    if iteration_nr > 0:
        last_iteration_responses = [
            x
            for x in state.iteration_responses
            if x.iteration_nr == iteration_nr
            and x.tool != DRPath.CLARIFIER.value
            and x.tool != DRPath.THINKING.value
        ]
        if last_iteration_responses:
            response_wrapper = f"For the previous iteration {iteration_nr}, here are the tool calls I decided to execute, \
the questions and tasks posed, and responses:\n\n"
            for last_iteration_response in last_iteration_responses:
                response_wrapper += f"{last_iteration_response.tool}: {last_iteration_response.question}\n"
                response_wrapper += f"Response: {last_iteration_response.answer}\n\n"
            message_history_for_continuation.append(AIMessage(content=response_wrapper))
            new_messages_for_continuation.append(AIMessage(content=response_wrapper))
    iteration_nr += 1
    current_step_nr = state.current_step_nr
    # NOTE(review): bare enum access with no effect — candidate for removal.
    ResearchType.DEEP
    remaining_time_budget = state.remaining_time_budget
    # NOTE(review): result of this expression is discarded — candidate for removal.
    state.chat_history_string or "(No chat history yet available)"
    next_tool_name = None
    # Identify early exit condition based on tool call history
    next_tool_based_on_tool_call_history = (
        _get_implied_next_tool_based_on_tool_call_history(state.tools_used)
    )
    if next_tool_based_on_tool_call_history == DRPath.LOGGER.value:
        # Forced hand-off (currently: right after image generation) — skip the
        # LLM decision entirely and route to the logger.
        return OrchestrationUpdate(
            tools_used=[DRPath.LOGGER.value],
            query_list=[],
            iteration_nr=iteration_nr,
            current_step_nr=current_step_nr,
            log_messages=[
                get_langgraph_node_log_string(
                    graph_component="main",
                    node_name="orchestrator",
                    node_start_time=node_start_time,
                )
            ],
            plan_of_record=plan_of_record,
            remaining_time_budget=remaining_time_budget,
            iteration_instructions=[
                IterationInstructions(
                    iteration_nr=iteration_nr,
                    plan=plan_of_record.plan if plan_of_record else None,
                    reasoning="",
                    purpose="",
                )
            ],
        )
    # no early exit forced. Continue.
    # NOTE(review): the next three expressions are computed and discarded —
    # candidates for removal.
    state.available_tools or {}
    state.uploaded_test_context or ""
    state.uploaded_image_context or []
    # default to closer
    query_list = ["Answer the question with the information you have."]
    reasoning_result = "(No reasoning result provided yet.)"
    ORCHESTRATOR_PROMPT = ORCHESTRATOR_PROMPT_TEMPLATE
    message_history_for_continuation.append(HumanMessage(content=ORCHESTRATOR_PROMPT))
    new_messages_for_continuation.append(HumanMessage(content=ORCHESTRATOR_PROMPT))
    # Offer the meta-tools only when enabled and not used in the immediately
    # preceding step (prevents think/clarify loops).
    tools = [_SEARCH_TOOL, _CLOSER_TOOL]
    if (
        _EXPLORATION_TEST_USE_THINKING
        and previous_tool_call_name != DRPath.THINKING.value
    ):
        tools.append(_THINKING_TOOL)
    if (
        _EXPLORATION_TEST_USE_CALRIFIER
        and previous_tool_call_name != DRPath.CLARIFIER.value
    ):
        tools.append(_CLARIFIER_TOOL)
    in_orchestration_iteration_answers: list[IterationAnswer] = []
    if remaining_time_budget > 0:
        # Ask the LLM which tool(s) to invoke next.
        orchestrator_action: AIMessage = invoke_llm_raw(
            llm=graph_config.tooling.primary_llm,
            prompt=message_history_for_continuation,
            tools=tools,
            timeout_override=TF_DR_TIMEOUT_LONG,
            # max_tokens=1500,
        )
        tool_calls = orchestrator_action.tool_calls
        # NOTE(review): when the LLM returns no tool calls, `next_tool_name`
        # stays None and `tools_used` becomes [""] below — confirm this is the
        # intended fallback.
        if tool_calls:
            for tool_call in tool_calls:
                if tool_call["name"] == "search_tool":
                    query_list = tool_call["args"]["request"]
                    next_tool_name = DRPath.INTERNAL_SEARCH.value
                elif tool_call["name"] == "thinking_tool":
                    # The "thinking" is already contained in the tool call
                    # arguments; record it and loop back to the orchestrator.
                    reasoning_result = tool_call["args"]["request"]
                    next_tool_name = (
                        DRPath.THINKING.value
                    )  # note: thinking already done. Will return to Orchestrator.
                    message_history_for_continuation.append(
                        AIMessage(content=reasoning_result)
                    )
                    new_messages_for_continuation.append(
                        AIMessage(content=reasoning_result)
                    )
                    in_orchestration_iteration_answers.append(
                        IterationAnswer(
                            tool=DRPath.THINKING.value,
                            tool_id=102,
                            iteration_nr=iteration_nr,
                            parallelization_nr=0,
                            question="",
                            cited_documents={},
                            answer=reasoning_result,
                            reasoning=reasoning_result,
                        )
                    )
                elif tool_call["name"] == "closer_tool":
                    reasoning_result = "Time to wrap up."
                    next_tool_name = DRPath.CLOSER.value
                elif tool_call["name"] == "clarifier_tool":
                    reasoning_result = tool_call["args"]["request"]
                    next_tool_name = DRPath.CLARIFIER.value
                else:
                    raise ValueError(f"Unknown tool: {tool_call['name']}")
    else:
        # Out of budget: force the CLOSER path without another LLM call.
        reasoning_result = "Time to wrap up. All information is available"
        new_messages_for_continuation.append(AIMessage(content=reasoning_result))
        next_tool_name = DRPath.CLOSER.value
    return OrchestrationUpdate(
        tools_used=[next_tool_name or ""],
        query_list=query_list or [],
        iteration_nr=iteration_nr,
        current_step_nr=current_step_nr,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="orchestrator",
                node_start_time=node_start_time,
            )
        ],
        plan_of_record=plan_of_record,
        # Each orchestrator pass consumes one unit of the time budget.
        remaining_time_budget=remaining_time_budget - 1.0,
        iteration_instructions=[
            IterationInstructions(
                iteration_nr=iteration_nr,
                plan=plan_of_record.plan if plan_of_record else None,
                reasoning=reasoning_result,
                purpose="",
            )
        ],
        message_history_for_continuation=new_messages_for_continuation,
        iteration_responses=in_orchestration_iteration_answers,
    )

View File

@@ -0,0 +1,431 @@
import re
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from sqlalchemy.orm import Session
from onyx.agents.agent_search.exploration.constants import MAX_CHAT_HISTORY_MESSAGES
from onyx.agents.agent_search.exploration.constants import MAX_NUM_CLOSER_SUGGESTIONS
from onyx.agents.agent_search.exploration.enums import DRPath
from onyx.agents.agent_search.exploration.enums import ResearchAnswerPurpose
from onyx.agents.agent_search.exploration.enums import ResearchType
from onyx.agents.agent_search.exploration.models import AggregatedDRContext
from onyx.agents.agent_search.exploration.models import TestInfoCompleteResponse
from onyx.agents.agent_search.exploration.states import FinalUpdate
from onyx.agents.agent_search.exploration.states import MainState
from onyx.agents.agent_search.exploration.states import OrchestrationUpdate
from onyx.agents.agent_search.exploration.sub_agents.image_generation.models import (
GeneratedImageFullResult,
)
from onyx.agents.agent_search.exploration.utils import aggregate_context
from onyx.agents.agent_search.exploration.utils import (
convert_inference_sections_to_search_docs,
)
from onyx.agents.agent_search.exploration.utils import get_chat_history_string
from onyx.agents.agent_search.exploration.utils import get_prompt_question
from onyx.agents.agent_search.exploration.utils import parse_plan_to_dict
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.chat.chat_utils import llm_doc_from_inference_section
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
from onyx.context.search.models import InferenceSection
from onyx.db.chat import create_search_doc_from_inference_section
from onyx.db.chat import update_db_session_with_messages
from onyx.db.models import ChatMessage__SearchDoc
from onyx.db.models import ResearchAgentIteration
from onyx.db.models import ResearchAgentIterationSubStep
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.llm.utils import check_number_of_tokens
from onyx.prompts.chat_prompts import PROJECT_INSTRUCTIONS_SEPARATOR
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
from onyx.prompts.dr_prompts import TEST_INFO_COMPLETE_PROMPT
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationStart
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.server.query_and_chat.streaming_models import StreamingType
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def extract_citation_numbers(text: str) -> list[int]:
    """
    Find every citation number referenced in *text*.

    Citations appear either as [[<number>]] or as a comma-separated group
    like [[<n1>, <n2>, ...]]. Returns the unique numbers found (order is
    not guaranteed, since deduplication goes through a set).
    """
    # One regex hit per bracketed group; the captured group holds the
    # digits, possibly several separated by commas.
    group_pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
    seen: set[int] = set()
    for group in re.findall(group_pattern, text):
        seen.update(int(token) for token in group.split(","))
    return list(seen)
def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
    """
    Rewrite one regex match of a citation group into markdown-linked form.

    The match's first group holds either a single number ("3") or a
    comma-separated list ("3, 5, 7"). Each number becomes its own
    [[n]](link) citation, taking the link from the (1-indexed)
    corresponding entry in *docs*; numbers past the end of *docs* are
    kept as plain [[n]] with no link.
    """
    cited_numbers = [int(token.strip()) for token in match.group(1).split(",")]
    pieces: list[str] = []
    for number in cited_numbers:
        doc_index = number - 1
        if doc_index < len(docs):
            url = docs[doc_index].link or ""
            pieces.append(f"[[{number}]]({url})")
        else:
            # Out-of-range citation: keep the marker but attach no link.
            pieces.append(f"[[{number}]]")
    return "".join(pieces)
def insert_chat_message_search_doc_pair(
    message_id: int, search_doc_ids: list[int], db_session: Session
) -> None:
    """
    Link a chat message to each of the given search docs.

    Adds one ChatMessage__SearchDoc association row per entry in
    *search_doc_ids* to the session. No commit happens here — the caller
    is responsible for committing.

    Args:
        message_id: The ID of the chat message
        search_doc_ids: The IDs of the search documents to associate
        db_session: The database session
    """
    for search_doc_id in search_doc_ids:
        chat_message_search_doc = ChatMessage__SearchDoc(
            chat_message_id=message_id, search_doc_id=search_doc_id
        )
        db_session.add(chat_message_search_doc)
def save_iteration(
    state: MainState,
    graph_config: GraphConfig,
    aggregated_context: AggregatedDRContext,
    final_answer: str,
    all_cited_documents: list[InferenceSection],
    is_internet_marker_dict: dict[str, bool],
) -> None:
    """
    Persist the outcome of a deep-research run for one chat message.

    In order: creates SearchDoc rows for every cited inference section,
    links them to the chat message, stores the citation mapping and final
    answer on the message (and its parent), then records one
    ResearchAgentIteration row per iteration and one
    ResearchAgentIterationSubStep row per sub-step. Commits once at the end.

    Args:
        state: Graph state with the plan of record and iteration instructions.
        graph_config: Supplies the db session, message id and behavior flags.
        aggregated_context: Aggregated responses from all iterations.
        final_answer: Final answer text containing [[n]] citation markers.
        all_cited_documents: Inference sections cited by the final answer.
        is_internet_marker_dict: document_id -> whether the doc came from an
            internet search (marks the created SearchDoc rows).
    """
    db_session = graph_config.persistence.db_session
    message_id = graph_config.persistence.message_id
    research_type = graph_config.behavior.research_type

    # first, insert the search_docs
    search_docs = [
        create_search_doc_from_inference_section(
            inference_section=inference_section,
            is_internet=is_internet_marker_dict.get(
                inference_section.center_chunk.document_id, False
            ),  # TODO: revisit
            db_session=db_session,
            commit=False,
        )
        for inference_section in all_cited_documents
    ]

    # then, map search_docs to message
    insert_chat_message_search_doc_pair(
        message_id, [search_doc.id for search_doc in search_docs], db_session
    )

    # lastly, build the citation mapping (citation number -> search doc id).
    # Citation numbers are 1-indexed into search_docs; guard the lookup since
    # the LLM can emit a citation number with no corresponding document.
    citation_dict: dict[int, int] = {}
    for cited_doc_nr in extract_citation_numbers(final_answer):
        if 1 <= cited_doc_nr <= len(search_docs):
            citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id

    # TODO: generate plan as dict in the first place
    plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
    plan_of_record_dict = parse_plan_to_dict(plan_of_record)

    # Update the chat message and its parent message in database
    update_db_session_with_messages(
        db_session=db_session,
        chat_message_id=message_id,
        chat_session_id=graph_config.persistence.chat_session_id,
        is_agentic=graph_config.behavior.use_agentic_search,
        message=final_answer,
        citations=citation_dict,
        research_type=research_type,
        research_plan=plan_of_record_dict,
        final_documents=search_docs,
        update_parent_message=True,
        research_answer_purpose=ResearchAnswerPurpose.ANSWER,
    )

    # One row per iteration: the orchestrator's reasoning and purpose.
    for iteration_preparation in state.iteration_instructions:
        research_agent_iteration_step = ResearchAgentIteration(
            primary_question_id=message_id,
            reasoning=iteration_preparation.reasoning,
            purpose=iteration_preparation.purpose,
            iteration_nr=iteration_preparation.iteration_nr,
        )
        db_session.add(research_agent_iteration_step)

    # One row per parallel sub-step within the iterations.
    for iteration_answer in aggregated_context.global_iteration_responses:
        retrieved_search_docs = convert_inference_sections_to_search_docs(
            list(iteration_answer.cited_documents.values())
        )
        # Convert SavedSearchDoc objects to JSON-serializable format
        serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
        research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
            primary_question_id=message_id,
            iteration_nr=iteration_answer.iteration_nr,
            iteration_sub_step_nr=iteration_answer.parallelization_nr,
            sub_step_instructions=iteration_answer.question,
            sub_step_tool_id=iteration_answer.tool_id,
            sub_answer=iteration_answer.answer,
            reasoning=iteration_answer.reasoning,
            claims=iteration_answer.claims,
            cited_doc_results=serialized_search_docs,
            generated_images=(
                GeneratedImageFullResult(images=iteration_answer.generated_images)
                if iteration_answer.generated_images
                else None
            ),
            additional_data=iteration_answer.additional_data,
            queries=iteration_answer.queries,
        )
        db_session.add(research_agent_iteration_sub_step)

    db_session.commit()
def closer(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> FinalUpdate | OrchestrationUpdate:
    """
    LangGraph node to close the DR process and finalize the answer.

    For DEEP research, first asks the LLM whether the gathered information
    suffices; if not (and the retry budget is not exhausted), routes back
    to the orchestrator with the identified gaps. Otherwise streams the
    final answer plus a citations section and returns a FinalUpdate.

    Raises:
        ValueError: if no original question is present, the research type is
            unknown, or final-answer streaming fails or times out.
    """
    node_start_time = datetime.now()

    # TODO: generate final answer using all the previous steps
    # (right now, answers from each step are concatenated onto each other)
    # Also, add missing fields once usage in UI is clear.

    current_step_nr = state.current_step_nr

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = state.original_question
    if not base_question:
        raise ValueError("Question is required for closer")

    # NOTE(review): research_type is pinned to DEEP here, which makes the
    # THOUGHTFUL/FAST prompt branch below unreachable — confirm intent.
    research_type = ResearchType.DEEP

    assistant_system_prompt: str = state.assistant_system_prompt or ""
    assistant_task_prompt = state.assistant_task_prompt
    uploaded_context = state.uploaded_test_context or ""
    clarification = state.clarification
    prompt_question = get_prompt_question(base_question, clarification)

    chat_history_string = (
        get_chat_history_string(
            graph_config.inputs.prompt_builder.message_history,
            MAX_CHAT_HISTORY_MESSAGES,
        )
        or "(No chat history yet available)"
    )

    # Two views of the aggregated iteration context: with documents (richer)
    # and without (cheaper in tokens).
    aggregated_context_w_docs = aggregate_context(
        state.iteration_responses, include_documents=True
    )
    aggregated_context_wo_docs = aggregate_context(
        state.iteration_responses, include_documents=False
    )
    iteration_responses_w_docs_string = aggregated_context_w_docs.context
    iteration_responses_wo_docs_string = aggregated_context_wo_docs.context
    all_cited_documents = aggregated_context_w_docs.cited_documents

    num_closer_suggestions = state.num_closer_suggestions

    # Completeness check: give the LLM a bounded number of chances to flag
    # gaps and send the flow back to the orchestrator.
    if (
        num_closer_suggestions < MAX_NUM_CLOSER_SUGGESTIONS
        and research_type == ResearchType.DEEP
    ):
        test_info_complete_prompt = TEST_INFO_COMPLETE_PROMPT.build(
            base_question=prompt_question,
            questions_answers_claims=iteration_responses_wo_docs_string,
            chat_history_string=chat_history_string,
            high_level_plan=(
                state.plan_of_record.plan
                if state.plan_of_record
                else "No plan available"
            ),
        )
        test_info_complete_json = invoke_llm_json(
            llm=graph_config.tooling.primary_llm,
            prompt=create_question_prompt(
                assistant_system_prompt,
                test_info_complete_prompt + (assistant_task_prompt or ""),
            ),
            schema=TestInfoCompleteResponse,
            timeout_override=TF_DR_TIMEOUT_LONG,
            # max_tokens=1000,
        )
        if not test_info_complete_json.complete:
            # Information is still missing: hand control back to the
            # orchestrator along with the identified gaps.
            return OrchestrationUpdate(
                tools_used=[DRPath.ORCHESTRATOR.value],
                query_list=[],
                log_messages=[
                    get_langgraph_node_log_string(
                        graph_component="main",
                        node_name="closer",
                        node_start_time=node_start_time,
                    )
                ],
                gaps=test_info_complete_json.gaps,
                num_closer_suggestions=num_closer_suggestions + 1,
            )

    retrieved_search_docs = convert_inference_sections_to_search_docs(
        all_cited_documents
    )

    # Open the answer message section, shipping the final documents up front.
    write_custom_event(
        current_step_nr,
        MessageStart(
            content="",
            final_documents=retrieved_search_docs,
        ),
        writer,
    )

    if research_type in [ResearchType.THOUGHTFUL, ResearchType.FAST]:
        final_answer_base_prompt = FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
    elif research_type == ResearchType.DEEP:
        final_answer_base_prompt = FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
    else:
        raise ValueError(f"Invalid research type: {research_type}")

    estimated_final_answer_prompt_tokens = check_number_of_tokens(
        final_answer_base_prompt.build(
            base_question=prompt_question,
            iteration_responses_string=iteration_responses_w_docs_string,
            chat_history_string=chat_history_string,
            uploaded_context=uploaded_context,
        )
    )

    # for DR, rely only on sub-answers and claims to save tokens if context is too long
    # TODO: consider compression step for Thoughtful mode if context is too long.
    # Should generally not be the case though.
    max_allowed_input_tokens = graph_config.tooling.primary_llm.config.max_input_tokens
    if (
        estimated_final_answer_prompt_tokens > 0.8 * max_allowed_input_tokens
        and research_type == ResearchType.DEEP
    ):
        iteration_responses_string = iteration_responses_wo_docs_string
    else:
        iteration_responses_string = iteration_responses_w_docs_string

    final_answer_prompt = final_answer_base_prompt.build(
        base_question=prompt_question,
        iteration_responses_string=iteration_responses_string,
        chat_history_string=chat_history_string,
        uploaded_context=uploaded_context,
    )

    if graph_config.inputs.project_instructions:
        assistant_system_prompt = (
            assistant_system_prompt
            + PROJECT_INSTRUCTIONS_SEPARATOR
            + (graph_config.inputs.project_instructions or "")
        )

    all_context_llmdocs = [
        llm_doc_from_inference_section(inference_section)
        for inference_section in all_cited_documents
    ]

    # Stream the final answer; the outer timeout is deliberately looser than
    # the inner LLM timeout so the LLM call times out first.
    try:
        streamed_output, _, citation_infos = run_with_timeout(
            int(3 * TF_DR_TIMEOUT_LONG),
            lambda: stream_llm_answer(
                llm=graph_config.tooling.primary_llm,
                prompt=create_question_prompt(
                    assistant_system_prompt,
                    final_answer_prompt + (assistant_task_prompt or ""),
                ),
                event_name="basic_response",
                writer=writer,
                agent_answer_level=0,
                agent_answer_question_num=0,
                agent_answer_type="agent_level_answer",
                timeout_override=int(2 * TF_DR_TIMEOUT_LONG),
                answer_piece=StreamingType.MESSAGE_DELTA.value,
                ind=current_step_nr,
                context_docs=all_context_llmdocs,
                replace_citations=True,
                # max_tokens=None,
            ),
        )
        final_answer = "".join(streamed_output)
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise ValueError(f"Error in consolidate_research: {e}") from e

    write_custom_event(current_step_nr, SectionEnd(), writer)
    current_step_nr += 1

    # Emit the citations section for the final answer.
    write_custom_event(current_step_nr, CitationStart(), writer)
    write_custom_event(current_step_nr, CitationDelta(citations=citation_infos), writer)
    write_custom_event(current_step_nr, SectionEnd(), writer)
    current_step_nr += 1

    # Log the research agent steps
    # save_iteration(
    #     state,
    #     graph_config,
    #     aggregated_context,
    #     final_answer,
    #     all_cited_documents,
    #     is_internet_marker_dict,
    # )

    return FinalUpdate(
        final_answer=final_answer,
        all_cited_documents=all_cited_documents,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="closer",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -0,0 +1,326 @@
import json
import re
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from sqlalchemy.orm import Session
from onyx.agents.agent_search.exploration.dr_experimentation_prompts import (
EXTRACTION_SYSTEM_PROMPT,
)
from onyx.agents.agent_search.exploration.enums import DRPath
from onyx.agents.agent_search.exploration.enums import ResearchAnswerPurpose
from onyx.agents.agent_search.exploration.models import AggregatedDRContext
from onyx.agents.agent_search.exploration.states import LoggerUpdate
from onyx.agents.agent_search.exploration.states import MainState
from onyx.agents.agent_search.exploration.sub_agents.image_generation.models import (
GeneratedImageFullResult,
)
from onyx.agents.agent_search.exploration.utils import aggregate_context
from onyx.agents.agent_search.exploration.utils import (
convert_inference_sections_to_search_docs,
)
from onyx.agents.agent_search.exploration.utils import parse_plan_to_dict
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_raw
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.context.search.models import InferenceSection
from onyx.db.chat import create_search_doc_from_inference_section
from onyx.db.chat import update_db_session_with_messages
from onyx.db.models import ChatMessage__SearchDoc
from onyx.db.models import ResearchAgentIteration
from onyx.db.models import ResearchAgentIterationSubStep
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.db.users import fetch_user_by_id
from onyx.db.users import update_user_cheat_sheet_context
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _extract_citation_numbers(text: str) -> list[int]:
"""
Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
Returns a list of all unique citation numbers found.
"""
# Pattern to match [[number]] or [[number1, number2, ...]]
pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
matches = re.findall(pattern, text)
cited_numbers = []
for match in matches:
# Split by comma and extract all numbers
numbers = [int(num.strip()) for num in match.split(",")]
cited_numbers.extend(numbers)
return list(set(cited_numbers)) # Return unique numbers
def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
    """
    Turn a single matched citation group into per-number markdown links.

    Group 1 of *match* is "3" or "3, 5, 7". Every number is re-emitted as
    its own [[n]](link) using the link of the (1-indexed) matching doc in
    *docs*; a number beyond the docs list stays as a bare [[n]].
    """
    rendered: list[str] = []
    for token in match.group(1).split(","):
        number = int(token.strip())
        if number <= len(docs):
            # 1-indexed citation -> 0-indexed docs entry.
            rendered.append(f"[[{number}]]({docs[number - 1].link or ''})")
        else:
            # No matching document: keep the citation without a link.
            rendered.append(f"[[{number}]]")
    return "".join(rendered)
def _insert_chat_message_search_doc_pair(
    message_id: int, search_doc_ids: list[int], db_session: Session
) -> None:
    """
    Link a chat message to each of the given search docs.

    Adds one ChatMessage__SearchDoc association row per entry in
    *search_doc_ids* to the session. No commit happens here — the caller
    is responsible for committing.

    Args:
        message_id: The ID of the chat message
        search_doc_ids: The IDs of the search documents to associate
        db_session: The database session
    """
    for search_doc_id in search_doc_ids:
        chat_message_search_doc = ChatMessage__SearchDoc(
            chat_message_id=message_id, search_doc_id=search_doc_id
        )
        db_session.add(chat_message_search_doc)
def save_iteration(
    state: MainState,
    graph_config: GraphConfig,
    aggregated_context: AggregatedDRContext,
    final_answer: str,
    all_cited_documents: list[InferenceSection],
    is_internet_marker_dict: dict[str, bool],
    num_tokens: int,
    new_cheat_sheet_context: dict[str, Any] | None = None,
) -> None:
    """
    Persist the outcome of a deep-research run for one chat message.

    In order: creates SearchDoc rows for every cited inference section,
    links them to the chat message, stores the citation mapping, final
    answer and token count on the message (and its parent), records one
    ResearchAgentIteration row per iteration and one
    ResearchAgentIterationSubStep row per sub-step, and optionally updates
    the user's cheat-sheet context. Commits once at the end.

    Args:
        state: Graph state with the plan of record and iteration instructions.
        graph_config: Supplies the db session, message id and behavior flags.
        aggregated_context: Aggregated responses from all iterations.
        final_answer: Final answer text containing [[n]] citation markers.
        all_cited_documents: Inference sections cited by the final answer.
        is_internet_marker_dict: document_id -> whether the doc came from an
            internet search (marks the created SearchDoc rows).
        num_tokens: Token count of the final answer, stored on the message.
        new_cheat_sheet_context: Replacement cheat-sheet context to persist
            on the user; skipped when None/empty or no user is available.
    """
    db_session = graph_config.persistence.db_session
    message_id = graph_config.persistence.message_id
    research_type = graph_config.behavior.research_type

    # first, insert the search_docs
    search_docs = [
        create_search_doc_from_inference_section(
            inference_section=inference_section,
            is_internet=is_internet_marker_dict.get(
                inference_section.center_chunk.document_id, False
            ),  # TODO: revisit
            db_session=db_session,
            commit=False,
        )
        for inference_section in all_cited_documents
    ]

    # then, map search_docs to message
    _insert_chat_message_search_doc_pair(
        message_id, [search_doc.id for search_doc in search_docs], db_session
    )

    # lastly, build the citation mapping (citation number -> search doc id).
    # Citation numbers are 1-indexed into search_docs; bound-check each one
    # since the LLM can emit a number with no corresponding document.
    citation_dict: dict[int, int] = {}
    for cited_doc_nr in _extract_citation_numbers(final_answer):
        if 1 <= cited_doc_nr <= len(search_docs):
            citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id

    # TODO: generate plan as dict in the first place
    plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
    plan_of_record_dict = parse_plan_to_dict(plan_of_record)

    # Update the chat message and its parent message in database
    update_db_session_with_messages(
        db_session=db_session,
        chat_message_id=message_id,
        chat_session_id=graph_config.persistence.chat_session_id,
        is_agentic=graph_config.behavior.use_agentic_search,
        message=final_answer,
        citations=citation_dict,
        research_type=research_type,
        research_plan=plan_of_record_dict,
        final_documents=search_docs,
        update_parent_message=True,
        research_answer_purpose=ResearchAnswerPurpose.ANSWER,
        token_count=num_tokens,
    )

    # One row per iteration: the orchestrator's reasoning and purpose.
    for iteration_preparation in state.iteration_instructions:
        research_agent_iteration_step = ResearchAgentIteration(
            primary_question_id=message_id,
            reasoning=iteration_preparation.reasoning,
            purpose=iteration_preparation.purpose,
            iteration_nr=iteration_preparation.iteration_nr,
        )
        db_session.add(research_agent_iteration_step)

    # One row per parallel sub-step within the iterations.
    for iteration_answer in aggregated_context.global_iteration_responses:
        retrieved_search_docs = convert_inference_sections_to_search_docs(
            list(iteration_answer.cited_documents.values())
        )
        # Convert SavedSearchDoc objects to JSON-serializable format
        serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
        research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
            primary_question_id=message_id,
            iteration_nr=iteration_answer.iteration_nr,
            iteration_sub_step_nr=iteration_answer.parallelization_nr,
            sub_step_instructions=iteration_answer.question,
            sub_step_tool_id=iteration_answer.tool_id,
            sub_answer=iteration_answer.answer,
            reasoning=iteration_answer.reasoning,
            claims=iteration_answer.claims,
            cited_doc_results=serialized_search_docs,
            generated_images=(
                GeneratedImageFullResult(images=iteration_answer.generated_images)
                if iteration_answer.generated_images
                else None
            ),
            additional_data=iteration_answer.additional_data,
            queries=iteration_answer.queries,
        )
        db_session.add(research_agent_iteration_sub_step)

    # Persist the updated cheat-sheet context on the user, when available.
    if graph_config.tooling.search_tool and graph_config.tooling.search_tool.user:
        user = fetch_user_by_id(db_session, graph_config.tooling.search_tool.user.id)
        if new_cheat_sheet_context and user:
            update_user_cheat_sheet_context(
                user=user,
                new_cheat_sheet_context=new_cheat_sheet_context,
                db_session=db_session,
            )

    db_session.commit()
def logging(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph node that finalizes a DR run: emits the overall stop event,
    extracts cheat-sheet knowledge updates from the answer via an LLM call,
    merges them into the user's cheat-sheet context, and persists all
    iteration results.

    Raises:
        ValueError: if no original question is present in the state.
    """
    node_start_time = datetime.now()

    # TODO: generate final answer using all the previous steps
    # (right now, answers from each step are concatenated onto each other)
    # Also, add missing fields once usage in UI is clear.

    current_step_nr = state.current_step_nr
    base_cheat_sheet_context = state.cheat_sheet_context or {}

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = state.original_question
    if not base_question:
        raise ValueError("Question is required for logging")

    aggregated_context = aggregate_context(
        state.iteration_responses, include_documents=True
    )
    all_cited_documents = aggregated_context.cited_documents
    is_internet_marker_dict = aggregated_context.is_internet_marker_dict
    final_answer = state.final_answer or ""

    # Count the answer tokens with the tokenizer matching the primary LLM.
    llm_provider = graph_config.tooling.primary_llm.config.model_provider
    llm_model_name = graph_config.tooling.primary_llm.config.model_name
    llm_tokenizer = get_tokenizer(
        model_name=llm_model_name,
        provider_type=llm_provider,
    )
    num_tokens = len(llm_tokenizer.encode(final_answer))

    write_custom_event(current_step_nr, OverallStop(), writer)

    # build extraction context: internal-search claims plus the final answer
    extracted_facts: list[str] = []
    for iteration_response in state.iteration_responses:
        if iteration_response.tool == DRPath.INTERNAL_SEARCH.value:
            extracted_facts.extend(iteration_response.claims)

    extraction_context = (
        "Extracted facts: \n - "
        + "\n - ".join(extracted_facts)
        + f"\n\nProvided Answer: \n\n{final_answer}"
    )

    extraction_prompt = [
        SystemMessage(content=EXTRACTION_SYSTEM_PROMPT),
        HumanMessage(content=extraction_context),
    ]
    extraction_response = invoke_llm_raw(
        llm=graph_config.tooling.primary_llm,
        prompt=extraction_prompt,
        # schema=ExtractionResponse,
    )
    # NOTE(review): assumes the LLM reliably returns valid JSON text in
    # .content — consider a schema-validated call instead.
    extraction_information = json.loads(extraction_response.content)

    # Merge the extracted knowledge into the cheat-sheet context per area.
    # NOTE(review): consolidated_updates is collected but never consumed
    # downstream yet — "update" changes are recorded, not applied.
    consolidated_updates: dict[str, list[tuple[str, str]]] = {}
    for area in ["user", "company", "search_strategy", "reasoning_strategy"]:
        update_knowledge = extraction_information.get(area, {})
        base_area_knowledge = base_cheat_sheet_context.get(area, {})
        if area not in base_cheat_sheet_context:
            base_cheat_sheet_context[area] = {}
        for update_info in update_knowledge:
            update_type = update_info.get("type", "n/a")
            change_type = update_info.get("change_type", "n/a")
            information = update_info.get("information", "n/a")
            if update_type in base_area_knowledge:
                if change_type == "update":
                    # setdefault avoids the KeyError the plain indexing
                    # raised on the first "update" for an area.
                    consolidated_updates.setdefault(area, []).append(
                        (base_area_knowledge.get(update_type, ""), information)
                    )
                elif change_type == "delete":
                    del base_cheat_sheet_context.get(area, {})[update_type]
                elif change_type == "add":
                    base_cheat_sheet_context.get(area, {})[update_type] = information
            else:
                base_cheat_sheet_context[area][update_type] = information

    # Log the research agent steps
    save_iteration(
        state,
        graph_config,
        aggregated_context,
        final_answer,
        all_cited_documents,
        is_internet_marker_dict,
        num_tokens,
        new_cheat_sheet_context=base_cheat_sheet_context,
    )

    return LoggerUpdate(
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="logger",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -0,0 +1,132 @@
from collections.abc import Iterator
from typing import cast
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langgraph.types import StreamWriter
from pydantic import BaseModel
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import CitationInfo
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
PassThroughAnswerResponseHandler,
)
from onyx.chat.stream_processing.utils import map_document_id_order
from onyx.context.search.models import InferenceSection
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationStart
from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
logger = setup_logger()
class BasicSearchProcessedStreamResults(BaseModel):
    """Aggregated outcome of processing one LLM response stream."""

    # Accumulated tool-call chunks; empty content when no tool was invoked.
    ai_message_chunk: AIMessageChunk = AIMessageChunk(content="")
    # Full concatenated answer text, if any was streamed.
    full_answer: str | None = None
    # NOTE(review): the two fields below are declared but not populated by
    # process_llm_stream in this file — confirm whether they are still needed.
    cited_references: list[InferenceSection] = []
    retrieved_documents: list[LlmDoc] = []
def process_llm_stream(
    messages: Iterator[BaseMessage],
    should_stream_answer: bool,
    writer: StreamWriter,
    ind: int,
    search_results: list[LlmDoc] | None = None,
    generate_final_answer: bool = False,
    chat_message_id: str | None = None,
) -> BasicSearchProcessedStreamResults:
    """
    Consume an LLM message stream, forwarding answer tokens to the event
    writer and accumulating any tool-call chunks.

    Args:
        messages: Raw LLM stream (answer chunks and/or tool-call chunks).
        should_stream_answer: If True, answer pieces are emitted to *writer*.
        writer: LangGraph stream writer for custom events.
        ind: Step index attached to every emitted event.
        search_results: Context docs for citation processing; when falsy,
            answer pieces pass through without citation handling.
        generate_final_answer: If True, wraps the streamed answer in
            MessageStart/SectionEnd events and emits collected citations.
        chat_message_id: Required when generate_final_answer is True.

    Returns:
        BasicSearchProcessedStreamResults holding the accumulated tool-call
        chunk and the full concatenated answer text.

    Raises:
        ValueError: if an answer piece arrives with generate_final_answer
            set but chat_message_id is None.
    """
    tool_call_chunk = AIMessageChunk(content="")

    # Citation-aware handling only when context docs are available.
    if search_results:
        answer_handler: AnswerResponseHandler = CitationResponseHandler(
            context_docs=search_results,
            doc_id_to_rank_map=map_document_id_order(search_results),
        )
    else:
        answer_handler = PassThroughAnswerResponseHandler()
    full_answer = ""
    # Becomes True once the first answer piece opened the message section.
    start_final_answer_streaming_set = False
    # Accumulate citation infos if handler emits them
    collected_citation_infos: list[CitationInfo] = []
    # This stream will be the llm answer if no tool is chosen. When a tool is chosen,
    # the stream will contain AIMessageChunks with tool call information.
    for message in messages:
        answer_piece = message.content
        if not isinstance(answer_piece, str):
            # this is only used for logging, so fine to
            # just add the string representation
            answer_piece = str(answer_piece)
        full_answer += answer_piece
        if isinstance(message, AIMessageChunk) and (
            message.tool_call_chunks or message.tool_calls
        ):
            # Tool-call path: fold the chunk into the accumulator; nothing
            # is streamed to the client for these.
            tool_call_chunk += message  # type: ignore
        elif should_stream_answer:
            for response_part in answer_handler.handle_response_part(message):
                # only stream out answer parts
                if (
                    isinstance(response_part, (OnyxAnswerPiece, AgentAnswerPiece))
                    and generate_final_answer
                    and response_part.answer_piece
                ):
                    if chat_message_id is None:
                        raise ValueError(
                            "chat_message_id is required when generating final answer"
                        )
                    if not start_final_answer_streaming_set:
                        # Convert LlmDocs to SavedSearchDocs
                        saved_search_docs = saved_search_docs_from_llm_docs(
                            search_results
                        )
                        write_custom_event(
                            ind,
                            MessageStart(content="", final_documents=saved_search_docs),
                            writer,
                        )
                        start_final_answer_streaming_set = True
                    write_custom_event(
                        ind,
                        MessageDelta(content=response_part.answer_piece),
                        writer,
                    )
                # collect citation info objects
                elif isinstance(response_part, CitationInfo):
                    collected_citation_infos.append(response_part)
    if generate_final_answer and start_final_answer_streaming_set:
        # start_final_answer_streaming_set is only set if the answer is verbal and not a tool call
        write_custom_event(
            ind,
            SectionEnd(),
            writer,
        )
        # Emit citations section if any were collected
        if collected_citation_infos:
            write_custom_event(ind, CitationStart(), writer)
            write_custom_event(
                ind, CitationDelta(citations=collected_citation_infos), writer
            )
            write_custom_event(ind, SectionEnd(), writer)
    logger.debug(f"Full answer: {full_answer}")
    # NOTE(review): citation infos are streamed but not returned, and the
    # result's cited_references/retrieved_documents stay empty — confirm
    # whether callers need them populated.
    return BasicSearchProcessedStreamResults(
        ai_message_chunk=cast(AIMessageChunk, tool_call_chunk), full_answer=full_answer
    )

View File

@@ -0,0 +1,99 @@
from operator import add
from typing import Annotated
from typing import Any
from typing import Dict
from typing import TypedDict
from langchain_core.messages import AIMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from pydantic import BaseModel
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.exploration.models import IterationAnswer
from onyx.agents.agent_search.exploration.models import IterationInstructions
from onyx.agents.agent_search.exploration.models import OrchestrationClarificationInfo
from onyx.agents.agent_search.exploration.models import OrchestrationPlan
from onyx.agents.agent_search.exploration.models import OrchestratorTool
from onyx.context.search.models import InferenceSection
from onyx.db.connector import DocumentSource
### States ###
class LoggerUpdate(BaseModel):
    """Base LangGraph state update: an accumulating list of log messages."""

    # `Annotated[..., add]` makes LangGraph merge updates by concatenation.
    log_messages: Annotated[list[str], add] = []
class OrchestrationUpdate(LoggerUpdate):
    """State slice written by the orchestrator while routing the DR flow."""

    # Tool paths taken so far; concatenated across parallel node updates.
    tools_used: Annotated[list[str], add] = []
    # Queries chosen for the upcoming iteration.
    query_list: list[str] = []
    iteration_nr: int = 0
    current_step_nr: int = 1
    plan_of_record: OrchestrationPlan | None = None  # None for Thoughtful
    remaining_time_budget: float = 2.0  # set by default to about 2 searches
    num_closer_suggestions: int = 0  # how many times the closer was suggested
    gaps: list[str] = (
        []
    )  # gaps that may be identified by the closer before being able to answer the question.
    iteration_instructions: Annotated[list[IterationInstructions], add] = []
    # Running message history used when continuing the conversation.
    message_history_for_continuation: Annotated[
        list[SystemMessage | HumanMessage | AIMessage], add
    ] = []
    # Per-iteration sub-agent answers; concatenated across parallel branches.
    iteration_responses: Annotated[list[IterationAnswer], add] = []
class OrchestrationSetup(OrchestrationUpdate):
    """Initial orchestration state assembled before the DR loop starts."""

    original_question: str | None = None
    chat_history_string: str | None = None
    clarification: OrchestrationClarificationInfo | None = None
    available_tools: dict[str, OrchestratorTool] | None = None
    active_source_types: list[DocumentSource] | None = None
    active_source_types_descriptions: str | None = None
    assistant_system_prompt: str | None = None
    assistant_task_prompt: str | None = None
    uploaded_test_context: str | None = None
    uploaded_image_context: list[dict[str, Any]] | None = None
    # num_closer_suggestions and message_history_for_continuation are
    # inherited from OrchestrationUpdate; the previous identical
    # re-declarations here were redundant and have been dropped.
    # Persistent per-user knowledge carried into the run (see logging node).
    cheat_sheet_context: dict[str, Any] | None = None
    # Feature toggles for the DR graph.
    use_clarifier: bool = False
    use_thinking: bool = False
    use_plan: bool = False
    use_plan_updates: bool = False
    use_corpus_history: bool = False
class AnswerUpdate(LoggerUpdate):
    """State slice carrying per-iteration answers from sub-agents."""

    # Concatenated across parallel sub-agent branches via `add`.
    iteration_responses: Annotated[list[IterationAnswer], add] = []
class FinalUpdate(LoggerUpdate):
    """State slice produced by the closer with the finalized answer."""

    final_answer: str | None = None
    # Inference sections cited in the final answer.
    all_cited_documents: list[InferenceSection] = []
## Graph Input State
class MainInput(CoreState):
    """Input state of the main DR graph (no fields beyond CoreState)."""

    pass
## Graph State
class MainState(
    # This includes the core state
    MainInput,
    OrchestrationSetup,
    AnswerUpdate,
    FinalUpdate,
):
    """Combined working state of the main DR graph: the union of all slices."""

    pass
## Graph Output State
class MainOutput(TypedDict):
    """Output payload of the main DR graph."""

    log_messages: list[str]
    final_answer: str | None
    all_cited_documents: list[InferenceSection]

View File

@@ -0,0 +1,47 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.exploration.states import LoggerUpdate
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.utils.logger import setup_logger
logger = setup_logger()
def basic_search_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph node to perform a standard search as part of the DR process.

    Announces the start of an internal (non-internet) search on the event
    stream, then returns a log entry for this branching step.
    """
    started_at = datetime.now()
    logger.debug(
        f"Search start for Basic Search {state.iteration_nr} at {datetime.now()}"
    )
    # Signal the UI that an internal search step is beginning.
    write_custom_event(
        state.current_step_nr,
        SearchToolStart(
            is_internet_search=False,
        ),
        writer,
    )
    log_entry = get_langgraph_node_log_string(
        graph_component="basic_search",
        node_name="branching",
        node_start_time=started_at,
    )
    return LoggerUpdate(log_messages=[log_entry])

View File

@@ -0,0 +1,288 @@
import re
from datetime import datetime
from typing import cast
from uuid import UUID
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.exploration.enums import ResearchType
from onyx.agents.agent_search.exploration.models import BaseSearchProcessingResponse
from onyx.agents.agent_search.exploration.models import IterationAnswer
from onyx.agents.agent_search.exploration.models import SearchAnswer
from onyx.agents.agent_search.exploration.sub_agents.states import BranchInput
from onyx.agents.agent_search.exploration.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.exploration.utils import (
convert_inference_sections_to_search_docs,
)
from onyx.agents.agent_search.exploration.utils import extract_document_citations
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.chat.models import LlmDoc
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.context.search.models import InferenceSection
from onyx.db.connector import DocumentSource
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.prompts.dr_prompts import BASE_SEARCH_PROCESSING_PROMPT
from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger
logger = setup_logger()
def basic_search(
    state: BranchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
    """
    LangGraph node to perform a standard search as part of the DR process.

    Pipeline:
      1. Use the LLM to rewrite the branch query and extract implied
         time/source filters.
      2. Run the configured SearchTool with those filters (fresh DB session).
      3. Stream the rewritten query and retrieved documents to the UI.
      4. For DEEP research, summarize the top documents with the LLM and
         extract per-document citations; otherwise pass documents through.

    Returns:
        BranchUpdate carrying one IterationAnswer for this parallel branch.

    Raises:
        ValueError: on missing state fields, a mismatched search tool,
            unexpected document types, or out-of-range citation numbers.
    """
    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr
    current_step_nr = state.current_step_nr
    assistant_system_prompt = state.assistant_system_prompt
    assistant_task_prompt = state.assistant_task_prompt
    branch_query = state.branch_question
    if not branch_query:
        raise ValueError("branch_query is not set")
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = graph_config.inputs.prompt_builder.raw_user_query
    # NOTE(review): research type is hard-coded to DEEP here; the non-DEEP
    # branch below is currently only reachable if this line changes.
    research_type = ResearchType.DEEP
    if not state.available_tools:
        raise ValueError("available_tools is not set")
    elif len(state.tools_used) == 0:
        raise ValueError("tools_used is empty")
    # The most recently used tool is expected to be the search tool.
    search_tool_info = state.available_tools[state.tools_used[-1]]
    search_tool = cast(SearchTool, search_tool_info.tool_object)
    force_use_tool = graph_config.tooling.force_use_tool
    # sanity check
    if search_tool != graph_config.tooling.search_tool:
        raise ValueError("search_tool does not match the configured search tool")
    # rewrite query and identify source types
    active_source_types_str = ", ".join(
        [source.value for source in state.active_source_types or []]
    )
    base_search_processing_prompt = BASE_SEARCH_PROCESSING_PROMPT.build(
        active_source_types_str=active_source_types_str,
        branch_query=branch_query,
        current_time=datetime.now().strftime("%Y-%m-%d %H:%M"),
    )
    try:
        search_processing = invoke_llm_json(
            llm=graph_config.tooling.primary_llm,
            prompt=create_question_prompt(
                assistant_system_prompt, base_search_processing_prompt
            ),
            schema=BaseSearchProcessingResponse,
            timeout_override=TF_DR_TIMEOUT_SHORT,
            # max_tokens=100,
        )
    except Exception as e:
        logger.error(f"Could not process query: {e}")
        raise e
    rewritten_query = search_processing.rewritten_query
    # give back the query so we can render it in the UI
    write_custom_event(
        current_step_nr,
        SearchToolDelta(
            queries=[rewritten_query],
            documents=[],
        ),
        writer,
    )
    implied_start_date = search_processing.time_filter
    # Validate time_filter format if it exists
    implied_time_filter = None
    if implied_start_date:
        # Check if time_filter is in YYYY-MM-DD format
        date_pattern = r"^\d{4}-\d{2}-\d{2}$"
        if re.match(date_pattern, implied_start_date):
            implied_time_filter = datetime.strptime(implied_start_date, "%Y-%m-%d")
    specified_source_types: list[DocumentSource] | None = (
        strings_to_document_sources(search_processing.specified_source_types)
        if search_processing.specified_source_types
        else None
    )
    # Treat an empty source-type list the same as "no filter".
    if specified_source_types is not None and len(specified_source_types) == 0:
        specified_source_types = None
    logger.debug(
        f"Search start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )
    retrieved_docs: list[InferenceSection] = []
    callback_container: list[list[InferenceSection]] = []
    # Optional scoping to specific user files / a project, if forced by caller.
    user_file_ids: list[UUID] | None = None
    project_id: int | None = None
    if force_use_tool.override_kwargs and isinstance(
        force_use_tool.override_kwargs, SearchToolOverrideKwargs
    ):
        override_kwargs = force_use_tool.override_kwargs
        user_file_ids = override_kwargs.user_file_ids
        project_id = override_kwargs.project_id
    # new db session to avoid concurrency issues
    with get_session_with_current_tenant() as search_db_session:
        for tool_response in search_tool.run(
            query=rewritten_query,
            document_sources=specified_source_types,
            time_filter=implied_time_filter,
            override_kwargs=SearchToolOverrideKwargs(
                force_no_rerank=True,
                alternate_db_session=search_db_session,
                retrieved_sections_callback=callback_container.append,
                skip_query_analysis=True,
                original_query=rewritten_query,
                user_file_ids=user_file_ids,
                project_id=project_id,
            ),
        ):
            # get retrieved docs to send to the rest of the graph
            if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
                response = cast(SearchResponseSummary, tool_response.response)
                retrieved_docs = response.top_sections
                break
    # render the retrieved docs in the UI
    write_custom_event(
        current_step_nr,
        SearchToolDelta(
            queries=[],
            documents=convert_inference_sections_to_search_docs(
                retrieved_docs, is_internet=False
            ),
        ),
        writer,
    )
    # Build the LLM context string from (at most) the top 15 documents.
    document_texts_list = []
    for doc_num, retrieved_doc in enumerate(retrieved_docs[:15]):
        if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
            raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
        chunk_text = build_document_context(retrieved_doc, doc_num + 1)
        document_texts_list.append(chunk_text)
    document_texts = "\n\n".join(document_texts_list)
    logger.debug(
        f"Search end/LLM start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )
    # Built prompt
    if research_type == ResearchType.DEEP:
        search_prompt = INTERNAL_SEARCH_PROMPTS[research_type].build(
            search_query=branch_query,
            base_question=base_question,
            document_text=document_texts,
        )
        # Run LLM
        # search_answer_json = None
        search_answer_json = invoke_llm_json(
            llm=graph_config.tooling.primary_llm,
            prompt=create_question_prompt(
                assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
            ),
            schema=SearchAnswer,
            timeout_override=TF_DR_TIMEOUT_LONG,
            # max_tokens=1500,
        )
        logger.debug(
            f"LLM/all done for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
        )
        # get cited documents
        answer_string = search_answer_json.answer
        claims = search_answer_json.claims or []
        reasoning = search_answer_json.reasoning
        # answer_string = ""
        # claims = []
        (
            citation_numbers,
            answer_string,
            claims,
        ) = extract_document_citations(answer_string, claims)
        # Citation numbers are 1-based indices into retrieved_docs.
        if citation_numbers and (
            (max(citation_numbers) > len(retrieved_docs)) or min(citation_numbers) < 1
        ):
            raise ValueError("Citation numbers are out of range for retrieved docs.")
        cited_documents = {
            citation_number: retrieved_docs[citation_number - 1]
            for citation_number in citation_numbers
        }
    else:
        # Non-DEEP path: no LLM summary; expose the top docs directly.
        answer_string = ""
        claims = []
        cited_documents = {
            doc_num + 1: retrieved_doc
            for doc_num, retrieved_doc in enumerate(retrieved_docs[:15])
        }
        reasoning = ""
    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=search_tool_info.llm_path,
                tool_id=search_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=branch_query,
                answer=answer_string,
                claims=claims,
                cited_documents=cited_documents,
                reasoning=reasoning,
                additional_data=None,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="basic_search",
                node_name="searching",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -0,0 +1,77 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.context.search.models import SavedSearchDoc
from onyx.context.search.models import SearchDoc
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
logger = setup_logger()
def is_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    Reducer node of the basic-search sub-agent.

    Gathers the branch responses produced in the current iteration, converts
    their cited documents into SavedSearchDocs (marked as non-internet),
    closes the current UI section, and advances the step counter.

    Returns:
        SubAgentUpdate with this iteration's responses and the bumped step nr.
    """
    node_start_time = datetime.now()
    branch_updates = state.branch_iteration_responses
    current_iteration = state.iteration_nr
    current_step_nr = state.current_step_nr

    # Only keep branch responses that belong to this iteration.
    new_updates = [
        update for update in branch_updates if update.iteration_nr == current_iteration
    ]

    # Flatten the cited documents across all parallel branches.
    doc_list = [
        doc for update in new_updates for doc in update.cited_documents.values()
    ]

    # Convert InferenceSections to SavedSearchDocs
    search_docs = SearchDoc.from_chunks_or_sections(doc_list)
    retrieved_saved_search_docs = [
        SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
        for search_doc in search_docs
    ]
    # These docs came from an internal search, not the internet.
    for retrieved_saved_search_doc in retrieved_saved_search_docs:
        retrieved_saved_search_doc.is_internet = False

    # Close the current streaming section in the UI.
    write_custom_event(
        current_step_nr,
        SectionEnd(),
        writer,
    )
    current_step_nr += 1

    return SubAgentUpdate(
        iteration_responses=new_updates,
        current_step_nr=current_step_nr,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="basic_search",
                node_name="consolidation",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -0,0 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.exploration.sub_agents.basic_search.dr_basic_search_1_branch import (
basic_search_branch,
)
from onyx.agents.agent_search.exploration.sub_agents.basic_search.dr_basic_search_2_act import (
basic_search,
)
from onyx.agents.agent_search.exploration.sub_agents.basic_search.dr_basic_search_3_reduce import (
is_reducer,
)
from onyx.agents.agent_search.exploration.sub_agents.basic_search.dr_image_generation_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
logger = setup_logger()
def dr_basic_search_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for the Basic (internal) Search Sub-Agent.

    Flow: START -> branch -> (fan-out via branching_router) -> act -> reducer -> END
    """
    graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
    ### Add nodes ###
    graph.add_node("branch", basic_search_branch)
    graph.add_node("act", basic_search)
    graph.add_node("reducer", is_reducer)
    ### Add edges ###
    graph.add_edge(start_key=START, end_key="branch")
    # One "act" invocation is dispatched per generated query (parallel branches).
    graph.add_conditional_edges("branch", branching_router)
    graph.add_edge(start_key="act", end_key="reducer")
    graph.add_edge(start_key="reducer", end_key=END)
    return graph

View File

@@ -0,0 +1,30 @@
from collections.abc import Hashable
from langgraph.types import Send
from onyx.agents.agent_search.exploration.constants import MAX_DR_PARALLEL_SEARCH
from onyx.agents.agent_search.exploration.sub_agents.states import BranchInput
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """
    Fan out one parallel "act" branch per query, capped at MAX_DR_PARALLEL_SEARCH.
    """
    dispatches: list[Send | Hashable] = []
    for parallelization_nr, query in enumerate(
        state.query_list[:MAX_DR_PARALLEL_SEARCH]
    ):
        branch_input = BranchInput(
            iteration_nr=state.iteration_nr,
            parallelization_nr=parallelization_nr,
            branch_question=query,
            current_step_nr=state.current_step_nr,
            context="",
            active_source_types=state.active_source_types,
            tools_used=state.tools_used,
            available_tools=state.available_tools,
            assistant_system_prompt=state.assistant_system_prompt,
            assistant_task_prompt=state.assistant_task_prompt,
        )
        dispatches.append(Send("act", branch_input))
    return dispatches

View File

@@ -0,0 +1,36 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.exploration.states import LoggerUpdate
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def custom_tool_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    Entry node of the custom-tool sub-agent; records that the iteration started.
    """
    start_time = datetime.now()

    logger.debug(
        f"Search start for Generic Tool {state.iteration_nr} at {datetime.now()}"
    )

    log_entry = get_langgraph_node_log_string(
        graph_component="custom_tool",
        node_name="branching",
        node_start_time=start_time,
    )
    return LoggerUpdate(log_messages=[log_entry])

View File

@@ -0,0 +1,169 @@
import json
from datetime import datetime
from typing import cast
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.exploration.sub_agents.states import BranchInput
from onyx.agents.agent_search.exploration.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.exploration.sub_agents.states import IterationAnswer
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
from onyx.tools.tool_implementations.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
from onyx.tools.tool_implementations.mcp.mcp_tool import MCP_TOOL_RESPONSE_ID
from onyx.utils.logger import setup_logger
logger = setup_logger()
def custom_tool_act(
    state: BranchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
    """
    LangGraph node to perform a generic tool call as part of the DR process.

    Determines tool-call arguments (preferring a tool-calling LLM, falling
    back to a non-tool-calling prompt), runs the custom/MCP tool, summarises
    its result with the LLM, and returns one IterationAnswer for this branch.

    Raises:
        ValueError: on missing state fields, missing tool arguments, or an
            invalid/missing tool response summary.
    """
    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr
    if not state.available_tools:
        raise ValueError("available_tools is not set")
    # The most recently used tool is the custom tool to invoke.
    custom_tool_info = state.available_tools[state.tools_used[-1]]
    custom_tool_name = custom_tool_info.name
    custom_tool = cast(CustomTool, custom_tool_info.tool_object)
    branch_query = state.branch_question
    if not branch_query:
        raise ValueError("branch_query is not set")
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = graph_config.inputs.prompt_builder.raw_user_query
    logger.debug(
        f"Tool call start for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )
    # get tool call args
    tool_args: dict | None = None
    if graph_config.tooling.using_tool_calling_llm:
        # get tool call args from tool-calling LLM
        tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
            query=branch_query,
            base_question=base_question,
            tool_description=custom_tool_info.description,
        )
        tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
            tool_use_prompt,
            tools=[custom_tool.tool_definition()],
            tool_choice="required",
            timeout_override=TF_DR_TIMEOUT_LONG,
        )
        # make sure we got a tool call
        if (
            isinstance(tool_calling_msg, AIMessage)
            and len(tool_calling_msg.tool_calls) == 1
        ):
            tool_args = tool_calling_msg.tool_calls[0]["args"]
        else:
            logger.warning("Tool-calling LLM did not emit a tool call")
    if tool_args is None:
        # get tool call args from non-tool-calling LLM or for failed tool-calling LLM
        tool_args = custom_tool.get_args_for_non_tool_calling_llm(
            query=branch_query,
            history=[],
            llm=graph_config.tooling.primary_llm,
            force_run=True,
        )
    if tool_args is None:
        raise ValueError("Failed to obtain tool arguments from LLM")
    # run the tool
    response_summary: CustomToolCallSummary | None = None
    for tool_response in custom_tool.run(**tool_args):
        # Accept both plain custom tools and MCP-backed tools.
        if tool_response.id in {CUSTOM_TOOL_RESPONSE_ID, MCP_TOOL_RESPONSE_ID}:
            response_summary = cast(CustomToolCallSummary, tool_response.response)
            break
    if not response_summary:
        raise ValueError("Custom tool did not return a valid response summary")
    # summarise tool result
    if not response_summary.response_type:
        raise ValueError("Response type is not returned.")
    # Render the tool result as text so it can be fed back to the LLM.
    if response_summary.response_type == "json":
        tool_result_str = json.dumps(response_summary.tool_result, ensure_ascii=False)
    elif response_summary.response_type in {"image", "csv"}:
        tool_result_str = f"{response_summary.response_type} files: {response_summary.tool_result.file_ids}"
    else:
        tool_result_str = str(response_summary.tool_result)
    tool_str = (
        f"Tool used: {custom_tool_name}\n"
        f"Description: {custom_tool_info.description}\n"
        f"Result: {tool_result_str}"
    )
    tool_summary_prompt = CUSTOM_TOOL_USE_PROMPT.build(
        query=branch_query, base_question=base_question, tool_response=tool_str
    )
    answer_string = str(
        graph_config.tooling.primary_llm.invoke_langchain(
            tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
        ).content
    ).strip()
    # get file_ids:
    file_ids = None
    if response_summary.response_type in {"image", "csv"} and hasattr(
        response_summary.tool_result, "file_ids"
    ):
        file_ids = response_summary.tool_result.file_ids
    logger.debug(
        f"Tool call end for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )
    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=custom_tool_name,
                tool_id=custom_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=branch_query,
                answer=answer_string,
                claims=[],
                cited_documents={},
                reasoning="",
                additional_data=None,
                response_type=response_summary.response_type,
                data=response_summary.tool_result,
                file_ids=file_ids,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="custom_tool",
                node_name="tool_calling",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -0,0 +1,82 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
logger = setup_logger()
def custom_tool_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    Reducer node of the custom-tool sub-agent.

    Streams each of this iteration's tool results to the UI (start, delta,
    section end), advances the step counter, and forwards the iteration
    responses to the main graph state.

    Raises:
        ValueError: if a branch response is missing its response type.
    """
    node_start_time = datetime.now()
    current_step_nr = state.current_step_nr
    branch_updates = state.branch_iteration_responses
    current_iteration = state.iteration_nr
    # Only consider responses produced in this iteration.
    new_updates = [
        update for update in branch_updates if update.iteration_nr == current_iteration
    ]
    for new_update in new_updates:
        if not new_update.response_type:
            raise ValueError("Response type is not returned.")
        write_custom_event(
            current_step_nr,
            CustomToolStart(
                tool_name=new_update.tool,
            ),
            writer,
        )
        write_custom_event(
            current_step_nr,
            CustomToolDelta(
                tool_name=new_update.tool,
                response_type=new_update.response_type,
                data=new_update.data,
                file_ids=new_update.file_ids,
            ),
            writer,
        )
        write_custom_event(
            current_step_nr,
            SectionEnd(),
            writer,
        )
    current_step_nr += 1
    return SubAgentUpdate(
        iteration_responses=new_updates,
        # Propagate the advanced step counter; previously the increment was
        # discarded, leaving the graph's step numbering stale after this node
        # (the basic-search reducer already returns it).
        current_step_nr=current_step_nr,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="custom_tool",
                node_name="consolidation",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -0,0 +1,28 @@
from collections.abc import Hashable
from langgraph.types import Send
from onyx.agents.agent_search.exploration.sub_agents.states import BranchInput
from onyx.agents.agent_search.exploration.sub_agents.states import (
SubAgentInput,
)
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """
    Dispatch "act" branches for the custom-tool sub-agent.

    Only the first query is used (parallel calls are not enabled yet).
    """
    dispatches: list[Send | Hashable] = []
    for parallelization_nr, query in enumerate(
        state.query_list[:1]  # no parallel call for now
    ):
        branch_input = BranchInput(
            iteration_nr=state.iteration_nr,
            parallelization_nr=parallelization_nr,
            branch_question=query,
            context="",
            active_source_types=state.active_source_types,
            tools_used=state.tools_used,
            available_tools=state.available_tools,
        )
        dispatches.append(Send("act", branch_input))
    return dispatches

View File

@@ -0,0 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.exploration.sub_agents.custom_tool.dr_custom_tool_1_branch import (
custom_tool_branch,
)
from onyx.agents.agent_search.exploration.sub_agents.custom_tool.dr_custom_tool_2_act import (
custom_tool_act,
)
from onyx.agents.agent_search.exploration.sub_agents.custom_tool.dr_custom_tool_3_reduce import (
custom_tool_reducer,
)
from onyx.agents.agent_search.exploration.sub_agents.custom_tool.dr_custom_tool_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
logger = setup_logger()
def dr_custom_tool_graph_builder() -> StateGraph:
    """
    Build the LangGraph state graph for the custom-tool sub-agent.

    Flow: START -> branch -> (conditional fan-out) -> act -> reducer -> END
    """
    builder = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    # Nodes
    builder.add_node("branch", custom_tool_branch)
    builder.add_node("act", custom_tool_act)
    builder.add_node("reducer", custom_tool_reducer)

    # Edges
    builder.add_edge(start_key=START, end_key="branch")
    builder.add_conditional_edges("branch", branching_router)
    builder.add_edge(start_key="act", end_key="reducer")
    builder.add_edge(start_key="reducer", end_key=END)

    return builder

View File

@@ -0,0 +1,36 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.exploration.states import LoggerUpdate
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def generic_internal_tool_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    Entry node of the generic-internal-tool sub-agent; records that the
    iteration started.
    """
    start_time = datetime.now()

    logger.debug(
        f"Search start for Generic Tool {state.iteration_nr} at {datetime.now()}"
    )

    log_entry = get_langgraph_node_log_string(
        graph_component="generic_internal_tool",
        node_name="branching",
        node_start_time=start_time,
    )
    return LoggerUpdate(log_messages=[log_entry])

View File

@@ -0,0 +1,149 @@
import json
from datetime import datetime
from typing import cast
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.exploration.sub_agents.states import BranchInput
from onyx.agents.agent_search.exploration.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.exploration.sub_agents.states import IterationAnswer
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
from onyx.prompts.dr_prompts import OKTA_TOOL_USE_SPECIAL_PROMPT
from onyx.utils.logger import setup_logger
logger = setup_logger()
def generic_internal_tool_act(
    state: BranchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
    """
    LangGraph node to perform a generic tool call as part of the DR process.

    Determines tool-call arguments (preferring a tool-calling LLM, falling
    back to a non-tool-calling prompt), runs the internal tool, serialises
    the final result as JSON, summarises it with the LLM, and returns one
    IterationAnswer for this branch.

    Raises:
        ValueError: on missing state fields or missing tool arguments.
    """
    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr
    if not state.available_tools:
        raise ValueError("available_tools is not set")
    # The most recently used tool is the internal tool to invoke.
    generic_internal_tool_info = state.available_tools[state.tools_used[-1]]
    generic_internal_tool_name = generic_internal_tool_info.llm_path
    generic_internal_tool = generic_internal_tool_info.tool_object
    if generic_internal_tool is None:
        raise ValueError("generic_internal_tool is not set")
    branch_query = state.branch_question
    if not branch_query:
        raise ValueError("branch_query is not set")
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = graph_config.inputs.prompt_builder.raw_user_query
    logger.debug(
        f"Tool call start for {generic_internal_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )
    # get tool call args
    tool_args: dict | None = None
    if graph_config.tooling.using_tool_calling_llm:
        # get tool call args from tool-calling LLM
        tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
            query=branch_query,
            base_question=base_question,
            tool_description=generic_internal_tool_info.description,
        )
        tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
            tool_use_prompt,
            tools=[generic_internal_tool.tool_definition()],
            tool_choice="required",
            timeout_override=TF_DR_TIMEOUT_SHORT,
        )
        # make sure we got a tool call
        if (
            isinstance(tool_calling_msg, AIMessage)
            and len(tool_calling_msg.tool_calls) == 1
        ):
            tool_args = tool_calling_msg.tool_calls[0]["args"]
        else:
            logger.warning("Tool-calling LLM did not emit a tool call")
    if tool_args is None:
        # get tool call args from non-tool-calling LLM or for failed tool-calling LLM
        tool_args = generic_internal_tool.get_args_for_non_tool_calling_llm(
            query=branch_query,
            history=[],
            llm=graph_config.tooling.primary_llm,
            force_run=True,
        )
    if tool_args is None:
        raise ValueError("Failed to obtain tool arguments from LLM")
    # run the tool
    tool_responses = list(generic_internal_tool.run(**tool_args))
    final_data = generic_internal_tool.final_result(*tool_responses)
    tool_result_str = json.dumps(final_data, ensure_ascii=False)
    tool_str = (
        f"Tool used: {generic_internal_tool.display_name}\n"
        f"Description: {generic_internal_tool_info.description}\n"
        f"Result: {tool_result_str}"
    )
    # The Okta Profile tool has a dedicated summarisation prompt.
    if generic_internal_tool.display_name == "Okta Profile":
        tool_prompt = OKTA_TOOL_USE_SPECIAL_PROMPT
    else:
        tool_prompt = CUSTOM_TOOL_USE_PROMPT
    tool_summary_prompt = tool_prompt.build(
        query=branch_query, base_question=base_question, tool_response=tool_str
    )
    answer_string = str(
        graph_config.tooling.primary_llm.invoke_langchain(
            tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
        ).content
    ).strip()
    logger.debug(
        f"Tool call end for {generic_internal_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )
    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=generic_internal_tool.llm_name,
                tool_id=generic_internal_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=branch_query,
                answer=answer_string,
                claims=[],
                cited_documents={},
                reasoning="",
                additional_data=None,
                response_type="text",  # TODO: convert all response types to enums
                data=answer_string,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="custom_tool",
                node_name="tool_calling",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -0,0 +1,82 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
logger = setup_logger()
def generic_internal_tool_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    Reducer node of the generic-internal-tool sub-agent.

    Streams each of this iteration's tool results to the UI (start, delta,
    section end), advances the step counter, and forwards the iteration
    responses to the main graph state.

    Raises:
        ValueError: if a branch response is missing its response type.
    """
    node_start_time = datetime.now()
    current_step_nr = state.current_step_nr
    branch_updates = state.branch_iteration_responses
    current_iteration = state.iteration_nr
    # Only consider responses produced in this iteration.
    new_updates = [
        update for update in branch_updates if update.iteration_nr == current_iteration
    ]
    for new_update in new_updates:
        if not new_update.response_type:
            raise ValueError("Response type is not returned.")
        write_custom_event(
            current_step_nr,
            CustomToolStart(
                tool_name=new_update.tool,
            ),
            writer,
        )
        write_custom_event(
            current_step_nr,
            CustomToolDelta(
                tool_name=new_update.tool,
                response_type=new_update.response_type,
                data=new_update.data,
                file_ids=[],
            ),
            writer,
        )
        write_custom_event(
            current_step_nr,
            SectionEnd(),
            writer,
        )
    current_step_nr += 1
    return SubAgentUpdate(
        iteration_responses=new_updates,
        # Propagate the advanced step counter; previously the increment was
        # discarded, leaving the graph's step numbering stale after this node
        # (the basic-search reducer already returns it).
        current_step_nr=current_step_nr,
        log_messages=[
            get_langgraph_node_log_string(
                # Was "custom_tool" (copy-paste); this sub-agent's branch node
                # logs under "generic_internal_tool".
                graph_component="generic_internal_tool",
                node_name="consolidation",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -0,0 +1,28 @@
from collections.abc import Hashable
from langgraph.types import Send
from onyx.agents.agent_search.exploration.sub_agents.states import BranchInput
from onyx.agents.agent_search.exploration.sub_agents.states import (
SubAgentInput,
)
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """
    Dispatch "act" branches for the generic-internal-tool sub-agent.

    Only the first query is used (parallel calls are not enabled yet).
    """
    dispatches: list[Send | Hashable] = []
    for parallelization_nr, query in enumerate(
        state.query_list[:1]  # no parallel call for now
    ):
        branch_input = BranchInput(
            iteration_nr=state.iteration_nr,
            parallelization_nr=parallelization_nr,
            branch_question=query,
            context="",
            active_source_types=state.active_source_types,
            tools_used=state.tools_used,
            available_tools=state.available_tools,
        )
        dispatches.append(Send("act", branch_input))
    return dispatches

View File

@@ -0,0 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.exploration.sub_agents.generic_internal_tool.dr_generic_internal_tool_1_branch import (
generic_internal_tool_branch,
)
from onyx.agents.agent_search.exploration.sub_agents.generic_internal_tool.dr_generic_internal_tool_2_act import (
generic_internal_tool_act,
)
from onyx.agents.agent_search.exploration.sub_agents.generic_internal_tool.dr_generic_internal_tool_3_reduce import (
generic_internal_tool_reducer,
)
from onyx.agents.agent_search.exploration.sub_agents.generic_internal_tool.dr_generic_internal_tool_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
logger = setup_logger()
def dr_generic_internal_tool_graph_builder() -> StateGraph:
    """Build the LangGraph graph for the Generic Tool Sub-Agent.

    Pipeline: START -> branch -> (conditional fan-out) -> act -> reducer -> END.
    """
    graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    # Register the three pipeline nodes
    for node_name, node_fn in (
        ("branch", generic_internal_tool_branch),
        ("act", generic_internal_tool_act),
        ("reducer", generic_internal_tool_reducer),
    ):
        graph.add_node(node_name, node_fn)

    # Wire the edges; "branch" fans out to "act" via the conditional router
    graph.add_edge(start_key=START, end_key="branch")
    graph.add_conditional_edges("branch", branching_router)
    graph.add_edge(start_key="act", end_key="reducer")
    graph.add_edge(start_key="reducer", end_key=END)

    return graph

View File

@@ -0,0 +1,45 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.exploration.states import LoggerUpdate
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
from onyx.utils.logger import setup_logger
logger = setup_logger()
def image_generation_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph node that announces the start of image generation for this
    DR iteration to the frontend and records a timing log entry.
    """
    node_start_time = datetime.now()

    logger.debug(f"Image generation start {state.iteration_nr} at {datetime.now()}")

    # tell frontend that we are starting the image generation tool
    write_custom_event(
        state.current_step_nr,
        ImageGenerationToolStart(),
        writer,
    )

    log_entry = get_langgraph_node_log_string(
        graph_component="image_generation",
        node_name="branching",
        node_start_time=node_start_time,
    )
    return LoggerUpdate(log_messages=[log_entry])

View File

@@ -0,0 +1,189 @@
import json
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.exploration.models import GeneratedImage
from onyx.agents.agent_search.exploration.models import IterationAnswer
from onyx.agents.agent_search.exploration.sub_agents.states import BranchInput
from onyx.agents.agent_search.exploration.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.file_store.utils import build_frontend_file_url
from onyx.file_store.utils import save_files
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeartbeat
from onyx.tools.tool_implementations.images.image_generation_tool import (
IMAGE_GENERATION_HEARTBEAT_ID,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
IMAGE_GENERATION_RESPONSE_ID,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationResponse,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.tools.tool_implementations.images.image_generation_tool import ImageShape
from onyx.utils.logger import setup_logger
logger = setup_logger()
def image_generation(
    state: BranchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
    """
    LangGraph node that generates images for one branch of the DR process.

    Parses the branch question (either a plain prompt or a JSON object with
    "prompt" and "shape" keys), runs the image-generation tool while streaming
    heartbeats to the frontend, persists the images to the file store, and
    returns an IterationAnswer describing the results.

    Raises:
        ValueError: if the branch question or available tools are not set.
    """
    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr

    branch_query = state.branch_question
    if not branch_query:
        raise ValueError("branch_query is not set")

    if not state.available_tools:
        raise ValueError("available_tools is not set")
    # The most recently used tool entry identifies the tool to run
    image_tool_info = state.available_tools[state.tools_used[-1]]
    image_tool = cast(ImageGenerationTool, image_tool_info.tool_object)

    # The LLM may pass either a plain prompt string or a JSON object of the
    # form {"prompt": ..., "shape": ...}; fall back to the raw query otherwise.
    image_prompt = branch_query
    requested_shape: ImageShape | None = None

    try:
        parsed_query = json.loads(branch_query)
    except json.JSONDecodeError:
        parsed_query = None

    if isinstance(parsed_query, dict):
        prompt_from_llm = parsed_query.get("prompt")
        if isinstance(prompt_from_llm, str) and prompt_from_llm.strip():
            image_prompt = prompt_from_llm.strip()

        raw_shape = parsed_query.get("shape")
        if isinstance(raw_shape, str):
            try:
                requested_shape = ImageShape(raw_shape)
            except ValueError:
                logger.warning(
                    "Received unsupported image shape '%s' from LLM. Falling back to square.",
                    raw_shape,
                )

    logger.debug(
        f"Image generation start for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    # Generate images using the image generation tool, forwarding heartbeats
    # to the frontend until the final response arrives
    image_generation_responses: list[ImageGenerationResponse] = []
    if requested_shape is not None:
        tool_iterator = image_tool.run(
            prompt=image_prompt,
            shape=requested_shape.value,
        )
    else:
        tool_iterator = image_tool.run(prompt=image_prompt)

    for tool_response in tool_iterator:
        if tool_response.id == IMAGE_GENERATION_HEARTBEAT_ID:
            # Stream heartbeat to frontend
            write_custom_event(
                state.current_step_nr,
                ImageGenerationToolHeartbeat(),
                writer,
            )
        elif tool_response.id == IMAGE_GENERATION_RESPONSE_ID:
            response = cast(list[ImageGenerationResponse], tool_response.response)
            image_generation_responses = response
            break

    # save images to file store
    file_ids = save_files(
        urls=[],
        base64_files=[img.image_data for img in image_generation_responses],
    )

    final_generated_images = [
        GeneratedImage(
            file_id=file_id,
            url=build_frontend_file_url(file_id),
            revised_prompt=img.revised_prompt,
            shape=(requested_shape or ImageShape.SQUARE).value,
        )
        for file_id, img in zip(file_ids, image_generation_responses)
    ]

    logger.debug(
        f"Image generation complete for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    # Create answer string describing the generated images
    if final_generated_images:
        image_descriptions = []
        for i, img in enumerate(final_generated_images, 1):
            if img.shape and img.shape != ImageShape.SQUARE.value:
                image_descriptions.append(
                    f"Image {i}: {img.revised_prompt} (shape: {img.shape})"
                )
            else:
                image_descriptions.append(f"Image {i}: {img.revised_prompt}")

        answer_string = (
            f"Generated {len(final_generated_images)} image(s) based on the request: {image_prompt}\n\n"
            + "\n".join(image_descriptions)
        )

        if requested_shape:
            reasoning = (
                "Used image generation tool to create "
                f"{len(final_generated_images)} image(s) in {requested_shape.value} orientation."
            )
        else:
            reasoning = (
                "Used image generation tool to create "
                f"{len(final_generated_images)} image(s) based on the user's request."
            )
    else:
        answer_string = f"Failed to generate images for request: {image_prompt}"
        reasoning = "Image generation tool did not return any results."

    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=image_tool_info.llm_path,
                tool_id=image_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=branch_query,
                answer=answer_string,
                claims=[],
                cited_documents={},
                reasoning=reasoning,
                generated_images=final_generated_images,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="image_generation",
                node_name="generating",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -0,0 +1,71 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.exploration.models import GeneratedImage
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolDelta
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
logger = setup_logger()
def is_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    LangGraph node that consolidates image-generation branch results for the
    current iteration and streams the generated images to the frontend.
    """
    node_start_time = datetime.now()

    current_iteration = state.iteration_nr
    step_nr = state.current_step_nr

    # Keep only the branch responses produced in this iteration
    relevant_updates = [
        resp
        for resp in state.branch_iteration_responses
        if resp.iteration_nr == current_iteration
    ]

    collected_images: list[GeneratedImage] = []
    for resp in relevant_updates:
        if resp.generated_images:
            collected_images.extend(resp.generated_images)

    # Stream the images, then close out this streaming section
    write_custom_event(
        step_nr,
        ImageGenerationToolDelta(
            images=collected_images,
        ),
        writer,
    )
    write_custom_event(step_nr, SectionEnd(), writer)
    step_nr += 1

    log_entry = get_langgraph_node_log_string(
        graph_component="image_generation",
        node_name="consolidation",
        node_start_time=node_start_time,
    )
    return SubAgentUpdate(
        iteration_responses=relevant_updates,
        current_step_nr=step_nr,
        log_messages=[log_entry],
    )

View File

@@ -0,0 +1,29 @@
from collections.abc import Hashable
from langgraph.types import Send
from onyx.agents.agent_search.exploration.constants import MAX_DR_PARALLEL_SEARCH
from onyx.agents.agent_search.exploration.sub_agents.states import BranchInput
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """Fan out one image-generation "act" branch per query, capped at
    MAX_DR_PARALLEL_SEARCH parallel branches."""
    selected_queries = state.query_list[:MAX_DR_PARALLEL_SEARCH]

    branches: list[Send | Hashable] = []
    for branch_nr, branch_query in enumerate(selected_queries):
        branch_payload = BranchInput(
            iteration_nr=state.iteration_nr,
            parallelization_nr=branch_nr,
            branch_question=branch_query,
            context="",
            active_source_types=state.active_source_types,
            tools_used=state.tools_used,
            available_tools=state.available_tools,
            assistant_system_prompt=state.assistant_system_prompt,
            assistant_task_prompt=state.assistant_task_prompt,
        )
        branches.append(Send("act", branch_payload))
    return branches

View File

@@ -0,0 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.exploration.sub_agents.image_generation.dr_image_generation_1_branch import (
image_generation_branch,
)
from onyx.agents.agent_search.exploration.sub_agents.image_generation.dr_image_generation_2_act import (
image_generation,
)
from onyx.agents.agent_search.exploration.sub_agents.image_generation.dr_image_generation_3_reduce import (
is_reducer,
)
from onyx.agents.agent_search.exploration.sub_agents.image_generation.dr_image_generation_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
logger = setup_logger()
def dr_image_generation_graph_builder() -> StateGraph:
    """Build the LangGraph graph for the Image Generation Sub-Agent.

    Pipeline: START -> branch -> (conditional fan-out) -> act -> reducer -> END.
    """
    graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    # Register the three pipeline nodes
    for node_name, node_fn in (
        ("branch", image_generation_branch),
        ("act", image_generation),
        ("reducer", is_reducer),
    ):
        graph.add_node(node_name, node_fn)

    # Wire the edges; "branch" fans out to "act" via the conditional router
    graph.add_edge(start_key=START, end_key="branch")
    graph.add_conditional_edges("branch", branching_router)
    graph.add_edge(start_key="act", end_key="reducer")
    graph.add_edge(start_key="reducer", end_key=END)

    return graph

View File

@@ -0,0 +1,13 @@
from pydantic import BaseModel
class GeneratedImage(BaseModel):
    """A single image produced by the image-generation tool."""

    # Id of the saved image in the file store
    file_id: str
    # Frontend-facing URL for retrieving the image
    url: str
    # Prompt as (possibly) rewritten by the image-generation backend
    revised_prompt: str
    # Shape/aspect identifier (e.g. square); None when not specified
    shape: str | None = None
# Needed for PydanticType
class GeneratedImageFullResult(BaseModel):
    """Wrapper model holding the full list of generated images."""

    images: list[GeneratedImage]

View File

@@ -0,0 +1,36 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.exploration.states import LoggerUpdate
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def kg_search_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph node marking the start of a KG-search iteration.

    Performs no work beyond logging; the actual search happens in the
    downstream "act" node.
    """
    node_start_time = datetime.now()

    logger.debug(
        f"Search start for KG Search {state.iteration_nr} at {datetime.now()}"
    )

    log_entry = get_langgraph_node_log_string(
        graph_component="kg_search",
        node_name="branching",
        node_start_time=node_start_time,
    )
    return LoggerUpdate(log_messages=[log_entry])

View File

@@ -0,0 +1,97 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.exploration.models import IterationAnswer
from onyx.agents.agent_search.exploration.sub_agents.states import BranchInput
from onyx.agents.agent_search.exploration.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.exploration.utils import extract_document_citations
from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
from onyx.agents.agent_search.kb_search.states import MainInput as KbMainInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger
logger = setup_logger()
def kg_search(
    state: BranchInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BranchUpdate:
    """
    LangGraph node to perform a KG search as part of the DR process.

    Runs the knowledge-graph sub-graph for the branch question and converts
    its final answer plus retrieved documents into an IterationAnswer.

    Raises:
        ValueError: if the branch question or available tools are not set.
    """
    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr

    search_query = state.branch_question
    if not search_query:
        raise ValueError("search_query is not set")

    logger.debug(
        f"Search start for KG Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    if not state.available_tools:
        raise ValueError("available_tools is not set")
    # The most recently used tool entry identifies the KG tool metadata
    kg_tool_info = state.available_tools[state.tools_used[-1]]

    kb_graph = kb_graph_builder().compile()

    kb_results = kb_graph.invoke(
        input=KbMainInput(question=search_query, individual_flow=False),
        config=config,
    )

    # get cited documents
    answer_string = kb_results.get("final_answer") or "No answer provided"
    claims: list[str] = []
    retrieved_docs: list[InferenceSection] = kb_results.get("retrieved_documents", [])

    (
        citation_numbers,
        answer_string,
        claims,
    ) = extract_document_citations(answer_string, claims)

    # if citation is empty, the answer must have come from the KG rather than a doc
    # in that case, simply cite the docs returned by the KG
    if not citation_numbers:
        citation_numbers = [i + 1 for i in range(len(retrieved_docs))]

    # Citations are 1-based; guard both bounds so a stray 0/negative citation
    # number cannot wrap around via negative indexing and cite the wrong doc
    cited_documents = {
        citation_number: retrieved_docs[citation_number - 1]
        for citation_number in citation_numbers
        if 1 <= citation_number <= len(retrieved_docs)
    }

    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=kg_tool_info.llm_path,
                tool_id=kg_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=search_query,
                answer=answer_string,
                claims=claims,
                cited_documents=cited_documents,
                reasoning=None,
                additional_data=None,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="kg_search",
                node_name="searching",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -0,0 +1,123 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.exploration.utils import (
convert_inference_sections_to_search_docs,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
logger = setup_logger()
_MAX_KG_STEAMED_ANSWER_LENGTH = 1000 # num characters
def kg_search_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    LangGraph node that consolidates KG-search branch results for the current
    iteration and streams the cited documents and KG answer to the frontend.
    """
    node_start_time = datetime.now()

    current_iteration = state.iteration_nr
    step_nr = state.current_step_nr

    # Keep only the branch responses produced in this iteration
    relevant_updates = [
        resp
        for resp in state.branch_iteration_responses
        if resp.iteration_nr == current_iteration
    ]

    queries = [resp.question for resp in relevant_updates]

    # Flatten the cited documents across all branches
    cited_docs = []
    for resp in relevant_updates:
        cited_docs.extend(resp.cited_documents.values())
    retrieved_search_docs = convert_inference_sections_to_search_docs(cited_docs)

    # Only surface the KG answer directly when there was a single query
    if len(queries) == 1:
        kg_answer = "The Knowledge Graph Answer:\n\n" + relevant_updates[0].answer
    else:
        kg_answer = None

    if len(retrieved_search_docs) > 0:
        write_custom_event(
            step_nr,
            SearchToolStart(
                is_internet_search=False,
            ),
            writer,
        )
        write_custom_event(
            step_nr,
            SearchToolDelta(
                queries=queries,
                documents=retrieved_search_docs,
            ),
            writer,
        )
        write_custom_event(step_nr, SectionEnd(), writer)
        step_nr += 1

    if kg_answer is not None:
        # Truncate very long KG answers before streaming them as reasoning
        if len(kg_answer) > _MAX_KG_STEAMED_ANSWER_LENGTH:
            kg_display_answer = f"{kg_answer[:_MAX_KG_STEAMED_ANSWER_LENGTH]}..."
        else:
            kg_display_answer = kg_answer

        write_custom_event(step_nr, ReasoningStart(), writer)
        write_custom_event(
            step_nr,
            ReasoningDelta(reasoning=kg_display_answer),
            writer,
        )
        write_custom_event(step_nr, SectionEnd(), writer)
        step_nr += 1

    log_entry = get_langgraph_node_log_string(
        graph_component="kg_search",
        node_name="consolidation",
        node_start_time=node_start_time,
    )
    return SubAgentUpdate(
        iteration_responses=relevant_updates,
        current_step_nr=step_nr,
        log_messages=[log_entry],
    )

View File

@@ -0,0 +1,27 @@
from collections.abc import Hashable
from langgraph.types import Send
from onyx.agents.agent_search.exploration.sub_agents.states import BranchInput
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """Fan out one KG-search "act" branch per query (single branch for now)."""
    selected_queries = state.query_list[:1]  # no parallel search for now

    branches: list[Send | Hashable] = []
    for branch_nr, branch_query in enumerate(selected_queries):
        branch_payload = BranchInput(
            iteration_nr=state.iteration_nr,
            parallelization_nr=branch_nr,
            branch_question=branch_query,
            context="",
            tools_used=state.tools_used,
            available_tools=state.available_tools,
            assistant_system_prompt=state.assistant_system_prompt,
            assistant_task_prompt=state.assistant_task_prompt,
        )
        branches.append(Send("act", branch_payload))
    return branches

View File

@@ -0,0 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.exploration.sub_agents.kg_search.dr_kg_search_1_branch import (
kg_search_branch,
)
from onyx.agents.agent_search.exploration.sub_agents.kg_search.dr_kg_search_2_act import (
kg_search,
)
from onyx.agents.agent_search.exploration.sub_agents.kg_search.dr_kg_search_3_reduce import (
kg_search_reducer,
)
from onyx.agents.agent_search.exploration.sub_agents.kg_search.dr_kg_search_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
logger = setup_logger()
def dr_kg_search_graph_builder() -> StateGraph:
    """Build the LangGraph graph for the KG Search Sub-Agent.

    Pipeline: START -> branch -> (conditional fan-out) -> act -> reducer -> END.
    """
    graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    # Register the three pipeline nodes
    for node_name, node_fn in (
        ("branch", kg_search_branch),
        ("act", kg_search),
        ("reducer", kg_search_reducer),
    ):
        graph.add_node(node_name, node_fn)

    # Wire the edges; "branch" fans out to "act" via the conditional router
    graph.add_edge(start_key=START, end_key="branch")
    graph.add_conditional_edges("branch", branching_router)
    graph.add_edge(start_key="act", end_key="reducer")
    graph.add_edge(start_key="reducer", end_key=END)

    return graph

View File

@@ -0,0 +1,46 @@
from operator import add
from typing import Annotated
from onyx.agents.agent_search.exploration.models import IterationAnswer
from onyx.agents.agent_search.exploration.models import OrchestratorTool
from onyx.agents.agent_search.exploration.states import LoggerUpdate
from onyx.db.connector import DocumentSource
class SubAgentUpdate(LoggerUpdate):
    """State update emitted by a sub-agent's reducer node."""

    # Answers for this iteration; merged across updates via operator.add
    iteration_responses: Annotated[list[IterationAnswer], add] = []
    # Next frontend streaming step number
    current_step_nr: int = 1
class BranchUpdate(LoggerUpdate):
    """State update emitted by a single parallel "act" branch."""

    # Per-branch answers; merged across branches via operator.add
    branch_iteration_responses: Annotated[list[IterationAnswer], add] = []
class SubAgentInput(LoggerUpdate):
    """Core input state handed to a sub-agent graph by the orchestrator."""

    # DR iteration this sub-agent run belongs to
    iteration_nr: int = 0
    # Frontend streaming step number to write the next event to
    current_step_nr: int = 1
    # Queries to fan out over parallel branches
    query_list: list[str] = []
    # Optional extra context string for the branch
    context: str | None = None
    # Source types used for search filtering — presumably set by the
    # orchestrator; confirm against caller
    active_source_types: list[DocumentSource] | None = None
    # Tool names used so far; the last entry selects the tool to run
    tools_used: Annotated[list[str], add] = []
    # Tools made available by the orchestrator, keyed by the same names
    # that appear in tools_used
    available_tools: dict[str, OrchestratorTool] | None = None
    # Optional assistant prompt overrides applied when calling the LLM
    assistant_system_prompt: str | None = None
    assistant_task_prompt: str | None = None
class SubAgentMainState(
    # This includes the core state
    SubAgentInput,
    SubAgentUpdate,
    BranchUpdate,
):
    """Combined main state for a sub-agent graph: inputs plus the reducer
    and branch update channels."""

    pass
class BranchInput(SubAgentInput):
    """Input for one parallel branch of a sub-agent."""

    # Index of this branch within the iteration's parallel fan-out
    parallelization_nr: int = 0
    # The question this branch should answer
    branch_question: str
class CustomToolBranchInput(LoggerUpdate):
    """Input for a custom-tool branch; carries the tool to invoke."""

    tool_info: OrchestratorTool

View File

@@ -0,0 +1,71 @@
from collections.abc import Sequence
from exa_py import Exa
from exa_py.api import HighlightsContentsOptions
from onyx.agents.agent_search.exploration.sub_agents.web_search.models import (
WebContent,
)
from onyx.agents.agent_search.exploration.sub_agents.web_search.models import (
WebSearchProvider,
)
from onyx.agents.agent_search.exploration.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.configs.chat_configs import EXA_API_KEY
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.utils.retry_wrapper import retry_builder
class ExaClient(WebSearchProvider):
    """Web-search provider backed by the Exa API."""

    def __init__(self, api_key: str | None = EXA_API_KEY) -> None:
        self.exa = Exa(api_key=api_key)

    @retry_builder(tries=3, delay=1, backoff=2)
    def search(self, query: str) -> list[WebSearchResult]:
        """Run a search and return lightweight results with short snippets."""
        response = self.exa.search_and_contents(
            query,
            type="auto",
            highlights=HighlightsContentsOptions(
                num_sentences=2,
                highlights_per_url=1,
            ),
            num_results=10,
        )

        search_results: list[WebSearchResult] = []
        for hit in response.results:
            snippet = hit.highlights[0] if hit.highlights else ""
            published = (
                time_str_to_utc(hit.published_date) if hit.published_date else None
            )
            search_results.append(
                WebSearchResult(
                    title=hit.title or "",
                    link=hit.url,
                    snippet=snippet,
                    author=hit.author,
                    published_date=published,
                )
            )
        return search_results

    @retry_builder(tries=3, delay=1, backoff=2)
    def contents(self, urls: Sequence[str]) -> list[WebContent]:
        """Fetch full page contents for the given URLs."""
        response = self.exa.get_contents(
            urls=list(urls),
            text=True,
            livecrawl="preferred",
        )

        pages: list[WebContent] = []
        for hit in response.results:
            published = (
                time_str_to_utc(hit.published_date) if hit.published_date else None
            )
            pages.append(
                WebContent(
                    title=hit.title or "",
                    link=hit.url,
                    full_content=hit.text or "",
                    published_date=published,
                )
            )
        return pages

View File

@@ -0,0 +1,148 @@
import json
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
import requests
from onyx.agents.agent_search.exploration.sub_agents.web_search.models import (
WebContent,
)
from onyx.agents.agent_search.exploration.sub_agents.web_search.models import (
WebSearchProvider,
)
from onyx.agents.agent_search.exploration.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.configs.chat_configs import SERPER_API_KEY
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.utils.retry_wrapper import retry_builder
SERPER_SEARCH_URL = "https://google.serper.dev/search"
SERPER_CONTENTS_URL = "https://scrape.serper.dev"
class SerperClient(WebSearchProvider):
    """Web-search provider backed by the Serper.dev search and scrape APIs."""

    def __init__(self, api_key: str | None = SERPER_API_KEY) -> None:
        self.headers = {
            "X-API-KEY": api_key,
            "Content-Type": "application/json",
        }

    @retry_builder(tries=3, delay=1, backoff=2)
    def search(self, query: str) -> list[WebSearchResult]:
        """Run a Serper search and map the organic results to WebSearchResult.

        Serper does not return author/publish-date in organic results, so
        those fields are always None here.
        """
        payload = {
            "q": query,
        }
        response = requests.post(
            SERPER_SEARCH_URL,
            headers=self.headers,
            data=json.dumps(payload),
        )
        response.raise_for_status()
        results = response.json()
        # NOTE(review): assumes the "organic" key is always present on a
        # 2xx response — confirm against the Serper API contract
        organic_results = results["organic"]
        return [
            WebSearchResult(
                title=result["title"],
                link=result["link"],
                snippet=result["snippet"],
                author=None,
                published_date=None,
            )
            for result in organic_results
        ]

    def contents(self, urls: Sequence[str]) -> list[WebContent]:
        """Scrape each URL concurrently, returning one WebContent per URL."""
        if not urls:
            return []

        # Serper can respond with 500s regularly. We want to retry,
        # but in the event of failure, return an unsuccessful scrape.
        def safe_get_webpage_content(url: str) -> WebContent:
            try:
                return self._get_webpage_content(url)
            except Exception:
                return WebContent(
                    title="",
                    link=url,
                    full_content="",
                    published_date=None,
                    scrape_successful=False,
                )

        with ThreadPoolExecutor(max_workers=min(8, len(urls))) as e:
            return list(e.map(safe_get_webpage_content, urls))

    @retry_builder(tries=3, delay=1, backoff=2)
    def _get_webpage_content(self, url: str) -> WebContent:
        """Scrape a single URL via Serper; retried on transient failures."""
        payload = {
            "url": url,
        }
        response = requests.post(
            SERPER_CONTENTS_URL,
            headers=self.headers,
            data=json.dumps(payload),
        )

        # 400 returned when serper cannot scrape; treat as a failed scrape
        # rather than raising (and retrying) like other HTTP errors below
        if response.status_code == 400:
            return WebContent(
                title="",
                link=url,
                full_content="",
                published_date=None,
                scrape_successful=False,
            )

        response.raise_for_status()
        response_json = response.json()

        # Response only guarantees text
        text = response_json["text"]

        # metadata & jsonld is not guaranteed to be present
        metadata = response_json.get("metadata", {})
        jsonld = response_json.get("jsonld", {})

        title = extract_title_from_metadata(metadata)

        # Serper does not provide a reliable mechanism to extract the url
        response_url = url

        published_date_str = extract_published_date_from_jsonld(jsonld)
        published_date = None
        if published_date_str:
            try:
                published_date = time_str_to_utc(published_date_str)
            except Exception:
                # Unparseable date strings are dropped rather than failing
                published_date = None

        return WebContent(
            title=title or "",
            link=response_url,
            full_content=text or "",
            published_date=published_date,
        )
def extract_title_from_metadata(metadata: dict[str, str]) -> str | None:
    """Pull a page title from Serper metadata, preferring the plain title tag
    over the OpenGraph one."""
    return extract_value_from_dict(metadata, ["title", "og:title"])
def extract_published_date_from_jsonld(jsonld: dict[str, str]) -> str | None:
    """Pull a publish-date string from a page's JSON-LD block, if present."""
    return extract_value_from_dict(jsonld, ["dateModified"])
def extract_value_from_dict(data: dict[str, str], keys: list[str]) -> str | None:
for key in keys:
if key in data:
return data[key]
return None

View File

@@ -0,0 +1,47 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.exploration.states import LoggerUpdate
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.utils.logger import setup_logger
logger = setup_logger()
def is_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph node that opens the internet-search section for this iteration
    and records a timing log entry.
    """
    node_start_time = datetime.now()

    logger.debug(
        f"Search start for Web Search {state.iteration_nr} at {datetime.now()}"
    )

    # Announce an internet-search section to the frontend
    write_custom_event(
        state.current_step_nr,
        SearchToolStart(
            is_internet_search=True,
        ),
        writer,
    )

    log_entry = get_langgraph_node_log_string(
        graph_component="internet_search",
        node_name="branching",
        node_start_time=node_start_time,
    )
    return LoggerUpdate(log_messages=[log_entry])

View File

@@ -0,0 +1,128 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from langsmith import traceable
from onyx.agents.agent_search.exploration.models import WebSearchAnswer
from onyx.agents.agent_search.exploration.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.agents.agent_search.exploration.sub_agents.web_search.providers import (
get_default_provider,
)
from onyx.agents.agent_search.exploration.sub_agents.web_search.states import (
InternetSearchInput,
)
from onyx.agents.agent_search.exploration.sub_agents.web_search.states import (
InternetSearchUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.prompts.dr_prompts import WEB_SEARCH_URL_SELECTION_PROMPT
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.utils.logger import setup_logger
logger = setup_logger()
def web_search(
    state: InternetSearchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> InternetSearchUpdate:
    """
    LangGraph node to perform internet search and decide which URLs to fetch.

    Streams the query to the frontend, runs the configured search provider,
    then asks the fast LLM which result URLs are worth opening; the selected
    (query, result) pairs are returned for the fetch stage.

    Raises:
        ValueError: if the step number, available tools, persona, or search
            provider are not set.
    """
    node_start_time = datetime.now()
    current_step_nr = state.current_step_nr
    if not current_step_nr:
        raise ValueError("Current step number is not set. This should not happen.")
    assistant_system_prompt = state.assistant_system_prompt
    assistant_task_prompt = state.assistant_task_prompt
    if not state.available_tools:
        raise ValueError("available_tools is not set")
    search_query = state.branch_question
    # Stream the query immediately so the UI shows it before results arrive
    write_custom_event(
        current_step_nr,
        SearchToolDelta(
            queries=[search_query],
            documents=[],
        ),
        writer,
    )
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = graph_config.inputs.prompt_builder.raw_user_query
    if graph_config.inputs.persona is None:
        raise ValueError("persona is not set")
    provider = get_default_provider()
    if not provider:
        raise ValueError("No internet search provider found")

    # Best-effort search: provider errors are logged and yield an empty list
    @traceable(name="Search Provider API Call")
    def _search(search_query: str) -> list[WebSearchResult]:
        search_results: list[WebSearchResult] = []
        try:
            search_results = list(provider.search(search_query))
        except Exception as e:
            logger.error(f"Error performing search: {e}")
        return search_results

    search_results: list[WebSearchResult] = _search(search_query)
    # Render the results as a numbered list (0-based, matching the indices
    # the LLM is asked to return) for the URL-selection prompt
    search_results_text = "\n\n".join(
        [
            f"{i}. {result.title}\n URL: {result.link}\n"
            + (f" Author: {result.author}\n" if result.author else "")
            + (
                f" Date: {result.published_date.strftime('%Y-%m-%d')}\n"
                if result.published_date
                else ""
            )
            + (f" Snippet: {result.snippet}\n" if result.snippet else "")
            for i, result in enumerate(search_results)
        ]
    )
    agent_decision_prompt = WEB_SEARCH_URL_SELECTION_PROMPT.build(
        search_query=search_query,
        base_question=base_question,
        search_results_text=search_results_text,
    )
    # Ask the fast LLM which result indices are worth opening
    agent_decision = invoke_llm_json(
        llm=graph_config.tooling.fast_llm,
        prompt=create_question_prompt(
            assistant_system_prompt,
            agent_decision_prompt + (assistant_task_prompt or ""),
        ),
        schema=WebSearchAnswer,
        timeout_override=TF_DR_TIMEOUT_SHORT,
    )
    # Keep only indices that are actually in range (the LLM may hallucinate)
    results_to_open = [
        (search_query, search_results[i])
        for i in agent_decision.urls_to_open_indices
        if i < len(search_results) and i >= 0
    ]
    return InternetSearchUpdate(
        results_to_open=results_to_open,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="internet_search",
                node_name="searching",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -0,0 +1,54 @@
from collections import defaultdict
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.exploration.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.agents.agent_search.exploration.sub_agents.web_search.states import (
InternetSearchInput,
)
from onyx.agents.agent_search.exploration.sub_agents.web_search.utils import (
dummy_inference_section_from_internet_search_result,
)
from onyx.agents.agent_search.exploration.utils import (
convert_inference_sections_to_search_docs,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
def dedup_urls(
    state: InternetSearchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> InternetSearchInput:
    """
    Collapse the (question, search result) pairs from all parallel search
    branches into one unique-URL set, stream placeholder docs to the UI,
    and record which URLs belong to which branch question.
    """
    question_to_urls: dict[str, list[str]] = defaultdict(list)
    unique_hits: dict[str, WebSearchResult] = {}
    for question, hit in state.results_to_open:
        question_to_urls[question].append(hit.link)
        # keep the first result seen for each link
        unique_hits.setdefault(hit.link, hit)
    placeholder_sections = [
        dummy_inference_section_from_internet_search_result(hit)
        for hit in unique_hits.values()
    ]
    write_custom_event(
        state.current_step_nr,
        SearchToolDelta(
            queries=[],
            documents=convert_inference_sections_to_search_docs(
                placeholder_sections, is_internet=True
            ),
        ),
        writer,
    )
    return InternetSearchInput(
        results_to_open=[],
        branch_questions_to_urls=question_to_urls,
    )

View File

@@ -0,0 +1,71 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.exploration.sub_agents.web_search.providers import (
get_default_provider,
)
from onyx.agents.agent_search.exploration.sub_agents.web_search.states import FetchInput
from onyx.agents.agent_search.exploration.sub_agents.web_search.states import (
FetchUpdate,
)
from onyx.agents.agent_search.exploration.sub_agents.web_search.utils import (
dummy_inference_section_from_internet_content,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger
logger = setup_logger()
def web_fetch(
    state: FetchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> FetchUpdate:
    """
    LangGraph node that downloads the contents of the selected URLs and
    wraps each fetched page in an InferenceSection.
    """
    started_at = datetime.now()
    if not state.available_tools:
        raise ValueError("available_tools is not set")
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    if graph_config.inputs.persona is None:
        raise ValueError("persona is not set")
    provider = get_default_provider()
    if provider is None:
        raise ValueError("No web search provider found")
    fetched_sections: list[InferenceSection] = []
    try:
        # on any provider error the whole batch is dropped (best-effort fetch)
        fetched_sections = [
            dummy_inference_section_from_internet_content(page)
            for page in provider.contents(state.urls_to_open)
        ]
    except Exception as e:
        logger.error(f"Error fetching URLs: {e}")
    if not fetched_sections:
        logger.warning("No content retrieved from URLs")
    return FetchUpdate(
        raw_documents=fetched_sections,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="internet_search",
                node_name="fetching",
                node_start_time=started_at,
            )
        ],
    )

View File

@@ -0,0 +1,19 @@
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.exploration.sub_agents.web_search.states import FetchInput
from onyx.agents.agent_search.exploration.sub_agents.web_search.states import (
InternetSearchInput,
)
def collect_raw_docs(
    state: FetchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> InternetSearchInput:
    """Fan-in node: forward the fetched documents on to the summarize stage."""
    return InternetSearchInput(raw_documents=state.raw_documents)

View File

@@ -0,0 +1,135 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.exploration.enums import ResearchType
from onyx.agents.agent_search.exploration.models import IterationAnswer
from onyx.agents.agent_search.exploration.models import SearchAnswer
from onyx.agents.agent_search.exploration.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.exploration.sub_agents.web_search.states import (
SummarizeInput,
)
from onyx.agents.agent_search.exploration.utils import extract_document_citations
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.context.search.models import InferenceSection
from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
from onyx.utils.logger import setup_logger
from onyx.utils.url import normalize_url
logger = setup_logger()
def is_summarize(
    state: SummarizeInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
    """
    LangGraph node that turns the fetched web documents for one branch
    question into an IterationAnswer. For DEEP research the documents are
    summarized by the LLM and citations are extracted; otherwise the
    documents are passed through un-summarized with an empty answer.
    """
    node_start_time = datetime.now()
    # build branch iterations from fetch inputs
    # Normalize URLs to handle mismatches from query parameters (e.g., ?activeTab=explore)
    url_to_raw_document: dict[str, InferenceSection] = {}
    for raw_document in state.raw_documents:
        # semantic_identifier of fetched web docs is the page URL
        # (set by the dummy inference-section builders)
        normalized_url = normalize_url(raw_document.center_chunk.semantic_identifier)
        url_to_raw_document[normalized_url] = raw_document
    # Normalize the URLs from branch_questions_to_urls as well
    urls = [
        normalize_url(url)
        for url in state.branch_questions_to_urls[state.branch_question]
    ]
    current_iteration = state.iteration_nr
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    research_type = graph_config.behavior.research_type
    if not state.available_tools:
        raise ValueError("available_tools is not set")
    # tool metadata for the most recently used tool (this web-search tool)
    is_tool_info = state.available_tools[state.tools_used[-1]]
    if research_type == ResearchType.DEEP:
        # NOTE(review): this raises KeyError if a URL selected for this branch
        # was never successfully fetched — confirm that is acceptable here
        cited_raw_documents = [url_to_raw_document[url] for url in urls]
        document_texts = _create_document_texts(cited_raw_documents)
        search_prompt = INTERNAL_SEARCH_PROMPTS[research_type].build(
            search_query=state.branch_question,
            base_question=graph_config.inputs.prompt_builder.raw_user_query,
            document_text=document_texts,
        )
        assistant_system_prompt = state.assistant_system_prompt
        assistant_task_prompt = state.assistant_task_prompt
        search_answer_json = invoke_llm_json(
            llm=graph_config.tooling.primary_llm,
            prompt=create_question_prompt(
                assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
            ),
            schema=SearchAnswer,
            timeout_override=TF_DR_TIMEOUT_SHORT,
        )
        answer_string = search_answer_json.answer
        claims = search_answer_json.claims or []
        reasoning = search_answer_json.reasoning or ""
        # rewrite [n] citations to the CITE-prefixed form and collect indices
        (
            citation_numbers,
            answer_string,
            claims,
        ) = extract_document_citations(answer_string, claims)
        # citation numbers are 1-based positions in cited_raw_documents
        cited_documents = {
            citation_number: cited_raw_documents[citation_number - 1]
            for citation_number in citation_numbers
        }
    else:
        # non-DEEP research: skip LLM summarization and return every fetched
        # document for this branch, numbered from 1, with an empty answer
        answer_string = ""
        reasoning = ""
        claims = []
        cited_raw_documents = [url_to_raw_document[url] for url in urls]
        cited_documents = {
            doc_num + 1: retrieved_doc
            for doc_num, retrieved_doc in enumerate(cited_raw_documents)
        }
    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=is_tool_info.llm_path,
                tool_id=is_tool_info.tool_id,
                iteration_nr=current_iteration,
                parallelization_nr=0,
                question=state.branch_question,
                answer=answer_string,
                claims=claims,
                cited_documents=cited_documents,
                reasoning=reasoning,
                additional_data=None,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="internet_search",
                node_name="summarizing",
                node_start_time=node_start_time,
            )
        ],
    )
def _create_document_texts(raw_documents: list[InferenceSection]) -> str:
    """Render each document as a numbered context block, joined by blank lines."""
    rendered_blocks: list[str] = []
    for document_number, document in enumerate(raw_documents, start=1):
        if not isinstance(document, InferenceSection):
            raise ValueError(f"Unexpected document type: {type(document)}")
        rendered_blocks.append(build_document_context(document, document_number))
    return "\n\n".join(rendered_blocks)

View File

@@ -0,0 +1,56 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
logger = setup_logger()
def is_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    Fan-in node: keep only the branch answers produced in the current
    iteration, close the streamed UI section, and advance the step counter.
    """
    started_at = datetime.now()
    iteration = state.iteration_nr
    step_nr = state.current_step_nr
    responses_this_iteration = [
        response
        for response in state.branch_iteration_responses
        if response.iteration_nr == iteration
    ]
    write_custom_event(step_nr, SectionEnd(), writer)
    return SubAgentUpdate(
        iteration_responses=responses_this_iteration,
        current_step_nr=step_nr + 1,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="internet_search",
                node_name="consolidation",
                node_start_time=started_at,
            )
        ],
    )

View File

@@ -0,0 +1,81 @@
from collections.abc import Hashable
from langgraph.types import Send
from onyx.agents.agent_search.exploration.constants import MAX_DR_PARALLEL_SEARCH
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.exploration.sub_agents.web_search.states import FetchInput
from onyx.agents.agent_search.exploration.sub_agents.web_search.states import (
InternetSearchInput,
)
from onyx.agents.agent_search.exploration.sub_agents.web_search.states import (
SummarizeInput,
)
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """Fan out one `search` branch per query, capped at MAX_DR_PARALLEL_SEARCH."""
    sends: list[Send | Hashable] = []
    for branch_nr, branch_query in enumerate(
        state.query_list[:MAX_DR_PARALLEL_SEARCH]
    ):
        branch_state = InternetSearchInput(
            iteration_nr=state.iteration_nr,
            current_step_nr=state.current_step_nr,
            parallelization_nr=branch_nr,
            query_list=[branch_query],
            branch_question=branch_query,
            context="",
            tools_used=state.tools_used,
            available_tools=state.available_tools,
            assistant_system_prompt=state.assistant_system_prompt,
            assistant_task_prompt=state.assistant_task_prompt,
            results_to_open=[],
        )
        sends.append(Send("search", branch_state))
    return sends
def fetch_router(state: InternetSearchInput) -> list[Send | Hashable]:
    """Fan out one `fetch` branch per unique URL across all branch questions."""
    question_to_urls = state.branch_questions_to_urls
    unique_urls = {url for url_list in question_to_urls.values() for url in url_list}
    return [
        Send(
            "fetch",
            FetchInput(
                iteration_nr=state.iteration_nr,
                urls_to_open=[url],
                tools_used=state.tools_used,
                available_tools=state.available_tools,
                assistant_system_prompt=state.assistant_system_prompt,
                assistant_task_prompt=state.assistant_task_prompt,
                current_step_nr=state.current_step_nr,
                branch_questions_to_urls=question_to_urls,
                raw_documents=state.raw_documents,
            ),
        )
        for url in unique_urls
    ]
def summarize_router(state: InternetSearchInput) -> list[Send | Hashable]:
    """Fan out one `summarize` branch per original branch question."""
    question_to_urls = state.branch_questions_to_urls
    sends: list[Send | Hashable] = []
    for question in question_to_urls:
        sends.append(
            Send(
                "summarize",
                SummarizeInput(
                    iteration_nr=state.iteration_nr,
                    raw_documents=state.raw_documents,
                    branch_questions_to_urls=question_to_urls,
                    branch_question=question,
                    tools_used=state.tools_used,
                    available_tools=state.available_tools,
                    assistant_system_prompt=state.assistant_system_prompt,
                    assistant_task_prompt=state.assistant_task_prompt,
                    current_step_nr=state.current_step_nr,
                ),
            )
        )
    return sends

View File

@@ -0,0 +1,84 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.exploration.sub_agents.web_search.dr_ws_1_branch import (
is_branch,
)
from onyx.agents.agent_search.exploration.sub_agents.web_search.dr_ws_2_search import (
web_search,
)
from onyx.agents.agent_search.exploration.sub_agents.web_search.dr_ws_3_dedup_urls import (
dedup_urls,
)
from onyx.agents.agent_search.exploration.sub_agents.web_search.dr_ws_4_fetch import (
web_fetch,
)
from onyx.agents.agent_search.exploration.sub_agents.web_search.dr_ws_5_collect_raw_docs import (
collect_raw_docs,
)
from onyx.agents.agent_search.exploration.sub_agents.web_search.dr_ws_6_summarize import (
is_summarize,
)
from onyx.agents.agent_search.exploration.sub_agents.web_search.dr_ws_7_reduce import (
is_reducer,
)
from onyx.agents.agent_search.exploration.sub_agents.web_search.dr_ws_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.exploration.sub_agents.web_search.dr_ws_conditional_edges import (
fetch_router,
)
from onyx.agents.agent_search.exploration.sub_agents.web_search.dr_ws_conditional_edges import (
summarize_router,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def dr_ws_graph_builder() -> StateGraph:
    """
    Build the LangGraph for the web search sub-agent:
    branch -> search -> dedup_urls -> fetch -> collect_raw_docs
    -> summarize -> reducer.
    """
    builder = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    # nodes
    builder.add_node("branch", is_branch)
    builder.add_node("search", web_search)
    builder.add_node("dedup_urls", dedup_urls)
    builder.add_node("fetch", web_fetch)
    builder.add_node("collect_raw_docs", collect_raw_docs)
    builder.add_node("summarize", is_summarize)
    builder.add_node("reducer", is_reducer)

    # edges (conditional edges fan out via Send routers)
    builder.add_edge(START, "branch")
    builder.add_conditional_edges("branch", branching_router)
    builder.add_edge("search", "dedup_urls")
    builder.add_conditional_edges("dedup_urls", fetch_router)
    builder.add_edge("fetch", "collect_raw_docs")
    builder.add_conditional_edges("collect_raw_docs", summarize_router)
    builder.add_edge("summarize", "reducer")
    builder.add_edge("reducer", END)

    return builder

View File

@@ -0,0 +1,53 @@
from abc import ABC
from abc import abstractmethod
from collections.abc import Sequence
from datetime import datetime
from enum import Enum
from pydantic import BaseModel
from pydantic import field_validator
from onyx.utils.url import normalize_url
class ProviderType(Enum):
    """Enum for internet search provider types"""

    GOOGLE = "google"
    # NOTE(review): provider wiring uses Exa and Serper; GOOGLE appears
    # unreferenced in this module — confirm whether it is still needed.
    EXA = "exa"
class WebSearchResult(BaseModel):
    """A single web-search hit: result metadata only, no page body."""

    title: str
    link: str
    snippet: str | None = None
    author: str | None = None
    published_date: datetime | None = None

    @field_validator("link")
    @classmethod
    def normalize_link(cls, v: str) -> str:
        # canonicalize the URL so later dedup-by-link works reliably
        return normalize_url(v)
class WebContent(BaseModel):
    """Full fetched page content for a single URL."""

    title: str
    link: str
    full_content: str
    published_date: datetime | None = None
    # False when scraping failed; downstream marks such docs hidden
    scrape_successful: bool = True

    @field_validator("link")
    @classmethod
    def normalize_link(cls, v: str) -> str:
        # canonicalize so the same page fetched via variant URLs dedupes
        return normalize_url(v)
class WebSearchProvider(ABC):
    """Interface for a web-search backend: query for hits, then fetch page contents."""

    @abstractmethod
    def search(self, query: str) -> Sequence[WebSearchResult]:
        """Return search hits (metadata only) for the query."""
        pass

    @abstractmethod
    def contents(self, urls: Sequence[str]) -> list[WebContent]:
        """Fetch the full page content for each of the given URLs."""
        pass

View File

@@ -0,0 +1,19 @@
from onyx.agents.agent_search.exploration.sub_agents.web_search.clients.exa_client import (
ExaClient,
)
from onyx.agents.agent_search.exploration.sub_agents.web_search.clients.serper_client import (
SerperClient,
)
from onyx.agents.agent_search.exploration.sub_agents.web_search.models import (
WebSearchProvider,
)
from onyx.configs.chat_configs import EXA_API_KEY
from onyx.configs.chat_configs import SERPER_API_KEY
def get_default_provider() -> WebSearchProvider | None:
    """Return the first configured provider (Exa preferred over Serper), or None."""
    if EXA_API_KEY:
        return ExaClient()
    return SerperClient() if SERPER_API_KEY else None

View File

@@ -0,0 +1,37 @@
from operator import add
from typing import Annotated
from onyx.agents.agent_search.exploration.states import LoggerUpdate
from onyx.agents.agent_search.exploration.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.exploration.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.context.search.models import InferenceSection
class InternetSearchInput(SubAgentInput):
    """State for one parallel web-search branch within the sub-agent."""

    # (branch question, search hit) pairs picked for fetching; `add` merges branches
    results_to_open: Annotated[list[tuple[str, WebSearchResult]], add] = []
    parallelization_nr: int = 0
    # last-write-wins reducers: each branch overwrites rather than accumulates
    branch_question: Annotated[str, lambda x, y: y] = ""
    branch_questions_to_urls: Annotated[dict[str, list[str]], lambda x, y: y] = {}
    raw_documents: Annotated[list[InferenceSection], add] = []
class InternetSearchUpdate(LoggerUpdate):
    """Output of the search node: which results were selected for fetching."""

    # (branch question, search hit) pairs; `add` merges parallel branches
    results_to_open: Annotated[list[tuple[str, WebSearchResult]], add] = []
class FetchInput(SubAgentInput):
    """Input state for one fetch branch."""

    # URLs this fetch branch should download
    urls_to_open: Annotated[list[str], add] = []
    # maps each branch question to the URLs selected for it
    branch_questions_to_urls: dict[str, list[str]]
    raw_documents: Annotated[list[InferenceSection], add] = []
class FetchUpdate(LoggerUpdate):
    """Output of the fetch node: fetched pages wrapped as InferenceSections."""

    # `add` merges documents from all parallel fetch branches
    raw_documents: Annotated[list[InferenceSection], add] = []
class SummarizeInput(SubAgentInput):
    """Input state for one summarize branch."""

    # all fetched documents for the iteration (shared across branches)
    raw_documents: Annotated[list[InferenceSection], add] = []
    # maps each branch question to the URLs selected for it
    branch_questions_to_urls: dict[str, list[str]]
    # the single branch question this summarize branch handles
    branch_question: str

View File

@@ -0,0 +1,99 @@
from onyx.agents.agent_search.exploration.sub_agents.web_search.models import (
WebContent,
)
from onyx.agents.agent_search.exploration.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.chat.models import DOCUMENT_CITATION_NUMBER_EMPTY_VALUE
from onyx.chat.models import LlmDoc
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
def truncate_search_result_content(content: str, max_chars: int = 10000) -> str:
    """Cap content at max_chars characters, appending an ellipsis when cut."""
    return content if len(content) <= max_chars else content[:max_chars] + "..."
def dummy_inference_section_from_internet_content(
    result: WebContent,
) -> InferenceSection:
    """Wrap a fetched web page in a single-chunk InferenceSection."""
    body = truncate_search_result_content(result.full_content)
    page_chunk = InferenceChunk(
        chunk_id=0,
        blurb=result.title,
        content=body,
        source_links={0: result.link},
        section_continuation=False,
        document_id="INTERNET_SEARCH_DOC_" + result.link,
        source_type=DocumentSource.WEB,
        semantic_identifier=result.link,
        title=result.title,
        boost=1,
        recency_bias=1.0,
        score=1.0,
        # failed scrapes are kept but hidden
        hidden=(not result.scrape_successful),
        metadata={},
        match_highlights=[],
        doc_summary=body,
        chunk_context=body,
        updated_at=result.published_date,
        image_file_id=None,
    )
    return InferenceSection(
        center_chunk=page_chunk,
        chunks=[],
        combined_content=body,
    )
def dummy_inference_section_from_internet_search_result(
    result: WebSearchResult,
) -> InferenceSection:
    """Wrap a search hit (no page body yet) in an empty single-chunk InferenceSection."""
    placeholder_chunk = InferenceChunk(
        chunk_id=0,
        blurb=result.title,
        content="",
        source_links={0: result.link},
        section_continuation=False,
        document_id="INTERNET_SEARCH_DOC_" + result.link,
        source_type=DocumentSource.WEB,
        semantic_identifier=result.link,
        title=result.title,
        boost=1,
        recency_bias=1.0,
        score=1.0,
        hidden=False,
        metadata={},
        match_highlights=[],
        doc_summary="",
        chunk_context="",
        updated_at=result.published_date,
        image_file_id=None,
    )
    return InferenceSection(
        center_chunk=placeholder_chunk,
        chunks=[],
        combined_content="",
    )
def llm_doc_from_web_content(web_content: WebContent) -> LlmDoc:
    """Create an LlmDoc from WebContent with the INTERNET_SEARCH_DOC_ prefix."""
    truncated_body = truncate_search_result_content(web_content.full_content)
    # TODO: document_id is overloaded here — it normally identifies a document
    # in the database, but internet results are never persisted, so the
    # prefixed link is used instead.
    return LlmDoc(
        document_id="INTERNET_SEARCH_DOC_" + web_content.link,
        content=truncated_body,
        blurb=web_content.link,
        semantic_identifier=web_content.link,
        source_type=DocumentSource.WEB,
        metadata={},
        link=web_content.link,
        document_citation_number=DOCUMENT_CITATION_NUMBER_EMPTY_VALUE,
        updated_at=web_content.published_date,
        source_links={},
        match_highlights=[],
    )

View File

@@ -0,0 +1,277 @@
import copy
import re
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from onyx.agents.agent_search.exploration.models import AggregatedDRContext
from onyx.agents.agent_search.exploration.models import IterationAnswer
from onyx.agents.agent_search.exploration.models import OrchestrationClarificationInfo
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_section_list,
)
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import SavedSearchDoc
from onyx.context.search.models import SearchDoc
from onyx.tools.tool_implementations.web_search.web_search_tool import (
WebSearchTool,
)
CITATION_PREFIX = "CITE:"


def extract_document_citations(
    answer: str, claims: list[str]
) -> tuple[list[int], str, list[str]]:
    """
    Find every citation group of the form [1] or [2, 3] in the answer and the
    claims, collect the cited indices, and rewrite each group as
    [CITE:2][CITE:3] etc. so later deduplication can re-number them safely.
    """
    seen_indices: set[int] = set()
    # one bracket group may carry a single index or a comma-separated list
    citation_group = re.compile(r"\[(\d+(?:,\s*\d+)*)\]")

    def _rewrite(match: re.Match[str]) -> str:
        indices = [int(part) for part in match.group(1).split(",")]
        seen_indices.update(indices)
        return "".join(f"[{CITATION_PREFIX}{idx}]" for idx in indices)

    rewritten_answer = citation_group.sub(_rewrite, answer)
    rewritten_claims = [citation_group.sub(_rewrite, claim) for claim in claims]
    return list(seen_indices), rewritten_answer, rewritten_claims
def aggregate_context(
    iteration_responses: list[IterationAnswer], include_documents: bool = True
) -> AggregatedDRContext:
    """
    Converts the iteration response into a single string with unified citations.
    For example,
    it 1: the answer is x [3][4]. {3: doc_abc, 4: doc_xyz}
    it 2: blah blah [1, 3]. {1: doc_xyz, 3: doc_pqr}
    Output:
    it 1: the answer is x [1][2].
    it 2: blah blah [2][3]
    [1]: doc_xyz
    [2]: doc_abc
    [3]: doc_pqr
    """
    # dedupe and merge inference section contents
    unrolled_inference_sections: list[InferenceSection] = []
    is_internet_marker_dict: dict[str, bool] = {}
    # iterate in deterministic (iteration, parallelization) order so global
    # citation numbering is stable across calls
    for iteration_response in sorted(
        iteration_responses,
        key=lambda x: (x.iteration_nr, x.parallelization_nr),
    ):
        iteration_tool = iteration_response.tool
        is_internet = iteration_tool == WebSearchTool._NAME
        for cited_doc in iteration_response.cited_documents.values():
            unrolled_inference_sections.append(cited_doc)
            # first citation wins: a doc cited by both an internet and a
            # non-internet iteration keeps the marker of its first occurrence
            if cited_doc.center_chunk.document_id not in is_internet_marker_dict:
                is_internet_marker_dict[cited_doc.center_chunk.document_id] = (
                    is_internet
                )
            cited_doc.center_chunk.score = None  # None means maintain order
    global_documents = dedup_inference_section_list(unrolled_inference_sections)
    # global citation numbers are 1-based positions in the deduped list
    global_citations = {
        doc.center_chunk.document_id: i for i, doc in enumerate(global_documents, 1)
    }
    # build output string
    output_strings: list[str] = []
    global_iteration_responses: list[IterationAnswer] = []
    for iteration_response in sorted(
        iteration_responses,
        key=lambda x: (x.iteration_nr, x.parallelization_nr),
    ):
        # add basic iteration info
        output_strings.append(
            f"Iteration: {iteration_response.iteration_nr}, "
            f"Question {iteration_response.parallelization_nr}"
        )
        output_strings.append(f"Tool: {iteration_response.tool}")
        output_strings.append(f"Question: {iteration_response.question}")
        # get answer and claims with global citations
        answer_str = iteration_response.answer
        claims = iteration_response.claims or []
        iteration_citations: list[int] = []
        for local_number, cited_doc in iteration_response.cited_documents.items():
            global_number = global_citations[cited_doc.center_chunk.document_id]
            # translate local citations to global citations
            answer_str = answer_str.replace(
                f"[{CITATION_PREFIX}{local_number}]", f"[{global_number}]"
            )
            claims = [
                claim.replace(
                    f"[{CITATION_PREFIX}{local_number}]", f"[{global_number}]"
                )
                for claim in claims
            ]
            iteration_citations.append(global_number)
        # add answer, claims, and citation info
        if answer_str:
            output_strings.append(f"Answer: {answer_str}")
        if claims:
            # NOTE(review): the `or "No claims provided"` branch is unreachable —
            # the left operand always starts with "Claims: " and is truthy, and
            # this branch only runs when claims is non-empty anyway. Confirm intent.
            output_strings.append(
                "Claims: " + "".join(f"\n - {claim}" for claim in claims or [])
                or "No claims provided"
            )
        if not answer_str and not claims:
            output_strings.append(
                "Retrieved documents: "
                + (
                    "".join(
                        f"[{global_number}]"
                        for global_number in sorted(iteration_citations)
                    )
                    or "No documents retrieved"
                )
            )
        output_strings.append("\n---\n")
        # save global iteration response
        iteration_response_copy = iteration_response.model_copy()
        iteration_response_copy.answer = answer_str
        iteration_response_copy.claims = claims
        iteration_response_copy.cited_documents = {
            global_citations[doc.center_chunk.document_id]: doc
            for doc in iteration_response.cited_documents.values()
        }
        global_iteration_responses.append(iteration_response_copy)
    # add document contents if requested
    if include_documents:
        if global_documents:
            output_strings.append("Cited document contents:")
            for doc in global_documents:
                output_strings.append(
                    build_document_context(
                        doc, global_citations[doc.center_chunk.document_id]
                    )
                )
        output_strings.append("\n---\n")
    return AggregatedDRContext(
        context="\n".join(output_strings),
        cited_documents=global_documents,
        is_internet_marker_dict=is_internet_marker_dict,
        global_iteration_responses=global_iteration_responses,
    )
def get_chat_history_string(chat_history: list[BaseMessage], max_messages: int) -> str:
    """
    Render the last max_messages USER/ASSISTANT pairs of the chat history as a
    plain-text transcript, stripping non-text content pieces (e.g. images).
    A leading "...\n" marks that older messages were truncated.
    """
    recent_messages = copy.deepcopy(chat_history[-max_messages * 2 :])
    for message in recent_messages:
        content = message.content
        if not isinstance(content, list):
            continue
        # keep plain strings and dict pieces explicitly typed as text
        text_pieces = [
            piece
            for piece in content
            if not (isinstance(piece, dict) and piece.get("type") != "text")
        ]
        if len(text_pieces) != len(content):
            message.content = text_pieces
    truncation_marker = "...\n" if len(chat_history) > len(recent_messages) else ""
    return truncation_marker + "\n".join(
        ("user" if isinstance(msg, HumanMessage) else "you")
        + f": {str(msg.content).strip()}"
        for msg in recent_messages
    )
def get_prompt_question(
    question: str, clarification: OrchestrationClarificationInfo | None
) -> str:
    """Return the user question, augmented with the clarification exchange if present."""
    if not clarification:
        return question
    return (
        f"Initial User Question: {question}\n"
        f"(Clarification Question: {clarification.clarification_question}\n"
        f"User Response: {clarification.clarification_response})"
    )
def create_tool_call_string(tool_name: str, query_list: list[str]) -> str:
    """Render a tool call as "Tool: <name>" followed by a list of its questions."""
    joined_questions = "\n - ".join(query_list)
    return f"Tool: {tool_name}\n\nQuestions:\n{joined_questions}"
def parse_plan_to_dict(plan_text: str) -> dict[str, str]:
    """
    Parse a numbered plan string ("1. do x 2) do y") into a dict mapping the
    step number (as a string, without the "." or ")") to the step text.
    """
    if not plan_text:
        return {}
    # split on numbered markers like "1." or "2)", keeping the markers
    segments = re.split(r"(\d+[.)])", plan_text)
    plan_dict: dict[str, str] = {}
    # segments[0] is any text before the first marker (dropped);
    # thereafter markers and bodies alternate
    for marker, body in zip(segments[1::2], segments[2::2]):
        step_text = body.strip()
        if step_text:  # skip markers with no content after them
            plan_dict[marker.rstrip(".)")] = step_text
    return plan_dict
def convert_inference_sections_to_search_docs(
    inference_sections: list[InferenceSection],
    is_internet: bool = False,
) -> list[SavedSearchDoc]:
    """Convert InferenceSections into SavedSearchDocs, tagging internet results."""
    search_docs = SearchDoc.from_chunks_or_sections(inference_sections)
    for doc in search_docs:
        doc.is_internet = is_internet
    # db_doc_id=0: these docs are not persisted, so there is no real DB id
    return [
        SavedSearchDoc.from_search_doc(doc, db_doc_id=0) for doc in search_docs
    ]

View File

@@ -12,6 +12,10 @@ from onyx.agents.agent_search.dc_search_analysis.graph_builder import (
from onyx.agents.agent_search.dc_search_analysis.states import MainInput as DCMainInput
from onyx.agents.agent_search.dr.graph_builder import dr_graph_builder
from onyx.agents.agent_search.dr.states import MainInput as DRMainInput
from onyx.agents.agent_search.exploration.graph_builder import exploration_graph_builder
from onyx.agents.agent_search.exploration.states import (
MainInput as ExplorationMainInput,
)
from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
from onyx.agents.agent_search.kb_search.states import MainInput as KBMainInput
from onyx.agents.agent_search.models import GraphConfig
@@ -81,6 +85,16 @@ def run_dr_graph(
yield from run_graph(compiled_graph, config, input)
def run_exploration_graph(
    config: GraphConfig,
) -> AnswerStream:
    """Compile the exploration agent graph and stream its answer packets."""
    builder = exploration_graph_builder()
    # renamed from `input` to avoid shadowing the builtin
    exploration_input = ExplorationMainInput(log_messages=[])
    yield from run_graph(builder.compile(), config, exploration_input)
def run_dc_graph(
config: GraphConfig,
) -> AnswerStream:

View File

@@ -7,6 +7,7 @@ from typing import TypeVar
from braintrust import traced
from langchain.schema.language_model import LanguageModelInput
from langchain_core.messages import AIMessage
from langchain_core.messages import HumanMessage
from langgraph.types import StreamWriter
from pydantic import BaseModel
@@ -183,6 +184,32 @@ def invoke_llm_json(
return schema.model_validate_json(response_content)
# FOR EXPERIMENTAL USE ONLY
def invoke_llm_raw(
    llm: LLM,
    prompt: LanguageModelInput,
    tools: list[dict] | None = None,
    tool_choice: ToolChoiceOptions | None = None,
    timeout_override: int | None = None,
    max_tokens: int | None = None,
) -> AIMessage:
    """
    Invoke an LLM and return the raw AIMessage response, without any
    JSON-schema forcing or validation (unlike invoke_llm_json).

    Args:
        llm: the LLM to invoke
        prompt: the prompt/messages to send
        tools: optional tool definitions to expose to the model
        tool_choice: optional tool-choice constraint
        timeout_override: optional per-call timeout in seconds
        max_tokens: optional cap on generated tokens
    """
    # Fixed: the previous docstring was copied from invoke_llm_json and wrongly
    # claimed JSON forcing; the no-op `**cast(dict, {})` splat is removed.
    response = llm.invoke_langchain(
        prompt,
        tools=tools,
        tool_choice=tool_choice,
        timeout_override=timeout_override,
        max_tokens=max_tokens,
    )
    return cast(AIMessage, response)
def get_answer_from_llm(
llm: LLM,
prompt: str,

View File

@@ -5,12 +5,13 @@ from uuid import UUID
from sqlalchemy.orm import Session
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.exploration.enums import ResearchType as ExpResearchType
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.models import GraphInputs
from onyx.agents.agent_search.models import GraphPersistence
from onyx.agents.agent_search.models import GraphSearchConfig
from onyx.agents.agent_search.models import GraphTooling
from onyx.agents.agent_search.run_graph import run_dr_graph
from onyx.agents.agent_search.run_graph import run_exploration_graph
from onyx.chat.models import AnswerStream
from onyx.chat.models import AnswerStreamPart
from onyx.chat.models import AnswerStyleConfig
@@ -32,6 +33,7 @@ from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.gpu_utils import fast_gpu_status_request
from onyx.utils.logger import setup_logger
from shared_configs.configs import EXPLORATION_TEST_TYPE
logger = setup_logger()
@@ -60,7 +62,7 @@ class Answer:
skip_gen_ai_answer_generation: bool = False,
is_connected: Callable[[], bool] | None = None,
use_agentic_search: bool = False,
research_type: ResearchType | None = None,
research_type: ResearchType | ExpResearchType | None = None,
research_plan: dict[str, Any] | None = None,
project_instructions: str | None = None,
) -> None:
@@ -121,6 +123,9 @@ class Answer:
else:
research_type = ResearchType.THOUGHTFUL
if EXPLORATION_TEST_TYPE != "not_explicitly_set":
research_type = ExpResearchType(research_type)
self.search_behavior_config = GraphSearchConfig(
use_agentic_search=use_agentic_search,
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
@@ -144,7 +149,8 @@ class Answer:
return
# TODO: add toggle in UI with customizable TimeBudget
stream = run_dr_graph(self.graph_config)
# stream = run_dr_graph(self.graph_config)
stream = run_exploration_graph(self.graph_config)
processed_stream: list[AnswerStreamPart] = []
for packet in stream:

View File

@@ -0,0 +1,20 @@
import os
def _env_flag(name: str) -> bool:
    """Read a boolean feature flag from the environment.

    Unset or empty values count as "false"; the comparison is
    case-insensitive, so "True"/"TRUE" also enable the flag.
    """
    return (os.environ.get(name) or "false").lower() == "true"


# Default toggles for the exploration test graph; all default to False.
EXPLORATION_TEST_USE_DC_DEFAULT = _env_flag("EXPLORATION_TEST_USE_DC_DEFAULT")
# NOTE(review): "CALRIFIER" looks like a typo for "CLARIFIER", but the env var
# name is part of the deployment contract, so it is preserved as-is.
EXPLORATION_TEST_USE_CALRIFIER_DEFAULT = _env_flag(
    "EXPLORATION_TEST_USE_CALRIFIER_DEFAULT"
)
EXPLORATION_TEST_USE_PLAN_DEFAULT = _env_flag("EXPLORATION_TEST_USE_PLAN_DEFAULT")
EXPLORATION_TEST_USE_PLAN_UPDATES_DEFAULT = _env_flag(
    "EXPLORATION_TEST_USE_PLAN_UPDATES_DEFAULT"
)
EXPLORATION_TEST_USE_CORPUS_HISTORY_DEFAULT = _env_flag(
    "EXPLORATION_TEST_USE_CORPUS_HISTORY_DEFAULT"
)
EXPLORATION_TEST_USE_THINKING_DEFAULT = _env_flag(
    "EXPLORATION_TEST_USE_THINKING_DEFAULT"
)

View File

@@ -0,0 +1,85 @@
from datetime import datetime
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Document
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.db.models import SubscriptionRegistration
from onyx.db.models import SubscriptionResult
def get_subscription_registration(
    db_session: Session, user_id: UUID
) -> SubscriptionRegistration | None:
    """Fetch a user's subscription registration.

    Args:
        db_session: Active database session.
        user_id: The user's id (User.id is a UUID, matching the sibling
            helpers in this module; the previous ``str`` annotation was
            inconsistent).

    Returns:
        The registration row, or None when the user has never registered
        (``.first()`` returns None on an empty result).
    """
    return (
        db_session.query(SubscriptionRegistration)
        .filter(SubscriptionRegistration.user_id == user_id)
        .first()
    )
def get_subscription_result(
    db_session: Session, user_id: UUID
) -> SubscriptionResult | None:
    """Return the most recent subscription result for a user.

    Orders by ``created_at`` descending so the newest row wins. Returns None
    when the user has no results yet (``.first()`` returns None on an empty
    result; the previous non-optional return annotation was incorrect).
    """
    return (
        db_session.query(SubscriptionResult)
        .filter(SubscriptionResult.user_id == user_id)
        .order_by(SubscriptionResult.created_at.desc())
        .first()
    )
def save_subscription_result(
    db_session: Session, subscription_result: SubscriptionResult
) -> None:
    """Persist a SubscriptionResult and commit the transaction immediately.

    Note: commits the whole session, so any other pending changes on
    ``db_session`` are committed as well.
    """
    db_session.add(subscription_result)
    db_session.commit()
def get_document_ids_by_cc_pair_name(
    db_session: Session,
    cc_pair_name: str,
    date_threshold: datetime | None = None,
) -> list[tuple[str, str | None]]:
    """Return (document_id, document_link) pairs for a named cc-pair.

    Looks up the ConnectorCredentialPair by name, then collects the ids and
    links of all documents attached to that connector/credential combination.
    When ``date_threshold`` is given, only documents whose ``doc_updated_at``
    is at or after the threshold are included.

    Args:
        db_session: Active database session.
        cc_pair_name: Name of the connector credential pair.
        date_threshold: Optional lower bound on document update time.

    Returns:
        List of (document_id, document_link) tuples; empty if no cc-pair
        with that name exists.
    """
    cc_pair = (
        db_session.execute(
            select(ConnectorCredentialPair).where(
                ConnectorCredentialPair.name == cc_pair_name
            )
        )
        .scalars()
        .first()
    )
    if cc_pair is None:
        return []

    # All filters are ANDed together; the date filter is only added on demand.
    conditions = [
        DocumentByConnectorCredentialPair.connector_id == cc_pair.connector_id,
        DocumentByConnectorCredentialPair.credential_id == cc_pair.credential_id,
    ]
    if date_threshold is not None:
        conditions.append(Document.doc_updated_at >= date_threshold)

    doc_stmt = (
        select(DocumentByConnectorCredentialPair.id, Document.link)
        .join(Document, DocumentByConnectorCredentialPair.id == Document.id)
        .where(*conditions)
    )
    return [(doc_id, link) for doc_id, link in db_session.execute(doc_stmt)]

View File

@@ -194,6 +194,9 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
personal_name: Mapped[str | None] = mapped_column(String, nullable=True)
personal_role: Mapped[str | None] = mapped_column(String, nullable=True)
use_memories: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
cheat_sheet_context: Mapped[dict[str, Any] | None] = mapped_column(
postgresql.JSONB(), nullable=True, default=None
)
chosen_assistants: Mapped[list[int] | None] = mapped_column(
postgresql.JSONB(), nullable=True, default=None
@@ -3952,3 +3955,72 @@ class ExternalGroupPermissionSyncAttempt(Base):
def is_finished(self) -> bool:
return self.status.is_terminal()
# EXPLORATION TESTING
class TemporaryUserCheatSheetContext(Base):
    """
    Represents the context of a user's cheat sheet. Replace with column in user table once
    login is working again.
    """

    __tablename__ = "temporary_user_cheat_sheet_context"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Intentionally disabled until auth works again; rows are currently
    # global rather than per-user.
    # user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"))
    # Free-form JSON payload; no fixed schema is enforced here.
    context: Mapped[dict[str, Any]] = mapped_column(postgresql.JSONB())
    # DB-side timestamps; updated_at refreshes automatically on UPDATE.
    created_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )
    updated_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )
# HACKATHON CHANGES
class SubscriptionRegistration(Base):
    """
    Represents a user's subscription registration with document extraction contexts
    and search questions.
    """

    __tablename__ = "subscription_registrations"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Owning user; one registration row references one user.
    user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), nullable=False)
    # Mapping of extraction-context name -> context text (JSONB).
    # NOTE(review): keys/values are assumed to both be strings per the
    # annotation — JSONB itself does not enforce this; confirm with writers.
    doc_extraction_contexts: Mapped[dict[str, str]] = mapped_column(
        postgresql.JSONB(), nullable=False
    )
    # Questions to run as searches for this subscription (Postgres text[]).
    search_questions: Mapped[list[str]] = mapped_column(
        postgresql.ARRAY(String), nullable=False
    )
    # DB-side timestamps; updated_at refreshes automatically on UPDATE.
    created_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )
    updated_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )
class SubscriptionResult(Base):
    """
    Represents the results of a subscription for a user, including notifications.
    """

    __tablename__ = "subscription_results"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Owning user; multiple result rows per user are expected (see the
    # created_at-descending lookup in the subscription db helpers).
    user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), nullable=False)
    # Discriminator for the kind of result; plain string, values are
    # defined by the producers — TODO confirm the allowed set.
    type: Mapped[str] = mapped_column(String, nullable=False)
    # Notification payload as free-form JSONB; no fixed schema enforced here.
    notifications: Mapped[dict[str, Any]] = mapped_column(
        postgresql.JSONB(), nullable=False
    )
    # DB-side timestamps; updated_at refreshes automatically on UPDATE.
    created_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )
    updated_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )

View File

@@ -0,0 +1,32 @@
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.models import TemporaryUserCheatSheetContext
# EXPLORATION TESTING
def get_user_cheat_sheet_context(db_session: Session) -> dict[str, Any] | None:
    """Return the most recently created cheat-sheet context, or None.

    The query is limited to a single row: without ``.limit(1)``,
    ``scalar_one_or_none()`` raises MultipleResultsFound as soon as the
    table holds more than one row, which the update helper makes possible.
    """
    stmt = (
        select(TemporaryUserCheatSheetContext)
        .order_by(TemporaryUserCheatSheetContext.created_at.desc())
        .limit(1)
    )
    result = db_session.execute(stmt).scalar_one_or_none()
    return result.context if result else None
def update_user_cheat_sheet_context(
    db_session: Session, new_cheat_sheet_context: dict[str, Any]
) -> None:
    """Upsert the latest cheat-sheet context row and commit.

    Updates the newest existing row in place, or inserts the first row when
    the table is empty. ``.limit(1)`` prevents ``scalar_one_or_none()`` from
    raising MultipleResultsFound once more than one row exists.
    """
    stmt = (
        select(TemporaryUserCheatSheetContext)
        .order_by(TemporaryUserCheatSheetContext.created_at.desc())
        .limit(1)
    )
    existing = db_session.execute(stmt).scalar_one_or_none()
    if existing:
        existing.context = new_cheat_sheet_context
    else:
        db_session.add(
            TemporaryUserCheatSheetContext(context=new_cheat_sheet_context)
        )
    db_session.commit()

View File

@@ -342,3 +342,19 @@ def delete_user_from_db(
# NOTE: edge case may exist with race conditions
# with this `invited user` scheme generally.
remove_user_from_invited_users(user_to_delete.email)
# EXPLORATION TESTING
def get_user_cheat_sheet_context(
    user: User, db_session: Session
) -> dict[str, Any] | None:
    """Return the user's stored cheat-sheet context (None if never set)."""
    # db_session is unused here; kept so the signature mirrors the
    # update helper below.
    return user.cheat_sheet_context
def update_user_cheat_sheet_context(
    user: User, new_cheat_sheet_context: dict[str, Any], db_session: Session
) -> None:
    """Replace the user's cheat-sheet context and flush (no commit).

    Only flushes — the caller owns the transaction and must commit.
    """
    user.cheat_sheet_context = new_cheat_sheet_context
    db_session.flush()  # Make the change visible to subsequent queries in the same transaction

View File

@@ -207,6 +207,8 @@ def construct_tools(
continue
if db_tool_model.in_code_tool_id:
if db_tool_model.in_code_tool_id != "SearchTool":
continue
tool_cls = get_built_in_tool_by_id(db_tool_model.in_code_tool_id)
try:

View File

@@ -220,3 +220,5 @@ INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH = int(
)
# Deployment environment name; sentinel string when the env var is unset.
ENVIRONMENT = os.environ.get("ENVIRONMENT") or "not_explicitly_set"
# Selects the exploration-test behavior (checked by the Answer flow);
# the sentinel means "feature not enabled".
EXPLORATION_TEST_TYPE = os.environ.get("EXPLORATION_TEST_TYPE") or "not_explicitly_set"