Compare commits

..

3 Commits

Author SHA1 Message Date
SubashMohan
f71290fe05 move pageheader component to layouts 2025-12-01 19:00:26 +05:30
Chris Weaver
6b7c6c9a37 fix: icon coloring in Renderer (#6491) 2025-11-30 18:53:26 -08:00
SubashMohan
53ae1b598b fix(WebSearch): adjust Separator styling for improved layout consistency (#6487) 2025-11-30 11:45:37 +05:30
289 changed files with 26786 additions and 27597 deletions

2
.gitignore vendored
View File

@@ -49,7 +49,5 @@ CLAUDE.md
# Local .terraform.lock.hcl file
.terraform.lock.hcl
node_modules
# MCP configs
.playwright-mcp

View File

@@ -1,104 +0,0 @@
"""add_open_url_tool
Revision ID: 4f8a2b3c1d9e
Revises: a852cbe15577
Create Date: 2025-11-24 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "4f8a2b3c1d9e"
down_revision = "a852cbe15577"
branch_labels = None
depends_on = None
# Seed row for the Open URL tool. Keys map 1:1 to columns of the `tool`
# table and double as the bound parameters used by upgrade() below.
OPEN_URL_TOOL = {
    "name": "OpenURLTool",
    "display_name": "Open URL",
    "description": (
        "The Open URL Action allows the agent to fetch and read contents of web pages."
    ),
    # Stable identifier used for lookups; name/display_name may change.
    "in_code_tool_id": "OpenURLTool",
    "enabled": True,
}
def upgrade() -> None:
    """Ensure the OpenURLTool row exists and is attached to every persona.

    Idempotent: updates the tool row in place when it already exists,
    inserts it otherwise, then backfills the persona__tool association for
    any persona that does not have it yet.
    """
    conn = op.get_bind()

    # Upsert the tool row, keyed on the stable in_code_tool_id.
    existing = conn.execute(
        sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
        {"in_code_tool_id": OPEN_URL_TOOL["in_code_tool_id"]},
    ).fetchone()

    if existing:
        tool_id = existing[0]
        conn.execute(
            sa.text(
                """
                UPDATE tool
                SET name = :name,
                    display_name = :display_name,
                    description = :description
                WHERE in_code_tool_id = :in_code_tool_id
                """
            ),
            OPEN_URL_TOOL,
        )
    else:
        conn.execute(
            sa.text(
                """
                INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled)
                VALUES (:name, :display_name, :description, :in_code_tool_id, :enabled)
                """
            ),
            OPEN_URL_TOOL,
        )
        # Re-select to get the newly inserted row's id.
        result = conn.execute(
            sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
            {"in_code_tool_id": OPEN_URL_TOOL["in_code_tool_id"]},
        ).fetchone()
        tool_id = result[0]  # type: ignore

    # Associate the tool with every persona that lacks it. A single
    # INSERT ... SELECT replaces the previous per-persona SELECT + INSERT
    # loop (one statement instead of O(n) round-trips), with the same
    # duplicate-safe semantics via NOT EXISTS.
    conn.execute(
        sa.text(
            """
            INSERT INTO persona__tool (persona_id, tool_id)
            SELECT p.id, :tool_id
            FROM persona p
            WHERE NOT EXISTS (
                SELECT 1 FROM persona__tool pt
                WHERE pt.persona_id = p.id AND pt.tool_id = :tool_id
            )
            """
        ),
        {"tool_id": tool_id},
    )
def downgrade() -> None:
    # Intentionally a no-op: leaving the tool row and its persona
    # associations in place is harmless, and a subsequent upgrade() is
    # idempotent (it updates rather than duplicates).
    pass

View File

@@ -1,572 +0,0 @@
"""New Chat History
Revision ID: a852cbe15577
Revises: 6436661d5b65
Create Date: 2025-11-08 15:16:37.781308
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "a852cbe15577"
down_revision = "6436661d5b65"
branch_labels = None
depends_on = None
def _column_exists(conn: sa.engine.Connection, table: str, column: str) -> bool:
    """Return True iff `table` currently has a column named `column`."""
    # Bound parameters replace the previous copy-pasted (and in two spots
    # f-string-interpolated) information_schema queries.
    return (
        conn.execute(
            sa.text(
                """
                SELECT 1 FROM information_schema.columns
                WHERE table_name = :table AND column_name = :column
                """
            ),
            {"table": table, "column": column},
        ).fetchone()
        is not None
    )


def _column_type(conn: sa.engine.Connection, table: str, column: str) -> str | None:
    """Return the information_schema data_type of the column, or None if absent."""
    row = conn.execute(
        sa.text(
            """
            SELECT data_type FROM information_schema.columns
            WHERE table_name = :table AND column_name = :column
            """
        ),
        {"table": table, "column": column},
    ).fetchone()
    return row[0] if row else None


def _table_exists(conn: sa.engine.Connection, table: str) -> bool:
    """Return True iff a table named `table` exists."""
    return (
        conn.execute(
            sa.text(
                """
                SELECT 1 FROM information_schema.tables
                WHERE table_name = :table
                """
            ),
            {"table": table},
        ).fetchone()
        is not None
    )


def upgrade() -> None:
    """Restructure chat history storage.

    Drops the legacy agent/research tables, renames and re-keys several
    chat_message columns, and reshapes tool_call for the new chat history
    model. Every step is guarded by an existence check so the migration is
    safe to re-run against a partially migrated schema.
    """
    # Drop research agent tables (if they exist)
    op.execute("DROP TABLE IF EXISTS research_agent_iteration_sub_step CASCADE")
    op.execute("DROP TABLE IF EXISTS research_agent_iteration CASCADE")

    # Drop agent sub query and sub question tables (if they exist)
    op.execute("DROP TABLE IF EXISTS agent__sub_query__search_doc CASCADE")
    op.execute("DROP TABLE IF EXISTS agent__sub_query CASCADE")
    op.execute("DROP TABLE IF EXISTS agent__sub_question CASCADE")

    conn = op.get_bind()

    # --- ChatMessage ----------------------------------------------------
    # parent_message -> parent_message_id, now a self-referencing FK.
    if _column_exists(conn, "chat_message", "parent_message"):
        op.alter_column(
            "chat_message", "parent_message", new_column_name="parent_message_id"
        )
        op.create_foreign_key(
            "fk_chat_message_parent_message_id",
            "chat_message",
            "chat_message",
            ["parent_message_id"],
            ["id"],
        )

    # latest_child_message -> latest_child_message_id, now a self-referencing FK.
    if _column_exists(conn, "chat_message", "latest_child_message"):
        op.alter_column(
            "chat_message",
            "latest_child_message",
            new_column_name="latest_child_message_id",
        )
        op.create_foreign_key(
            "fk_chat_message_latest_child_message_id",
            "chat_message",
            "chat_message",
            ["latest_child_message_id"],
            ["id"],
        )

    if not _column_exists(conn, "chat_message", "reasoning_tokens"):
        op.add_column(
            "chat_message", sa.Column("reasoning_tokens", sa.Text(), nullable=True)
        )

    # Columns superseded by the new chat history model.
    for col in [
        "rephrased_query",
        "alternate_assistant_id",
        "overridden_model",
        "is_agentic",
        "refined_answer_improvement",
        "research_type",
        "research_plan",
        "research_answer_purpose",
    ]:
        if _column_exists(conn, "chat_message", col):
            op.drop_column("chat_message", col)

    # --- ToolCall --------------------------------------------------------
    if not _column_exists(conn, "tool_call", "chat_session_id"):
        op.add_column(
            "tool_call",
            sa.Column("chat_session_id", postgresql.UUID(as_uuid=True), nullable=False),
        )
        op.create_foreign_key(
            "fk_tool_call_chat_session_id",
            "tool_call",
            "chat_session",
            ["chat_session_id"],
            ["id"],
        )

    # message_id -> parent_chat_message_id, now nullable.
    if _column_exists(conn, "tool_call", "message_id"):
        op.alter_column(
            "tool_call",
            "message_id",
            new_column_name="parent_chat_message_id",
            nullable=True,
        )

    if not _column_exists(conn, "tool_call", "parent_tool_call_id"):
        op.add_column(
            "tool_call", sa.Column("parent_tool_call_id", sa.Integer(), nullable=True)
        )
        op.create_foreign_key(
            "fk_tool_call_parent_tool_call_id",
            "tool_call",
            "tool_call",
            ["parent_tool_call_id"],
            ["id"],
        )
        # The old one-tool-call-per-message uniqueness no longer holds.
        op.drop_constraint("uq_tool_call_message_id", "tool_call", type_="unique")

    for col_name in ["turn_number", "tool_id"]:
        if not _column_exists(conn, "tool_call", col_name):
            op.add_column(
                "tool_call",
                sa.Column(col_name, sa.Integer(), nullable=False, server_default="0"),
            )

    if not _column_exists(conn, "tool_call", "tool_call_id"):
        op.add_column(
            "tool_call",
            sa.Column("tool_call_id", sa.String(), nullable=False, server_default=""),
        )

    if not _column_exists(conn, "tool_call", "reasoning_tokens"):
        op.add_column(
            "tool_call", sa.Column("reasoning_tokens", sa.Text(), nullable=True)
        )

    if _column_exists(conn, "tool_call", "tool_arguments"):
        op.alter_column(
            "tool_call", "tool_arguments", new_column_name="tool_call_arguments"
        )

    # tool_result -> tool_call_response, converting JSONB to TEXT.
    if _column_exists(conn, "tool_call", "tool_result"):
        op.alter_column(
            "tool_call", "tool_result", new_column_name="tool_call_response"
        )
        op.execute(
            sa.text(
                """
                ALTER TABLE tool_call
                ALTER COLUMN tool_call_response TYPE TEXT
                USING tool_call_response::text
                """
            )
        )
    elif _column_type(conn, "tool_call", "tool_call_response") == "jsonb":
        # Already renamed on a previous run but still JSONB: finish the
        # type conversion only.
        op.execute(
            sa.text(
                """
                ALTER TABLE tool_call
                ALTER COLUMN tool_call_response TYPE TEXT
                USING tool_call_response::text
                """
            )
        )

    if not _column_exists(conn, "tool_call", "tool_call_tokens"):
        op.add_column(
            "tool_call",
            sa.Column(
                "tool_call_tokens", sa.Integer(), nullable=False, server_default="0"
            ),
        )

    # Stored images so image generation tool calls can be replayed.
    if not _column_exists(conn, "tool_call", "generated_images"):
        op.add_column(
            "tool_call",
            sa.Column("generated_images", postgresql.JSONB(), nullable=True),
        )

    if _column_exists(conn, "tool_call", "tool_name"):
        op.drop_column("tool_call", "tool_name")

    # --- Associations + Persona ------------------------------------------
    if not _table_exists(conn, "tool_call__search_doc"):
        op.create_table(
            "tool_call__search_doc",
            sa.Column("tool_call_id", sa.Integer(), nullable=False),
            sa.Column("search_doc_id", sa.Integer(), nullable=False),
            sa.ForeignKeyConstraint(
                ["tool_call_id"], ["tool_call.id"], ondelete="CASCADE"
            ),
            sa.ForeignKeyConstraint(
                ["search_doc_id"], ["search_doc.id"], ondelete="CASCADE"
            ),
            sa.PrimaryKeyConstraint("tool_call_id", "search_doc_id"),
        )

    if not _column_exists(conn, "persona", "replace_base_system_prompt"):
        op.add_column(
            "persona",
            sa.Column(
                "replace_base_system_prompt",
                sa.Boolean(),
                nullable=False,
                server_default="false",
            ),
        )
def downgrade() -> None:
    """Best-effort reversal of the New Chat History migration.

    Unlike upgrade(), these steps are not guarded by existence checks, so
    downgrade assumes a fully upgraded schema.
    """
    # Reverse persona changes
    op.drop_column("persona", "replace_base_system_prompt")
    # Drop tool_call__search_doc association table
    op.execute("DROP TABLE IF EXISTS tool_call__search_doc CASCADE")
    # Reverse ToolCall changes
    # NOTE(review): tool_name is re-added NOT NULL with no server_default;
    # this add_column will fail if tool_call contains any rows — confirm
    # this is intended (or acceptable) for downgrade.
    op.add_column("tool_call", sa.Column("tool_name", sa.String(), nullable=False))
    op.drop_column("tool_call", "tool_id")
    op.drop_column("tool_call", "tool_call_tokens")
    op.drop_column("tool_call", "generated_images")
    # Change tool_call_response back to JSONB before renaming
    op.execute(
        sa.text(
            """
            ALTER TABLE tool_call
            ALTER COLUMN tool_call_response TYPE JSONB
            USING tool_call_response::jsonb
            """
        )
    )
    op.alter_column("tool_call", "tool_call_response", new_column_name="tool_result")
    op.alter_column(
        "tool_call", "tool_call_arguments", new_column_name="tool_arguments"
    )
    op.drop_column("tool_call", "reasoning_tokens")
    op.drop_column("tool_call", "tool_call_id")
    op.drop_column("tool_call", "turn_number")
    # FK constraints must go before their columns.
    op.drop_constraint(
        "fk_tool_call_parent_tool_call_id", "tool_call", type_="foreignkey"
    )
    op.drop_column("tool_call", "parent_tool_call_id")
    op.alter_column(
        "tool_call",
        "parent_chat_message_id",
        new_column_name="message_id",
        nullable=False,
    )
    op.drop_constraint("fk_tool_call_chat_session_id", "tool_call", type_="foreignkey")
    op.drop_column("tool_call", "chat_session_id")
    # Restore the chat_message columns removed by upgrade().
    op.add_column(
        "chat_message",
        sa.Column(
            "research_answer_purpose",
            sa.Enum("INTRO", "DEEP_DIVE", name="researchanswerpurpose"),
            nullable=True,
        ),
    )
    op.add_column(
        "chat_message", sa.Column("research_plan", postgresql.JSONB(), nullable=True)
    )
    op.add_column(
        "chat_message",
        sa.Column(
            "research_type",
            sa.Enum("SIMPLE", "DEEP", name="researchtype"),
            nullable=True,
        ),
    )
    op.add_column(
        "chat_message",
        sa.Column("refined_answer_improvement", sa.Boolean(), nullable=True),
    )
    op.add_column(
        "chat_message",
        sa.Column("is_agentic", sa.Boolean(), nullable=False, server_default="false"),
    )
    op.add_column(
        "chat_message", sa.Column("overridden_model", sa.String(), nullable=True)
    )
    op.add_column(
        "chat_message", sa.Column("alternate_assistant_id", sa.Integer(), nullable=True)
    )
    op.add_column(
        "chat_message", sa.Column("rephrased_query", sa.Text(), nullable=True)
    )
    op.drop_column("chat_message", "reasoning_tokens")
    # Undo the *_id renames and their self-referencing FKs.
    op.drop_constraint(
        "fk_chat_message_latest_child_message_id", "chat_message", type_="foreignkey"
    )
    op.alter_column(
        "chat_message",
        "latest_child_message_id",
        new_column_name="latest_child_message",
    )
    op.drop_constraint(
        "fk_chat_message_parent_message_id", "chat_message", type_="foreignkey"
    )
    op.alter_column(
        "chat_message", "parent_message_id", new_column_name="parent_message"
    )
    # Recreate agent sub question and sub query tables
    op.create_table(
        "agent__sub_question",
        sa.Column("id", sa.Integer(), primary_key=True),
        sa.Column("primary_question_id", sa.Integer(), nullable=False),
        sa.Column("chat_session_id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("sub_question", sa.Text(), nullable=False),
        sa.Column("level", sa.Integer(), nullable=False),
        sa.Column("level_question_num", sa.Integer(), nullable=False),
        sa.Column(
            "time_created",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column("sub_answer", sa.Text(), nullable=False),
        sa.Column("sub_question_doc_results", postgresql.JSONB(), nullable=False),
        sa.ForeignKeyConstraint(
            ["primary_question_id"], ["chat_message.id"], ondelete="CASCADE"
        ),
        sa.ForeignKeyConstraint(["chat_session_id"], ["chat_session.id"]),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_table(
        "agent__sub_query",
        sa.Column("id", sa.Integer(), primary_key=True),
        sa.Column("parent_question_id", sa.Integer(), nullable=False),
        sa.Column("chat_session_id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("sub_query", sa.Text(), nullable=False),
        sa.Column(
            "time_created",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.ForeignKeyConstraint(
            ["parent_question_id"], ["agent__sub_question.id"], ondelete="CASCADE"
        ),
        sa.ForeignKeyConstraint(["chat_session_id"], ["chat_session.id"]),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_table(
        "agent__sub_query__search_doc",
        sa.Column("sub_query_id", sa.Integer(), nullable=False),
        sa.Column("search_doc_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["sub_query_id"], ["agent__sub_query.id"], ondelete="CASCADE"
        ),
        sa.ForeignKeyConstraint(["search_doc_id"], ["search_doc.id"]),
        sa.PrimaryKeyConstraint("sub_query_id", "search_doc_id"),
    )
    # Recreate research agent tables
    op.create_table(
        "research_agent_iteration",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column("primary_question_id", sa.Integer(), nullable=False),
        sa.Column("iteration_nr", sa.Integer(), nullable=False),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column("purpose", sa.String(), nullable=True),
        sa.Column("reasoning", sa.String(), nullable=True),
        sa.ForeignKeyConstraint(
            ["primary_question_id"], ["chat_message.id"], ondelete="CASCADE"
        ),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint(
            "primary_question_id",
            "iteration_nr",
            name="_research_agent_iteration_unique_constraint",
        ),
    )
    op.create_table(
        "research_agent_iteration_sub_step",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column("primary_question_id", sa.Integer(), nullable=False),
        sa.Column("iteration_nr", sa.Integer(), nullable=False),
        sa.Column("iteration_sub_step_nr", sa.Integer(), nullable=False),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column("sub_step_instructions", sa.String(), nullable=True),
        sa.Column("sub_step_tool_id", sa.Integer(), nullable=True),
        sa.Column("reasoning", sa.String(), nullable=True),
        sa.Column("sub_answer", sa.String(), nullable=True),
        sa.Column("cited_doc_results", postgresql.JSONB(), nullable=False),
        sa.Column("claims", postgresql.JSONB(), nullable=True),
        sa.Column("is_web_fetch", sa.Boolean(), nullable=True),
        sa.Column("queries", postgresql.JSONB(), nullable=True),
        sa.Column("generated_images", postgresql.JSONB(), nullable=True),
        sa.Column("additional_data", postgresql.JSONB(), nullable=True),
        # Composite FK ties each sub step to its parent iteration row.
        sa.ForeignKeyConstraint(
            ["primary_question_id", "iteration_nr"],
            [
                "research_agent_iteration.primary_question_id",
                "research_agent_iteration.iteration_nr",
            ],
            ondelete="CASCADE",
        ),
        sa.ForeignKeyConstraint(["sub_step_tool_id"], ["tool.id"], ondelete="SET NULL"),
        sa.PrimaryKeyConstraint("id"),
    )

View File

@@ -199,7 +199,10 @@ def fetch_persona_message_analytics(
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
ChatSession.persona_id == persona_id,
or_(
ChatMessage.alternate_assistant_id == persona_id,
ChatSession.persona_id == persona_id,
),
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,
@@ -228,7 +231,10 @@ def fetch_persona_unique_users(
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
ChatSession.persona_id == persona_id,
or_(
ChatMessage.alternate_assistant_id == persona_id,
ChatSession.persona_id == persona_id,
),
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,
@@ -259,7 +265,10 @@ def fetch_assistant_message_analytics(
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
ChatSession.persona_id == assistant_id,
or_(
ChatMessage.alternate_assistant_id == assistant_id,
ChatSession.persona_id == assistant_id,
),
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,
@@ -290,7 +299,10 @@ def fetch_assistant_unique_users(
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
ChatSession.persona_id == assistant_id,
or_(
ChatMessage.alternate_assistant_id == assistant_id,
ChatSession.persona_id == assistant_id,
),
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,
@@ -320,7 +332,10 @@ def fetch_assistant_unique_users_total(
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
ChatSession.persona_id == assistant_id,
or_(
ChatMessage.alternate_assistant_id == assistant_id,
ChatSession.persona_id == assistant_id,
),
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,

View File

@@ -55,7 +55,18 @@ def get_empty_chat_messages_entries__paginated(
# Get assistant name (from session persona, or alternate if specified)
assistant_name = None
if chat_session.persona:
if message.alternate_assistant_id:
# If there's an alternate assistant, we need to fetch it
from onyx.db.models import Persona
alternate_persona = (
db_session.query(Persona)
.filter(Persona.id == message.alternate_assistant_id)
.first()
)
if alternate_persona:
assistant_name = alternate_persona.name
elif chat_session.persona:
assistant_name = chat_session.persona.name
message_skeletons.append(

View File

@@ -9,7 +9,7 @@ from ee.onyx.server.query_and_chat.models import (
)
from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import create_chat_history_chain
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.models import ChatBasicResponse
from onyx.chat.process_message import gather_stream
from onyx.chat.process_message import stream_chat_message_objects
@@ -69,7 +69,7 @@ def handle_simplified_chat_message(
chat_session_id = chat_message_req.chat_session_id
try:
parent_message, _ = create_chat_history_chain(
parent_message, _ = create_chat_chain(
chat_session_id=chat_session_id, db_session=db_session
)
except Exception:

View File

@@ -8,29 +8,10 @@ from pydantic import model_validator
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import BasicChunkRequest
from onyx.context.search.models import ChunkContext
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import RetrievalDetails
from onyx.server.manage.models import StandardAnswer
class StandardAnswerRequest(BaseModel):
message: str
slack_bot_categories: list[str]
class StandardAnswerResponse(BaseModel):
standard_answers: list[StandardAnswer] = Field(default_factory=list)
class DocumentSearchRequest(BasicChunkRequest):
user_selected_filters: BaseFilters | None = None
class DocumentSearchResponse(BaseModel):
top_documents: list[InferenceChunk]
from onyx.server.query_and_chat.streaming_models import SubQuestionIdentifier
class BasicCreateChatMessageRequest(ChunkContext):
@@ -90,17 +71,17 @@ class SimpleDoc(BaseModel):
metadata: dict | None
class AgentSubQuestion(BaseModel):
class AgentSubQuestion(SubQuestionIdentifier):
sub_question: str
document_ids: list[str]
class AgentAnswer(BaseModel):
class AgentAnswer(SubQuestionIdentifier):
answer: str
answer_type: Literal["agent_sub_answer", "agent_level_answer"]
class AgentSubQuery(BaseModel):
class AgentSubQuery(SubQuestionIdentifier):
sub_query: str
query_id: int
@@ -146,3 +127,12 @@ class AgentSubQuery(BaseModel):
sorted(level_question_dict.items(), key=lambda x: (x is None, x))
)
return sorted_dict
class StandardAnswerRequest(BaseModel):
message: str
slack_bot_categories: list[str]
class StandardAnswerResponse(BaseModel):
standard_answers: list[StandardAnswer] = Field(default_factory=list)

View File

@@ -24,7 +24,7 @@ from onyx.auth.users import current_admin_user
from onyx.auth.users import get_display_email
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.background.task_utils import construct_query_history_report_name
from onyx.chat.chat_utils import create_chat_history_chain
from onyx.chat.chat_utils import create_chat_chain
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import FileType
@@ -123,9 +123,10 @@ def snapshot_from_chat_session(
) -> ChatSessionSnapshot | None:
try:
# Older chats may not have the right structure
messages = create_chat_history_chain(
last_message, messages = create_chat_chain(
chat_session_id=chat_session.id, db_session=db_session
)
messages.append(last_message)
except RuntimeError:
return None

View File

@@ -1,309 +1,365 @@
# import json
# from collections.abc import Callable
# from collections.abc import Iterator
# from collections.abc import Sequence
# from dataclasses import dataclass
# from typing import Any
import json
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
# from onyx.agents.agent_framework.models import RunItemStreamEvent
# from onyx.agents.agent_framework.models import StreamEvent
# from onyx.agents.agent_framework.models import ToolCallStreamItem
# from onyx.llm.interfaces import LanguageModelInput
# from onyx.llm.interfaces import LLM
# from onyx.llm.interfaces import ToolChoiceOptions
# from onyx.llm.message_types import ChatCompletionMessage
# from onyx.llm.message_types import ToolCall
# from onyx.llm.model_response import ModelResponseStream
# from onyx.tools.tool import Tool
# from onyx.tracing.framework.create import agent_span
# from onyx.tracing.framework.create import generation_span
import onyx.tracing.framework._error_tracing as _error_tracing
from onyx.agents.agent_framework.models import RunItemStreamEvent
from onyx.agents.agent_framework.models import StreamEvent
from onyx.agents.agent_framework.models import ToolCallOutputStreamItem
from onyx.agents.agent_framework.models import ToolCallStreamItem
from onyx.llm.interfaces import LanguageModelInput
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.message_types import ChatCompletionMessage
from onyx.llm.message_types import ToolCall
from onyx.llm.model_response import ModelResponseStream
from onyx.tools.tool import RunContextWrapper
from onyx.tools.tool import Tool
from onyx.tracing.framework.create import agent_span
from onyx.tracing.framework.create import function_span
from onyx.tracing.framework.create import generation_span
from onyx.tracing.framework.spans import SpanError
@dataclass
class QueryResult:
    """Result of one LLM query: the live event stream plus the messages
    accumulated so far for stateful continuation."""

    # Lazily-consumed stream of events produced while the query runs.
    stream: Iterator[StreamEvent]
    # Messages appended during this query (populated as the stream is consumed).
    new_messages_stateful: list[ChatCompletionMessage]
# def _serialize_tool_output(output: Any) -> str:
# if isinstance(output, str):
# return output
# try:
# return json.dumps(output)
# except TypeError:
# return str(output)
def _serialize_tool_output(output: Any) -> str:
if isinstance(output, str):
return output
try:
return json.dumps(output)
except TypeError:
return str(output)
# def _parse_tool_calls_from_message_content(
# content: str,
# ) -> list[dict[str, Any]]:
# """Parse JSON content that represents tool call instructions."""
# try:
# parsed_content = json.loads(content)
# except json.JSONDecodeError:
# return []
def _parse_tool_calls_from_message_content(
content: str,
) -> list[dict[str, Any]]:
"""Parse JSON content that represents tool call instructions."""
try:
parsed_content = json.loads(content)
except json.JSONDecodeError:
return []
# if isinstance(parsed_content, dict):
# candidates = [parsed_content]
# elif isinstance(parsed_content, list):
# candidates = [item for item in parsed_content if isinstance(item, dict)]
# else:
# return []
if isinstance(parsed_content, dict):
candidates = [parsed_content]
elif isinstance(parsed_content, list):
candidates = [item for item in parsed_content if isinstance(item, dict)]
else:
return []
# tool_calls: list[dict[str, Any]] = []
tool_calls: list[dict[str, Any]] = []
# for candidate in candidates:
# name = candidate.get("name")
# arguments = candidate.get("arguments")
for candidate in candidates:
name = candidate.get("name")
arguments = candidate.get("arguments")
# if not isinstance(name, str) or arguments is None:
# continue
if not isinstance(name, str) or arguments is None:
continue
# if not isinstance(arguments, dict):
# continue
if not isinstance(arguments, dict):
continue
# call_id = candidate.get("id")
# arguments_str = json.dumps(arguments)
# tool_calls.append(
# {
# "id": call_id,
# "name": name,
# "arguments": arguments_str,
# }
# )
call_id = candidate.get("id")
arguments_str = json.dumps(arguments)
tool_calls.append(
{
"id": call_id,
"name": name,
"arguments": arguments_str,
}
)
# return tool_calls
return tool_calls
def _try_convert_content_to_tool_calls_for_non_tool_calling_llms(
    tool_calls_in_progress: dict[int, dict[str, Any]],
    content_parts: list[str],
    structured_response_format: dict | None,
    next_synthetic_tool_call_id: Callable[[], str],
) -> None:
    """Populate tool_calls_in_progress when a non-tool-calling LLM returns JSON content describing tool calls.

    No-op when native tool calls already exist, there is no content, or a
    structured response format was requested (JSON content is then the
    answer itself, not a tool call). On success the content parts are
    consumed (cleared) and tool_calls_in_progress is filled in place, with
    synthetic ids generated for candidates that lack one.
    """
    if tool_calls_in_progress or not content_parts or structured_response_format:
        return

    tool_calls_from_content = _parse_tool_calls_from_message_content(
        "".join(content_parts)
    )

    if not tool_calls_from_content:
        return

    # The content was tool-call JSON, not an answer — consume it.
    content_parts.clear()

    for index, tool_call_data in enumerate(tool_calls_from_content):
        call_id = tool_call_data["id"] or next_synthetic_tool_call_id()
        tool_calls_in_progress[index] = {
            "id": call_id,
            "name": tool_call_data["name"],
            "arguments": tool_call_data["arguments"],
        }
# def _update_tool_call_with_delta(
# tool_calls_in_progress: dict[int, dict[str, Any]],
# tool_call_delta: Any,
# ) -> None:
# index = tool_call_delta.index
def _update_tool_call_with_delta(
tool_calls_in_progress: dict[int, dict[str, Any]],
tool_call_delta: Any,
) -> None:
index = tool_call_delta.index
# if index not in tool_calls_in_progress:
# tool_calls_in_progress[index] = {
# "id": None,
# "name": None,
# "arguments": "",
# }
if index not in tool_calls_in_progress:
tool_calls_in_progress[index] = {
"id": None,
"name": None,
"arguments": "",
}
# if tool_call_delta.id:
# tool_calls_in_progress[index]["id"] = tool_call_delta.id
if tool_call_delta.id:
tool_calls_in_progress[index]["id"] = tool_call_delta.id
# if tool_call_delta.function:
# if tool_call_delta.function.name:
# tool_calls_in_progress[index]["name"] = tool_call_delta.function.name
if tool_call_delta.function:
if tool_call_delta.function.name:
tool_calls_in_progress[index]["name"] = tool_call_delta.function.name
# if tool_call_delta.function.arguments:
# tool_calls_in_progress[index][
# "arguments"
# ] += tool_call_delta.function.arguments
if tool_call_delta.function.arguments:
tool_calls_in_progress[index][
"arguments"
] += tool_call_delta.function.arguments
def query(
    llm_with_default_settings: LLM,
    messages: LanguageModelInput,
    tools: Sequence[Tool],
    context: Any,
    tool_choice: ToolChoiceOptions | None = None,
    structured_response_format: dict | None = None,
) -> QueryResult:
    """Run one LLM turn with tool support and return a streaming result.

    Streams model chunks (wrapped in start/done marker events for reasoning
    and message sections), accumulates any tool-call deltas emitted by the
    model, executes the completed tool calls sequentially once the stream
    finishes, and records the resulting assistant/tool messages.

    Args:
        llm_with_default_settings: LLM client used for the streaming call.
        messages: Prompt input forwarded to ``LLM.stream``.
        tools: Tools the model may call; looked up by name for execution.
        context: Opaque caller context handed to tools via RunContextWrapper.
        tool_choice: Tool-choice option forwarded to the LLM; ``"none"`` also
            disables the synthetic tool-call conversion fallback.
        structured_response_format: Optional structured-output format dict.

    Returns:
        QueryResult whose ``stream`` is a lazy generator of stream events and
        whose ``new_messages_stateful`` is the SAME list the generator appends
        to — it only fills in as the stream is consumed.
    """
    tool_definitions = [tool.tool_definition() for tool in tools]
    tools_by_name = {tool.name: tool for tool in tools}
    # Shared with stream_generator()'s closure; populated lazily as the
    # stream is drained.
    new_messages_stateful: list[ChatCompletionMessage] = []
    current_span = agent_span(
        name="agent_framework_query",
        output_type="dict" if structured_response_format else "str",
    )
    current_span.start(mark_as_current=True)
    current_span.span_data.tools = [t.name for t in tools]

    def stream_generator() -> Iterator[StreamEvent]:
        # Lifecycle flags for the message/reasoning start/done marker events.
        message_started = False
        reasoning_started = False
        # Partial tool calls keyed by their index within the streamed deltas.
        tool_calls_in_progress: dict[int, dict[str, Any]] = {}
        content_parts: list[str] = []
        synthetic_tool_call_counter = 0

        def _next_synthetic_tool_call_id() -> str:
            # Ids for tool calls synthesized from plain content
            # (fallback path for non-tool-calling LLMs).
            nonlocal synthetic_tool_call_counter
            call_id = f"synthetic_tool_call_{synthetic_tool_call_counter}"
            synthetic_tool_call_counter += 1
            return call_id

        with generation_span( # type: ignore[misc]
            model=llm_with_default_settings.config.model_name,
            model_config={
                "base_url": str(llm_with_default_settings.config.api_base or ""),
                "model_impl": "litellm",
            },
        ) as span_generation:
            # Only set input if messages is a sequence (not a string)
            # ChatCompletionMessage TypedDicts are compatible with Mapping[str, Any] at runtime
            if isinstance(messages, Sequence) and not isinstance(messages, str):
                # Convert ChatCompletionMessage sequence to Sequence[Mapping[str, Any]]
                span_generation.span_data.input = [dict(msg) for msg in messages] # type: ignore[assignment]
            for chunk in llm_with_default_settings.stream(
                prompt=messages,
                tools=tool_definitions,
                tool_choice=tool_choice,
                structured_response_format=structured_response_format,
            ):
                assert isinstance(chunk, ModelResponseStream)
                usage = getattr(chunk, "usage", None)
                if usage:
                    span_generation.span_data.usage = {
                        "input_tokens": usage.prompt_tokens,
                        "output_tokens": usage.completion_tokens,
                        "cache_read_input_tokens": usage.cache_read_input_tokens,
                        "cache_creation_input_tokens": usage.cache_creation_input_tokens,
                    }
                delta = chunk.choice.delta
                finish_reason = chunk.choice.finish_reason
                if delta.reasoning_content:
                    if not reasoning_started:
                        yield RunItemStreamEvent(type="reasoning_start")
                        reasoning_started = True
                if delta.content:
                    # Content ends any in-flight reasoning section.
                    if reasoning_started:
                        yield RunItemStreamEvent(type="reasoning_done")
                        reasoning_started = False
                    content_parts.append(delta.content)
                    if not message_started:
                        yield RunItemStreamEvent(type="message_start")
                        message_started = True
                if delta.tool_calls:
                    # Tool calls close out both reasoning and message sections.
                    if reasoning_started:
                        yield RunItemStreamEvent(type="reasoning_done")
                        reasoning_started = False
                    if message_started:
                        yield RunItemStreamEvent(type="message_done")
                        message_started = False
                    for tool_call_delta in delta.tool_calls:
                        _update_tool_call_with_delta(
                            tool_calls_in_progress, tool_call_delta
                        )
                yield chunk
                if not finish_reason:
                    continue
                # Stream is finishing: flush any outstanding done markers.
                if reasoning_started:
                    yield RunItemStreamEvent(type="reasoning_done")
                    reasoning_started = False
                if message_started:
                    yield RunItemStreamEvent(type="message_done")
                    message_started = False
                if tool_choice != "none":
                    # Fallback for models without native tool calling: try to
                    # parse the accumulated content into tool calls.
                    _try_convert_content_to_tool_calls_for_non_tool_calling_llms(
                        tool_calls_in_progress,
                        content_parts,
                        structured_response_format,
                        _next_synthetic_tool_call_id,
                    )
                if content_parts:
                    new_messages_stateful.append(
                        {
                            "role": "assistant",
                            "content": "".join(content_parts),
                        }
                    )
                span_generation.span_data.output = new_messages_stateful
        # Execute tool calls outside of the stream loop and generation_span
        if tool_calls_in_progress:
            sorted_tool_calls = sorted(tool_calls_in_progress.items())
            # Build tool calls for the message and execute tools
            assistant_tool_calls: list[ToolCall] = []
            tool_outputs: dict[str, str] = {}
            for _, tool_call_data in sorted_tool_calls:
                call_id = tool_call_data["id"]
                name = tool_call_data["name"]
                arguments_str = tool_call_data["arguments"]
                # Skip partial tool calls that never received an id/name.
                if call_id is None or name is None:
                    continue
                assistant_tool_calls.append(
                    {
                        "id": call_id,
                        "type": "function",
                        "function": {
                            "name": name,
                            "arguments": arguments_str,
                        },
                    }
                )
                yield RunItemStreamEvent(
                    type="tool_call",
                    details=ToolCallStreamItem(
                        call_id=call_id,
                        name=name,
                        arguments=arguments_str,
                    ),
                )
                if name in tools_by_name:
                    tool = tools_by_name[name]
                    arguments = json.loads(arguments_str)
                    run_context = RunContextWrapper(context=context)
                    # TODO: Instead of executing sequentially, execute in parallel
                    # In practice, it's not a must right now since we don't use parallel
                    # tool calls, so kicking the can down the road for now.
                    with function_span(tool.name) as span_fn:
                        span_fn.span_data.input = arguments
                        try:
                            output = tool.run_v2(run_context, **arguments)
                            tool_outputs[call_id] = _serialize_tool_output(output)
                            span_fn.span_data.output = output
                        except Exception as e:
                            _error_tracing.attach_error_to_current_span(
                                SpanError(
                                    message="Error running tool",
                                    data={"tool_name": tool.name, "error": str(e)},
                                )
                            )
                            # Treat the error as the tool output so the framework can continue
                            error_output = f"Error: {str(e)}"
                            tool_outputs[call_id] = error_output
                            output = error_output
                    yield RunItemStreamEvent(
                        type="tool_call_output",
                        details=ToolCallOutputStreamItem(
                            call_id=call_id,
                            output=output,
                        ),
                    )
                else:
                    not_found_output = f"Tool {name} not found"
                    tool_outputs[call_id] = _serialize_tool_output(not_found_output)
                    yield RunItemStreamEvent(
                        type="tool_call_output",
                        details=ToolCallOutputStreamItem(
                            call_id=call_id,
                            output=not_found_output,
                        ),
                    )
            # Record the assistant tool-call message, then one "tool" message
            # per call that produced an output.
            new_messages_stateful.append(
                {
                    "role": "assistant",
                    "content": None,
                    "tool_calls": assistant_tool_calls,
                }
            )
            for _, tool_call_data in sorted_tool_calls:
                call_id = tool_call_data["id"]
                if call_id in tool_outputs:
                    new_messages_stateful.append(
                        {
                            "role": "tool",
                            "content": tool_outputs[call_id],
                            "tool_call_id": call_id,
                        }
                    )
        current_span.finish(reset_current=True)

    return QueryResult(
        stream=stream_generator(),
        new_messages_stateful=new_messages_stateful,
    )

View File

@@ -1,21 +1,21 @@
# from operator import add
# from typing import Annotated
from operator import add
from typing import Annotated
# from pydantic import BaseModel
from pydantic import BaseModel
class CoreState(BaseModel):
    """
    This is the core state that is shared across all subgraphs.
    """

    # Accumulated log lines; the `add` reducer makes LangGraph merge updates
    # from parallel nodes by list concatenation.
    log_messages: Annotated[list[str], add] = []
    current_step_nr: int = 1
class SubgraphCoreState(BaseModel):
    """
    This is the core state that is shared across all subgraphs.
    """

    # Accumulated log lines; merged via the `add` reducer across parallel
    # subgraph nodes.
    log_messages: Annotated[list[str], add] = []

View File

@@ -1,62 +1,62 @@
# from collections.abc import Hashable
# from typing import cast
from collections.abc import Hashable
from typing import cast
# from langchain_core.runnables.config import RunnableConfig
# from langgraph.types import Send
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import Send
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
# from onyx.agents.agent_search.dc_search_analysis.states import (
# ObjectResearchInformationUpdate,
# )
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
# from onyx.agents.agent_search.dc_search_analysis.states import (
# SearchSourcesObjectsUpdate,
# )
# from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
from onyx.agents.agent_search.dc_search_analysis.states import (
ObjectResearchInformationUpdate,
)
from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
from onyx.agents.agent_search.dc_search_analysis.states import (
SearchSourcesObjectsUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
def parallel_object_source_research_edge(
    state: SearchSourcesObjectsUpdate, config: RunnableConfig
) -> list[Send | Hashable]:
    """
    LangGraph edge to parallelize the research for an individual object and source
    """
    search_objects = state.analysis_objects
    search_sources = state.analysis_sources

    # One research task per (object, source) pair. `obj` avoids shadowing
    # the `object` builtin.
    object_source_combinations = [
        (obj, source) for obj in search_objects for source in search_sources
    ]

    return [
        Send(
            "research_object_source",
            ObjectSourceInput(
                object_source_combination=object_source_combination,
                log_messages=[],
            ),
        )
        for object_source_combination in object_source_combinations
    ]
def parallel_object_research_consolidation_edge(
    state: ObjectResearchInformationUpdate, config: RunnableConfig
) -> list[Send | Hashable]:
    """
    LangGraph edge to parallelize the research for an individual object and source
    """
    # NOTE(review): the cast result is discarded; the line is kept because the
    # subscript access still raises KeyError when the graph config is missing.
    cast(GraphConfig, config["metadata"]["config"])
    object_research_information_results = state.object_research_information_results

    # One consolidation task per researched object.
    return [
        Send(
            "consolidate_object_research",
            ObjectInformationInput(
                object_information=object_information,
                log_messages=[],
            ),
        )
        for object_information in object_research_information_results
    ]

View File

@@ -1,103 +1,103 @@
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from onyx.agents.agent_search.dc_search_analysis.edges import (
# parallel_object_research_consolidation_edge,
# )
# from onyx.agents.agent_search.dc_search_analysis.edges import (
# parallel_object_source_research_edge,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a1_search_objects import (
# search_objects,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a2_research_object_source import (
# research_object_source,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a3_structure_research_by_object import (
# structure_research_by_object,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a4_consolidate_object_research import (
# consolidate_object_research,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a5_consolidate_research import (
# consolidate_research,
# )
# from onyx.agents.agent_search.dc_search_analysis.states import MainInput
# from onyx.agents.agent_search.dc_search_analysis.states import MainState
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dc_search_analysis.edges import (
parallel_object_research_consolidation_edge,
)
from onyx.agents.agent_search.dc_search_analysis.edges import (
parallel_object_source_research_edge,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a1_search_objects import (
search_objects,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a2_research_object_source import (
research_object_source,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a3_structure_research_by_object import (
structure_research_by_object,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a4_consolidate_object_research import (
consolidate_object_research,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a5_consolidate_research import (
consolidate_research,
)
from onyx.agents.agent_search.dc_search_analysis.states import MainInput
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.utils.logger import setup_logger
# Module-level logger and default test flag for the DivCon graph builder.
logger = setup_logger()

test_mode = False
def divide_and_conquer_graph_builder(test_mode: bool = False) -> StateGraph:
    """
    LangGraph graph builder for the knowledge graph search process.

    Wires the five DivCon nodes into a fan-out/fan-in pipeline:
    search_objects -> (parallel) research_object_source ->
    structure_research_by_source -> (parallel) consolidate_object_research ->
    consolidate_research.
    """
    # NOTE(review): `test_mode` is currently unused inside the builder.
    graph = StateGraph(
        state_schema=MainState,
        input=MainInput,
    )

    ### Add nodes ###

    graph.add_node(
        "search_objects",
        search_objects,
    )
    graph.add_node(
        "structure_research_by_source",
        structure_research_by_object,
    )
    graph.add_node(
        "research_object_source",
        research_object_source,
    )
    graph.add_node(
        "consolidate_object_research",
        consolidate_object_research,
    )
    graph.add_node(
        "consolidate_research",
        consolidate_research,
    )

    ### Add edges ###

    graph.add_edge(start_key=START, end_key="search_objects")

    # Fan out: one research task per (object, source) combination.
    graph.add_conditional_edges(
        source="search_objects",
        path=parallel_object_source_research_edge,
        path_map=["research_object_source"],
    )

    graph.add_edge(
        start_key="research_object_source",
        end_key="structure_research_by_source",
    )

    # Fan out again: one consolidation task per researched object.
    graph.add_conditional_edges(
        source="structure_research_by_source",
        path=parallel_object_research_consolidation_edge,
        path_map=["consolidate_object_research"],
    )

    graph.add_edge(
        start_key="consolidate_object_research",
        end_key="consolidate_research",
    )

    graph.add_edge(
        start_key="consolidate_research",
        end_key=END,
    )

    return graph

View File

@@ -1,146 +1,146 @@
# from typing import cast
from typing import cast
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
# from onyx.agents.agent_search.dc_search_analysis.ops import research
# from onyx.agents.agent_search.dc_search_analysis.states import MainState
# from onyx.agents.agent_search.dc_search_analysis.states import (
# SearchSourcesObjectsUpdate,
# )
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
# trim_prompt_piece,
# )
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_SEPARATOR
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT
# from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.ops import research
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.agents.agent_search.dc_search_analysis.states import (
SearchSourcesObjectsUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.prompts.agents.dc_prompts import DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT
from onyx.prompts.agents.dc_prompts import DC_OBJECT_SEPARATOR
from onyx.prompts.agents.dc_prompts import DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT
from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# Module-level logger for the search_objects node.
logger = setup_logger()
def search_objects(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> SearchSourcesObjectsUpdate:
    """
    LangGraph node to start the agentic search process.

    Parses the "Agent Step 1" section of the persona system prompt into a
    task, independent research sources, and an output objective, then asks
    the primary LLM to extract the objects of interest (retrieving documents
    first when no base data is embedded in the prompt).

    Raises:
        ValueError: if the search tool/persona is missing, the instructions
            are malformed, or the LLM call / response parsing fails.
    """
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = graph_config.inputs.prompt_builder.raw_user_query
    search_tool = graph_config.tooling.search_tool

    if search_tool is None or graph_config.inputs.persona is None:
        raise ValueError("Search tool and persona must be provided for DivCon search")

    try:
        instructions = graph_config.inputs.persona.system_prompt or ""

        agent_1_instructions = extract_section(
            instructions, "Agent Step 1:", "Agent Step 2:"
        )
        if agent_1_instructions is None:
            raise ValueError("Agent 1 instructions not found")

        # Optional inline base data; when absent we fall back to retrieval.
        agent_1_base_data = extract_section(instructions, "|Start Data|", "|End Data|")

        agent_1_task = extract_section(
            agent_1_instructions, "Task:", "Independent Research Sources:"
        )
        if agent_1_task is None:
            raise ValueError("Agent 1 task not found")

        agent_1_independent_sources_str = extract_section(
            agent_1_instructions, "Independent Research Sources:", "Output Objective:"
        )
        if agent_1_independent_sources_str is None:
            raise ValueError("Agent 1 Independent Research Sources not found")

        document_sources = strings_to_document_sources(
            [
                x.strip().lower()
                for x in agent_1_independent_sources_str.split(DC_OBJECT_SEPARATOR)
            ]
        )

        agent_1_output_objective = extract_section(
            agent_1_instructions, "Output Objective:"
        )
        if agent_1_output_objective is None:
            raise ValueError("Agent 1 output objective not found")

    except Exception as e:
        # Chain the cause so the original parsing failure stays visible.
        raise ValueError(
            f"Agent 1 instructions not found or not formatted correctly: {e}"
        ) from e

    # Extract objects

    if agent_1_base_data is None:
        # Retrieve chunks for objects

        retrieved_docs = research(question, search_tool)[:10]

        document_texts_list = []
        for doc_num, doc in enumerate(retrieved_docs):
            chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
            document_texts_list.append(chunk_text)

        document_texts = "\n\n".join(document_texts_list)

        dc_object_extraction_prompt = DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT.format(
            question=question,
            task=agent_1_task,
            document_text=document_texts,
            objects_of_interest=agent_1_output_objective,
        )
    else:
        dc_object_extraction_prompt = DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT.format(
            question=question,
            task=agent_1_task,
            base_data=agent_1_base_data,
            objects_of_interest=agent_1_output_objective,
        )

    msg = [
        HumanMessage(
            content=trim_prompt_piece(
                config=graph_config.tooling.primary_llm.config,
                prompt_piece=dc_object_extraction_prompt,
                reserved_str="",
            ),
        )
    ]
    primary_llm = graph_config.tooling.primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            30,
            primary_llm.invoke_langchain,
            prompt=msg,
            timeout_override=30,
            max_tokens=300,
        )

        # Strip code fences and newlines, then parse the
        # "OBJECTS: a; b; c" list from the response.
        cleaned_response = (
            str(llm_response.content)
            .replace("```json\n", "")
            .replace("\n```", "")
            .replace("\n", "")
        )
        cleaned_response = cleaned_response.split("OBJECTS:")[1]
        object_list = [x.strip() for x in cleaned_response.split(";")]

    except Exception as e:
        raise ValueError(f"Error in search_objects: {e}") from e

    return SearchSourcesObjectsUpdate(
        analysis_objects=object_list,
        analysis_sources=document_sources,
        log_messages=["Agent 1 Task done"],
    )

View File

@@ -1,180 +1,180 @@
# from datetime import datetime
# from datetime import timedelta
# from datetime import timezone
# from typing import cast
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import cast
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
# from onyx.agents.agent_search.dc_search_analysis.ops import research
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
# from onyx.agents.agent_search.dc_search_analysis.states import (
# ObjectSourceResearchUpdate,
# )
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
# trim_prompt_piece,
# )
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_SOURCE_RESEARCH_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.ops import research
from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
from onyx.agents.agent_search.dc_search_analysis.states import (
ObjectSourceResearchUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.prompts.agents.dc_prompts import DC_OBJECT_SOURCE_RESEARCH_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# Module-level logger for the research_object_source node.
logger = setup_logger()
# def research_object_source(
# state: ObjectSourceInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> ObjectSourceResearchUpdate:
# """
# LangGraph node to start the agentic search process.
# """
# datetime.now()
def research_object_source(
    state: ObjectSourceInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> ObjectSourceResearchUpdate:
    """Research one (object, document source) combination (DivCon Agent Step 2).

    Parses the Agent Step 2 instructions out of the persona's system prompt,
    retrieves documents for the object from the given source (optionally
    restricted by a "<number>d" time cutoff), asks the primary LLM to extract
    the research results, and returns them as a state update.

    Args:
        state: fan-out input carrying the (object, source) pair to research.
        config: LangGraph runnable config; ``config["metadata"]["config"]``
            holds the GraphConfig.
        writer: LangGraph stream writer (unused here; kept for node signature).

    Raises:
        ValueError: if the search tool/persona are missing, the Agent Step 2
            instructions are malformed, the time cutoff is not "<number>d",
            or the LLM call fails/times out.
    """
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    search_tool = graph_config.tooling.search_tool
    question = graph_config.inputs.prompt_builder.raw_user_query
    object, document_source = state.object_source_combination

    if search_tool is None or graph_config.inputs.persona is None:
        raise ValueError("Search tool and persona must be provided for DivCon search")

    try:
        instructions = graph_config.inputs.persona.system_prompt or ""

        agent_2_instructions = extract_section(
            instructions, "Agent Step 2:", "Agent Step 3:"
        )
        if agent_2_instructions is None:
            raise ValueError("Agent 2 instructions not found")

        agent_2_task = extract_section(
            agent_2_instructions, "Task:", "Independent Research Sources:"
        )
        if agent_2_task is None:
            raise ValueError("Agent 2 task not found")

        agent_2_time_cutoff = extract_section(
            agent_2_instructions, "Time Cutoff:", "Research Topics:"
        )
        agent_2_research_topics = extract_section(
            agent_2_instructions, "Research Topics:", "Output Objective"
        )
        agent_2_output_objective = extract_section(
            agent_2_instructions, "Output Objective:"
        )
        if agent_2_output_objective is None:
            raise ValueError("Agent 2 output objective not found")
    except Exception as e:
        # Previously a plain string containing a literal "{e}" with `e` unbound;
        # use an f-string and chain the cause so the real failure stays visible.
        raise ValueError(
            f"Agent 2 instructions not found or not formatted correctly: {e}"
        ) from e

    # Translate an optional "<number>d" time cutoff into an absolute start time
    if agent_2_time_cutoff is not None and agent_2_time_cutoff.strip() != "":
        if agent_2_time_cutoff.strip().endswith("d"):
            try:
                days = int(agent_2_time_cutoff.strip()[:-1])
                agent_2_source_start_time = datetime.now(timezone.utc) - timedelta(
                    days=days
                )
            except ValueError:
                raise ValueError(
                    f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
                )
        else:
            raise ValueError(
                f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
            )
    else:
        agent_2_source_start_time = None

    document_sources = [document_source] if document_source else None

    # Prefer the user question, then the configured research topics, then
    # fall back to the bare object as the research area
    if len(question.strip()) > 0:
        research_area = f"{question} for {object}"
    elif agent_2_research_topics and len(agent_2_research_topics.strip()) > 0:
        research_area = f"{agent_2_research_topics} for {object}"
    else:
        research_area = object

    # Retrieve chunks for the object from the designated source
    retrieved_docs = research(
        question=research_area,
        search_tool=search_tool,
        document_sources=document_sources,
        time_cutoff=agent_2_source_start_time,
    )

    # Generate document text
    document_texts_list = []
    for doc_num, doc in enumerate(retrieved_docs):
        chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
        document_texts_list.append(chunk_text)
    document_texts = "\n\n".join(document_texts_list)

    # Build prompt
    today = datetime.now().strftime("%A, %Y-%m-%d")
    dc_object_source_research_prompt = (
        DC_OBJECT_SOURCE_RESEARCH_PROMPT.format(
            today=today,
            question=question,
            task=agent_2_task,
            document_text=document_texts,
            format=agent_2_output_objective,
        )
        .replace("---object---", object)
        .replace("---source---", document_source.value)
    )

    # Run LLM
    msg = [
        HumanMessage(
            content=trim_prompt_piece(
                config=graph_config.tooling.primary_llm.config,
                prompt_piece=dc_object_source_research_prompt,
                reserved_str="",
            ),
        )
    ]
    primary_llm = graph_config.tooling.primary_llm
    try:
        llm_response = run_with_timeout(
            30,
            primary_llm.invoke_langchain,
            prompt=msg,
            timeout_override=30,
            max_tokens=300,
        )

        cleaned_response = str(llm_response.content).replace("```json\n", "")
        # Everything after the marker is the research result; an IndexError
        # here (marker missing from the LLM output) is wrapped below.
        cleaned_response = cleaned_response.split("RESEARCH RESULTS:")[1]
        object_research_results = {
            "object": object,
            "source": document_source.value,
            "research_result": cleaned_response,
        }
    except Exception as e:
        raise ValueError(f"Error in research_object_source: {e}") from e

    logger.debug("DivCon Step A2 - Object Source Research - completed for an object")

    return ObjectSourceResearchUpdate(
        object_source_research_results=[object_research_results],
        log_messages=["Agent Step 2 done for one object"],
    )

View File

@@ -1,48 +1,48 @@
# from collections import defaultdict
# from typing import Dict
# from typing import List
from collections import defaultdict
from typing import Dict
from typing import List
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dc_search_analysis.states import MainState
# from onyx.agents.agent_search.dc_search_analysis.states import (
# ObjectResearchInformationUpdate,
# )
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.agents.agent_search.dc_search_analysis.states import (
ObjectResearchInformationUpdate,
)
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def structure_research_by_object(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> ObjectResearchInformationUpdate:
# """
# LangGraph node to start the agentic search process.
# """
def structure_research_by_object(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ObjectResearchInformationUpdate:
    """Group the per-(object, source) research results by object (Step A3).

    Each object's source-level results are concatenated (in arrival order)
    into a single "Source: <source>\\n<result>" paragraph per source, joined
    with newlines into one information string per object.
    """
    # Bucket the source-level results under their object, preserving order.
    grouped: Dict[str, List[str]] = defaultdict(list)
    for result in state.object_source_research_results:
        grouped[result["object"]].append(
            f"Source: {result['source']}\n{result['research_result']}"
        )

    object_research_information_results: List[Dict[str, str]] = [
        {"object": obj, "information": "\n".join(fragments)}
        for obj, fragments in grouped.items()
    ]

    logger.debug("DivCon Step A3 - Object Research Information Structuring - completed")

    return ObjectResearchInformationUpdate(
        object_research_information_results=object_research_information_results,
        log_messages=["A3 - Object Research Information structured"],
    )

View File

@@ -1,103 +1,103 @@
# from typing import cast
from typing import cast
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectResearchUpdate
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
# trim_prompt_piece,
# )
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_CONSOLIDATION_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
from onyx.agents.agent_search.dc_search_analysis.states import ObjectResearchUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.prompts.agents.dc_prompts import DC_OBJECT_CONSOLIDATION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# logger = setup_logger()
logger = setup_logger()
# def consolidate_object_research(
# state: ObjectInformationInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> ObjectResearchUpdate:
# """
# LangGraph node to start the agentic search process.
# """
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# search_tool = graph_config.tooling.search_tool
# question = graph_config.inputs.prompt_builder.raw_user_query
def consolidate_object_research(
    state: ObjectInformationInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> ObjectResearchUpdate:
    """Consolidate the per-source research for one object (DivCon Agent Step 4).

    Reads the Agent Step 4 output objective from the persona's system prompt,
    asks the primary LLM to merge the object's gathered information into that
    format, and returns the consolidated result for this object.

    Args:
        state: fan-out input carrying one object's aggregated information
            (keys "object" and "information").
        config: LangGraph runnable config; ``config["metadata"]["config"]``
            holds the GraphConfig.
        writer: LangGraph stream writer (unused here; kept for node signature).

    Raises:
        ValueError: if the search tool/persona are missing, the Agent Step 4
            instructions are malformed, or the LLM call fails/times out.
    """
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    search_tool = graph_config.tooling.search_tool
    question = graph_config.inputs.prompt_builder.raw_user_query

    if search_tool is None or graph_config.inputs.persona is None:
        raise ValueError("Search tool and persona must be provided for DivCon search")

    instructions = graph_config.inputs.persona.system_prompt or ""

    agent_4_instructions = extract_section(
        instructions, "Agent Step 4:", "Agent Step 5:"
    )
    if agent_4_instructions is None:
        raise ValueError("Agent 4 instructions not found")
    agent_4_output_objective = extract_section(
        agent_4_instructions, "Output Objective:"
    )
    if agent_4_output_objective is None:
        raise ValueError("Agent 4 output objective not found")

    object_information = state.object_information
    object = object_information["object"]
    information = object_information["information"]

    # Create a prompt for the object consolidation
    dc_object_consolidation_prompt = DC_OBJECT_CONSOLIDATION_PROMPT.format(
        question=question,
        object=object,
        information=information,
        format=agent_4_output_objective,
    )

    # Run LLM
    msg = [
        HumanMessage(
            content=trim_prompt_piece(
                config=graph_config.tooling.primary_llm.config,
                prompt_piece=dc_object_consolidation_prompt,
                reserved_str="",
            ),
        )
    ]
    primary_llm = graph_config.tooling.primary_llm
    try:
        llm_response = run_with_timeout(
            30,
            primary_llm.invoke_langchain,
            prompt=msg,
            timeout_override=30,
            max_tokens=300,
        )

        cleaned_response = str(llm_response.content).replace("```json\n", "")
        # Everything after the marker is the consolidated information; an
        # IndexError here (marker missing from the LLM output) is wrapped below.
        consolidated_information = cleaned_response.split("INFORMATION:")[1]

    except Exception as e:
        # Chain the original exception so the root cause stays visible.
        raise ValueError(f"Error in consolidate_object_research: {e}") from e

    object_research_results = {
        "object": object,
        "research_result": consolidated_information,
    }

    logger.debug(
        "DivCon Step A4 - Object Research Consolidation - completed for an object"
    )

    return ObjectResearchUpdate(
        object_research_results=[object_research_results],
        log_messages=["Agent Source Consolidation done"],
    )

View File

@@ -1,127 +1,127 @@
# from typing import cast
from typing import cast
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
# from onyx.agents.agent_search.dc_search_analysis.states import MainState
# from onyx.agents.agent_search.dc_search_analysis.states import ResearchUpdate
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
# trim_prompt_piece,
# )
# from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
# from onyx.prompts.agents.dc_prompts import DC_FORMATTING_NO_BASE_DATA_PROMPT
# from onyx.prompts.agents.dc_prompts import DC_FORMATTING_WITH_BASE_DATA_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.agents.agent_search.dc_search_analysis.states import ResearchUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.prompts.agents.dc_prompts import DC_FORMATTING_NO_BASE_DATA_PROMPT
from onyx.prompts.agents.dc_prompts import DC_FORMATTING_WITH_BASE_DATA_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# logger = setup_logger()
logger = setup_logger()
# def consolidate_research(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> ResearchUpdate:
# """
# LangGraph node to start the agentic search process.
# """
def consolidate_research(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ResearchUpdate:
    """Format and stream the final answer from all object research (Agent Step 5).

    Parses the Agent Step 5 instructions from the persona's system prompt,
    concatenates the per-object research results, and streams a formatted
    final answer via the primary LLM (optionally merged with base data
    embedded in the prompt between |Start Data| and |End Data|).

    Args:
        state: graph state holding ``object_research_results``.
        config: LangGraph runnable config; ``config["metadata"]["config"]``
            holds the GraphConfig.
        writer: LangGraph stream writer that receives the streamed answer.

    Raises:
        ValueError: if the search tool/persona are missing, the Agent Step 5
            instructions are malformed, or the LLM call fails/times out.
        NotImplementedError: if the Agent 5 task is not "*concatenate*".
    """
    graph_config = cast(GraphConfig, config["metadata"]["config"])

    search_tool = graph_config.tooling.search_tool

    if search_tool is None or graph_config.inputs.persona is None:
        raise ValueError("Search tool and persona must be provided for DivCon search")

    # Populate prompt
    instructions = graph_config.inputs.persona.system_prompt or ""

    try:
        agent_5_instructions = extract_section(
            instructions, "Agent Step 5:", "Agent End"
        )
        if agent_5_instructions is None:
            raise ValueError("Agent 5 instructions not found")
        agent_5_base_data = extract_section(instructions, "|Start Data|", "|End Data|")
        agent_5_task = extract_section(
            agent_5_instructions, "Task:", "Independent Research Sources:"
        )
        if agent_5_task is None:
            raise ValueError("Agent 5 task not found")
        agent_5_output_objective = extract_section(
            agent_5_instructions, "Output Objective:"
        )
        if agent_5_output_objective is None:
            raise ValueError("Agent 5 output objective not found")
    except ValueError as e:
        # Chain the cause so the original validation failure stays visible.
        raise ValueError(
            f"Instructions for Agent Step 5 were not properly formatted: {e}"
        ) from e

    research_result_list = []

    if agent_5_task.strip() == "*concatenate*":
        object_research_results = state.object_research_results

        for object_research_result in object_research_results:
            object = object_research_result["object"]
            research_result = object_research_result["research_result"]
            research_result_list.append(f"Object: {object}\n\n{research_result}")

        research_results = "\n\n".join(research_result_list)
    else:
        raise NotImplementedError("Only '*concatenate*' is currently supported")

    # Create a prompt for the final formatting step
    if agent_5_base_data is None:
        dc_formatting_prompt = DC_FORMATTING_NO_BASE_DATA_PROMPT.format(
            text=research_results,
            format=agent_5_output_objective,
        )
    else:
        dc_formatting_prompt = DC_FORMATTING_WITH_BASE_DATA_PROMPT.format(
            base_data=agent_5_base_data,
            text=research_results,
            format=agent_5_output_objective,
        )

    # Run LLM
    msg = [
        HumanMessage(
            content=trim_prompt_piece(
                config=graph_config.tooling.primary_llm.config,
                prompt_piece=dc_formatting_prompt,
                reserved_str="",
            ),
        )
    ]

    try:
        # Stream the formatted answer to the writer; the return value is unused.
        _ = run_with_timeout(
            60,
            lambda: stream_llm_answer(
                llm=graph_config.tooling.primary_llm,
                prompt=msg,
                event_name="initial_agent_answer",
                writer=writer,
                agent_answer_level=0,
                agent_answer_question_num=0,
                agent_answer_type="agent_level_answer",
                timeout_override=30,
                max_tokens=None,
            ),
        )

    except Exception as e:
        raise ValueError(f"Error in consolidate_research: {e}") from e

    logger.debug("DivCon Step A5 - Final Generation - completed")

    return ResearchUpdate(
        research_results=research_results,
        log_messages=["Agent Source Consolidation done"],
    )

View File

@@ -1,50 +1,61 @@
# from datetime import datetime
# from typing import cast
from datetime import datetime
from typing import cast
# from onyx.chat.models import LlmDoc
# from onyx.configs.constants import DocumentSource
# from onyx.tools.models import SearchToolOverrideKwargs
# from onyx.tools.tool_implementations.search.search_tool import (
# FINAL_CONTEXT_DOCUMENTS_ID,
# )
# from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.chat.models import LlmDoc
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
# def research(
# question: str,
# search_tool: SearchTool,
# document_sources: list[DocumentSource] | None = None,
# time_cutoff: datetime | None = None,
# ) -> list[LlmDoc]:
# # new db session to avoid concurrency issues
def research(
    question: str,
    search_tool: SearchTool,
    document_sources: list[DocumentSource] | None = None,
    time_cutoff: datetime | None = None,
) -> list[LlmDoc]:
    """Run the search tool for *question* and return up to 10 context docs.

    Opens a fresh DB session to avoid concurrency issues with parallel
    research nodes, optionally restricting the search to the given document
    sources and to documents newer than *time_cutoff*.
    """
    callback_container: list[list[InferenceSection]] = []
    final_docs: list[LlmDoc] = []

    # new db session to avoid concurrency issues
    with get_session_with_current_tenant() as db_session:
        override = SearchToolOverrideKwargs(
            force_no_rerank=False,
            alternate_db_session=db_session,
            retrieved_sections_callback=callback_container.append,
            skip_query_analysis=True,
            document_sources=document_sources,
            time_cutoff=time_cutoff,
        )
        for response in search_tool.run(query=question, override_kwargs=override):
            # capture the final context docs to hand to the rest of the graph
            if response.id == FINAL_CONTEXT_DOCUMENTS_ID:
                final_docs = cast(list[LlmDoc], response.response)[:10]
                break

    return final_docs
# def extract_section(
# text: str, start_marker: str, end_marker: str | None = None
# ) -> str | None:
# """Extract text between markers, returning None if markers not found"""
# parts = text.split(start_marker)
def extract_section(
text: str, start_marker: str, end_marker: str | None = None
) -> str | None:
"""Extract text between markers, returning None if markers not found"""
parts = text.split(start_marker)
# if len(parts) == 1:
# return None
if len(parts) == 1:
return None
# after_start = parts[1].strip()
after_start = parts[1].strip()
# if not end_marker:
# return after_start
if not end_marker:
return after_start
# extract = after_start.split(end_marker)[0]
extract = after_start.split(end_marker)[0]
# return extract.strip()
return extract.strip()

View File

@@ -1,72 +1,72 @@
# from operator import add
# from typing import Annotated
# from typing import Dict
# from typing import TypedDict
from operator import add
from typing import Annotated
from typing import Dict
from typing import TypedDict
# from pydantic import BaseModel
from pydantic import BaseModel
# from onyx.agents.agent_search.core_state import CoreState
# from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
# from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
# from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
# from onyx.configs.constants import DocumentSource
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.configs.constants import DocumentSource
# ### States ###
# class LoggerUpdate(BaseModel):
# log_messages: Annotated[list[str], add] = []
### States ###
class LoggerUpdate(BaseModel):
    """Base state update carrying an append-only log of node messages."""

    # Annotated with operator.add so LangGraph concatenates the lists
    # contributed by parallel/sequential nodes
    log_messages: Annotated[list[str], add] = []
# class SearchSourcesObjectsUpdate(LoggerUpdate):
# analysis_objects: list[str] = []
# analysis_sources: list[DocumentSource] = []
class SearchSourcesObjectsUpdate(LoggerUpdate):
    """State update carrying the objects and document sources to analyze."""

    # objects to research (presumably extracted from the user request) — TODO confirm
    analysis_objects: list[str] = []
    # document sources each object should be researched against
    analysis_sources: list[DocumentSource] = []
# class ObjectSourceInput(LoggerUpdate):
# object_source_combination: tuple[str, DocumentSource]
class ObjectSourceInput(LoggerUpdate):
    """Fan-out input: one (object, document source) pair to research."""

    object_source_combination: tuple[str, DocumentSource]
# class ObjectSourceResearchUpdate(LoggerUpdate):
# object_source_research_results: Annotated[list[Dict[str, str]], add] = []
class ObjectSourceResearchUpdate(LoggerUpdate):
    """State update with per-(object, source) research results."""

    # each dict holds keys "object", "source", "research_result";
    # Annotated with add so parallel branch results are concatenated
    object_source_research_results: Annotated[list[Dict[str, str]], add] = []
# class ObjectInformationInput(LoggerUpdate):
# object_information: Dict[str, str]
class ObjectInformationInput(LoggerUpdate):
    """Fan-out input: one object's aggregated information to consolidate."""

    # dict with keys "object" and "information"
    object_information: Dict[str, str]
# class ObjectResearchInformationUpdate(LoggerUpdate):
# object_research_information_results: Annotated[list[Dict[str, str]], add] = []
class ObjectResearchInformationUpdate(LoggerUpdate):
    """State update with research information grouped per object."""

    # each dict holds keys "object" and "information"
    object_research_information_results: Annotated[list[Dict[str, str]], add] = []
# class ObjectResearchUpdate(LoggerUpdate):
# object_research_results: Annotated[list[Dict[str, str]], add] = []
class ObjectResearchUpdate(LoggerUpdate):
    """State update with the consolidated research result per object."""

    # each dict holds keys "object" and "research_result"
    object_research_results: Annotated[list[Dict[str, str]], add] = []
# class ResearchUpdate(LoggerUpdate):
# research_results: str | None = None
class ResearchUpdate(LoggerUpdate):
    """State update with the final concatenated research text."""

    # None until the final consolidation step has produced a result
    research_results: str | None = None
# ## Graph Input State
# class MainInput(CoreState):
# pass
## Graph Input State
class MainInput(CoreState):
    """Graph input state; adds no fields beyond the core state."""

    pass
# ## Graph State
# class MainState(
# # This includes the core state
# MainInput,
# ToolChoiceInput,
# ToolCallUpdate,
# ToolChoiceUpdate,
# SearchSourcesObjectsUpdate,
# ObjectSourceResearchUpdate,
# ObjectResearchInformationUpdate,
# ObjectResearchUpdate,
# ResearchUpdate,
# ):
# pass
## Graph State
class MainState(
    # This includes the core state
    MainInput,
    ToolChoiceInput,
    ToolCallUpdate,
    ToolChoiceUpdate,
    SearchSourcesObjectsUpdate,
    ObjectSourceResearchUpdate,
    ObjectResearchInformationUpdate,
    ObjectResearchUpdate,
    ResearchUpdate,
):
    """Full graph state: the union of all node update states, combined via
    multiple inheritance so every node can read/write its slice."""

    pass
# ## Graph Output State - presently not used
# class MainOutput(TypedDict):
# log_messages: list[str]
## Graph Output State - presently not used
class MainOutput(TypedDict):
    """Graph output state — presently not used."""

    log_messages: list[str]

View File

@@ -1,36 +1,36 @@
# from pydantic import BaseModel
from pydantic import BaseModel
class RefinementSubQuestion(BaseModel):
    """A refinement sub-question and its (eventual) answer status."""

    sub_question: str
    sub_question_id: str
    verified: bool
    answered: bool
    answer: str


class AgentTimings(BaseModel):
    """Wall-clock durations (seconds) for the agent's answer phases."""

    base_duration_s: float | None
    refined_duration_s: float | None
    full_duration_s: float | None


class AgentBaseMetrics(BaseModel):
    """Metrics collected for the base (initial) answer pass."""

    num_verified_documents_total: int | None
    num_verified_documents_core: int | None
    verified_avg_score_core: float | None
    num_verified_documents_base: int | float | None
    verified_avg_score_base: float | None = None
    base_doc_boost_factor: float | None = None
    support_boost_factor: float | None = None
    duration_s: float | None = None


class AgentRefinedMetrics(BaseModel):
    """Metrics collected for the refined answer pass."""

    refined_doc_boost_factor: float | None = None
    refined_question_boost_factor: float | None = None
    duration_s: float | None = None


class AgentAdditionalMetrics(BaseModel):
    """Placeholder for extra metrics; currently empty by design."""

    pass

View File

@@ -1,61 +1,61 @@
# from collections.abc import Hashable
from collections.abc import Hashable
# from langgraph.graph import END
# from langgraph.types import Send
from langgraph.graph import END
from langgraph.types import Send
# from onyx.agents.agent_search.dr.enums import DRPath
# from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.states import MainState
def decision_router(state: MainState) -> list[Send | Hashable] | DRPath | str:
    """Route from the clarifier/orchestrator to the next graph node.

    The last entry of ``state.tools_used`` names either an available tool or a
    terminal ``DRPath`` value; unknown names fall back to the orchestrator.

    Raises:
        IndexError: if ``state.tools_used`` is empty.
        ValueError: if no tools are available, or the chosen path is CLARIFIER
            (not valid mid-iteration).
    """
    if not state.tools_used:
        raise IndexError("state.tools_used cannot be empty")

    # next_tool is either a generic tool name or a DRPath string
    next_tool_name = state.tools_used[-1]

    available_tools = state.available_tools
    if not available_tools:
        raise ValueError("No tool is available. This should not happen.")

    if next_tool_name in available_tools:
        next_tool_path = available_tools[next_tool_name].path
    elif next_tool_name == DRPath.END.value:
        return END
    elif next_tool_name == DRPath.LOGGER.value:
        return DRPath.LOGGER
    elif next_tool_name == DRPath.CLOSER.value:
        return DRPath.CLOSER
    else:
        return DRPath.ORCHESTRATOR

    # handle invalid paths
    if next_tool_path == DRPath.CLARIFIER:
        raise ValueError("CLARIFIER is not a valid path during iteration")

    # handle tool calls without a query: search/generation tools need at least
    # one query, otherwise skip straight to the closer
    if (
        next_tool_path
        in (
            DRPath.INTERNAL_SEARCH,
            DRPath.WEB_SEARCH,
            DRPath.KNOWLEDGE_GRAPH,
            DRPath.IMAGE_GENERATION,
        )
        and len(state.query_list) == 0
    ):
        return DRPath.CLOSER

    return next_tool_path
def completeness_router(state: MainState) -> DRPath | str:
    """Route after the closer: back to the orchestrator or on to the logger.

    Raises:
        IndexError: if ``state.tools_used`` is empty.
    """
    if not state.tools_used:
        raise IndexError("tools_used cannot be empty")

    # return to the orchestrator only if it was explicitly requested again;
    # otherwise proceed to logging
    next_path = state.tools_used[-1]

    if next_path == DRPath.ORCHESTRATOR.value:
        return DRPath.ORCHESTRATOR
    return DRPath.LOGGER

View File

@@ -1,27 +1,31 @@
# from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchType
# Constants for the deep-research (DR) agent.

MAX_CHAT_HISTORY_MESSAGES = (
    3  # note: actual count is x2 to account for user and assistant messages
)

MAX_DR_PARALLEL_SEARCH = 4

# TODO: test more, generally not needed/adds unnecessary iterations
MAX_NUM_CLOSER_SUGGESTIONS = (
    0  # how many times the closer can send back to the orchestrator
)

CLARIFICATION_REQUEST_PREFIX = "PLEASE CLARIFY:"
HIGH_LEVEL_PLAN_PREFIX = "The Plan:"

# Relative cost per tool, used by the orchestrator to budget iterations.
AVERAGE_TOOL_COSTS: dict[DRPath, float] = {
    DRPath.INTERNAL_SEARCH: 1.0,
    DRPath.KNOWLEDGE_GRAPH: 2.0,
    DRPath.WEB_SEARCH: 1.5,
    DRPath.IMAGE_GENERATION: 3.0,
    DRPath.GENERIC_TOOL: 1.5,  # TODO: see todo in OrchestratorTool
    DRPath.CLOSER: 0.0,
}

# Time budget per research type, in the same abstract cost units as
# AVERAGE_TOOL_COSTS (presumably — confirm against the orchestrator).
DR_TIME_BUDGET_BY_TYPE = {
    ResearchType.THOUGHTFUL: 3.0,
    ResearchType.DEEP: 12.0,
    ResearchType.FAST: 0.5,
}

View File

@@ -1,111 +1,112 @@
# from datetime import datetime
from datetime import datetime
# from onyx.agents.agent_search.dr.enums import DRPath
# from onyx.agents.agent_search.dr.models import DRPromptPurpose
# from onyx.agents.agent_search.dr.models import OrchestratorTool
# from onyx.prompts.dr_prompts import GET_CLARIFICATION_PROMPT
# from onyx.prompts.dr_prompts import KG_TYPES_DESCRIPTIONS
# from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
# from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
# from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
# from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
# from onyx.prompts.dr_prompts import ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
# from onyx.prompts.dr_prompts import TOOL_DIFFERENTIATION_HINTS
# from onyx.prompts.dr_prompts import TOOL_QUESTION_HINTS
# from onyx.prompts.prompt_template import PromptTemplate
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import DRPromptPurpose
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.prompts.dr_prompts import GET_CLARIFICATION_PROMPT
from onyx.prompts.dr_prompts import KG_TYPES_DESCRIPTIONS
from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
from onyx.prompts.dr_prompts import TOOL_DIFFERENTIATION_HINTS
from onyx.prompts.dr_prompts import TOOL_QUESTION_HINTS
from onyx.prompts.prompt_template import PromptTemplate
def get_dr_prompt_orchestration_templates(
    purpose: DRPromptPurpose,
    research_type: ResearchType,
    available_tools: dict[str, OrchestratorTool],
    entity_types_string: str | None = None,
    relationship_types_string: str | None = None,
    reasoning_result: str | None = None,
    tool_calls_string: str | None = None,
) -> PromptTemplate:
    """Select and partially build the orchestration prompt for *purpose*.

    Args:
        purpose: which orchestration prompt is needed (plan, next step, ...).
        research_type: research mode; THOUGHTFUL selects the fast iterative
            prompts, other modes select the deep ones.
        available_tools: tools the orchestrator may choose from, keyed by name.
        entity_types_string / relationship_types_string: knowledge-graph type
            descriptions, only used when the KG tool is available.
        reasoning_result: optional prior reasoning to embed in the prompt.
        tool_calls_string: optional prior tool-call summary to embed.

    Returns:
        A PromptTemplate with the tool/KG placeholders already filled in.

    Raises:
        ValueError: for purpose/research_type combinations that are not
            supported, or an unknown purpose.
    """
    available_tools = available_tools or {}
    tool_names = list(available_tools.keys())
    tool_description_str = "\n\n".join(
        f"- {tool_name}: {tool.description}"
        for tool_name, tool in available_tools.items()
    )
    tool_cost_str = "\n".join(
        f"{tool_name}: {tool.cost}" for tool_name, tool in available_tools.items()
    )

    # hints that differentiate pairs of available tools (ordered pairs)
    tool_differentiations: list[str] = [
        TOOL_DIFFERENTIATION_HINTS[(tool_1, tool_2)]
        for tool_1 in available_tools
        for tool_2 in available_tools
        if (tool_1, tool_2) in TOOL_DIFFERENTIATION_HINTS
    ]
    tool_differentiation_hint_string = (
        "\n".join(tool_differentiations) or "(No differentiating hints available)"
    )
    # TODO: add tool delineation pairs for custom tools as well

    tool_question_hint_string = (
        "\n".join(
            "- " + TOOL_QUESTION_HINTS[tool]
            for tool in available_tools
            if tool in TOOL_QUESTION_HINTS
        )
        or "(No examples available)"
    )

    # only describe KG types when the KG tool is available and types exist
    if DRPath.KNOWLEDGE_GRAPH.value in available_tools and (
        entity_types_string or relationship_types_string
    ):
        kg_types_descriptions = KG_TYPES_DESCRIPTIONS.build(
            possible_entities=entity_types_string or "",
            possible_relationships=relationship_types_string or "",
        )
    else:
        kg_types_descriptions = "(The Knowledge Graph is not used.)"

    if purpose == DRPromptPurpose.PLAN:
        if research_type == ResearchType.THOUGHTFUL:
            raise ValueError(
                "plan generation is not supported for THOUGHTFUL research type"
            )
        base_template = ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT

    elif purpose == DRPromptPurpose.NEXT_STEP_REASONING:
        if research_type == ResearchType.THOUGHTFUL:
            base_template = ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
        else:
            raise ValueError(
                "separate reasoning is only supported for THOUGHTFUL research type"
            )

    elif purpose == DRPromptPurpose.NEXT_STEP_PURPOSE:
        base_template = ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT

    elif purpose == DRPromptPurpose.NEXT_STEP:
        if research_type == ResearchType.THOUGHTFUL:
            base_template = ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
        else:
            base_template = ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT

    elif purpose == DRPromptPurpose.CLARIFICATION:
        if research_type == ResearchType.THOUGHTFUL:
            raise ValueError(
                "clarification is not supported for THOUGHTFUL research type"
            )
        base_template = GET_CLARIFICATION_PROMPT

    else:
        # for mypy, clearly a mypy bug
        raise ValueError(f"Invalid purpose: {purpose}")

    return base_template.partial_build(
        num_available_tools=str(len(tool_names)),
        available_tools=", ".join(tool_names),
        tool_choice_options=" or ".join(tool_names),
        current_time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        kg_types_descriptions=kg_types_descriptions,
        tool_descriptions=tool_description_str,
        tool_differentiation_hints=tool_differentiation_hint_string,
        tool_question_hints=tool_question_hint_string,
        average_tool_costs=tool_cost_str,
        reasoning_result=reasoning_result or "(No reasoning result provided.)",
        tool_calls_string=tool_calls_string or "(No tool calls provided.)",
    )

View File

@@ -1,22 +1,32 @@
# from enum import Enum
from enum import Enum
class ResearchType(str, Enum):
    """Research type options for agent search operations"""

    LEGACY_AGENTIC = "LEGACY_AGENTIC"  # only used for legacy agentic search migrations
    THOUGHTFUL = "THOUGHTFUL"
    DEEP = "DEEP"
    FAST = "FAST"
class ResearchAnswerPurpose(str, Enum):
    """Research answer purpose options for agent search operations"""

    ANSWER = "ANSWER"
    CLARIFICATION_REQUEST = "CLARIFICATION_REQUEST"


class DRPath(str, Enum):
    """Node identifiers for stages of the deep-research agent graph."""

    CLARIFIER = "Clarifier"
    ORCHESTRATOR = "Orchestrator"
    INTERNAL_SEARCH = "Internal Search"
    GENERIC_TOOL = "Generic Tool"
    KNOWLEDGE_GRAPH = "Knowledge Graph Search"
    WEB_SEARCH = "Web Search"
    IMAGE_GENERATION = "Image Generation"
    GENERIC_INTERNAL_TOOL = "Generic Internal Tool"
    CLOSER = "Closer"
    LOGGER = "Logger"
    END = "End"

View File

@@ -1,88 +1,88 @@
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from onyx.agents.agent_search.dr.conditional_edges import completeness_router
# from onyx.agents.agent_search.dr.conditional_edges import decision_router
# from onyx.agents.agent_search.dr.enums import DRPath
# from onyx.agents.agent_search.dr.nodes.dr_a0_clarification import clarifier
# from onyx.agents.agent_search.dr.nodes.dr_a1_orchestrator import orchestrator
# from onyx.agents.agent_search.dr.nodes.dr_a2_closer import closer
# from onyx.agents.agent_search.dr.nodes.dr_a3_logger import logging
# from onyx.agents.agent_search.dr.states import MainInput
# from onyx.agents.agent_search.dr.states import MainState
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_graph_builder import (
# dr_basic_search_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_graph_builder import (
# dr_custom_tool_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_graph_builder import (
# dr_generic_internal_tool_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_graph_builder import (
# dr_image_generation_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_graph_builder import (
# dr_kg_search_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_graph_builder import (
# dr_ws_graph_builder,
# )
from onyx.agents.agent_search.dr.conditional_edges import completeness_router
from onyx.agents.agent_search.dr.conditional_edges import decision_router
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.nodes.dr_a0_clarification import clarifier
from onyx.agents.agent_search.dr.nodes.dr_a1_orchestrator import orchestrator
from onyx.agents.agent_search.dr.nodes.dr_a2_closer import closer
from onyx.agents.agent_search.dr.nodes.dr_a3_logger import logging
from onyx.agents.agent_search.dr.states import MainInput
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_graph_builder import (
dr_basic_search_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_graph_builder import (
dr_custom_tool_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_graph_builder import (
dr_generic_internal_tool_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_graph_builder import (
dr_image_generation_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_graph_builder import (
dr_kg_search_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_graph_builder import (
dr_ws_graph_builder,
)
def dr_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for the deep research agent.

    Wires the clarifier, orchestrator, tool sub-graphs, closer, and logger
    into a single StateGraph over MainState.
    """
    graph = StateGraph(state_schema=MainState, input=MainInput)

    ### Add nodes ###
    graph.add_node(DRPath.CLARIFIER, clarifier)
    graph.add_node(DRPath.ORCHESTRATOR, orchestrator)

    # each tool is a compiled sub-graph mounted as a single node
    basic_search_graph = dr_basic_search_graph_builder().compile()
    graph.add_node(DRPath.INTERNAL_SEARCH, basic_search_graph)

    kg_search_graph = dr_kg_search_graph_builder().compile()
    graph.add_node(DRPath.KNOWLEDGE_GRAPH, kg_search_graph)

    internet_search_graph = dr_ws_graph_builder().compile()
    graph.add_node(DRPath.WEB_SEARCH, internet_search_graph)

    image_generation_graph = dr_image_generation_graph_builder().compile()
    graph.add_node(DRPath.IMAGE_GENERATION, image_generation_graph)

    custom_tool_graph = dr_custom_tool_graph_builder().compile()
    graph.add_node(DRPath.GENERIC_TOOL, custom_tool_graph)

    generic_internal_tool_graph = dr_generic_internal_tool_graph_builder().compile()
    graph.add_node(DRPath.GENERIC_INTERNAL_TOOL, generic_internal_tool_graph)

    graph.add_node(DRPath.CLOSER, closer)
    graph.add_node(DRPath.LOGGER, logging)

    ### Add edges ###
    graph.add_edge(start_key=START, end_key=DRPath.CLARIFIER)

    # clarifier and orchestrator both dispatch via the decision router
    graph.add_conditional_edges(DRPath.CLARIFIER, decision_router)
    graph.add_conditional_edges(DRPath.ORCHESTRATOR, decision_router)

    # every tool node returns control to the orchestrator
    graph.add_edge(start_key=DRPath.INTERNAL_SEARCH, end_key=DRPath.ORCHESTRATOR)
    graph.add_edge(start_key=DRPath.KNOWLEDGE_GRAPH, end_key=DRPath.ORCHESTRATOR)
    graph.add_edge(start_key=DRPath.WEB_SEARCH, end_key=DRPath.ORCHESTRATOR)
    graph.add_edge(start_key=DRPath.IMAGE_GENERATION, end_key=DRPath.ORCHESTRATOR)
    graph.add_edge(start_key=DRPath.GENERIC_TOOL, end_key=DRPath.ORCHESTRATOR)
    graph.add_edge(start_key=DRPath.GENERIC_INTERNAL_TOOL, end_key=DRPath.ORCHESTRATOR)

    graph.add_conditional_edges(DRPath.CLOSER, completeness_router)
    graph.add_edge(start_key=DRPath.LOGGER, end_key=END)

    return graph

View File

@@ -1,131 +1,131 @@
# from enum import Enum
from enum import Enum
# from pydantic import BaseModel
from pydantic import BaseModel
from pydantic import ConfigDict
# from onyx.agents.agent_search.dr.enums import DRPath
# from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
# GeneratedImage,
# )
# from onyx.context.search.models import InferenceSection
# from onyx.tools.tool import Tool
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
GeneratedImage,
)
from onyx.context.search.models import InferenceSection
from onyx.tools.tool import Tool
class OrchestratorStep(BaseModel):
    """A single orchestrator step: one tool plus the questions to ask it."""

    tool: str
    questions: list[str]


class OrchestratorDecisonsNoPlan(BaseModel):
    """Orchestrator decision (reasoning + next step) without a plan."""

    reasoning: str
    next_step: OrchestratorStep


class OrchestrationPlan(BaseModel):
    """High-level research plan with its supporting reasoning."""

    reasoning: str
    plan: str


class ClarificationGenerationResponse(BaseModel):
    """LLM response indicating whether/what clarification to request."""

    clarification_needed: bool
    clarification_question: str


class DecisionResponse(BaseModel):
    """A generic reasoned yes/no-style decision from the LLM."""

    reasoning: str
    decision: str


class QueryEvaluationResponse(BaseModel):
    """LLM evaluation of whether a query is permitted."""

    reasoning: str
    query_permitted: bool


class OrchestrationClarificationInfo(BaseModel):
    """A clarification question and the user's (optional) response."""

    clarification_question: str
    clarification_response: str | None = None


class WebSearchAnswer(BaseModel):
    """Indices of search-result URLs the agent should open."""

    urls_to_open_indices: list[int]


class SearchAnswer(BaseModel):
    """A reasoned search answer with optional supporting claims."""

    reasoning: str
    answer: str
    claims: list[str] | None = None


class TestInfoCompleteResponse(BaseModel):
    """LLM assessment of whether gathered information is complete."""

    reasoning: str
    complete: bool
    gaps: list[str]
# TODO: revisit with custom tools implementation in v2
# each tool should be a class with the attributes below, plus the actual tool implementation
# this will also allow custom tools to have their own cost
class OrchestratorTool(BaseModel):
    """Metadata the orchestrator uses to describe, cost, and route a tool."""

    tool_id: int
    name: str
    llm_path: str  # the path for the LLM to refer by
    path: DRPath  # the actual path in the graph
    description: str
    metadata: dict[str, str]
    cost: float
    tool_object: Tool | None = None  # None for CLOSER

    # Tool is not a pydantic type, so arbitrary types must be allowed
    model_config = ConfigDict(arbitrary_types_allowed=True)
class IterationInstructions(BaseModel):
    """Instructions for one research iteration: plan, reasoning, purpose."""

    iteration_nr: int
    plan: str | None
    reasoning: str
    purpose: str
class IterationAnswer(BaseModel):
    """The result of one tool invocation within a research iteration."""

    tool: str
    tool_id: int
    iteration_nr: int
    parallelization_nr: int  # index within parallel calls of the same iteration
    question: str
    reasoning: str | None
    answer: str
    cited_documents: dict[int, InferenceSection]
    background_info: str | None = None
    claims: list[str] | None = None
    additional_data: dict[str, str] | None = None
    response_type: str | None = None
    data: dict | list | str | int | float | bool | None = None
    file_ids: list[str] | None = None
    # TODO: This is not ideal, but we can rework the schema
    # for deep research later
    is_web_fetch: bool = False
    # for image generation step-types
    generated_images: list[GeneratedImage] | None = None
    # for multi-query search tools (v2 web search and internal search)
    # TODO: Clean this up to be more flexible to tools
    queries: list[str] | None = None
class AggregatedDRContext(BaseModel):
    """Aggregated context across all iterations, used for the final answer."""

    context: str
    cited_documents: list[InferenceSection]
    # maps document identifiers to whether they came from the internet
    # (presumably keyed by document id — confirm against aggregate_context)
    is_internet_marker_dict: dict[str, bool]
    global_iteration_responses: list[IterationAnswer]
class DRPromptPurpose(str, Enum):
    """Which orchestration prompt template is being requested."""

    PLAN = "PLAN"
    NEXT_STEP = "NEXT_STEP"
    NEXT_STEP_REASONING = "NEXT_STEP_REASONING"
    NEXT_STEP_PURPOSE = "NEXT_STEP_PURPOSE"
    CLARIFICATION = "CLARIFICATION"
class BaseSearchProcessingResponse(BaseModel):
    """Preprocessed search request: source filters, rewritten query, time filter."""

    specified_source_types: list[str]
    rewritten_query: str
    time_filter: str

File diff suppressed because it is too large Load Diff

View File

@@ -1,418 +1,423 @@
# import re
# from datetime import datetime
# from typing import cast
import re
from datetime import datetime
from typing import cast
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
# from sqlalchemy.orm import Session
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from sqlalchemy.orm import Session
# from onyx.agents.agent_search.dr.constants import MAX_CHAT_HISTORY_MESSAGES
# from onyx.agents.agent_search.dr.constants import MAX_NUM_CLOSER_SUGGESTIONS
# from onyx.agents.agent_search.dr.enums import DRPath
# from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
# from onyx.agents.agent_search.dr.models import AggregatedDRContext
# from onyx.agents.agent_search.dr.models import TestInfoCompleteResponse
# from onyx.agents.agent_search.dr.states import FinalUpdate
# from onyx.agents.agent_search.dr.states import MainState
# from onyx.agents.agent_search.dr.states import OrchestrationUpdate
# from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
# GeneratedImageFullResult,
# )
# from onyx.agents.agent_search.dr.utils import aggregate_context
# from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
# from onyx.agents.agent_search.dr.utils import get_chat_history_string
# from onyx.agents.agent_search.dr.utils import get_prompt_question
# from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
# from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.agents.agent_search.utils import create_question_prompt
# from onyx.chat.chat_utils import llm_doc_from_inference_section
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
# from onyx.context.search.models import InferenceSection
# from onyx.db.chat import create_search_doc_from_inference_section
# from onyx.db.chat import update_db_session_with_messages
# from onyx.db.models import ChatMessage__SearchDoc
# from onyx.db.models import ResearchAgentIteration
# from onyx.db.models import ResearchAgentIterationSubStep
# from onyx.db.models import SearchDoc as DbSearchDoc
# from onyx.llm.utils import check_number_of_tokens
# from onyx.prompts.chat_prompts import PROJECT_INSTRUCTIONS_SEPARATOR
# from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
# from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
# from onyx.prompts.dr_prompts import TEST_INFO_COMPLETE_PROMPT
# from onyx.server.query_and_chat.streaming_models import CitationDelta
# from onyx.server.query_and_chat.streaming_models import CitationStart
# from onyx.server.query_and_chat.streaming_models import MessageStart
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.server.query_and_chat.streaming_models import StreamingType
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.agents.agent_search.dr.constants import MAX_CHAT_HISTORY_MESSAGES
from onyx.agents.agent_search.dr.constants import MAX_NUM_CLOSER_SUGGESTIONS
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import AggregatedDRContext
from onyx.agents.agent_search.dr.models import TestInfoCompleteResponse
from onyx.agents.agent_search.dr.states import FinalUpdate
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.states import OrchestrationUpdate
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
GeneratedImageFullResult,
)
from onyx.agents.agent_search.dr.utils import aggregate_context
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.dr.utils import get_chat_history_string
from onyx.agents.agent_search.dr.utils import get_prompt_question
from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.chat.chat_utils import llm_doc_from_inference_section
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
from onyx.context.search.models import InferenceSection
from onyx.db.chat import create_search_doc_from_inference_section
from onyx.db.chat import update_db_session_with_messages
from onyx.db.models import ChatMessage__SearchDoc
from onyx.db.models import ResearchAgentIteration
from onyx.db.models import ResearchAgentIterationSubStep
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.llm.utils import check_number_of_tokens
from onyx.prompts.chat_prompts import PROJECT_INSTRUCTIONS_SEPARATOR
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
from onyx.prompts.dr_prompts import TEST_INFO_COMPLETE_PROMPT
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationStart
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.server.query_and_chat.streaming_models import StreamingType
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# logger = setup_logger()
# Module-level logger for this graph node (configured via onyx.utils.logger).
logger = setup_logger()
# def extract_citation_numbers(text: str) -> list[int]:
# """
# Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
# Returns a list of all unique citation numbers found.
# """
# # Pattern to match [[number]] or [[number1, number2, ...]]
# pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
# matches = re.findall(pattern, text)
def extract_citation_numbers(text: str) -> list[int]:
    """
    Extract all citation numbers from text in the format [[<number>]] or
    [[<number_1>, <number_2>, ...]].

    Returns:
        A sorted list of the unique citation numbers found. (Previously the
        result was `list(set(...))`, whose ordering is arbitrary; sorting
        makes the output deterministic while keeping the same contract.)
    """
    # Pattern to match [[number]] or [[number1, number2, ...]]
    pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
    matches = re.findall(pattern, text)

    cited_numbers: set[int] = set()
    for match in matches:
        # Each match may hold several comma-separated numbers.
        cited_numbers.update(int(num.strip()) for num in match.split(","))

    return sorted(cited_numbers)
# def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
# citation_content = match.group(1) # e.g., "3" or "3, 5, 7"
# numbers = [int(num.strip()) for num in citation_content.split(",")]
def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
    """
    Render a citation regex match as markdown-linked citations.

    Args:
        match: Regex match whose group(1) holds the citation content,
            e.g. "3" or "3, 5, 7".
        docs: Search docs; citation number N refers to docs[N - 1] (1-based).

    Returns:
        The concatenated linked citations, e.g. "[[3]](https://...)";
        out-of-range numbers are rendered without a link.
    """
    citation_content = match.group(1)  # e.g., "3" or "3, 5, 7"
    numbers = [int(num.strip()) for num in citation_content.split(",")]

    # For multiple citations like [[3, 5, 7]], create separate linked citations
    linked_citations = []
    for num in numbers:
        # Citations are 1-based. The previous check only guarded the upper
        # bound, so a stray [[0]] would index docs[-1] and link the wrong
        # (last) document; guard both bounds instead.
        if 1 <= num <= len(docs):
            link = docs[num - 1].link or ""
            linked_citations.append(f"[[{num}]]({link})")
        else:
            linked_citations.append(f"[[{num}]]")  # No link if out of bounds

    return "".join(linked_citations)
# def insert_chat_message_search_doc_pair(
# message_id: int, search_doc_ids: list[int], db_session: Session
# ) -> None:
# """
# Insert a pair of message_id and search_doc_id into the chat_message__search_doc table.
def insert_chat_message_search_doc_pair(
    message_id: int, search_doc_ids: list[int], db_session: Session
) -> None:
    """
    Associate a chat message with search documents via the
    chat_message__search_doc join table.

    Args:
        message_id: The ID of the chat message
        search_doc_ids: The IDs of the search documents to link
        db_session: The database session (rows are added, not committed here)
    """
    association_rows = [
        ChatMessage__SearchDoc(chat_message_id=message_id, search_doc_id=doc_id)
        for doc_id in search_doc_ids
    ]
    db_session.add_all(association_rows)
# def save_iteration(
# state: MainState,
# graph_config: GraphConfig,
# aggregated_context: AggregatedDRContext,
# final_answer: str,
# all_cited_documents: list[InferenceSection],
# is_internet_marker_dict: dict[str, bool],
# ) -> None:
# db_session = graph_config.persistence.db_session
# message_id = graph_config.persistence.message_id
def save_iteration(
    state: MainState,
    graph_config: GraphConfig,
    aggregated_context: AggregatedDRContext,
    final_answer: str,
    all_cited_documents: list[InferenceSection],
    is_internet_marker_dict: dict[str, bool],
) -> None:
    """
    Persist the results of a research run: the cited search docs, their
    mapping to the chat message, the citations found in the final answer,
    the updated chat message itself, and every iteration (sub-)step.
    Commits the session at the end.

    Args:
        state: Current main graph state (plan and iteration instructions).
        graph_config: Graph configuration with persistence/behavior handles.
        aggregated_context: Aggregated iteration responses to persist.
        final_answer: The final answer text (scanned for [[N]] citations).
        all_cited_documents: Inference sections cited in the answer.
        is_internet_marker_dict: Per-document-id flag marking internet docs.
    """
    # (db_session was previously assigned twice; once is sufficient)
    db_session = graph_config.persistence.db_session
    message_id = graph_config.persistence.message_id
    research_type = graph_config.behavior.research_type

    # first, insert the search_docs
    search_docs = [
        create_search_doc_from_inference_section(
            inference_section=inference_section,
            is_internet=is_internet_marker_dict.get(
                inference_section.center_chunk.document_id, False
            ),  # TODO: revisit
            db_session=db_session,
            commit=False,
        )
        for inference_section in all_cited_documents
    ]

    # then, map search_docs to message
    insert_chat_message_search_doc_pair(
        message_id, [search_doc.id for search_doc in search_docs], db_session
    )

    # lastly, insert the citations
    citation_dict: dict[int, int] = {}
    cited_doc_nrs = extract_citation_numbers(final_answer)
    for cited_doc_nr in cited_doc_nrs:
        # Citations are 1-based; skip numbers outside the inserted docs.
        # Previously an out-of-range citation raised IndexError, and [[0]]
        # silently mapped to the last doc via negative indexing.
        if 1 <= cited_doc_nr <= len(search_docs):
            citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id

    # TODO: generate plan as dict in the first place
    plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
    plan_of_record_dict = parse_plan_to_dict(plan_of_record)

    # Update the chat message and its parent message in database
    update_db_session_with_messages(
        db_session=db_session,
        chat_message_id=message_id,
        chat_session_id=graph_config.persistence.chat_session_id,
        is_agentic=graph_config.behavior.use_agentic_search,
        message=final_answer,
        citations=citation_dict,
        research_type=research_type,
        research_plan=plan_of_record_dict,
        final_documents=search_docs,
        update_parent_message=True,
        research_answer_purpose=ResearchAnswerPurpose.ANSWER,
    )

    # One row per high-level iteration (the "preparation" / planning step).
    for iteration_preparation in state.iteration_instructions:
        research_agent_iteration_step = ResearchAgentIteration(
            primary_question_id=message_id,
            reasoning=iteration_preparation.reasoning,
            purpose=iteration_preparation.purpose,
            iteration_nr=iteration_preparation.iteration_nr,
        )
        db_session.add(research_agent_iteration_step)

    # One row per parallel sub-step of each iteration.
    for iteration_answer in aggregated_context.global_iteration_responses:
        retrieved_search_docs = convert_inference_sections_to_search_docs(
            list(iteration_answer.cited_documents.values())
        )

        # Convert SavedSearchDoc objects to JSON-serializable format
        serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]

        research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
            primary_question_id=message_id,
            iteration_nr=iteration_answer.iteration_nr,
            iteration_sub_step_nr=iteration_answer.parallelization_nr,
            sub_step_instructions=iteration_answer.question,
            sub_step_tool_id=iteration_answer.tool_id,
            sub_answer=iteration_answer.answer,
            reasoning=iteration_answer.reasoning,
            claims=iteration_answer.claims,
            cited_doc_results=serialized_search_docs,
            generated_images=(
                GeneratedImageFullResult(images=iteration_answer.generated_images)
                if iteration_answer.generated_images
                else None
            ),
            additional_data=iteration_answer.additional_data,
            queries=iteration_answer.queries,
        )
        db_session.add(research_agent_iteration_sub_step)

    db_session.commit()
# def closer(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> FinalUpdate | OrchestrationUpdate:
# """
# LangGraph node to close the DR process and finalize the answer.
# """
def closer(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> FinalUpdate | OrchestrationUpdate:
    """
    LangGraph node to close the DR process and finalize the answer.

    For DEEP research (up to MAX_NUM_CLOSER_SUGGESTIONS times) an LLM check
    first verifies that enough information was gathered; if not, control
    returns to the orchestrator with the identified gaps. Otherwise the final
    answer is streamed to the writer and a FinalUpdate is returned.

    Args:
        state: Current main graph state.
        config: Runnable config carrying the GraphConfig under
            config["metadata"]["config"].
        writer: Stream writer for custom events (no-op by default).

    Returns:
        FinalUpdate with the final answer and cited documents, or an
        OrchestrationUpdate when more research is suggested.

    Raises:
        ValueError: If the original question is missing, the research type is
            invalid, or answer streaming fails.
    """
    node_start_time = datetime.now()
    # TODO: generate final answer using all the previous steps
    # (right now, answers from each step are concatenated onto each other)
    # Also, add missing fields once usage in UI is clear.

    current_step_nr = state.current_step_nr

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = state.original_question
    if not base_question:
        raise ValueError("Question is required for closer")

    research_type = graph_config.behavior.research_type
    assistant_system_prompt: str = state.assistant_system_prompt or ""
    assistant_task_prompt = state.assistant_task_prompt

    uploaded_context = state.uploaded_test_context or ""

    clarification = state.clarification
    prompt_question = get_prompt_question(base_question, clarification)

    chat_history_string = (
        get_chat_history_string(
            graph_config.inputs.prompt_builder.message_history,
            MAX_CHAT_HISTORY_MESSAGES,
        )
        or "(No chat history yet available)"
    )

    # Aggregate twice: with raw documents (richer final-answer prompt) and
    # without (shorter fallback used when the prompt would be too long).
    aggregated_context_w_docs = aggregate_context(
        state.iteration_responses, include_documents=True
    )
    aggregated_context_wo_docs = aggregate_context(
        state.iteration_responses, include_documents=False
    )

    iteration_responses_w_docs_string = aggregated_context_w_docs.context
    iteration_responses_wo_docs_string = aggregated_context_wo_docs.context
    all_cited_documents = aggregated_context_w_docs.cited_documents

    num_closer_suggestions = state.num_closer_suggestions
    if (
        num_closer_suggestions < MAX_NUM_CLOSER_SUGGESTIONS
        and research_type == ResearchType.DEEP
    ):
        test_info_complete_prompt = TEST_INFO_COMPLETE_PROMPT.build(
            base_question=prompt_question,
            questions_answers_claims=iteration_responses_wo_docs_string,
            chat_history_string=chat_history_string,
            high_level_plan=(
                state.plan_of_record.plan
                if state.plan_of_record
                else "No plan available"
            ),
        )

        test_info_complete_json = invoke_llm_json(
            llm=graph_config.tooling.primary_llm,
            prompt=create_question_prompt(
                assistant_system_prompt,
                test_info_complete_prompt + (assistant_task_prompt or ""),
            ),
            schema=TestInfoCompleteResponse,
            timeout_override=TF_DR_TIMEOUT_LONG,
            # max_tokens=1000,
        )

        # If information is still incomplete, hand control back to the
        # orchestrator with the identified gaps. (Replaces the previous
        # `if complete: pass / else: return` construct with the direct test.)
        if not test_info_complete_json.complete:
            return OrchestrationUpdate(
                tools_used=[DRPath.ORCHESTRATOR.value],
                query_list=[],
                log_messages=[
                    get_langgraph_node_log_string(
                        graph_component="main",
                        node_name="closer",
                        node_start_time=node_start_time,
                    )
                ],
                gaps=test_info_complete_json.gaps,
                num_closer_suggestions=num_closer_suggestions + 1,
            )

    retrieved_search_docs = convert_inference_sections_to_search_docs(
        all_cited_documents
    )

    write_custom_event(
        current_step_nr,
        MessageStart(
            content="",
            final_documents=retrieved_search_docs,
        ),
        writer,
    )

    if research_type in [ResearchType.THOUGHTFUL, ResearchType.FAST]:
        final_answer_base_prompt = FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
    elif research_type == ResearchType.DEEP:
        final_answer_base_prompt = FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
    else:
        raise ValueError(f"Invalid research type: {research_type}")

    estimated_final_answer_prompt_tokens = check_number_of_tokens(
        final_answer_base_prompt.build(
            base_question=prompt_question,
            iteration_responses_string=iteration_responses_w_docs_string,
            chat_history_string=chat_history_string,
            uploaded_context=uploaded_context,
        )
    )

    # for DR, rely only on sub-answers and claims to save tokens if context is too long
    # TODO: consider compression step for Thoughtful mode if context is too long.
    # Should generally not be the case though.
    max_allowed_input_tokens = graph_config.tooling.primary_llm.config.max_input_tokens
    if (
        estimated_final_answer_prompt_tokens > 0.8 * max_allowed_input_tokens
        and research_type == ResearchType.DEEP
    ):
        iteration_responses_string = iteration_responses_wo_docs_string
    else:
        iteration_responses_string = iteration_responses_w_docs_string

    final_answer_prompt = final_answer_base_prompt.build(
        base_question=prompt_question,
        iteration_responses_string=iteration_responses_string,
        chat_history_string=chat_history_string,
        uploaded_context=uploaded_context,
    )

    if graph_config.inputs.project_instructions:
        assistant_system_prompt = (
            assistant_system_prompt
            + PROJECT_INSTRUCTIONS_SEPARATOR
            + (graph_config.inputs.project_instructions or "")
        )

    all_context_llmdocs = [
        llm_doc_from_inference_section(inference_section)
        for inference_section in all_cited_documents
    ]

    try:
        streamed_output, _, citation_infos = run_with_timeout(
            int(3 * TF_DR_TIMEOUT_LONG),
            lambda: stream_llm_answer(
                llm=graph_config.tooling.primary_llm,
                prompt=create_question_prompt(
                    assistant_system_prompt,
                    final_answer_prompt + (assistant_task_prompt or ""),
                ),
                event_name="basic_response",
                writer=writer,
                agent_answer_level=0,
                agent_answer_question_num=0,
                agent_answer_type="agent_level_answer",
                timeout_override=int(2 * TF_DR_TIMEOUT_LONG),
                answer_piece=StreamingType.MESSAGE_DELTA.value,
                ind=current_step_nr,
                context_docs=all_context_llmdocs,
                replace_citations=True,
                # max_tokens=None,
            ),
        )
        final_answer = "".join(streamed_output)
    except Exception as e:
        # Chain the original exception so the root cause stays visible.
        raise ValueError(f"Error in consolidate_research: {e}") from e

    write_custom_event(current_step_nr, SectionEnd(), writer)
    current_step_nr += 1

    write_custom_event(current_step_nr, CitationStart(), writer)
    write_custom_event(current_step_nr, CitationDelta(citations=citation_infos), writer)
    write_custom_event(current_step_nr, SectionEnd(), writer)
    current_step_nr += 1

    # NOTE(review): persisting the research-agent iterations is currently
    # disabled here — confirm whether save_iteration should be re-enabled.
    # save_iteration(
    #     state,
    #     graph_config,
    #     aggregated_context,
    #     final_answer,
    #     all_cited_documents,
    #     is_internet_marker_dict,
    # )

    return FinalUpdate(
        final_answer=final_answer,
        all_cited_documents=all_cited_documents,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="closer",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -1,246 +1,248 @@
# import re
# from datetime import datetime
# from typing import cast
import re
from datetime import datetime
from typing import cast
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
# from sqlalchemy.orm import Session
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from sqlalchemy.orm import Session
# from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
# from onyx.agents.agent_search.dr.models import AggregatedDRContext
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.states import MainState
# from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
# GeneratedImageFullResult,
# )
# from onyx.agents.agent_search.dr.utils import aggregate_context
# from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
# from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.context.search.models import InferenceSection
# from onyx.db.chat import create_search_doc_from_inference_section
# from onyx.db.chat import update_db_session_with_messages
# from onyx.db.models import ChatMessage__SearchDoc
# from onyx.db.models import ResearchAgentIteration
# from onyx.db.models import ResearchAgentIterationSubStep
# from onyx.db.models import SearchDoc as DbSearchDoc
# from onyx.natural_language_processing.utils import get_tokenizer
# from onyx.server.query_and_chat.streaming_models import OverallStop
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
from onyx.agents.agent_search.dr.models import AggregatedDRContext
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
GeneratedImageFullResult,
)
from onyx.agents.agent_search.dr.utils import aggregate_context
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.context.search.models import InferenceSection
from onyx.db.chat import create_search_doc_from_inference_section
from onyx.db.chat import update_db_session_with_messages
from onyx.db.models import ChatMessage__SearchDoc
from onyx.db.models import ResearchAgentIteration
from onyx.db.models import ResearchAgentIterationSubStep
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.utils.logger import setup_logger
# logger = setup_logger()
# Module-level logger for this graph node (configured via onyx.utils.logger).
logger = setup_logger()
# def _extract_citation_numbers(text: str) -> list[int]:
# """
# Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
# Returns a list of all unique citation numbers found.
# """
# # Pattern to match [[number]] or [[number1, number2, ...]]
# pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
# matches = re.findall(pattern, text)
def _extract_citation_numbers(text: str) -> list[int]:
"""
Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
Returns a list of all unique citation numbers found.
"""
# Pattern to match [[number]] or [[number1, number2, ...]]
pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
matches = re.findall(pattern, text)
# cited_numbers = []
# for match in matches:
# # Split by comma and extract all numbers
# numbers = [int(num.strip()) for num in match.split(",")]
# cited_numbers.extend(numbers)
cited_numbers = []
for match in matches:
# Split by comma and extract all numbers
numbers = [int(num.strip()) for num in match.split(",")]
cited_numbers.extend(numbers)
# return list(set(cited_numbers)) # Return unique numbers
return list(set(cited_numbers)) # Return unique numbers
# def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
# citation_content = match.group(1) # e.g., "3" or "3, 5, 7"
# numbers = [int(num.strip()) for num in citation_content.split(",")]
def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
    """
    Render a citation regex match as markdown-linked citations.

    Args:
        match: Regex match whose group(1) holds the citation content,
            e.g. "3" or "3, 5, 7".
        docs: Search docs; citation number N refers to docs[N - 1] (1-based).

    Returns:
        The concatenated linked citations, e.g. "[[3]](https://...)";
        out-of-range numbers are rendered without a link.
    """
    citation_content = match.group(1)  # e.g., "3" or "3, 5, 7"
    numbers = [int(num.strip()) for num in citation_content.split(",")]

    # For multiple citations like [[3, 5, 7]], create separate linked citations
    linked_citations = []
    for num in numbers:
        # Citations are 1-based. The previous check only guarded the upper
        # bound, so a stray [[0]] would index docs[-1] and link the wrong
        # (last) document; guard both bounds instead.
        if 1 <= num <= len(docs):
            link = docs[num - 1].link or ""
            linked_citations.append(f"[[{num}]]({link})")
        else:
            linked_citations.append(f"[[{num}]]")  # No link if out of bounds

    return "".join(linked_citations)
# def _insert_chat_message_search_doc_pair(
# message_id: int, search_doc_ids: list[int], db_session: Session
# ) -> None:
# """
# Insert a pair of message_id and search_doc_id into the chat_message__search_doc table.
def _insert_chat_message_search_doc_pair(
    message_id: int, search_doc_ids: list[int], db_session: Session
) -> None:
    """
    Associate a chat message with search documents via the
    chat_message__search_doc join table.

    Args:
        message_id: The ID of the chat message
        search_doc_ids: The IDs of the search documents to link
        db_session: The database session (rows are added, not committed here)
    """
    association_rows = [
        ChatMessage__SearchDoc(chat_message_id=message_id, search_doc_id=doc_id)
        for doc_id in search_doc_ids
    ]
    db_session.add_all(association_rows)
def save_iteration(
    state: MainState,
    graph_config: GraphConfig,
    aggregated_context: AggregatedDRContext,
    final_answer: str,
    all_cited_documents: list[InferenceSection],
    is_internet_marker_dict: dict[str, bool],
    num_tokens: int,
) -> None:
    """Persist one deep-research run to the database.

    Writes, in order: the cited search docs, the message<->search-doc
    mappings, the citation map, the updated chat message (and its parent),
    and the per-iteration research agent steps and sub-steps. A single
    commit is issued at the end.

    Args:
        state: Final graph state holding iteration instructions/responses.
        graph_config: Carries the DB session, message id and behavior flags.
        aggregated_context: Aggregated iteration responses and documents.
        final_answer: The final answer text stored on the chat message.
        all_cited_documents: Inference sections cited by the final answer.
        is_internet_marker_dict: document_id -> whether the doc came from
            an internet source.
        num_tokens: Token count of the final answer.
    """
    db_session = graph_config.persistence.db_session
    message_id = graph_config.persistence.message_id
    research_type = graph_config.behavior.research_type

    # first, insert the search_docs
    search_docs = [
        create_search_doc_from_inference_section(
            inference_section=inference_section,
            is_internet=is_internet_marker_dict.get(
                inference_section.center_chunk.document_id, False
            ),  # TODO: revisit
            db_session=db_session,
            commit=False,
        )
        for inference_section in all_cited_documents
    ]

    # then, map search_docs to the chat message
    _insert_chat_message_search_doc_pair(
        message_id, [search_doc.id for search_doc in search_docs], db_session
    )

    # lastly, insert the citations
    citation_dict: dict[int, int] = {}
    cited_doc_nrs = _extract_citation_numbers(final_answer)
    if search_docs:
        for cited_doc_nr in cited_doc_nrs:
            # citation numbers are 1-based indices into search_docs
            citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id

    # TODO: generate plan as dict in the first place
    plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
    plan_of_record_dict = parse_plan_to_dict(plan_of_record)

    # Update the chat message and its parent message in database
    update_db_session_with_messages(
        db_session=db_session,
        chat_message_id=message_id,
        chat_session_id=graph_config.persistence.chat_session_id,
        is_agentic=graph_config.behavior.use_agentic_search,
        message=final_answer,
        citations=citation_dict,
        research_type=research_type,
        research_plan=plan_of_record_dict,
        final_documents=search_docs,
        update_parent_message=True,
        research_answer_purpose=ResearchAnswerPurpose.ANSWER,
        token_count=num_tokens,
    )

    # one ResearchAgentIteration row per planned iteration
    for iteration_preparation in state.iteration_instructions:
        research_agent_iteration_step = ResearchAgentIteration(
            primary_question_id=message_id,
            reasoning=iteration_preparation.reasoning,
            purpose=iteration_preparation.purpose,
            iteration_nr=iteration_preparation.iteration_nr,
        )
        db_session.add(research_agent_iteration_step)

    # one ResearchAgentIterationSubStep row per parallel sub-answer
    for iteration_answer in aggregated_context.global_iteration_responses:
        retrieved_search_docs = convert_inference_sections_to_search_docs(
            list(iteration_answer.cited_documents.values())
        )

        # Convert SavedSearchDoc objects to JSON-serializable format
        serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]

        research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
            primary_question_id=message_id,
            iteration_nr=iteration_answer.iteration_nr,
            iteration_sub_step_nr=iteration_answer.parallelization_nr,
            sub_step_instructions=iteration_answer.question,
            sub_step_tool_id=iteration_answer.tool_id,
            sub_answer=iteration_answer.answer,
            reasoning=iteration_answer.reasoning,
            claims=iteration_answer.claims,
            cited_doc_results=serialized_search_docs,
            generated_images=(
                GeneratedImageFullResult(images=iteration_answer.generated_images)
                if iteration_answer.generated_images
                else None
            ),
            additional_data=iteration_answer.additional_data,
            queries=iteration_answer.queries,
        )
        db_session.add(research_agent_iteration_sub_step)

    db_session.commit()
# NOTE(review): this function's name shadows the stdlib `logging` module;
# renaming would change the graph node wiring, so it is left as-is.
def logging(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph node to close the DR process and finalize the answer.

    Computes the final-answer token count, emits the OverallStop event,
    persists the run via save_iteration, and returns a log-message update.
    """
    node_start_time = datetime.now()
    # TODO: generate final answer using all the previous steps
    # (right now, answers from each step are concatenated onto each other)
    # Also, add missing fields once usage in UI is clear.

    current_step_nr = state.current_step_nr

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = state.original_question
    if not base_question:
        raise ValueError("Question is required for closer")

    aggregated_context = aggregate_context(
        state.iteration_responses, include_documents=True
    )

    all_cited_documents = aggregated_context.cited_documents
    is_internet_marker_dict = aggregated_context.is_internet_marker_dict

    final_answer = state.final_answer or ""
    llm_provider = graph_config.tooling.primary_llm.config.model_provider
    llm_model_name = graph_config.tooling.primary_llm.config.model_name

    # token count is computed with the primary LLM's own tokenizer
    llm_tokenizer = get_tokenizer(
        model_name=llm_model_name,
        provider_type=llm_provider,
    )
    num_tokens = len(llm_tokenizer.encode(final_answer or ""))

    write_custom_event(current_step_nr, OverallStop(), writer)

    # Log the research agent steps
    save_iteration(
        state,
        graph_config,
        aggregated_context,
        final_answer,
        all_cited_documents,
        is_internet_marker_dict,
        num_tokens,
    )

    return LoggerUpdate(
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="logger",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -1,131 +1,132 @@
# from collections.abc import Iterator
# from typing import cast
from collections.abc import Iterator
from typing import cast
# from langchain_core.messages import AIMessageChunk
# from langchain_core.messages import BaseMessage
# from langgraph.types import StreamWriter
# from pydantic import BaseModel
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langgraph.types import StreamWriter
from pydantic import BaseModel
# from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
# from onyx.chat.models import AgentAnswerPiece
# from onyx.chat.models import CitationInfo
# from onyx.chat.models import LlmDoc
# from onyx.chat.models import OnyxAnswerPiece
# from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
# from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
# from onyx.chat.stream_processing.answer_response_handler import (
# PassThroughAnswerResponseHandler,
# )
# from onyx.chat.stream_processing.utils import map_document_id_order
# from onyx.context.search.models import InferenceSection
# from onyx.server.query_and_chat.streaming_models import CitationDelta
# from onyx.server.query_and_chat.streaming_models import CitationStart
# from onyx.server.query_and_chat.streaming_models import MessageDelta
# from onyx.server.query_and_chat.streaming_models import MessageStart
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import CitationInfo
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
PassThroughAnswerResponseHandler,
)
from onyx.chat.stream_processing.utils import map_document_id_order
from onyx.context.search.models import InferenceSection
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationStart
from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
class BasicSearchProcessedStreamResults(BaseModel):
    """Result of consuming a basic-search LLM stream.

    Bundles the accumulated tool-call chunk and the full textual answer
    produced by `process_llm_stream`, plus document/reference lists.
    """

    # accumulated AIMessageChunk holding tool-call information (if any)
    ai_message_chunk: AIMessageChunk = AIMessageChunk(content="")
    # concatenation of all streamed answer pieces
    full_answer: str | None = None
    # presumably the sections cited by the answer — not populated by
    # process_llm_stream in this file; TODO confirm against callers
    cited_references: list[InferenceSection] = []
    # presumably the docs retrieved by the search — not populated here either
    retrieved_documents: list[LlmDoc] = []
def process_llm_stream(
    messages: Iterator[BaseMessage],
    should_stream_answer: bool,
    writer: StreamWriter,
    ind: int,
    search_results: list[LlmDoc] | None = None,
    generate_final_answer: bool = False,
    chat_message_id: str | None = None,
) -> BasicSearchProcessedStreamResults:
    """Consume an LLM message stream, emitting answer and citation events.

    Tool-call chunks are accumulated into a single AIMessageChunk; plain
    answer chunks are (optionally) forwarded through the answer handler and
    streamed out via `writer`. Citation infos emitted by the handler are
    collected and streamed as a citation section at the end.

    Args:
        messages: The raw LLM stream (answer chunks or tool-call chunks).
        should_stream_answer: Whether to run answer chunks through the
            handler and emit events.
        writer: LangGraph stream writer for custom events.
        ind: Step index used for all emitted events.
        search_results: Docs for citation processing; when present a
            CitationResponseHandler is used, otherwise pass-through.
        generate_final_answer: Whether answer pieces form the final answer
            (emits MessageStart/MessageDelta/SectionEnd events).
        chat_message_id: Required when generate_final_answer is True.

    Returns:
        BasicSearchProcessedStreamResults with the accumulated tool-call
        chunk and the full concatenated answer text.

    Raises:
        ValueError: If generate_final_answer is True but chat_message_id
            is None when an answer piece arrives.
    """
    tool_call_chunk = AIMessageChunk(content="")

    if search_results:
        answer_handler: AnswerResponseHandler = CitationResponseHandler(
            context_docs=search_results,
            doc_id_to_rank_map=map_document_id_order(search_results),
        )
    else:
        answer_handler = PassThroughAnswerResponseHandler()

    full_answer = ""
    start_final_answer_streaming_set = False
    # Accumulate citation infos if handler emits them
    collected_citation_infos: list[CitationInfo] = []

    # This stream will be the llm answer if no tool is chosen. When a tool is chosen,
    # the stream will contain AIMessageChunks with tool call information.
    for message in messages:
        answer_piece = message.content
        if not isinstance(answer_piece, str):
            # this is only used for logging, so fine to
            # just add the string representation
            answer_piece = str(answer_piece)
        full_answer += answer_piece

        if isinstance(message, AIMessageChunk) and (
            message.tool_call_chunks or message.tool_calls
        ):
            tool_call_chunk += message  # type: ignore
        elif should_stream_answer:
            for response_part in answer_handler.handle_response_part(message):
                # only stream out answer parts
                if (
                    isinstance(response_part, (OnyxAnswerPiece, AgentAnswerPiece))
                    and generate_final_answer
                    and response_part.answer_piece
                ):
                    if chat_message_id is None:
                        raise ValueError(
                            "chat_message_id is required when generating final answer"
                        )

                    if not start_final_answer_streaming_set:
                        # Convert LlmDocs to SavedSearchDocs
                        saved_search_docs = saved_search_docs_from_llm_docs(
                            search_results
                        )
                        write_custom_event(
                            ind,
                            MessageStart(content="", final_documents=saved_search_docs),
                            writer,
                        )
                        start_final_answer_streaming_set = True

                    write_custom_event(
                        ind,
                        MessageDelta(content=response_part.answer_piece),
                        writer,
                    )
                # collect citation info objects
                elif isinstance(response_part, CitationInfo):
                    collected_citation_infos.append(response_part)

    if generate_final_answer and start_final_answer_streaming_set:
        # start_final_answer_streaming_set is only set if the answer is verbal and not a tool call
        write_custom_event(
            ind,
            SectionEnd(),
            writer,
        )

    # Emit citations section if any were collected
    if collected_citation_infos:
        write_custom_event(ind, CitationStart(), writer)
        write_custom_event(
            ind, CitationDelta(citations=collected_citation_infos), writer
        )
        write_custom_event(ind, SectionEnd(), writer)

    logger.debug(f"Full answer: {full_answer}")
    return BasicSearchProcessedStreamResults(
        ai_message_chunk=cast(AIMessageChunk, tool_call_chunk), full_answer=full_answer
    )

View File

@@ -1,82 +1,82 @@
# from operator import add
# from typing import Annotated
# from typing import Any
# from typing import TypedDict
from operator import add
from typing import Annotated
from typing import Any
from typing import TypedDict
# from pydantic import BaseModel
from pydantic import BaseModel
# from onyx.agents.agent_search.core_state import CoreState
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.models import IterationInstructions
# from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
# from onyx.agents.agent_search.dr.models import OrchestrationPlan
# from onyx.agents.agent_search.dr.models import OrchestratorTool
# from onyx.context.search.models import InferenceSection
# from onyx.db.connector import DocumentSource
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
from onyx.agents.agent_search.dr.models import OrchestrationPlan
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.context.search.models import InferenceSection
from onyx.db.connector import DocumentSource
# ### States ###
### States ###
class LoggerUpdate(BaseModel):
    """Base state update: accumulates log messages across graph nodes."""

    # the `add` reducer makes concurrent node updates concatenate
    # rather than overwrite
    log_messages: Annotated[list[str], add] = []
class OrchestrationUpdate(LoggerUpdate):
    """State written by the orchestrator while planning and running iterations."""

    tools_used: Annotated[list[str], add] = []
    query_list: list[str] = []
    iteration_nr: int = 0
    current_step_nr: int = 1
    plan_of_record: OrchestrationPlan | None = None  # None for Thoughtful
    remaining_time_budget: float = 2.0  # set by default to about 2 searches
    num_closer_suggestions: int = 0  # how many times the closer was suggested
    gaps: list[str] = (
        []
    )  # gaps that may be identified by the closer before being able to answer the question.
    iteration_instructions: Annotated[list[IterationInstructions], add] = []
class OrchestrationSetup(OrchestrationUpdate):
    """Initial orchestration state assembled before the first iteration."""

    original_question: str | None = None
    chat_history_string: str | None = None
    clarification: OrchestrationClarificationInfo | None = None
    available_tools: dict[str, OrchestratorTool] | None = None
    # num_closer_suggestions is inherited from OrchestrationUpdate (default 0);
    # the redundant re-declaration with the identical default was removed.

    active_source_types: list[DocumentSource] | None = None
    active_source_types_descriptions: str | None = None
    assistant_system_prompt: str | None = None
    assistant_task_prompt: str | None = None
    uploaded_test_context: str | None = None
    uploaded_image_context: list[dict[str, Any]] | None = None
class AnswerUpdate(LoggerUpdate):
    """Update accumulating per-iteration answers (list-concatenating reducer)."""

    iteration_responses: Annotated[list[IterationAnswer], add] = []
class FinalUpdate(LoggerUpdate):
    """Update carrying the final answer and the documents it cites."""

    final_answer: str | None = None
    all_cited_documents: list[InferenceSection] = []
## Graph Input State
class MainInput(CoreState):
    """Input state for the main DR graph (no fields beyond CoreState)."""
## Graph State
class MainState(
    # This includes the core state
    MainInput,
    OrchestrationSetup,
    AnswerUpdate,
    FinalUpdate,
):
    """Combined graph state: core input plus all node-update mixins."""
## Graph Output State
class MainOutput(TypedDict):
    """Final output of the main DR graph."""

    log_messages: list[str]
    final_answer: str | None
    all_cited_documents: list[InferenceSection]

View File

@@ -1,47 +1,47 @@
# from datetime import datetime
from datetime import datetime
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import SearchToolStart
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
def basic_search_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph node to perform a standard search as part of the DR process.

    Emits a SearchToolStart event for the current step (marked as a
    non-internet search) and returns a log-message update.
    """
    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    current_step_nr = state.current_step_nr

    logger.debug(f"Search start for Basic Search {iteration_nr} at {datetime.now()}")

    # announce the (non-internet) search tool start to the stream
    write_custom_event(
        current_step_nr,
        SearchToolStart(
            is_internet_search=False,
        ),
        writer,
    )

    return LoggerUpdate(
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="basic_search",
                node_name="branching",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -1,261 +1,286 @@
# import re
# from datetime import datetime
# from typing import cast
import re
from datetime import datetime
from typing import cast
from uuid import UUID
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.models import BaseSearchProcessingResponse
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.models import SearchAnswer
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
# from onyx.agents.agent_search.dr.utils import extract_document_citations
# from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.agents.agent_search.utils import create_question_prompt
# from onyx.chat.models import LlmDoc
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
# from onyx.context.search.models import InferenceSection
# from onyx.db.connector import DocumentSource
# from onyx.prompts.dr_prompts import BASE_SEARCH_PROCESSING_PROMPT
# from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
# from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
# from onyx.server.query_and_chat.streaming_models import SearchToolDelta
# from onyx.tools.models import SearchToolOverrideKwargs
# from onyx.tools.tool_implementations.search.search_tool import SearchTool
# from onyx.tools.tool_implementations.search_like_tool_utils import (
# SEARCH_INFERENCE_SECTIONS_ID,
# )
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import BaseSearchProcessingResponse
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import SearchAnswer
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.dr.utils import extract_document_citations
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.chat.models import LlmDoc
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.context.search.models import InferenceSection
from onyx.db.connector import DocumentSource
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.prompts.dr_prompts import BASE_SEARCH_PROCESSING_PROMPT
from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def basic_search(
# state: BranchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> BranchUpdate:
# """
# LangGraph node to perform a standard search as part of the DR process.
# """
def basic_search(
    state: BranchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
    """
    LangGraph node to perform a standard (internal) search as part of the DR process.

    Rewrites the branch question with the LLM, runs the internal search tool,
    streams the query and retrieved documents to the UI, and — for deep
    research — summarizes the documents with the LLM, extracting an answer,
    claims and citations.

    Args:
        state: branch-level input carrying the question, tools and step info.
        config: LangGraph runnable config; ``config["metadata"]["config"]``
            holds the GraphConfig.
        writer: stream writer used to emit UI events (no-op by default).

    Returns:
        BranchUpdate with one IterationAnswer for this branch plus a node
        log message.

    Raises:
        ValueError: if required state (branch question, tools) is missing,
            if the search tool does not match the configured one, or if the
            LLM cites document numbers outside the retrieved range.
    """
    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr
    current_step_nr = state.current_step_nr
    assistant_system_prompt = state.assistant_system_prompt
    assistant_task_prompt = state.assistant_task_prompt

    branch_query = state.branch_question
    if not branch_query:
        raise ValueError("branch_query is not set")

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = graph_config.inputs.prompt_builder.raw_user_query
    research_type = graph_config.behavior.research_type

    if not state.available_tools:
        raise ValueError("available_tools is not set")
    elif len(state.tools_used) == 0:
        raise ValueError("tools_used is empty")

    search_tool_info = state.available_tools[state.tools_used[-1]]
    search_tool = cast(SearchTool, search_tool_info.tool_object)
    force_use_tool = graph_config.tooling.force_use_tool

    # sanity check
    if search_tool != graph_config.tooling.search_tool:
        raise ValueError("search_tool does not match the configured search tool")

    # rewrite query and identify source types
    active_source_types_str = ", ".join(
        [source.value for source in state.active_source_types or []]
    )

    base_search_processing_prompt = BASE_SEARCH_PROCESSING_PROMPT.build(
        active_source_types_str=active_source_types_str,
        branch_query=branch_query,
        current_time=datetime.now().strftime("%Y-%m-%d %H:%M"),
    )

    try:
        search_processing = invoke_llm_json(
            llm=graph_config.tooling.primary_llm,
            prompt=create_question_prompt(
                assistant_system_prompt, base_search_processing_prompt
            ),
            schema=BaseSearchProcessingResponse,
            timeout_override=TF_DR_TIMEOUT_SHORT,
        )
    except Exception as e:
        logger.error(f"Could not process query: {e}")
        # bare raise preserves the original traceback
        raise

    rewritten_query = search_processing.rewritten_query

    # give back the query so we can render it in the UI
    write_custom_event(
        current_step_nr,
        SearchToolDelta(
            queries=[rewritten_query],
            documents=[],
        ),
        writer,
    )

    implied_start_date = search_processing.time_filter

    # Validate time_filter format if it exists; only accept strict YYYY-MM-DD
    implied_time_filter = None
    if implied_start_date:
        date_pattern = r"^\d{4}-\d{2}-\d{2}$"
        if re.match(date_pattern, implied_start_date):
            implied_time_filter = datetime.strptime(implied_start_date, "%Y-%m-%d")

    specified_source_types: list[DocumentSource] | None = (
        strings_to_document_sources(search_processing.specified_source_types)
        if search_processing.specified_source_types
        else None
    )
    # an empty list means "no restriction" downstream, so normalize to None
    if specified_source_types is not None and len(specified_source_types) == 0:
        specified_source_types = None

    logger.debug(
        f"Search start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    retrieved_docs: list[InferenceSection] = []
    callback_container: list[list[InferenceSection]] = []

    # user-file / project scoping is forwarded from the forced-tool kwargs, if any
    user_file_ids: list[UUID] | None = None
    project_id: int | None = None
    if force_use_tool.override_kwargs and isinstance(
        force_use_tool.override_kwargs, SearchToolOverrideKwargs
    ):
        override_kwargs = force_use_tool.override_kwargs
        user_file_ids = override_kwargs.user_file_ids
        project_id = override_kwargs.project_id

    # new db session to avoid concurrency issues
    with get_session_with_current_tenant() as search_db_session:
        for tool_response in search_tool.run(
            query=rewritten_query,
            document_sources=specified_source_types,
            time_filter=implied_time_filter,
            override_kwargs=SearchToolOverrideKwargs(
                force_no_rerank=True,
                alternate_db_session=search_db_session,
                retrieved_sections_callback=callback_container.append,
                skip_query_analysis=True,
                original_query=rewritten_query,
                user_file_ids=user_file_ids,
                project_id=project_id,
            ),
        ):
            # get retrieved docs to send to the rest of the graph
            if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
                response = cast(SearchResponseSummary, tool_response.response)
                retrieved_docs = response.top_sections
                break

    # render the retrieved docs in the UI
    write_custom_event(
        current_step_nr,
        SearchToolDelta(
            queries=[],
            documents=convert_inference_sections_to_search_docs(
                retrieved_docs, is_internet=False
            ),
        ),
        writer,
    )

    # build the LLM context string from (at most) the top 15 documents
    document_texts_list = []
    for doc_num, retrieved_doc in enumerate(retrieved_docs[:15]):
        if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
            raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
        chunk_text = build_document_context(retrieved_doc, doc_num + 1)
        document_texts_list.append(chunk_text)

    document_texts = "\n\n".join(document_texts_list)

    logger.debug(
        f"Search end/LLM start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    if research_type == ResearchType.DEEP:
        # deep research: summarize the docs and extract claims + citations
        search_prompt = INTERNAL_SEARCH_PROMPTS[research_type].build(
            search_query=branch_query,
            base_question=base_question,
            document_text=document_texts,
        )

        search_answer_json = invoke_llm_json(
            llm=graph_config.tooling.primary_llm,
            prompt=create_question_prompt(
                assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
            ),
            schema=SearchAnswer,
            timeout_override=TF_DR_TIMEOUT_LONG,
        )

        logger.debug(
            f"LLM/all done for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
        )

        # get cited documents
        answer_string = search_answer_json.answer
        claims = search_answer_json.claims or []
        reasoning = search_answer_json.reasoning

        (
            citation_numbers,
            answer_string,
            claims,
        ) = extract_document_citations(answer_string, claims)

        # citation numbers are 1-based indices into retrieved_docs
        if citation_numbers and (
            (max(citation_numbers) > len(retrieved_docs)) or min(citation_numbers) < 1
        ):
            raise ValueError("Citation numbers are out of range for retrieved docs.")

        cited_documents = {
            citation_number: retrieved_docs[citation_number - 1]
            for citation_number in citation_numbers
        }
    else:
        # non-deep research: pass the raw documents through without summarizing
        answer_string = ""
        claims = []
        cited_documents = {
            doc_num + 1: retrieved_doc
            for doc_num, retrieved_doc in enumerate(retrieved_docs[:15])
        }
        reasoning = ""

    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=search_tool_info.llm_path,
                tool_id=search_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=branch_query,
                answer=answer_string,
                claims=claims,
                cited_documents=cited_documents,
                reasoning=reasoning,
                additional_data=None,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="basic_search",
                node_name="searching",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -1,77 +1,77 @@
# from datetime import datetime
from datetime import datetime
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.context.search.models import SavedSearchDoc
# from onyx.context.search.models import SearchDoc
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.context.search.models import SavedSearchDoc
from onyx.context.search.models import SearchDoc
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def is_reducer(
# state: SubAgentMainState,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> SubAgentUpdate:
# """
# LangGraph node to perform a standard search as part of the DR process.
# """
def is_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    LangGraph node that consolidates the parallel branch results of the
    internal-search sub-agent for the current iteration.

    Collects the IterationAnswers produced in this iteration, converts their
    cited documents to SavedSearchDocs, closes the current UI section, and
    advances the step counter.

    Args:
        state: sub-agent state carrying all branch responses so far.
        config: LangGraph runnable config (unused here).
        writer: stream writer used to emit UI events (no-op by default).

    Returns:
        SubAgentUpdate with this iteration's responses, the incremented step
        number, and a node log message.
    """
    node_start_time = datetime.now()

    branch_updates = state.branch_iteration_responses
    current_iteration = state.iteration_nr
    current_step_nr = state.current_step_nr

    # keep only the branch answers produced by the current iteration
    new_updates = [
        update for update in branch_updates if update.iteration_nr == current_iteration
    ]

    # flatten the cited documents across all branches
    doc_list = []
    for update in new_updates:
        doc_list.extend(update.cited_documents.values())

    # Convert InferenceSections to SavedSearchDocs
    search_docs = SearchDoc.from_chunks_or_sections(doc_list)
    retrieved_saved_search_docs = [
        SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
        for search_doc in search_docs
    ]

    # these docs came from internal search, not the internet
    for retrieved_saved_search_doc in retrieved_saved_search_docs:
        retrieved_saved_search_doc.is_internet = False

    # close out the current UI section before moving on
    write_custom_event(
        current_step_nr,
        SectionEnd(),
        writer,
    )
    current_step_nr += 1

    return SubAgentUpdate(
        iteration_responses=new_updates,
        current_step_nr=current_step_nr,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="basic_search",
                node_name="consolidation",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -1,50 +1,50 @@
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_1_branch import (
# basic_search_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import (
# basic_search,
# )
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_3_reduce import (
# is_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_image_generation_conditional_edges import (
# branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_1_branch import (
basic_search_branch,
)
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import (
basic_search,
)
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_3_reduce import (
is_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_image_generation_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def dr_basic_search_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for Web Search Sub-Agent
# """
def dr_basic_search_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for the Basic (internal) Search Sub-Agent.

    Wires: START -> branch -> (conditional fan-out) -> act -> reducer -> END.

    Returns:
        The assembled (uncompiled) StateGraph.
    """
    graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    ### Add nodes ###
    graph.add_node("branch", basic_search_branch)
    graph.add_node("act", basic_search)
    graph.add_node("reducer", is_reducer)

    ### Add edges ###
    graph.add_edge(start_key=START, end_key="branch")
    # branching_router fans out one "act" invocation per rewritten query
    graph.add_conditional_edges("branch", branching_router)
    graph.add_edge(start_key="act", end_key="reducer")
    graph.add_edge(start_key="reducer", end_key=END)

    return graph

View File

@@ -1,30 +1,30 @@
# from collections.abc import Hashable
from collections.abc import Hashable
# from langgraph.types import Send
from langgraph.types import Send
# from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
# return [
# Send(
# "act",
# BranchInput(
# iteration_nr=state.iteration_nr,
# parallelization_nr=parallelization_nr,
# branch_question=query,
# current_step_nr=state.current_step_nr,
# context="",
# active_source_types=state.active_source_types,
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# assistant_system_prompt=state.assistant_system_prompt,
# assistant_task_prompt=state.assistant_task_prompt,
# ),
# )
# for parallelization_nr, query in enumerate(
# state.query_list[:MAX_DR_PARALLEL_SEARCH]
# )
# ]
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """Fan out one "act" branch per query, capped at MAX_DR_PARALLEL_SEARCH."""
    branches: list[Send | Hashable] = []
    for parallelization_nr, query in enumerate(
        state.query_list[:MAX_DR_PARALLEL_SEARCH]
    ):
        branch_input = BranchInput(
            iteration_nr=state.iteration_nr,
            parallelization_nr=parallelization_nr,
            branch_question=query,
            current_step_nr=state.current_step_nr,
            context="",
            active_source_types=state.active_source_types,
            tools_used=state.tools_used,
            available_tools=state.available_tools,
            assistant_system_prompt=state.assistant_system_prompt,
            assistant_task_prompt=state.assistant_task_prompt,
        )
        branches.append(Send("act", branch_input))
    return branches

View File

@@ -1,36 +1,36 @@
# from datetime import datetime
from datetime import datetime
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def custom_tool_branch(
# state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
# """
# LangGraph node to perform a generic tool call as part of the DR process.
# """
def custom_tool_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph branching node for the custom-tool sub-agent.

    Performs no tool work itself: it logs the start of the generic tool
    iteration and returns a timing log entry for the graph.

    Args:
        state: sub-agent input carrying the current iteration number.
        config: LangGraph runnable config (unused here).
        writer: stream writer (unused here; kept for node signature parity).

    Returns:
        LoggerUpdate with a single node log message.
    """
    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr

    logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")

    return LoggerUpdate(
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="custom_tool",
                node_name="branching",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -1,164 +1,169 @@
# import json
# from datetime import datetime
# from typing import cast
import json
from datetime import datetime
from typing import cast
# from langchain_core.messages import AIMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
# from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
# from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
# from onyx.tools.tool_implementations.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
# from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
# from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
# from onyx.tools.tool_implementations.mcp.mcp_tool import MCP_TOOL_RESPONSE_ID
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
from onyx.tools.tool_implementations.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
from onyx.tools.tool_implementations.mcp.mcp_tool import MCP_TOOL_RESPONSE_ID
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def custom_tool_act(
# state: BranchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> BranchUpdate:
# """
# LangGraph node to perform a generic tool call as part of the DR process.
# """
def custom_tool_act(
    state: BranchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
    """
    LangGraph node to perform a generic (custom/MCP) tool call as part of the
    DR process.

    Derives tool-call arguments (via the tool-calling LLM when available,
    falling back to the non-tool-calling path), runs the tool, summarizes the
    result with the LLM, and packages everything into an IterationAnswer.

    Args:
        state: branch-level input carrying the question and available tools.
        config: LangGraph runnable config; ``config["metadata"]["config"]``
            holds the GraphConfig.
        writer: stream writer (unused here; kept for node signature parity).

    Returns:
        BranchUpdate with one IterationAnswer for this branch plus a node
        log message.

    Raises:
        ValueError: if required state is missing, if no tool arguments could
            be obtained, or if the tool returned no valid response summary.
    """
    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr

    if not state.available_tools:
        raise ValueError("available_tools is not set")

    custom_tool_info = state.available_tools[state.tools_used[-1]]
    custom_tool_name = custom_tool_info.name
    custom_tool = cast(CustomTool, custom_tool_info.tool_object)

    branch_query = state.branch_question
    if not branch_query:
        raise ValueError("branch_query is not set")

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = graph_config.inputs.prompt_builder.raw_user_query

    logger.debug(
        f"Tool call start for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    # get tool call args
    tool_args: dict | None = None
    if graph_config.tooling.using_tool_calling_llm:
        # get tool call args from tool-calling LLM
        tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
            query=branch_query,
            base_question=base_question,
            tool_description=custom_tool_info.description,
        )
        tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
            tool_use_prompt,
            tools=[custom_tool.tool_definition()],
            tool_choice="required",
            timeout_override=TF_DR_TIMEOUT_LONG,
        )

        # make sure we got exactly one tool call
        if (
            isinstance(tool_calling_msg, AIMessage)
            and len(tool_calling_msg.tool_calls) == 1
        ):
            tool_args = tool_calling_msg.tool_calls[0]["args"]
        else:
            logger.warning("Tool-calling LLM did not emit a tool call")

    if tool_args is None:
        # get tool call args from non-tool-calling LLM or for failed tool-calling LLM
        tool_args = custom_tool.get_args_for_non_tool_calling_llm(
            query=branch_query,
            history=[],
            llm=graph_config.tooling.primary_llm,
            force_run=True,
        )

    if tool_args is None:
        raise ValueError("Failed to obtain tool arguments from LLM")

    # run the tool
    response_summary: CustomToolCallSummary | None = None
    for tool_response in custom_tool.run(**tool_args):
        if tool_response.id in {CUSTOM_TOOL_RESPONSE_ID, MCP_TOOL_RESPONSE_ID}:
            response_summary = cast(CustomToolCallSummary, tool_response.response)
            break

    if not response_summary:
        raise ValueError("Custom tool did not return a valid response summary")

    # summarise tool result
    if not response_summary.response_type:
        raise ValueError("Response type is not returned.")

    if response_summary.response_type == "json":
        tool_result_str = json.dumps(response_summary.tool_result, ensure_ascii=False)
    elif response_summary.response_type in {"image", "csv"}:
        tool_result_str = f"{response_summary.response_type} files: {response_summary.tool_result.file_ids}"
    else:
        tool_result_str = str(response_summary.tool_result)

    tool_str = (
        f"Tool used: {custom_tool_name}\n"
        f"Description: {custom_tool_info.description}\n"
        f"Result: {tool_result_str}"
    )

    tool_summary_prompt = CUSTOM_TOOL_USE_PROMPT.build(
        query=branch_query, base_question=base_question, tool_response=tool_str
    )
    answer_string = str(
        graph_config.tooling.primary_llm.invoke_langchain(
            tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
        ).content
    ).strip()

    # get file_ids:
    file_ids = None
    if response_summary.response_type in {"image", "csv"} and hasattr(
        response_summary.tool_result, "file_ids"
    ):
        file_ids = response_summary.tool_result.file_ids

    logger.debug(
        f"Tool call end for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=custom_tool_name,
                tool_id=custom_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=branch_query,
                answer=answer_string,
                claims=[],
                cited_documents={},
                reasoning="",
                additional_data=None,
                response_type=response_summary.response_type,
                data=response_summary.tool_result,
                file_ids=file_ids,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="custom_tool",
                node_name="tool_calling",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -1,82 +1,82 @@
# from datetime import datetime
from datetime import datetime
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import CustomToolDelta
# from onyx.server.query_and_chat.streaming_models import CustomToolStart
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
def custom_tool_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    LangGraph node that consolidates the custom-tool branch results of the
    current iteration and streams them to the frontend.

    For each branch response it emits a CustomToolStart / CustomToolDelta /
    SectionEnd event triplet, advancing the step counter once per response.

    Raises:
        ValueError: if a branch response is missing its response_type.
    """
    node_start_time = datetime.now()

    current_step_nr = state.current_step_nr

    branch_updates = state.branch_iteration_responses
    current_iteration = state.iteration_nr

    # Only surface the responses produced by the current iteration.
    new_updates = [
        update for update in branch_updates if update.iteration_nr == current_iteration
    ]

    for new_update in new_updates:
        if not new_update.response_type:
            raise ValueError("Response type is not returned.")

        write_custom_event(
            current_step_nr,
            CustomToolStart(
                tool_name=new_update.tool,
            ),
            writer,
        )

        write_custom_event(
            current_step_nr,
            CustomToolDelta(
                tool_name=new_update.tool,
                response_type=new_update.response_type,
                data=new_update.data,
                file_ids=new_update.file_ids,
            ),
            writer,
        )

        write_custom_event(
            current_step_nr,
            SectionEnd(),
            writer,
        )

        current_step_nr += 1

    return SubAgentUpdate(
        iteration_responses=new_updates,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="custom_tool",
                node_name="consolidation",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -1,28 +1,28 @@
# from collections.abc import Hashable
from collections.abc import Hashable
# from langgraph.types import Send
from langgraph.types import Send
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import (
# SubAgentInput,
# )
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import (
SubAgentInput,
)
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """Fan out one "act" Send per query generated for this iteration.

    Parallel execution is intentionally disabled: only the first query in
    state.query_list is dispatched (note the [:1] slice below).
    """
    return [
        Send(
            "act",
            BranchInput(
                iteration_nr=state.iteration_nr,
                parallelization_nr=parallelization_nr,
                branch_question=query,
                context="",
                active_source_types=state.active_source_types,
                tools_used=state.tools_used,
                available_tools=state.available_tools,
            ),
        )
        for parallelization_nr, query in enumerate(
            state.query_list[:1]  # no parallel call for now
        )
    ]

View File

@@ -1,50 +1,50 @@
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_1_branch import (
# custom_tool_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_2_act import (
# custom_tool_act,
# )
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_3_reduce import (
# custom_tool_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_conditional_edges import (
# branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_1_branch import (
custom_tool_branch,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_2_act import (
custom_tool_act,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_3_reduce import (
custom_tool_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
def dr_custom_tool_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for the Custom Tool Sub-Agent.

    Wires: START -> branch -> (conditional fan-out) -> act -> reducer -> END.
    """
    graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    ### Add nodes ###
    graph.add_node("branch", custom_tool_branch)

    graph.add_node("act", custom_tool_act)

    graph.add_node("reducer", custom_tool_reducer)

    ### Add edges ###
    graph.add_edge(start_key=START, end_key="branch")

    # branching_router emits one Send per query, targeting the "act" node.
    graph.add_conditional_edges("branch", branching_router)

    graph.add_edge(start_key="act", end_key="reducer")

    graph.add_edge(start_key="reducer", end_key=END)

    return graph

View File

@@ -1,36 +1,36 @@
# from datetime import datetime
from datetime import datetime
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
def generic_internal_tool_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph node that marks the start of a generic internal tool iteration.

    It performs no tool work itself; it only records a log entry so the
    iteration timeline can be reconstructed. The actual tool call happens in
    the downstream "act" node.
    """
    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr

    logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")

    return LoggerUpdate(
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="generic_internal_tool",
                node_name="branching",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -1,147 +1,149 @@
# import json
# from datetime import datetime
# from typing import cast
import json
from datetime import datetime
from typing import cast
# from langchain_core.messages import AIMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
# from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
# from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
# from onyx.prompts.dr_prompts import OKTA_TOOL_USE_SPECIAL_PROMPT
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
from onyx.prompts.dr_prompts import OKTA_TOOL_USE_SPECIAL_PROMPT
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
def generic_internal_tool_act(
    state: BranchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
    """
    LangGraph node that executes a generic internal tool call for one branch.

    Flow: obtain tool arguments (tool-calling LLM first, falling back to the
    non-tool-calling path), run the tool, then summarize the raw result with
    the primary LLM into a text answer for the iteration.

    Raises:
        ValueError: if available_tools, the tool object, or the branch
            question is missing, or if no tool arguments could be obtained.
    """
    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr

    if not state.available_tools:
        raise ValueError("available_tools is not set")

    # The tool to run is the most recently selected one.
    generic_internal_tool_info = state.available_tools[state.tools_used[-1]]
    generic_internal_tool_name = generic_internal_tool_info.llm_path
    generic_internal_tool = generic_internal_tool_info.tool_object

    if generic_internal_tool is None:
        raise ValueError("generic_internal_tool is not set")

    branch_query = state.branch_question
    if not branch_query:
        raise ValueError("branch_query is not set")

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = graph_config.inputs.prompt_builder.raw_user_query

    logger.debug(
        f"Tool call start for {generic_internal_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    # get tool call args
    tool_args: dict | None = None
    if graph_config.tooling.using_tool_calling_llm:
        # get tool call args from tool-calling LLM
        tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
            query=branch_query,
            base_question=base_question,
            tool_description=generic_internal_tool_info.description,
        )
        tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
            tool_use_prompt,
            tools=[generic_internal_tool.tool_definition()],
            tool_choice="required",
            timeout_override=TF_DR_TIMEOUT_SHORT,
        )

        # make sure we got a tool call
        if (
            isinstance(tool_calling_msg, AIMessage)
            and len(tool_calling_msg.tool_calls) == 1
        ):
            tool_args = tool_calling_msg.tool_calls[0]["args"]
        else:
            logger.warning("Tool-calling LLM did not emit a tool call")

    if tool_args is None:
        # get tool call args from non-tool-calling LLM or for failed tool-calling LLM
        tool_args = generic_internal_tool.get_args_for_non_tool_calling_llm(
            query=branch_query,
            history=[],
            llm=graph_config.tooling.primary_llm,
            force_run=True,
        )

    if tool_args is None:
        raise ValueError("Failed to obtain tool arguments from LLM")

    # run the tool
    tool_responses = list(generic_internal_tool.run(**tool_args))
    final_data = generic_internal_tool.final_result(*tool_responses)
    tool_result_str = json.dumps(final_data, ensure_ascii=False)

    # Raw tool output packaged for the summarization prompt below.
    tool_str = (
        f"Tool used: {generic_internal_tool.display_name}\n"
        f"Description: {generic_internal_tool_info.description}\n"
        f"Result: {tool_result_str}"
    )

    # Okta Profile results get a dedicated summarization prompt.
    if generic_internal_tool.display_name == "Okta Profile":
        tool_prompt = OKTA_TOOL_USE_SPECIAL_PROMPT
    else:
        tool_prompt = CUSTOM_TOOL_USE_PROMPT

    tool_summary_prompt = tool_prompt.build(
        query=branch_query, base_question=base_question, tool_response=tool_str
    )
    answer_string = str(
        graph_config.tooling.primary_llm.invoke_langchain(
            tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
        ).content
    ).strip()

    logger.debug(
        f"Tool call end for {generic_internal_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                # NOTE(review): the sibling custom_tool node reports
                # tool_info.llm_path here; confirm llm_name is intended.
                tool=generic_internal_tool.llm_name,
                tool_id=generic_internal_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=branch_query,
                answer=answer_string,
                claims=[],
                cited_documents={},
                reasoning="",
                additional_data=None,
                response_type="text",  # TODO: convert all response types to enums
                data=answer_string,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="custom_tool",
                node_name="tool_calling",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -1,82 +1,82 @@
# from datetime import datetime
from datetime import datetime
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import CustomToolDelta
# from onyx.server.query_and_chat.streaming_models import CustomToolStart
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
def generic_internal_tool_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    LangGraph node that consolidates the generic-internal-tool branch results
    of the current iteration and streams them to the frontend.

    For each branch response it emits a CustomToolStart / CustomToolDelta /
    SectionEnd event triplet, advancing the step counter once per response.

    Raises:
        ValueError: if a branch response is missing its response_type.
    """
    node_start_time = datetime.now()

    current_step_nr = state.current_step_nr

    branch_updates = state.branch_iteration_responses
    current_iteration = state.iteration_nr

    # Only surface the responses produced by the current iteration.
    new_updates = [
        update for update in branch_updates if update.iteration_nr == current_iteration
    ]

    for new_update in new_updates:
        if not new_update.response_type:
            raise ValueError("Response type is not returned.")

        write_custom_event(
            current_step_nr,
            CustomToolStart(
                tool_name=new_update.tool,
            ),
            writer,
        )

        write_custom_event(
            current_step_nr,
            CustomToolDelta(
                tool_name=new_update.tool,
                response_type=new_update.response_type,
                data=new_update.data,
                file_ids=[],
            ),
            writer,
        )

        write_custom_event(
            current_step_nr,
            SectionEnd(),
            writer,
        )

        current_step_nr += 1

    return SubAgentUpdate(
        iteration_responses=new_updates,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="custom_tool",
                node_name="consolidation",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -1,28 +1,28 @@
# from collections.abc import Hashable
from collections.abc import Hashable
# from langgraph.types import Send
from langgraph.types import Send
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import (
# SubAgentInput,
# )
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import (
SubAgentInput,
)
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """Fan out one "act" Send per query generated for this iteration.

    Parallel execution is intentionally disabled: only the first query in
    state.query_list is dispatched (note the [:1] slice below).
    """
    return [
        Send(
            "act",
            BranchInput(
                iteration_nr=state.iteration_nr,
                parallelization_nr=parallelization_nr,
                branch_question=query,
                context="",
                active_source_types=state.active_source_types,
                tools_used=state.tools_used,
                available_tools=state.available_tools,
            ),
        )
        for parallelization_nr, query in enumerate(
            state.query_list[:1]  # no parallel call for now
        )
    ]

View File

@@ -1,50 +1,50 @@
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_1_branch import (
# generic_internal_tool_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_2_act import (
# generic_internal_tool_act,
# )
# from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_3_reduce import (
# generic_internal_tool_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_conditional_edges import (
# branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_1_branch import (
generic_internal_tool_branch,
)
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_2_act import (
generic_internal_tool_act,
)
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_3_reduce import (
generic_internal_tool_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
def dr_generic_internal_tool_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for the Generic Internal Tool Sub-Agent.

    Wires: START -> branch -> (conditional fan-out) -> act -> reducer -> END.
    """
    graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    ### Add nodes ###
    graph.add_node("branch", generic_internal_tool_branch)

    graph.add_node("act", generic_internal_tool_act)

    graph.add_node("reducer", generic_internal_tool_reducer)

    ### Add edges ###
    graph.add_edge(start_key=START, end_key="branch")

    # branching_router emits one Send per query, targeting the "act" node.
    graph.add_conditional_edges("branch", branching_router)

    graph.add_edge(start_key="act", end_key="reducer")

    graph.add_edge(start_key="reducer", end_key=END)

    return graph

View File

@@ -1,45 +1,45 @@
# from datetime import datetime
from datetime import datetime
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
def image_generation_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph node that announces the start of an image-generation iteration.

    It streams an ImageGenerationToolStart event to the frontend and records
    a log entry; the actual image generation happens in the downstream node.
    """
    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr

    logger.debug(f"Image generation start {iteration_nr} at {datetime.now()}")

    # tell frontend that we are starting the image generation tool
    write_custom_event(
        state.current_step_nr,
        ImageGenerationToolStart(),
        writer,
    )

    return LoggerUpdate(
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="image_generation",
                node_name="branching",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -1,187 +1,189 @@
# import json
# from datetime import datetime
# from typing import cast
import json
from datetime import datetime
from typing import cast
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.models import GeneratedImage
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.file_store.utils import build_frontend_file_url
# from onyx.file_store.utils import save_files
# from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeartbeat
# from onyx.tools.tool_implementations.images.image_generation_tool import (
# IMAGE_GENERATION_HEARTBEAT_ID,
# )
# from onyx.tools.tool_implementations.images.image_generation_tool import (
# IMAGE_GENERATION_RESPONSE_ID,
# )
# from onyx.tools.tool_implementations.images.image_generation_tool import (
# ImageGenerationResponse,
# )
# from onyx.tools.tool_implementations.images.image_generation_tool import (
# ImageGenerationTool,
# )
# from onyx.tools.tool_implementations.images.image_generation_tool import ImageShape
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.models import GeneratedImage
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.file_store.utils import build_frontend_file_url
from onyx.file_store.utils import save_files
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeartbeat
from onyx.tools.tool_implementations.images.image_generation_tool import (
IMAGE_GENERATION_HEARTBEAT_ID,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
IMAGE_GENERATION_RESPONSE_ID,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationResponse,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.tools.tool_implementations.images.image_generation_tool import ImageShape
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def image_generation(
# state: BranchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> BranchUpdate:
# """
# LangGraph node to perform a standard search as part of the DR process.
# """
def image_generation(
    state: BranchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
    """
    LangGraph node that generates images for one DR branch.

    Parses the branch question (either a plain prompt or a JSON object with
    "prompt"/"shape" keys), runs the image generation tool while streaming
    heartbeats to the frontend, persists the images to the file store, and
    returns a BranchUpdate summarizing the outcome.

    Raises:
        ValueError: if the branch question or the available tools are unset.
    """
    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr

    branch_query = state.branch_question
    if not branch_query:
        raise ValueError("branch_query is not set")

    if not state.available_tools:
        raise ValueError("available_tools is not set")

    # The most recently selected tool is the image generation tool.
    image_tool_info = state.available_tools[state.tools_used[-1]]
    image_tool = cast(ImageGenerationTool, image_tool_info.tool_object)

    image_prompt = branch_query
    requested_shape: ImageShape | None = None

    # The LLM may emit either a raw prompt string or a JSON object of the
    # form {"prompt": ..., "shape": ...}; fall back to the raw string when
    # the query is not valid JSON.
    try:
        parsed_query = json.loads(branch_query)
    except json.JSONDecodeError:
        parsed_query = None

    if isinstance(parsed_query, dict):
        prompt_from_llm = parsed_query.get("prompt")
        if isinstance(prompt_from_llm, str) and prompt_from_llm.strip():
            image_prompt = prompt_from_llm.strip()

        raw_shape = parsed_query.get("shape")
        if isinstance(raw_shape, str):
            try:
                requested_shape = ImageShape(raw_shape)
            except ValueError:
                logger.warning(
                    "Received unsupported image shape '%s' from LLM. Falling back to square.",
                    raw_shape,
                )

    logger.debug(
        f"Image generation start for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    # Generate images using the image generation tool.
    image_generation_responses: list[ImageGenerationResponse] = []

    if requested_shape is not None:
        tool_iterator = image_tool.run(
            prompt=image_prompt,
            shape=requested_shape.value,
        )
    else:
        tool_iterator = image_tool.run(prompt=image_prompt)

    for tool_response in tool_iterator:
        if tool_response.id == IMAGE_GENERATION_HEARTBEAT_ID:
            # Stream heartbeat to frontend so the UI can show progress
            write_custom_event(
                state.current_step_nr,
                ImageGenerationToolHeartbeat(),
                writer,
            )
        elif tool_response.id == IMAGE_GENERATION_RESPONSE_ID:
            response = cast(list[ImageGenerationResponse], tool_response.response)
            image_generation_responses = response
            break

    # Persist images to the file store so the frontend can fetch them by URL.
    file_ids = save_files(
        urls=[],
        base64_files=[img.image_data for img in image_generation_responses],
    )

    final_generated_images = [
        GeneratedImage(
            file_id=file_id,
            url=build_frontend_file_url(file_id),
            revised_prompt=img.revised_prompt,
            shape=(requested_shape or ImageShape.SQUARE).value,
        )
        for file_id, img in zip(file_ids, image_generation_responses)
    ]

    logger.debug(
        f"Image generation complete for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    # Create an answer string describing the generated images.
    if final_generated_images:
        image_descriptions = []
        for i, img in enumerate(final_generated_images, 1):
            if img.shape and img.shape != ImageShape.SQUARE.value:
                image_descriptions.append(
                    f"Image {i}: {img.revised_prompt} (shape: {img.shape})"
                )
            else:
                image_descriptions.append(f"Image {i}: {img.revised_prompt}")

        answer_string = (
            f"Generated {len(final_generated_images)} image(s) based on the request: {image_prompt}\n\n"
            + "\n".join(image_descriptions)
        )
        if requested_shape:
            reasoning = (
                "Used image generation tool to create "
                f"{len(final_generated_images)} image(s) in {requested_shape.value} orientation."
            )
        else:
            reasoning = (
                "Used image generation tool to create "
                f"{len(final_generated_images)} image(s) based on the user's request."
            )
    else:
        answer_string = f"Failed to generate images for request: {image_prompt}"
        reasoning = "Image generation tool did not return any results."

    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=image_tool_info.llm_path,
                tool_id=image_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=branch_query,
                answer=answer_string,
                claims=[],
                cited_documents={},
                reasoning=reasoning,
                generated_images=final_generated_images,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="image_generation",
                node_name="generating",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -1,71 +1,71 @@
# from datetime import datetime
from datetime import datetime
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.models import GeneratedImage
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import ImageGenerationFinal
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.models import GeneratedImage
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolDelta
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def is_reducer(
# state: SubAgentMainState,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> SubAgentUpdate:
# """
# LangGraph node to perform a standard search as part of the DR process.
# """
def is_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    Consolidation node for the image-generation sub-agent.

    Gathers the branch responses produced in the current iteration, streams
    every image they generated to the frontend, closes the streaming section,
    and advances the step counter.
    """
    started_at = datetime.now()

    step_nr = state.current_step_nr
    iteration = state.iteration_nr

    current_responses = [
        response
        for response in state.branch_iteration_responses
        if response.iteration_nr == iteration
    ]

    images_to_stream: list[GeneratedImage] = [
        image
        for response in current_responses
        if response.generated_images
        for image in response.generated_images
    ]

    # Push the collected images to the stream, then close the section.
    write_custom_event(
        step_nr,
        ImageGenerationToolDelta(images=images_to_stream),
        writer,
    )
    write_custom_event(step_nr, SectionEnd(), writer)

    step_nr += 1

    return SubAgentUpdate(
        iteration_responses=current_responses,
        current_step_nr=step_nr,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="image_generation",
                node_name="consolidation",
                node_start_time=started_at,
            )
        ],
    )

View File

@@ -1,29 +1,29 @@
# from collections.abc import Hashable
from collections.abc import Hashable
# from langgraph.types import Send
from langgraph.types import Send
# from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
# return [
# Send(
# "act",
# BranchInput(
# iteration_nr=state.iteration_nr,
# parallelization_nr=parallelization_nr,
# branch_question=query,
# context="",
# active_source_types=state.active_source_types,
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# assistant_system_prompt=state.assistant_system_prompt,
# assistant_task_prompt=state.assistant_task_prompt,
# ),
# )
# for parallelization_nr, query in enumerate(
# state.query_list[:MAX_DR_PARALLEL_SEARCH]
# )
# ]
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """Fan out one "act" branch per query, capped at MAX_DR_PARALLEL_SEARCH."""
    branches: list[Send | Hashable] = []
    for branch_nr, branch_query in enumerate(
        state.query_list[:MAX_DR_PARALLEL_SEARCH]
    ):
        branches.append(
            Send(
                "act",
                BranchInput(
                    iteration_nr=state.iteration_nr,
                    parallelization_nr=branch_nr,
                    branch_question=branch_query,
                    context="",
                    active_source_types=state.active_source_types,
                    tools_used=state.tools_used,
                    available_tools=state.available_tools,
                    assistant_system_prompt=state.assistant_system_prompt,
                    assistant_task_prompt=state.assistant_task_prompt,
                ),
            )
        )
    return branches

View File

@@ -1,50 +1,50 @@
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_1_branch import (
# image_generation_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_2_act import (
# image_generation,
# )
# from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_3_reduce import (
# is_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_conditional_edges import (
# branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_1_branch import (
image_generation_branch,
)
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_2_act import (
image_generation,
)
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_3_reduce import (
is_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def dr_image_generation_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for Image Generation Sub-Agent
# """
def dr_image_generation_graph_builder() -> StateGraph:
    """
    Build the LangGraph graph for the image-generation sub-agent.

    Flow: START -> branch -> (fan-out via branching_router) -> act -> reducer -> END
    """
    graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    # Nodes
    graph.add_node("branch", image_generation_branch)
    graph.add_node("act", image_generation)
    graph.add_node("reducer", is_reducer)

    # Edges
    graph.add_edge(START, "branch")
    graph.add_conditional_edges("branch", branching_router)
    graph.add_edge("act", "reducer")
    graph.add_edge("reducer", END)

    return graph

View File

@@ -1,13 +1,13 @@
# from pydantic import BaseModel
from pydantic import BaseModel
# class GeneratedImage(BaseModel):
# file_id: str
# url: str
# revised_prompt: str
# shape: str | None = None
class GeneratedImage(BaseModel):
    """Metadata for a single image produced by the image generation tool."""

    # ID of the persisted image in the file store.
    file_id: str
    # Frontend-facing URL used to fetch the stored image.
    url: str
    # Prompt as (possibly) rewritten by the image model.
    revised_prompt: str
    # Shape/orientation identifier; None when unspecified.
    shape: str | None = None
# # Needed for PydanticType
# class GeneratedImageFullResult(BaseModel):
# images: list[GeneratedImage]
# Needed for PydanticType
class GeneratedImageFullResult(BaseModel):
    """Wrapper holding the full list of generated images (needed for PydanticType)."""

    # All images produced during the run.
    images: list[GeneratedImage]

View File

@@ -1,36 +1,36 @@
# from datetime import datetime
from datetime import datetime
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def kg_search_branch(
# state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
# """
# LangGraph node to perform a KG search as part of the DR process.
# """
def kg_search_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    Entry node for the KG-search sub-agent.

    Performs no work of its own beyond recording timing/log information
    before the branch fan-out.
    """
    started_at = datetime.now()

    logger.debug(
        f"Search start for KG Search {state.iteration_nr} at {datetime.now()}"
    )

    return LoggerUpdate(
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="kg_search",
                node_name="branching",
                node_start_time=started_at,
            )
        ],
    )

View File

@@ -1,97 +1,97 @@
# from datetime import datetime
from datetime import datetime
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.dr.utils import extract_document_citations
# from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
# from onyx.agents.agent_search.kb_search.states import MainInput as KbMainInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.context.search.models import InferenceSection
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.utils import extract_document_citations
from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
from onyx.agents.agent_search.kb_search.states import MainInput as KbMainInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def kg_search(
# state: BranchInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> BranchUpdate:
# """
# LangGraph node to perform a KG search as part of the DR process.
# """
def kg_search(
    state: BranchInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BranchUpdate:
    """
    LangGraph node that answers one branch question via the knowledge graph.

    Invokes the compiled KG sub-graph, extracts citation numbers from its
    answer, and maps them back to the retrieved documents.

    Raises:
        ValueError: if the branch question or the available tools are unset.
    """
    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr

    search_query = state.branch_question
    if not search_query:
        raise ValueError("search_query is not set")

    logger.debug(
        f"Search start for KG Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    if not state.available_tools:
        raise ValueError("available_tools is not set")

    # The most recently selected tool is the KG search tool.
    kg_tool_info = state.available_tools[state.tools_used[-1]]

    kb_graph = kb_graph_builder().compile()

    kb_results = kb_graph.invoke(
        input=KbMainInput(question=search_query, individual_flow=False),
        config=config,
    )

    # get cited documents
    answer_string = kb_results.get("final_answer") or "No answer provided"
    claims: list[str] = []
    retrieved_docs: list[InferenceSection] = kb_results.get("retrieved_documents", [])

    (
        citation_numbers,
        answer_string,
        claims,
    ) = extract_document_citations(answer_string, claims)

    # if citation is empty, the answer must have come from the KG rather than a doc
    # in that case, simply cite the docs returned by the KG
    if not citation_numbers:
        citation_numbers = [i + 1 for i in range(len(retrieved_docs))]

    # Citations are 1-based; guard both bounds so a malformed citation number
    # (0 or negative) cannot silently index from the end of the list.
    cited_documents = {
        citation_number: retrieved_docs[citation_number - 1]
        for citation_number in citation_numbers
        if 1 <= citation_number <= len(retrieved_docs)
    }

    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=kg_tool_info.llm_path,
                tool_id=kg_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=search_query,
                answer=answer_string,
                claims=claims,
                cited_documents=cited_documents,
                reasoning=None,
                additional_data=None,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="kg_search",
                node_name="searching",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -1,121 +1,121 @@
# from datetime import datetime
from datetime import datetime
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
# from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import ReasoningDelta
# from onyx.server.query_and_chat.streaming_models import ReasoningStart
# from onyx.server.query_and_chat.streaming_models import SearchToolDelta
# from onyx.server.query_and_chat.streaming_models import SearchToolStart
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# _MAX_KG_STEAMED_ANSWER_LENGTH = 1000 # num characters
_MAX_KG_STEAMED_ANSWER_LENGTH = 1000 # num characters
# def kg_search_reducer(
# state: SubAgentMainState,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> SubAgentUpdate:
# """
# LangGraph node to perform a KG search as part of the DR process.
# """
def kg_search_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    Consolidation node for the KG-search sub-agent.

    Streams the documents cited by this iteration's branches and, when a
    single query produced a direct KG answer, a (truncated) reasoning blurb.
    The step counter advances once per streamed section.
    """
    started_at = datetime.now()
    step_nr = state.current_step_nr

    iteration_updates = [
        update
        for update in state.branch_iteration_responses
        if update.iteration_nr == state.iteration_nr
    ]

    queries = [update.question for update in iteration_updates]
    cited_sections = [
        section
        for update in iteration_updates
        for section in update.cited_documents.values()
    ]

    search_docs = convert_inference_sections_to_search_docs(cited_sections)

    # A direct KG answer is only surfaced when exactly one query ran.
    if len(queries) == 1:
        kg_answer = "The Knowledge Graph Answer:\n\n" + iteration_updates[0].answer
    else:
        kg_answer = None

    if len(search_docs) > 0:
        write_custom_event(
            step_nr,
            SearchToolStart(is_internet_search=False),
            writer,
        )
        write_custom_event(
            step_nr,
            SearchToolDelta(queries=queries, documents=search_docs),
            writer,
        )
        write_custom_event(step_nr, SectionEnd(), writer)
        step_nr += 1

    if kg_answer is not None:
        # Truncate very long answers before streaming them as reasoning.
        if len(kg_answer) > _MAX_KG_STEAMED_ANSWER_LENGTH:
            shown_answer = f"{kg_answer[:_MAX_KG_STEAMED_ANSWER_LENGTH]}..."
        else:
            shown_answer = kg_answer

        write_custom_event(step_nr, ReasoningStart(), writer)
        write_custom_event(
            step_nr,
            ReasoningDelta(reasoning=shown_answer),
            writer,
        )
        write_custom_event(step_nr, SectionEnd(), writer)
        step_nr += 1

    return SubAgentUpdate(
        iteration_responses=iteration_updates,
        current_step_nr=step_nr,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="kg_search",
                node_name="consolidation",
                node_start_time=started_at,
            )
        ],
    )

View File

@@ -1,27 +1,27 @@
# from collections.abc import Hashable
from collections.abc import Hashable
# from langgraph.types import Send
from langgraph.types import Send
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
# return [
# Send(
# "act",
# BranchInput(
# iteration_nr=state.iteration_nr,
# parallelization_nr=parallelization_nr,
# branch_question=query,
# context="",
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# assistant_system_prompt=state.assistant_system_prompt,
# assistant_task_prompt=state.assistant_task_prompt,
# ),
# )
# for parallelization_nr, query in enumerate(
# state.query_list[:1] # no parallel search for now
# )
# ]
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
return [
Send(
"act",
BranchInput(
iteration_nr=state.iteration_nr,
parallelization_nr=parallelization_nr,
branch_question=query,
context="",
tools_used=state.tools_used,
available_tools=state.available_tools,
assistant_system_prompt=state.assistant_system_prompt,
assistant_task_prompt=state.assistant_task_prompt,
),
)
for parallelization_nr, query in enumerate(
state.query_list[:1] # no parallel search for now
)
]

View File

@@ -1,50 +1,50 @@
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_1_branch import (
# kg_search_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_2_act import (
# kg_search,
# )
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_3_reduce import (
# kg_search_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_conditional_edges import (
# branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_1_branch import (
kg_search_branch,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_2_act import (
kg_search,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_3_reduce import (
kg_search_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def dr_kg_search_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for KG Search Sub-Agent
# """
def dr_kg_search_graph_builder() -> StateGraph:
"""
LangGraph graph builder for KG Search Sub-Agent
"""
# graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
# ### Add nodes ###
### Add nodes ###
# graph.add_node("branch", kg_search_branch)
graph.add_node("branch", kg_search_branch)
# graph.add_node("act", kg_search)
graph.add_node("act", kg_search)
# graph.add_node("reducer", kg_search_reducer)
graph.add_node("reducer", kg_search_reducer)
# ### Add edges ###
### Add edges ###
# graph.add_edge(start_key=START, end_key="branch")
graph.add_edge(start_key=START, end_key="branch")
# graph.add_conditional_edges("branch", branching_router)
graph.add_conditional_edges("branch", branching_router)
# graph.add_edge(start_key="act", end_key="reducer")
graph.add_edge(start_key="act", end_key="reducer")
# graph.add_edge(start_key="reducer", end_key=END)
graph.add_edge(start_key="reducer", end_key=END)
# return graph
return graph

View File

@@ -1,46 +1,46 @@
# from operator import add
# from typing import Annotated
from operator import add
from typing import Annotated
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.models import OrchestratorTool
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.db.connector import DocumentSource
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.db.connector import DocumentSource
# class SubAgentUpdate(LoggerUpdate):
# iteration_responses: Annotated[list[IterationAnswer], add] = []
# current_step_nr: int = 1
class SubAgentUpdate(LoggerUpdate):
iteration_responses: Annotated[list[IterationAnswer], add] = []
current_step_nr: int = 1
# class BranchUpdate(LoggerUpdate):
# branch_iteration_responses: Annotated[list[IterationAnswer], add] = []
class BranchUpdate(LoggerUpdate):
branch_iteration_responses: Annotated[list[IterationAnswer], add] = []
# class SubAgentInput(LoggerUpdate):
# iteration_nr: int = 0
# current_step_nr: int = 1
# query_list: list[str] = []
# context: str | None = None
# active_source_types: list[DocumentSource] | None = None
# tools_used: Annotated[list[str], add] = []
# available_tools: dict[str, OrchestratorTool] | None = None
# assistant_system_prompt: str | None = None
# assistant_task_prompt: str | None = None
class SubAgentInput(LoggerUpdate):
iteration_nr: int = 0
current_step_nr: int = 1
query_list: list[str] = []
context: str | None = None
active_source_types: list[DocumentSource] | None = None
tools_used: Annotated[list[str], add] = []
available_tools: dict[str, OrchestratorTool] | None = None
assistant_system_prompt: str | None = None
assistant_task_prompt: str | None = None
# class SubAgentMainState(
# # This includes the core state
# SubAgentInput,
# SubAgentUpdate,
# BranchUpdate,
# ):
# pass
class SubAgentMainState(
# This includes the core state
SubAgentInput,
SubAgentUpdate,
BranchUpdate,
):
pass
# class BranchInput(SubAgentInput):
# parallelization_nr: int = 0
# branch_question: str
class BranchInput(SubAgentInput):
parallelization_nr: int = 0
branch_question: str
# class CustomToolBranchInput(LoggerUpdate):
# tool_info: OrchestratorTool
class CustomToolBranchInput(LoggerUpdate):
tool_info: OrchestratorTool

View File

@@ -1,71 +1,70 @@
# from collections.abc import Sequence
from collections.abc import Sequence
# from exa_py import Exa
# from exa_py.api import HighlightsContentsOptions
from exa_py import Exa
from exa_py.api import HighlightsContentsOptions
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebContent,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebSearchProvider,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebSearchResult,
# )
# from onyx.configs.chat_configs import EXA_API_KEY
# from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
# from onyx.utils.retry_wrapper import retry_builder
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebContent,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.utils.retry_wrapper import retry_builder
# class ExaClient(WebSearchProvider):
# def __init__(self, api_key: str | None = EXA_API_KEY) -> None:
# self.exa = Exa(api_key=api_key)
class ExaClient(WebSearchProvider):
def __init__(self, api_key: str) -> None:
self.exa = Exa(api_key=api_key)
# @retry_builder(tries=3, delay=1, backoff=2)
# def search(self, query: str) -> list[WebSearchResult]:
# response = self.exa.search_and_contents(
# query,
# type="auto",
# highlights=HighlightsContentsOptions(
# num_sentences=2,
# highlights_per_url=1,
# ),
# num_results=10,
# )
@retry_builder(tries=3, delay=1, backoff=2)
def search(self, query: str) -> list[WebSearchResult]:
response = self.exa.search_and_contents(
query,
type="auto",
highlights=HighlightsContentsOptions(
num_sentences=2,
highlights_per_url=1,
),
num_results=10,
)
# return [
# WebSearchResult(
# title=result.title or "",
# link=result.url,
# snippet=result.highlights[0] if result.highlights else "",
# author=result.author,
# published_date=(
# time_str_to_utc(result.published_date)
# if result.published_date
# else None
# ),
# )
# for result in response.results
# ]
return [
WebSearchResult(
title=result.title or "",
link=result.url,
snippet=result.highlights[0] if result.highlights else "",
author=result.author,
published_date=(
time_str_to_utc(result.published_date)
if result.published_date
else None
),
)
for result in response.results
]
# @retry_builder(tries=3, delay=1, backoff=2)
# def contents(self, urls: Sequence[str]) -> list[WebContent]:
# response = self.exa.get_contents(
# urls=list(urls),
# text=True,
# livecrawl="preferred",
# )
@retry_builder(tries=3, delay=1, backoff=2)
def contents(self, urls: Sequence[str]) -> list[WebContent]:
response = self.exa.get_contents(
urls=list(urls),
text=True,
livecrawl="preferred",
)
# return [
# WebContent(
# title=result.title or "",
# link=result.url,
# full_content=result.text or "",
# published_date=(
# time_str_to_utc(result.published_date)
# if result.published_date
# else None
# ),
# )
# for result in response.results
# ]
return [
WebContent(
title=result.title or "",
link=result.url,
full_content=result.text or "",
published_date=(
time_str_to_utc(result.published_date)
if result.published_date
else None
),
)
for result in response.results
]

View File

@@ -8,9 +8,11 @@ from typing import Any
import requests
from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebContent
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebContentProvider,
)
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.tools.tool_implementations.open_url.models import WebContent
from onyx.tools.tool_implementations.open_url.models import WebContentProvider
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder

View File

@@ -1,14 +1,16 @@
from __future__ import annotations
from collections.abc import Sequence
from datetime import datetime
from typing import Any
import requests
from onyx.tools.tool_implementations.web_search.models import (
from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebContent
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchProvider,
)
from onyx.tools.tool_implementations.web_search.models import WebSearchResult
from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebSearchResult
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
@@ -119,3 +121,18 @@ class GooglePSEClient(WebSearchProvider):
)
return results
def contents(self, urls: Sequence[str]) -> list[WebContent]:
logger.warning(
"Google PSE does not support content fetching; returning empty results."
)
return [
WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
for url in urls
]

View File

@@ -4,14 +4,14 @@ from collections.abc import Sequence
import requests
from onyx.file_processing.html_utils import ParsedHTML
from onyx.file_processing.html_utils import web_html_cleanup
from onyx.tools.tool_implementations.open_url.models import (
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebContent,
)
from onyx.tools.tool_implementations.open_url.models import (
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebContentProvider,
)
from onyx.file_processing.html_utils import ParsedHTML
from onyx.file_processing.html_utils import web_html_cleanup
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -20,7 +20,7 @@ DEFAULT_TIMEOUT_SECONDS = 15
DEFAULT_USER_AGENT = "OnyxWebCrawler/1.0 (+https://www.onyx.app)"
class OnyxWebCrawler(WebContentProvider):
class OnyxWebCrawlerClient(WebContentProvider):
"""
Lightweight built-in crawler that fetches HTML directly and extracts readable text.
Acts as the default content provider when no external crawler (e.g. Firecrawl) is

View File

@@ -1,148 +1,159 @@
# import json
# from collections.abc import Sequence
# from concurrent.futures import ThreadPoolExecutor
import json
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
# import requests
import requests
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebContent,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebSearchProvider,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebSearchResult,
# )
# from onyx.configs.chat_configs import SERPER_API_KEY
# from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
# from onyx.utils.retry_wrapper import retry_builder
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebContent,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.utils.retry_wrapper import retry_builder
# SERPER_SEARCH_URL = "https://google.serper.dev/search"
# SERPER_CONTENTS_URL = "https://scrape.serper.dev"
SERPER_SEARCH_URL = "https://google.serper.dev/search"
SERPER_CONTENTS_URL = "https://scrape.serper.dev"
# class SerperClient(WebSearchProvider):
# def __init__(self, api_key: str | None = SERPER_API_KEY) -> None:
# self.headers = {
# "X-API-KEY": api_key,
# "Content-Type": "application/json",
# }
class SerperClient(WebSearchProvider):
def __init__(self, api_key: str) -> None:
self.headers = {
"X-API-KEY": api_key,
"Content-Type": "application/json",
}
# @retry_builder(tries=3, delay=1, backoff=2)
# def search(self, query: str) -> list[WebSearchResult]:
# payload = {
# "q": query,
# }
@retry_builder(tries=3, delay=1, backoff=2)
def search(self, query: str) -> list[WebSearchResult]:
payload = {
"q": query,
}
# response = requests.post(
# SERPER_SEARCH_URL,
# headers=self.headers,
# data=json.dumps(payload),
# )
response = requests.post(
SERPER_SEARCH_URL,
headers=self.headers,
data=json.dumps(payload),
)
# response.raise_for_status()
try:
response.raise_for_status()
except Exception:
# Avoid leaking API keys/URLs
raise ValueError(
"Serper search failed. Check credentials or quota."
) from None
# results = response.json()
# organic_results = results["organic"]
results = response.json()
organic_results = results["organic"]
# return [
# WebSearchResult(
# title=result["title"],
# link=result["link"],
# snippet=result["snippet"],
# author=None,
# published_date=None,
# )
# for result in organic_results
# ]
return [
WebSearchResult(
title=result["title"],
link=result["link"],
snippet=result["snippet"],
author=None,
published_date=None,
)
for result in organic_results
]
# def contents(self, urls: Sequence[str]) -> list[WebContent]:
# if not urls:
# return []
def contents(self, urls: Sequence[str]) -> list[WebContent]:
if not urls:
return []
# # Serper can responds with 500s regularly. We want to retry,
# # but in the event of failure, return an unsuccesful scrape.
# def safe_get_webpage_content(url: str) -> WebContent:
# try:
# return self._get_webpage_content(url)
# except Exception:
# return WebContent(
# title="",
# link=url,
# full_content="",
# published_date=None,
# scrape_successful=False,
# )
# Serper can responds with 500s regularly. We want to retry,
# but in the event of failure, return an unsuccesful scrape.
def safe_get_webpage_content(url: str) -> WebContent:
try:
return self._get_webpage_content(url)
except Exception:
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
# with ThreadPoolExecutor(max_workers=min(8, len(urls))) as e:
# return list(e.map(safe_get_webpage_content, urls))
with ThreadPoolExecutor(max_workers=min(8, len(urls))) as e:
return list(e.map(safe_get_webpage_content, urls))
# @retry_builder(tries=3, delay=1, backoff=2)
# def _get_webpage_content(self, url: str) -> WebContent:
# payload = {
# "url": url,
# }
@retry_builder(tries=3, delay=1, backoff=2)
def _get_webpage_content(self, url: str) -> WebContent:
payload = {
"url": url,
}
# response = requests.post(
# SERPER_CONTENTS_URL,
# headers=self.headers,
# data=json.dumps(payload),
# )
response = requests.post(
SERPER_CONTENTS_URL,
headers=self.headers,
data=json.dumps(payload),
)
# # 400 returned when serper cannot scrape
# if response.status_code == 400:
# return WebContent(
# title="",
# link=url,
# full_content="",
# published_date=None,
# scrape_successful=False,
# )
# 400 returned when serper cannot scrape
if response.status_code == 400:
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
# response.raise_for_status()
try:
response.raise_for_status()
except Exception:
# Avoid leaking API keys/URLs
raise ValueError(
"Serper content fetch failed. Check credentials."
) from None
# response_json = response.json()
response_json = response.json()
# # Response only guarantees text
# text = response_json["text"]
# Response only guarantees text
text = response_json["text"]
# # metadata & jsonld is not guaranteed to be present
# metadata = response_json.get("metadata", {})
# jsonld = response_json.get("jsonld", {})
# metadata & jsonld is not guaranteed to be present
metadata = response_json.get("metadata", {})
jsonld = response_json.get("jsonld", {})
# title = extract_title_from_metadata(metadata)
title = extract_title_from_metadata(metadata)
# # Serper does not provide a reliable mechanism to extract the url
# response_url = url
# published_date_str = extract_published_date_from_jsonld(jsonld)
# published_date = None
# Serper does not provide a reliable mechanism to extract the url
response_url = url
published_date_str = extract_published_date_from_jsonld(jsonld)
published_date = None
# if published_date_str:
# try:
# published_date = time_str_to_utc(published_date_str)
# except Exception:
# published_date = None
if published_date_str:
try:
published_date = time_str_to_utc(published_date_str)
except Exception:
published_date = None
# return WebContent(
# title=title or "",
# link=response_url,
# full_content=text or "",
# published_date=published_date,
# )
return WebContent(
title=title or "",
link=response_url,
full_content=text or "",
published_date=published_date,
)
# def extract_title_from_metadata(metadata: dict[str, str]) -> str | None:
# keys = ["title", "og:title"]
# return extract_value_from_dict(metadata, keys)
def extract_title_from_metadata(metadata: dict[str, str]) -> str | None:
keys = ["title", "og:title"]
return extract_value_from_dict(metadata, keys)
# def extract_published_date_from_jsonld(jsonld: dict[str, str]) -> str | None:
# keys = ["dateModified"]
# return extract_value_from_dict(jsonld, keys)
def extract_published_date_from_jsonld(jsonld: dict[str, str]) -> str | None:
keys = ["dateModified"]
return extract_value_from_dict(jsonld, keys)
# def extract_value_from_dict(data: dict[str, str], keys: list[str]) -> str | None:
# for key in keys:
# if key in data:
# return data[key]
# return None
def extract_value_from_dict(data: dict[str, str], keys: list[str]) -> str | None:
for key in keys:
if key in data:
return data[key]
return None

View File

@@ -1,47 +1,47 @@
# from datetime import datetime
from datetime import datetime
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import SearchToolStart
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def is_branch(
# state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
# """
# LangGraph node to perform a web search as part of the DR process.
# """
def is_branch(
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
LangGraph node to perform a web search as part of the DR process.
"""
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
# current_step_nr = state.current_step_nr
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
current_step_nr = state.current_step_nr
# logger.debug(f"Search start for Web Search {iteration_nr} at {datetime.now()}")
logger.debug(f"Search start for Web Search {iteration_nr} at {datetime.now()}")
# write_custom_event(
# current_step_nr,
# SearchToolStart(
# is_internet_search=True,
# ),
# writer,
# )
write_custom_event(
current_step_nr,
SearchToolStart(
is_internet_search=True,
),
writer,
)
# return LoggerUpdate(
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="internet_search",
# node_name="branching",
# node_start_time=node_start_time,
# )
# ],
# )
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="internet_search",
node_name="branching",
node_start_time=node_start_time,
)
],
)

View File

@@ -1,128 +1,137 @@
# from datetime import datetime
# from typing import cast
from datetime import datetime
from typing import cast
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
# from langsmith import traceable
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from langsmith import traceable
# from onyx.agents.agent_search.dr.models import WebSearchAnswer
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebSearchResult,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
# get_default_provider,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
# InternetSearchInput,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
# InternetSearchUpdate,
# )
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.agents.agent_search.utils import create_question_prompt
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
# from onyx.prompts.dr_prompts import WEB_SEARCH_URL_SELECTION_PROMPT
# from onyx.server.query_and_chat.streaming_models import SearchToolDelta
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.models import WebSearchAnswer
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
get_default_provider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
InternetSearchInput,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
InternetSearchUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.prompts.dr_prompts import WEB_SEARCH_URL_SELECTION_PROMPT
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def web_search(
# state: InternetSearchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> InternetSearchUpdate:
# """
# LangGraph node to perform internet search and decide which URLs to fetch.
# """
def web_search(
state: InternetSearchInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> InternetSearchUpdate:
"""
LangGraph node to perform internet search and decide which URLs to fetch.
"""
# node_start_time = datetime.now()
# current_step_nr = state.current_step_nr
node_start_time = datetime.now()
current_step_nr = state.current_step_nr
# if not current_step_nr:
# raise ValueError("Current step number is not set. This should not happen.")
if not current_step_nr:
raise ValueError("Current step number is not set. This should not happen.")
# assistant_system_prompt = state.assistant_system_prompt
# assistant_task_prompt = state.assistant_task_prompt
assistant_system_prompt = state.assistant_system_prompt
assistant_task_prompt = state.assistant_task_prompt
# if not state.available_tools:
# raise ValueError("available_tools is not set")
# search_query = state.branch_question
if not state.available_tools:
raise ValueError("available_tools is not set")
search_query = state.branch_question
# write_custom_event(
# current_step_nr,
# SearchToolDelta(
# queries=[search_query],
# documents=[],
# ),
# writer,
# )
write_custom_event(
current_step_nr,
SearchToolDelta(
queries=[search_query],
documents=[],
),
writer,
)
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# base_question = graph_config.inputs.prompt_builder.raw_user_query
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = graph_config.inputs.prompt_builder.raw_user_query
# if graph_config.inputs.persona is None:
# raise ValueError("persona is not set")
if graph_config.inputs.persona is None:
raise ValueError("persona is not set")
# provider = get_default_provider()
# if not provider:
# raise ValueError("No internet search provider found")
provider = get_default_provider()
if not provider:
raise ValueError("No internet search provider found")
# @traceable(name="Search Provider API Call")
# def _search(search_query: str) -> list[WebSearchResult]:
# search_results: list[WebSearchResult] = []
# try:
# search_results = list(provider.search(search_query))
# except Exception as e:
# logger.error(f"Error performing search: {e}")
# return search_results
# Log which provider type is being used
provider_type = type(provider).__name__
logger.info(
f"Performing web search with {provider_type} for query: '{search_query}'"
)
# search_results: list[WebSearchResult] = _search(search_query)
# search_results_text = "\n\n".join(
# [
# f"{i}. {result.title}\n URL: {result.link}\n"
# + (f" Author: {result.author}\n" if result.author else "")
# + (
# f" Date: {result.published_date.strftime('%Y-%m-%d')}\n"
# if result.published_date
# else ""
# )
# + (f" Snippet: {result.snippet}\n" if result.snippet else "")
# for i, result in enumerate(search_results)
# ]
# )
# agent_decision_prompt = WEB_SEARCH_URL_SELECTION_PROMPT.build(
# search_query=search_query,
# base_question=base_question,
# search_results_text=search_results_text,
# )
# agent_decision = invoke_llm_json(
# llm=graph_config.tooling.fast_llm,
# prompt=create_question_prompt(
# assistant_system_prompt,
# agent_decision_prompt + (assistant_task_prompt or ""),
# ),
# schema=WebSearchAnswer,
# timeout_override=TF_DR_TIMEOUT_SHORT,
# )
# results_to_open = [
# (search_query, search_results[i])
# for i in agent_decision.urls_to_open_indices
# if i < len(search_results) and i >= 0
# ]
# return InternetSearchUpdate(
# results_to_open=results_to_open,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="internet_search",
# node_name="searching",
# node_start_time=node_start_time,
# )
# ],
# )
@traceable(name="Search Provider API Call")
def _search(search_query: str) -> list[WebSearchResult]:
search_results: list[WebSearchResult] = []
try:
search_results = list(provider.search(search_query))
logger.info(
f"Search returned {len(search_results)} results using {provider_type}"
)
except Exception as e:
logger.error(f"Error performing search with {provider_type}: {e}")
return search_results
search_results: list[WebSearchResult] = _search(search_query)
search_results_text = "\n\n".join(
[
f"{i}. {result.title}\n URL: {result.link}\n"
+ (f" Author: {result.author}\n" if result.author else "")
+ (
f" Date: {result.published_date.strftime('%Y-%m-%d')}\n"
if result.published_date
else ""
)
+ (f" Snippet: {result.snippet}\n" if result.snippet else "")
for i, result in enumerate(search_results)
]
)
agent_decision_prompt = WEB_SEARCH_URL_SELECTION_PROMPT.build(
search_query=search_query,
base_question=base_question,
search_results_text=search_results_text,
)
agent_decision = invoke_llm_json(
llm=graph_config.tooling.fast_llm,
prompt=create_question_prompt(
assistant_system_prompt,
agent_decision_prompt + (assistant_task_prompt or ""),
),
schema=WebSearchAnswer,
timeout_override=TF_DR_TIMEOUT_SHORT,
)
results_to_open = [
(search_query, search_results[i])
for i in agent_decision.urls_to_open_indices
if i < len(search_results) and i >= 0
]
return InternetSearchUpdate(
results_to_open=results_to_open,
log_messages=[
get_langgraph_node_log_string(
graph_component="internet_search",
node_name="searching",
node_start_time=node_start_time,
)
],
)

View File

@@ -1,52 +1,52 @@
# from collections import defaultdict
from collections import defaultdict
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebSearchResult,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
# InternetSearchInput,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
# dummy_inference_section_from_internet_search_result,
# )
# from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
InternetSearchInput,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
dummy_inference_section_from_internet_search_result,
)
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
# def dedup_urls(
# state: InternetSearchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> InternetSearchInput:
# branch_questions_to_urls: dict[str, list[str]] = defaultdict(list)
# unique_results_by_link: dict[str, WebSearchResult] = {}
# for query, result in state.results_to_open:
# branch_questions_to_urls[query].append(result.link)
# if result.link not in unique_results_by_link:
# unique_results_by_link[result.link] = result
def dedup_urls(
    state: InternetSearchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> InternetSearchInput:
    """Deduplicate search results by link and stream the unique docs to the UI.

    Builds a mapping of each branch question to the links its results produced,
    keeps only the first result seen per link, emits a SearchToolDelta event
    containing the unique (dummy) documents, and returns a fresh
    InternetSearchInput with ``results_to_open`` cleared.
    """
    branch_questions_to_urls: dict[str, list[str]] = defaultdict(list)
    unique_results_by_link: dict[str, WebSearchResult] = {}
    for query, result in state.results_to_open:
        branch_questions_to_urls[query].append(result.link)
        # first occurrence of a link wins; later duplicates are dropped
        if result.link not in unique_results_by_link:
            unique_results_by_link[result.link] = result

    unique_results = list(unique_results_by_link.values())
    dummy_docs_inference_sections = [
        dummy_inference_section_from_internet_search_result(doc)
        for doc in unique_results
    ]

    # surface the deduplicated documents to the frontend for this step
    write_custom_event(
        state.current_step_nr,
        SearchToolDelta(
            queries=[],
            documents=convert_inference_sections_to_search_docs(
                dummy_docs_inference_sections, is_internet=True
            ),
        ),
        writer,
    )

    return InternetSearchInput(
        results_to_open=[],
        branch_questions_to_urls=branch_questions_to_urls,
    )

View File

@@ -1,69 +1,69 @@
# from datetime import datetime
# from typing import cast
from datetime import datetime
from typing import cast
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
# get_default_provider,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchInput
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchUpdate
# from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
# inference_section_from_internet_page_scrape,
# )
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.context.search.models import InferenceSection
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
get_default_content_provider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchInput
from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchUpdate
from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
dummy_inference_section_from_internet_content,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def web_fetch(
# state: FetchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> FetchUpdate:
# """
# LangGraph node to fetch content from URLs and process the results.
# """
def web_fetch(
    state: FetchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> FetchUpdate:
    """
    LangGraph node to fetch content from URLs and process the results.

    Fetches ``state.urls_to_open`` via the configured web-content provider and
    wraps each page in a dummy InferenceSection. Fetch failures are logged and
    swallowed so the graph can proceed with whatever was retrieved (possibly
    nothing).

    Raises:
        ValueError: if ``available_tools``, the persona, or a content provider
            is not configured.
    """
    node_start_time = datetime.now()

    if not state.available_tools:
        raise ValueError("available_tools is not set")

    graph_config = cast(GraphConfig, config["metadata"]["config"])

    if graph_config.inputs.persona is None:
        raise ValueError("persona is not set")

    provider = get_default_content_provider()
    if provider is None:
        raise ValueError("No web content provider found")

    retrieved_docs: list[InferenceSection] = []
    try:
        retrieved_docs = [
            dummy_inference_section_from_internet_content(result)
            for result in provider.contents(state.urls_to_open)
        ]
    except Exception as e:
        # best-effort: a failed fetch should not abort the whole sub-agent run
        logger.exception(e)

    if not retrieved_docs:
        logger.warning("No content retrieved from URLs")

    return FetchUpdate(
        raw_documents=retrieved_docs,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="internet_search",
                node_name="fetching",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -1,19 +1,19 @@
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchInput
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
# InternetSearchInput,
# )
from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchInput
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
InternetSearchInput,
)
# def collect_raw_docs(
# state: FetchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> InternetSearchInput:
# raw_documents = state.raw_documents
def collect_raw_docs(
    state: FetchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> InternetSearchInput:
    """Fan-in node: gather the raw documents accumulated by parallel fetches.

    Simply repackages ``state.raw_documents`` into an InternetSearchInput so
    the downstream summarize router can branch over them.
    """
    return InternetSearchInput(
        raw_documents=state.raw_documents,
    )

View File

@@ -1,132 +1,133 @@
# from datetime import datetime
# from typing import cast
from datetime import datetime
from typing import cast
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.models import SearchAnswer
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import SummarizeInput
# from onyx.agents.agent_search.dr.utils import extract_document_citations
# from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.utils import create_question_prompt
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
# from onyx.context.search.models import InferenceSection
# from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
# from onyx.utils.logger import setup_logger
# from onyx.utils.url import normalize_url
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import SearchAnswer
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.sub_agents.web_search.states import SummarizeInput
from onyx.agents.agent_search.dr.utils import extract_document_citations
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.context.search.models import InferenceSection
from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
from onyx.utils.logger import setup_logger
from onyx.utils.url import normalize_url
# logger = setup_logger()
logger = setup_logger()
# def is_summarize(
# state: SummarizeInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> BranchUpdate:
# """
# LangGraph node to perform a internet search as part of the DR process.
# """
def is_summarize(
    state: SummarizeInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
    """
    LangGraph node to summarize fetched pages for one branch question.

    In DEEP research mode the fetched documents are summarized by the LLM and
    citations are extracted; otherwise the documents are passed through
    uncited with an empty answer. Returns a BranchUpdate carrying one
    IterationAnswer for this branch.
    """
    node_start_time = datetime.now()

    # build branch iterations from fetch inputs
    # Normalize URLs to handle mismatches from query parameters (e.g., ?activeTab=explore)
    url_to_raw_document: dict[str, InferenceSection] = {}
    for raw_document in state.raw_documents:
        normalized_url = normalize_url(raw_document.center_chunk.semantic_identifier)
        url_to_raw_document[normalized_url] = raw_document

    # Normalize the URLs from branch_questions_to_urls as well
    urls = [
        normalize_url(url)
        for url in state.branch_questions_to_urls[state.branch_question]
    ]

    current_iteration = state.iteration_nr

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    research_type = graph_config.behavior.research_type

    if not state.available_tools:
        raise ValueError("available_tools is not set")
    # the tool most recently used by this branch drives attribution below
    is_tool_info = state.available_tools[state.tools_used[-1]]

    if research_type == ResearchType.DEEP:
        cited_raw_documents = [url_to_raw_document[url] for url in urls]
        document_texts = _create_document_texts(cited_raw_documents)
        search_prompt = INTERNAL_SEARCH_PROMPTS[research_type].build(
            search_query=state.branch_question,
            base_question=graph_config.inputs.prompt_builder.raw_user_query,
            document_text=document_texts,
        )
        assistant_system_prompt = state.assistant_system_prompt
        assistant_task_prompt = state.assistant_task_prompt

        search_answer_json = invoke_llm_json(
            llm=graph_config.tooling.primary_llm,
            prompt=create_question_prompt(
                assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
            ),
            schema=SearchAnswer,
            timeout_override=TF_DR_TIMEOUT_SHORT,
        )

        answer_string = search_answer_json.answer
        claims = search_answer_json.claims or []
        reasoning = search_answer_json.reasoning or ""

        # citation numbers in the answer are 1-based indices into cited_raw_documents
        (
            citation_numbers,
            answer_string,
            claims,
        ) = extract_document_citations(answer_string, claims)

        cited_documents = {
            citation_number: cited_raw_documents[citation_number - 1]
            for citation_number in citation_numbers
        }
    else:
        # non-deep research: skip summarization, cite all fetched docs verbatim
        answer_string = ""
        reasoning = ""
        claims = []
        cited_raw_documents = [url_to_raw_document[url] for url in urls]
        cited_documents = {
            doc_num + 1: retrieved_doc
            for doc_num, retrieved_doc in enumerate(cited_raw_documents)
        }

    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=is_tool_info.llm_path,
                tool_id=is_tool_info.tool_id,
                iteration_nr=current_iteration,
                parallelization_nr=0,
                question=state.branch_question,
                answer=answer_string,
                claims=claims,
                cited_documents=cited_documents,
                reasoning=reasoning,
                additional_data=None,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="internet_search",
                node_name="summarizing",
                node_start_time=node_start_time,
            )
        ],
    )
# def _create_document_texts(raw_documents: list[InferenceSection]) -> str:
# document_texts_list = []
# for doc_num, retrieved_doc in enumerate(raw_documents):
# if not isinstance(retrieved_doc, InferenceSection):
# raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
# chunk_text = build_document_context(retrieved_doc, doc_num + 1)
# document_texts_list.append(chunk_text)
# return "\n\n".join(document_texts_list)
def _create_document_texts(raw_documents: list[InferenceSection]) -> str:
    """Render each document as a numbered context block, joined by blank lines.

    Raises:
        ValueError: if any entry is not an InferenceSection.
    """
    rendered_blocks: list[str] = []
    for document_number, document in enumerate(raw_documents, start=1):
        if not isinstance(document, InferenceSection):
            raise ValueError(f"Unexpected document type: {type(document)}")
        rendered_blocks.append(build_document_context(document, document_number))
    return "\n\n".join(rendered_blocks)

View File

@@ -1,56 +1,56 @@
# from datetime import datetime
from datetime import datetime
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def is_reducer(
# state: SubAgentMainState,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> SubAgentUpdate:
# """
# LangGraph node to perform a internet search as part of the DR process.
# """
def is_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    LangGraph reducer node: consolidate branch responses for this iteration.

    Keeps only the branch responses produced in the current iteration, closes
    the current UI section via a SectionEnd event, and advances the step
    counter.
    """
    node_start_time = datetime.now()

    branch_updates = state.branch_iteration_responses
    current_iteration = state.iteration_nr
    current_step_nr = state.current_step_nr

    # responses from earlier iterations were already consolidated upstream
    new_updates = [
        update for update in branch_updates if update.iteration_nr == current_iteration
    ]

    write_custom_event(
        current_step_nr,
        SectionEnd(),
        writer,
    )
    current_step_nr += 1

    return SubAgentUpdate(
        iteration_responses=new_updates,
        current_step_nr=current_step_nr,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="internet_search",
                node_name="consolidation",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -1,79 +1,79 @@
# from collections.abc import Hashable
from collections.abc import Hashable
# from langgraph.types import Send
from langgraph.types import Send
# from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchInput
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
# InternetSearchInput,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import SummarizeInput
from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchInput
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
InternetSearchInput,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.states import SummarizeInput
# def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
# return [
# Send(
# "search",
# InternetSearchInput(
# iteration_nr=state.iteration_nr,
# current_step_nr=state.current_step_nr,
# parallelization_nr=parallelization_nr,
# query_list=[query],
# branch_question=query,
# context="",
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# assistant_system_prompt=state.assistant_system_prompt,
# assistant_task_prompt=state.assistant_task_prompt,
# results_to_open=[],
# ),
# )
# for parallelization_nr, query in enumerate(
# state.query_list[:MAX_DR_PARALLEL_SEARCH]
# )
# ]
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """Fan out one "search" task per query, capped at MAX_DR_PARALLEL_SEARCH."""
    sends: list[Send | Hashable] = []
    capped_queries = state.query_list[:MAX_DR_PARALLEL_SEARCH]
    for branch_nr, branch_query in enumerate(capped_queries):
        search_input = InternetSearchInput(
            iteration_nr=state.iteration_nr,
            current_step_nr=state.current_step_nr,
            parallelization_nr=branch_nr,
            query_list=[branch_query],
            branch_question=branch_query,
            context="",
            tools_used=state.tools_used,
            available_tools=state.available_tools,
            assistant_system_prompt=state.assistant_system_prompt,
            assistant_task_prompt=state.assistant_task_prompt,
            results_to_open=[],
        )
        sends.append(Send("search", search_input))
    return sends
# def fetch_router(state: InternetSearchInput) -> list[Send | Hashable]:
# branch_questions_to_urls = state.branch_questions_to_urls
# return [
# Send(
# "fetch",
# FetchInput(
# iteration_nr=state.iteration_nr,
# urls_to_open=[url],
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# assistant_system_prompt=state.assistant_system_prompt,
# assistant_task_prompt=state.assistant_task_prompt,
# current_step_nr=state.current_step_nr,
# branch_questions_to_urls=branch_questions_to_urls,
# raw_documents=state.raw_documents,
# ),
# )
# for url in set(
# url for urls in branch_questions_to_urls.values() for url in urls
# )
# ]
def fetch_router(state: InternetSearchInput) -> list[Send | Hashable]:
    """Fan out one "fetch" task per unique URL across all branch questions."""
    question_to_urls = state.branch_questions_to_urls
    unique_urls: set[str] = set()
    for url_list in question_to_urls.values():
        unique_urls.update(url_list)

    sends: list[Send | Hashable] = []
    for url in unique_urls:
        fetch_input = FetchInput(
            iteration_nr=state.iteration_nr,
            urls_to_open=[url],
            tools_used=state.tools_used,
            available_tools=state.available_tools,
            assistant_system_prompt=state.assistant_system_prompt,
            assistant_task_prompt=state.assistant_task_prompt,
            current_step_nr=state.current_step_nr,
            branch_questions_to_urls=question_to_urls,
            raw_documents=state.raw_documents,
        )
        sends.append(Send("fetch", fetch_input))
    return sends
# def summarize_router(state: InternetSearchInput) -> list[Send | Hashable]:
# branch_questions_to_urls = state.branch_questions_to_urls
# return [
# Send(
# "summarize",
# SummarizeInput(
# iteration_nr=state.iteration_nr,
# raw_documents=state.raw_documents,
# branch_questions_to_urls=branch_questions_to_urls,
# branch_question=branch_question,
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# assistant_system_prompt=state.assistant_system_prompt,
# assistant_task_prompt=state.assistant_task_prompt,
# current_step_nr=state.current_step_nr,
# ),
# )
# for branch_question in branch_questions_to_urls.keys()
# ]
def summarize_router(state: InternetSearchInput) -> list[Send | Hashable]:
    """Fan out one "summarize" task per branch question."""
    question_to_urls = state.branch_questions_to_urls
    sends: list[Send | Hashable] = []
    for question in question_to_urls:
        summarize_input = SummarizeInput(
            iteration_nr=state.iteration_nr,
            raw_documents=state.raw_documents,
            branch_questions_to_urls=question_to_urls,
            branch_question=question,
            tools_used=state.tools_used,
            available_tools=state.available_tools,
            assistant_system_prompt=state.assistant_system_prompt,
            assistant_task_prompt=state.assistant_task_prompt,
            current_step_nr=state.current_step_nr,
        )
        sends.append(Send("summarize", summarize_input))
    return sends

View File

@@ -1,84 +1,84 @@
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_1_branch import (
# is_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_2_search import (
# web_search,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_3_dedup_urls import (
# dedup_urls,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_4_fetch import (
# web_fetch,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_5_collect_raw_docs import (
# collect_raw_docs,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_6_summarize import (
# is_summarize,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_7_reduce import (
# is_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_conditional_edges import (
# branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_conditional_edges import (
# fetch_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_conditional_edges import (
# summarize_router,
# )
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_1_branch import (
is_branch,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_2_search import (
web_search,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_3_dedup_urls import (
dedup_urls,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_4_fetch import (
web_fetch,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_5_collect_raw_docs import (
collect_raw_docs,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_6_summarize import (
is_summarize,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_7_reduce import (
is_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_conditional_edges import (
fetch_router,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_conditional_edges import (
summarize_router,
)
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def dr_ws_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for Internet Search Sub-Agent
# """
def dr_ws_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for the Internet Search Sub-Agent.

    Pipeline: branch -> (parallel) search -> dedup_urls -> (parallel) fetch
    -> collect_raw_docs -> (parallel) summarize -> reducer.
    """
    graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    ### Add nodes ###
    graph.add_node("branch", is_branch)
    graph.add_node("search", web_search)
    graph.add_node("dedup_urls", dedup_urls)
    graph.add_node("fetch", web_fetch)
    graph.add_node("collect_raw_docs", collect_raw_docs)
    graph.add_node("summarize", is_summarize)
    graph.add_node("reducer", is_reducer)

    ### Add edges ###
    graph.add_edge(start_key=START, end_key="branch")
    # conditional edges fan work out via the router functions
    graph.add_conditional_edges("branch", branching_router)
    graph.add_edge(start_key="search", end_key="dedup_urls")
    graph.add_conditional_edges("dedup_urls", fetch_router)
    graph.add_edge(start_key="fetch", end_key="collect_raw_docs")
    graph.add_conditional_edges("collect_raw_docs", summarize_router)
    graph.add_edge(start_key="summarize", end_key="reducer")
    graph.add_edge(start_key="reducer", end_key=END)

    return graph

View File

@@ -1,53 +1,47 @@
# from abc import ABC
# from abc import abstractmethod
# from collections.abc import Sequence
# from datetime import datetime
# from enum import Enum
from abc import ABC
from abc import abstractmethod
from collections.abc import Sequence
from datetime import datetime
# from pydantic import BaseModel
# from pydantic import field_validator
from pydantic import BaseModel
from pydantic import field_validator
# from onyx.utils.url import normalize_url
from onyx.utils.url import normalize_url
# class ProviderType(Enum):
# """Enum for internet search provider types"""
class WebSearchResult(BaseModel):
    """A single hit returned by a web search provider."""

    title: str
    link: str
    snippet: str | None = None
    author: str | None = None
    published_date: datetime | None = None

    @field_validator("link")
    @classmethod
    def normalize_link(cls, v: str) -> str:
        """Canonicalize the URL so results can be compared/deduped by link."""
        return normalize_url(v)
# class WebSearchResult(BaseModel):
# title: str
# link: str
# snippet: str | None = None
# author: str | None = None
# published_date: datetime | None = None
class WebContent(BaseModel):
    """The full scraped contents of one web page."""

    title: str
    link: str
    full_content: str
    published_date: datetime | None = None
    # False when the scrape failed and full_content is empty/partial
    scrape_successful: bool = True

    @field_validator("link")
    @classmethod
    def normalize_link(cls, v: str) -> str:
        """Canonicalize the URL so contents can be matched back to results."""
        return normalize_url(v)
# class WebContent(BaseModel):
# title: str
# link: str
# full_content: str
# published_date: datetime | None = None
# scrape_successful: bool = True
# @field_validator("link")
# @classmethod
# def normalize_link(cls, v: str) -> str:
# return normalize_url(v)
class WebContentProvider(ABC):
    """Interface for providers that can fetch full page contents for URLs."""

    @abstractmethod
    def contents(self, urls: Sequence[str]) -> list[WebContent]:
        """Fetch and return the content of every URL in ``urls``."""
        ...
# class WebSearchProvider(ABC):
# @abstractmethod
# def search(self, query: str) -> Sequence[WebSearchResult]:
# pass
# @abstractmethod
# def contents(self, urls: Sequence[str]) -> list[WebContent]:
# pass
class WebSearchProvider(WebContentProvider):
    """A content provider that can additionally run web searches."""

    @abstractmethod
    def search(self, query: str) -> Sequence[WebSearchResult]:
        """Run ``query`` against the provider and return its results."""
        ...

View File

@@ -1,19 +1,199 @@
# from onyx.agents.agent_search.dr.sub_agents.web_search.clients.exa_client import (
# ExaClient,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.clients.serper_client import (
# SerperClient,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebSearchProvider,
# )
# from onyx.configs.chat_configs import EXA_API_KEY
# from onyx.configs.chat_configs import SERPER_API_KEY
from typing import Any
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.exa_client import (
ExaClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.firecrawl_client import (
FIRECRAWL_SCRAPE_URL,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.firecrawl_client import (
FirecrawlClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.google_pse_client import (
GooglePSEClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.onyx_web_crawler_client import (
OnyxWebCrawlerClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.serper_client import (
SerperClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebContentProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchProvider,
)
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.web_search import fetch_active_web_content_provider
from onyx.db.web_search import fetch_active_web_search_provider
from onyx.utils.logger import setup_logger
from shared_configs.enums import WebContentProviderType
from shared_configs.enums import WebSearchProviderType
logger = setup_logger()
# def get_default_provider() -> WebSearchProvider | None:
# if EXA_API_KEY:
# return ExaClient()
# if SERPER_API_KEY:
# return SerperClient()
# return None
def build_search_provider_from_config(
    *,
    provider_type: WebSearchProviderType,
    api_key: str | None,
    config: dict[str, str] | None,
    provider_name: str = "web_search_provider",
) -> WebSearchProvider | None:
    """Instantiate a web search client for the given provider configuration.

    Returns None (after logging) for unrecognized provider types so callers can
    degrade gracefully; raises ValueError when the configuration is present but
    invalid (missing API key, missing/bad provider-specific settings).
    """
    provider_type_value = provider_type.value
    # Defensive re-validation: tolerate enum-like inputs whose value is not a
    # member of the current WebSearchProviderType (e.g. stale stored values).
    try:
        provider_type_enum = WebSearchProviderType(provider_type_value)
    except ValueError:
        logger.error(
            f"Unknown web search provider type '{provider_type_value}'. "
            "Skipping provider initialization."
        )
        return None

    # All web search providers require an API key
    # (previously followed by a dead `assert api_key is not None` — unreachable
    # after the raise, and stripped under `python -O`; removed)
    if not api_key:
        raise ValueError(
            f"Web search provider '{provider_name}' is missing an API key."
        )

    config = config or {}

    if provider_type_enum == WebSearchProviderType.EXA:
        return ExaClient(api_key=api_key)

    if provider_type_enum == WebSearchProviderType.SERPER:
        return SerperClient(api_key=api_key)

    if provider_type_enum == WebSearchProviderType.GOOGLE_PSE:
        # Accept several historical config keys for the engine id (cx).
        search_engine_id = (
            config.get("search_engine_id")
            or config.get("cx")
            or config.get("search_engine")
        )
        if not search_engine_id:
            raise ValueError(
                "Google PSE provider requires a search engine id (cx) in addition to the API key."
            )
        try:
            num_results = int(config.get("num_results", 10))
        except (TypeError, ValueError):
            raise ValueError(
                "Invalid value for Google PSE 'num_results'; expected integer."
            )
        try:
            timeout_seconds = int(config.get("timeout_seconds", 10))
        except (TypeError, ValueError):
            raise ValueError(
                "Invalid value for Google PSE 'timeout_seconds'; expected integer."
            )
        return GooglePSEClient(
            api_key=api_key,
            search_engine_id=search_engine_id,
            num_results=num_results,
            timeout_seconds=timeout_seconds,
        )

    logger.error(
        f"Unhandled web search provider type '{provider_type_value}'. "
        "Skipping provider initialization."
    )
    return None
def _build_search_provider(provider_model: Any) -> WebSearchProvider | None:
    """Adapt a stored provider row into a configured WebSearchProvider client."""
    resolved_type = WebSearchProviderType(provider_model.provider_type)
    resolved_config = provider_model.config or {}
    return build_search_provider_from_config(
        provider_type=resolved_type,
        api_key=provider_model.api_key,
        config=resolved_config,
        provider_name=provider_model.name,
    )
def build_content_provider_from_config(
    *,
    provider_type: WebContentProviderType,
    api_key: str | None,
    config: dict[str, str] | None,
    provider_name: str = "web_content_provider",
) -> WebContentProvider | None:
    """Instantiate a web content (page fetching) client for the given configuration.

    Returns None (after logging) for unrecognized provider types; raises
    ValueError when the configuration is present but invalid.
    """
    provider_type_value = provider_type.value
    # Defensive re-validation, mirroring build_search_provider_from_config.
    try:
        provider_type_enum = WebContentProviderType(provider_type_value)
    except ValueError:
        logger.error(
            f"Unknown web content provider type '{provider_type_value}'. "
            "Skipping provider initialization."
        )
        return None

    config = config or {}

    def _timeout_from_config(default: int, provider_label: str) -> int:
        # Both branches previously parsed 'timeout_seconds' with slightly
        # different code; unified here. Behavior is unchanged: missing key ->
        # default, unparsable value -> ValueError with the same message.
        try:
            return int(config.get("timeout_seconds", default))
        except (TypeError, ValueError):
            raise ValueError(
                f"Invalid value for {provider_label} 'timeout_seconds'; expected integer."
            )

    if provider_type_enum == WebContentProviderType.ONYX_WEB_CRAWLER:
        return OnyxWebCrawlerClient(
            timeout_seconds=_timeout_from_config(15, "Onyx Web Crawler")
        )

    if provider_type_enum == WebContentProviderType.FIRECRAWL:
        # (dead `assert api_key is not None` after the raise removed)
        if not api_key:
            raise ValueError("Firecrawl content provider requires an API key.")
        return FirecrawlClient(
            api_key=api_key,
            base_url=config.get("base_url") or FIRECRAWL_SCRAPE_URL,
            timeout_seconds=_timeout_from_config(10, "Firecrawl"),
        )

    logger.error(
        f"Unhandled web content provider type '{provider_type_value}'. "
        "Skipping provider initialization."
    )
    return None
def _build_content_provider(provider_model: Any) -> WebContentProvider | None:
    """Adapt a stored provider row into a configured WebContentProvider client."""
    resolved_type = WebContentProviderType(provider_model.provider_type)
    resolved_config = provider_model.config or {}
    return build_content_provider_from_config(
        provider_type=resolved_type,
        api_key=provider_model.api_key,
        config=resolved_config,
        provider_name=provider_model.name,
    )
def get_default_provider() -> WebSearchProvider | None:
    """Return a client for the tenant's active web search provider, if any."""
    with get_session_with_current_tenant() as db_session:
        active_model = fetch_active_web_search_provider(db_session)
        # No active provider configured -> nothing to build.
        return None if active_model is None else _build_search_provider(active_model)
def get_default_content_provider() -> WebContentProvider | None:
    """Return the configured content provider, falling back to the Onyx crawler."""
    with get_session_with_current_tenant() as db_session:
        active_model = fetch_active_web_content_provider(db_session)
        configured = _build_content_provider(active_model) if active_model else None
        if configured:
            return configured
    # Fall back to built-in Onyx crawler when nothing is configured.
    try:
        return OnyxWebCrawlerClient()
    except Exception as exc:  # pragma: no cover - defensive
        logger.error(f"Failed to initialize default Onyx crawler: {exc}")
        return None

View File

@@ -1,37 +1,37 @@
# from operator import add
# from typing import Annotated
from operator import add
from typing import Annotated
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebSearchResult,
# )
# from onyx.context.search.models import InferenceSection
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.context.search.models import InferenceSection
# class InternetSearchInput(SubAgentInput):
# results_to_open: Annotated[list[tuple[str, WebSearchResult]], add] = []
# parallelization_nr: int = 0
# branch_question: Annotated[str, lambda x, y: y] = ""
# branch_questions_to_urls: Annotated[dict[str, list[str]], lambda x, y: y] = {}
# raw_documents: Annotated[list[InferenceSection], add] = []
class InternetSearchInput(SubAgentInput):
    """State for the web-search sub-agent's search phase.

    NOTE(review): the Annotated metadata (operator.add, last-write-wins lambdas)
    presumably act as state reducers for concurrent branch updates (LangGraph
    convention) — confirm against the graph wiring.
    """

    # accumulated (branch question, search result) pairs; `add` concatenates
    results_to_open: Annotated[list[tuple[str, WebSearchResult]], add] = []
    # parallelization number of this branch
    parallelization_nr: int = 0
    # last-write-wins reducers: keep only the newest value
    branch_question: Annotated[str, lambda x, y: y] = ""
    branch_questions_to_urls: Annotated[dict[str, list[str]], lambda x, y: y] = {}
    # fetched documents, concatenated across branches
    raw_documents: Annotated[list[InferenceSection], add] = []
# class InternetSearchUpdate(LoggerUpdate):
# results_to_open: Annotated[list[tuple[str, WebSearchResult]], add] = []
class InternetSearchUpdate(LoggerUpdate):
    """Partial state update emitted by a search step."""

    # appended (via `add`) to the accumulated (branch question, result) pairs
    results_to_open: Annotated[list[tuple[str, WebSearchResult]], add] = []
# class FetchInput(SubAgentInput):
# urls_to_open: Annotated[list[str], add] = []
# branch_questions_to_urls: dict[str, list[str]]
# raw_documents: Annotated[list[InferenceSection], add] = []
class FetchInput(SubAgentInput):
    """State for the URL-fetching step."""

    # URLs queued to be opened; `add` concatenates contributions
    urls_to_open: Annotated[list[str], add] = []
    # branch question -> URLs selected for it (no reducer: single assignment)
    branch_questions_to_urls: dict[str, list[str]]
    # documents fetched so far, concatenated across steps
    raw_documents: Annotated[list[InferenceSection], add] = []
# class FetchUpdate(LoggerUpdate):
# raw_documents: Annotated[list[InferenceSection], add] = []
class FetchUpdate(LoggerUpdate):
    """Partial state update emitted by a fetch step."""

    # appended (via `add`) to the accumulated fetched documents
    raw_documents: Annotated[list[InferenceSection], add] = []
# class SummarizeInput(SubAgentInput):
# raw_documents: Annotated[list[InferenceSection], add] = []
# branch_questions_to_urls: dict[str, list[str]]
# branch_question: str
class SummarizeInput(SubAgentInput):
    """State for the summarization step over fetched documents."""

    # all documents fetched for this branch
    raw_documents: Annotated[list[InferenceSection], add] = []
    # branch question -> URLs that were fetched for it
    branch_questions_to_urls: dict[str, list[str]]
    # the question this branch is answering
    branch_question: str

View File

@@ -1,99 +1,99 @@
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebContent,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebSearchResult,
# )
# from onyx.chat.models import DOCUMENT_CITATION_NUMBER_EMPTY_VALUE
# from onyx.chat.models import LlmDoc
# from onyx.configs.constants import DocumentSource
# from onyx.context.search.models import InferenceChunk
# from onyx.context.search.models import InferenceSection
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebContent,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.chat.models import DOCUMENT_CITATION_NUMBER_EMPTY_VALUE
from onyx.chat.models import LlmDoc
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
# def truncate_search_result_content(content: str, max_chars: int = 10000) -> str:
# """Truncate search result content to a maximum number of characters"""
# if len(content) <= max_chars:
# return content
# return content[:max_chars] + "..."
def truncate_search_result_content(content: str, max_chars: int = 10000) -> str:
    """Clip ``content`` to at most ``max_chars`` characters, appending "..." when clipped."""
    needs_clipping = len(content) > max_chars
    return content[:max_chars] + "..." if needs_clipping else content
# def inference_section_from_internet_page_scrape(
# result: WebContent,
# ) -> InferenceSection:
# truncated_content = truncate_search_result_content(result.full_content)
# return InferenceSection(
# center_chunk=InferenceChunk(
# chunk_id=0,
# blurb=result.title,
# content=truncated_content,
# source_links={0: result.link},
# section_continuation=False,
# document_id="INTERNET_SEARCH_DOC_" + result.link,
# source_type=DocumentSource.WEB,
# semantic_identifier=result.link,
# title=result.title,
# boost=1,
# recency_bias=1.0,
# score=1.0,
# hidden=(not result.scrape_successful),
# metadata={},
# match_highlights=[],
# doc_summary=truncated_content,
# chunk_context=truncated_content,
# updated_at=result.published_date,
# image_file_id=None,
# ),
# chunks=[],
# combined_content=truncated_content,
# )
def dummy_inference_section_from_internet_content(
    result: WebContent,
) -> InferenceSection:
    """Wrap scraped web page content in a single-chunk InferenceSection."""
    clipped_content = truncate_search_result_content(result.full_content)
    web_chunk = InferenceChunk(
        chunk_id=0,
        blurb=result.title,
        content=clipped_content,
        source_links={0: result.link},
        section_continuation=False,
        document_id="INTERNET_SEARCH_DOC_" + result.link,
        source_type=DocumentSource.WEB,
        semantic_identifier=result.link,
        title=result.title,
        boost=1,
        recency_bias=1.0,
        score=1.0,
        # failed scrapes are kept but marked hidden
        hidden=(not result.scrape_successful),
        metadata={},
        match_highlights=[],
        doc_summary=clipped_content,
        chunk_context=clipped_content,
        updated_at=result.published_date,
        image_file_id=None,
    )
    return InferenceSection(
        center_chunk=web_chunk,
        chunks=[],
        combined_content=clipped_content,
    )
# def dummy_inference_section_from_internet_search_result(
# result: WebSearchResult,
# ) -> InferenceSection:
# return InferenceSection(
# center_chunk=InferenceChunk(
# chunk_id=0,
# blurb=result.title,
# content="",
# source_links={0: result.link},
# section_continuation=False,
# document_id="INTERNET_SEARCH_DOC_" + result.link,
# source_type=DocumentSource.WEB,
# semantic_identifier=result.link,
# title=result.title,
# boost=1,
# recency_bias=1.0,
# score=1.0,
# hidden=False,
# metadata={},
# match_highlights=[],
# doc_summary="",
# chunk_context="",
# updated_at=result.published_date,
# image_file_id=None,
# ),
# chunks=[],
# combined_content="",
# )
def dummy_inference_section_from_internet_search_result(
    result: WebSearchResult,
) -> InferenceSection:
    """Wrap a search hit (no scraped body yet) in a placeholder InferenceSection."""
    placeholder_chunk = InferenceChunk(
        chunk_id=0,
        blurb=result.title,
        # no page content has been fetched for a bare search result
        content="",
        source_links={0: result.link},
        section_continuation=False,
        document_id="INTERNET_SEARCH_DOC_" + result.link,
        source_type=DocumentSource.WEB,
        semantic_identifier=result.link,
        title=result.title,
        boost=1,
        recency_bias=1.0,
        score=1.0,
        hidden=False,
        metadata={},
        match_highlights=[],
        doc_summary="",
        chunk_context="",
        updated_at=result.published_date,
        image_file_id=None,
    )
    return InferenceSection(
        center_chunk=placeholder_chunk,
        chunks=[],
        combined_content="",
    )
# def llm_doc_from_web_content(web_content: WebContent) -> LlmDoc:
# """Create an LlmDoc from WebContent with the INTERNET_SEARCH_DOC_ prefix"""
# return LlmDoc(
# # TODO: Is this what we want to do for document_id? We're kind of overloading it since it
# # should ideally correspond to a document in the database. But I guess if you're calling this
# # function you know it won't be in the database.
# document_id="INTERNET_SEARCH_DOC_" + web_content.link,
# content=truncate_search_result_content(web_content.full_content),
# blurb=web_content.link,
# semantic_identifier=web_content.link,
# source_type=DocumentSource.WEB,
# metadata={},
# link=web_content.link,
# document_citation_number=DOCUMENT_CITATION_NUMBER_EMPTY_VALUE,
# updated_at=web_content.published_date,
# source_links={},
# match_highlights=[],
# )
def llm_doc_from_web_content(web_content: WebContent) -> LlmDoc:
    """Create an LlmDoc from WebContent with the INTERNET_SEARCH_DOC_ prefix"""
    # TODO: Is this what we want to do for document_id? We're kind of overloading it
    # since it should ideally correspond to a document in the database. But if you're
    # calling this function you know it won't be in the database.
    synthetic_doc_id = "INTERNET_SEARCH_DOC_" + web_content.link
    clipped_content = truncate_search_result_content(web_content.full_content)
    return LlmDoc(
        document_id=synthetic_doc_id,
        content=clipped_content,
        blurb=web_content.link,
        semantic_identifier=web_content.link,
        source_type=DocumentSource.WEB,
        metadata={},
        link=web_content.link,
        document_citation_number=DOCUMENT_CITATION_NUMBER_EMPTY_VALUE,
        updated_at=web_content.published_date,
        source_links={},
        match_highlights=[],
    )

View File

@@ -1,277 +1,277 @@
# import copy
# import re
import copy
import re
# from langchain.schema.messages import BaseMessage
# from langchain.schema.messages import HumanMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
# from onyx.agents.agent_search.dr.models import AggregatedDRContext
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
# from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
# from onyx.agents.agent_search.shared_graph_utils.operators import (
# dedup_inference_section_list,
# )
# from onyx.context.search.models import InferenceSection
# from onyx.context.search.models import SavedSearchDoc
# from onyx.context.search.models import SearchDoc
# from onyx.tools.tool_implementations.web_search.web_search_tool import (
# WebSearchTool,
# )
from onyx.agents.agent_search.dr.models import AggregatedDRContext
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_section_list,
)
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import SavedSearchDoc
from onyx.context.search.models import SearchDoc
from onyx.tools.tool_implementations.web_search.web_search_tool import (
WebSearchTool,
)
# Marker spliced into citation brackets (e.g. "[CITE:3]") so already-rewritten
# citations are not matched again during later global renumbering.
CITATION_PREFIX = "CITE:"


def extract_document_citations(
    answer: str, claims: list[str]
) -> tuple[list[int], str, list[str]]:
    """
    Finds all citations of the form [1], [2, 3], etc. and returns the list of cited
    indices (sorted ascending), as well as the answer and claims with the citations
    replaced with [<CITATION_PREFIX>1], etc., to help with citation deduplication
    later on.
    """
    citations: set[int] = set()

    # Matches single citations like [1] and grouped citations like [1, 2, 3].
    pattern = re.compile(r"\[(\d+(?:,\s*\d+)*)\]")

    def _extract_and_replace(match: re.Match[str]) -> str:
        # Record every cited index and split grouped citations into one
        # bracketed marker per index.
        numbers = [int(num) for num in match.group(1).split(",")]
        citations.update(numbers)
        return "".join(f"[{CITATION_PREFIX}{num}]" for num in numbers)

    new_answer = pattern.sub(_extract_and_replace, answer)
    new_claims = [pattern.sub(_extract_and_replace, claim) for claim in claims]

    # sorted() rather than list(): iteration order of a set is unspecified, and
    # a deterministic ascending citation list is strictly more useful to callers.
    return sorted(citations), new_answer, new_claims
# def aggregate_context(
# iteration_responses: list[IterationAnswer], include_documents: bool = True
# ) -> AggregatedDRContext:
# """
# Converts the iteration response into a single string with unified citations.
# For example,
# it 1: the answer is x [3][4]. {3: doc_abc, 4: doc_xyz}
# it 2: blah blah [1, 3]. {1: doc_xyz, 3: doc_pqr}
# Output:
# it 1: the answer is x [1][2].
# it 2: blah blah [2][3]
# [1]: doc_xyz
# [2]: doc_abc
# [3]: doc_pqr
# """
# # dedupe and merge inference section contents
# unrolled_inference_sections: list[InferenceSection] = []
# is_internet_marker_dict: dict[str, bool] = {}
# for iteration_response in sorted(
# iteration_responses,
# key=lambda x: (x.iteration_nr, x.parallelization_nr),
# ):
def aggregate_context(
    iteration_responses: list[IterationAnswer], include_documents: bool = True
) -> AggregatedDRContext:
    """
    Converts the iteration response into a single string with unified citations.

    Each iteration carries locally numbered citations; these are renumbered into
    a single global citation space shared across all iterations. For example,
    it 1: the answer is x [3][4]. {3: doc_abc, 4: doc_xyz}
    it 2: blah blah [1, 3]. {1: doc_xyz, 3: doc_pqr}
    Output:
    it 1: the answer is x [1][2].
    it 2: blah blah [2][3]
    [1]: doc_xyz
    [2]: doc_abc
    [3]: doc_pqr
    """
    # Pass 1: dedupe and merge inference section contents across all iterations,
    # remembering (first-wins) whether each document came from an internet tool.
    unrolled_inference_sections: list[InferenceSection] = []
    is_internet_marker_dict: dict[str, bool] = {}
    for iteration_response in sorted(
        iteration_responses,
        key=lambda x: (x.iteration_nr, x.parallelization_nr),
    ):
        iteration_tool = iteration_response.tool
        is_internet = iteration_tool == WebSearchTool._NAME
        for cited_doc in iteration_response.cited_documents.values():
            unrolled_inference_sections.append(cited_doc)
            if cited_doc.center_chunk.document_id not in is_internet_marker_dict:
                is_internet_marker_dict[cited_doc.center_chunk.document_id] = (
                    is_internet
                )
            cited_doc.center_chunk.score = None  # None means maintain order
    global_documents = dedup_inference_section_list(unrolled_inference_sections)
    # Global citation numbers are 1-based positions in the deduped list.
    global_citations = {
        doc.center_chunk.document_id: i for i, doc in enumerate(global_documents, 1)
    }
    # Pass 2: build output string, rewriting each iteration's local citation
    # markers ([CITE:n], produced by extract_document_citations) to global numbers.
    output_strings: list[str] = []
    global_iteration_responses: list[IterationAnswer] = []
    for iteration_response in sorted(
        iteration_responses,
        key=lambda x: (x.iteration_nr, x.parallelization_nr),
    ):
        # add basic iteration info
        output_strings.append(
            f"Iteration: {iteration_response.iteration_nr}, "
            f"Question {iteration_response.parallelization_nr}"
        )
        output_strings.append(f"Tool: {iteration_response.tool}")
        output_strings.append(f"Question: {iteration_response.question}")
        # get answer and claims with global citations
        answer_str = iteration_response.answer
        claims = iteration_response.claims or []
        iteration_citations: list[int] = []
        for local_number, cited_doc in iteration_response.cited_documents.items():
            global_number = global_citations[cited_doc.center_chunk.document_id]
            # translate local citations to global citations
            answer_str = answer_str.replace(
                f"[{CITATION_PREFIX}{local_number}]", f"[{global_number}]"
            )
            claims = [
                claim.replace(
                    f"[{CITATION_PREFIX}{local_number}]", f"[{global_number}]"
                )
                for claim in claims
            ]
            iteration_citations.append(global_number)
        # add answer, claims, and citation info
        if answer_str:
            output_strings.append(f"Answer: {answer_str}")
        if claims:
            # NOTE(review): the trailing `or "No claims provided"` binds to the
            # whole concatenation, which always starts with "Claims: " and is
            # therefore always truthy — the fallback is unreachable (and this
            # branch is already guarded by `if claims:`).
            output_strings.append(
                "Claims: " + "".join(f"\n - {claim}" for claim in claims or [])
                or "No claims provided"
            )
        if not answer_str and not claims:
            output_strings.append(
                "Retrieved documents: "
                + (
                    "".join(
                        f"[{global_number}]"
                        for global_number in sorted(iteration_citations)
                    )
                    or "No documents retrieved"
                )
            )
        output_strings.append("\n---\n")
        # save global iteration response (copy so the caller's inputs keep
        # their local numbering)
        iteration_response_copy = iteration_response.model_copy()
        iteration_response_copy.answer = answer_str
        iteration_response_copy.claims = claims
        iteration_response_copy.cited_documents = {
            global_citations[doc.center_chunk.document_id]: doc
            for doc in iteration_response.cited_documents.values()
        }
        global_iteration_responses.append(iteration_response_copy)
    # add document contents if requested
    if include_documents:
        if global_documents:
            output_strings.append("Cited document contents:")
            for doc in global_documents:
                output_strings.append(
                    build_document_context(
                        doc, global_citations[doc.center_chunk.document_id]
                    )
                )
            output_strings.append("\n---\n")
    return AggregatedDRContext(
        context="\n".join(output_strings),
        cited_documents=global_documents,
        is_internet_marker_dict=is_internet_marker_dict,
        global_iteration_responses=global_iteration_responses,
    )
# def get_chat_history_string(chat_history: list[BaseMessage], max_messages: int) -> str:
# """
# Get the chat history (up to max_messages) as a string.
# """
# # get past max_messages USER, ASSISTANT message pairs
def get_chat_history_string(chat_history: list[BaseMessage], max_messages: int) -> str:
    """
    Get the chat history (up to max_messages USER/ASSISTANT message pairs) as a
    newline-joined string of "user: ..." / "you: ..." lines, prefixed with
    "...\n" when older history was dropped.

    Non-text content pieces (dict entries whose "type" is not "text", e.g.
    image payloads in multimodal messages) are stripped from list-style message
    contents before rendering.
    """
    # get past max_messages USER, ASSISTANT message pairs
    past_messages = chat_history[-max_messages * 2 :]
    # deepcopy so stripping content pieces does not mutate the caller's messages
    filtered_past_messages = copy.deepcopy(past_messages)

    for past_message_number, past_message in enumerate(past_messages):
        # Plain-string contents need no filtering. (Replaces the original's
        # dead `else: continue` at the end of the loop body with a guard.)
        if not isinstance(past_message.content, list):
            continue
        # Set instead of list: membership is tested per content piece below.
        removal_indices = {
            content_piece_number
            for content_piece_number, content_piece in enumerate(past_message.content)
            if isinstance(content_piece, dict) and content_piece.get("type") != "text"
        }
        # Only rebuild the content list if there are items to remove
        if removal_indices:
            filtered_past_messages[past_message_number].content = [
                content_piece
                for content_piece_number, content_piece in enumerate(
                    past_message.content
                )
                if content_piece_number not in removal_indices
            ]

    return (
        "...\n" if len(chat_history) > len(filtered_past_messages) else ""
    ) + "\n".join(
        ("user" if isinstance(msg, HumanMessage) else "you")
        + f": {str(msg.content).strip()}"
        for msg in filtered_past_messages
    )
# def get_prompt_question(
# question: str, clarification: OrchestrationClarificationInfo | None
# ) -> str:
# if clarification:
# clarification_question = clarification.clarification_question
# clarification_response = clarification.clarification_response
# return (
# f"Initial User Question: {question}\n"
# f"(Clarification Question: {clarification_question}\n"
# f"User Response: {clarification_response})"
# )
def get_prompt_question(
    question: str, clarification: OrchestrationClarificationInfo | None
) -> str:
    """Render the user question, folding in any clarification exchange."""
    if not clarification:
        return question
    return (
        f"Initial User Question: {question}\n"
        f"(Clarification Question: {clarification.clarification_question}\n"
        f"User Response: {clarification.clarification_response})"
    )
# def create_tool_call_string(tool_name: str, query_list: list[str]) -> str:
# """
# Create a string representation of the tool call.
# """
# questions_str = "\n - ".join(query_list)
# return f"Tool: {tool_name}\n\nQuestions:\n{questions_str}"
def create_tool_call_string(tool_name: str, query_list: list[str]) -> str:
    """
    Create a string representation of the tool call.
    """
    header = f"Tool: {tool_name}"
    questions_block = "Questions:\n" + "\n - ".join(query_list)
    return f"{header}\n\n{questions_block}"
# def parse_plan_to_dict(plan_text: str) -> dict[str, str]:
# # Convert plan string to numbered dict format
# if not plan_text:
# return {}
def parse_plan_to_dict(plan_text: str) -> dict[str, str]:
    """Split a numbered plan string ("1. ... 2) ...") into {number: step_text}."""
    if not plan_text:
        return {}

    # Capture the numbering tokens ("1.", "2)") so re.split keeps them as
    # delimiters; parts then alternate [prefix, number, text, number, text, ...].
    parts = re.split(r"(\d+[.)])", plan_text)

    plan_dict: dict[str, str] = {}
    # zip pairs each number token with the text that follows it; a trailing
    # number with no following text is simply dropped, as before.
    for marker, body in zip(parts[1::2], parts[2::2]):
        step_text = body.strip()
        if step_text:  # Only add if there's actual content
            plan_dict[marker.rstrip(".)")] = step_text
    return plan_dict
# def convert_inference_sections_to_search_docs(
# inference_sections: list[InferenceSection],
# is_internet: bool = False,
# ) -> list[SavedSearchDoc]:
# # Convert InferenceSections to SavedSearchDocs
# search_docs = SearchDoc.from_chunks_or_sections(inference_sections)
# for search_doc in search_docs:
# search_doc.is_internet = is_internet
def convert_inference_sections_to_search_docs(
    inference_sections: list[InferenceSection],
    is_internet: bool = False,
) -> list[SavedSearchDoc]:
    """Convert InferenceSections into SavedSearchDocs, tagging internet origin."""
    converted_docs = SearchDoc.from_chunks_or_sections(inference_sections)
    for converted_doc in converted_docs:
        converted_doc.is_internet = is_internet
    # NOTE(review): db_doc_id=0 appears to be a placeholder for docs without a
    # database row — confirm against SavedSearchDoc.from_search_doc.
    return [
        SavedSearchDoc.from_search_doc(converted_doc, db_doc_id=0)
        for converted_doc in converted_docs
    ]

View File

@@ -1,87 +1,87 @@
# from collections.abc import Hashable
# from datetime import datetime
# from enum import Enum
from collections.abc import Hashable
from datetime import datetime
from enum import Enum
# from langgraph.types import Send
from langgraph.types import Send
# from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
# from onyx.agents.agent_search.kb_search.states import KGSearchType
# from onyx.agents.agent_search.kb_search.states import KGSourceDivisionType
# from onyx.agents.agent_search.kb_search.states import MainState
# from onyx.agents.agent_search.kb_search.states import ResearchObjectInput
# from onyx.configs.kg_configs import KG_MAX_DECOMPOSITION_SEGMENTS
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import KGSearchType
from onyx.agents.agent_search.kb_search.states import KGSourceDivisionType
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import ResearchObjectInput
from onyx.configs.kg_configs import KG_MAX_DECOMPOSITION_SEGMENTS
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# class KGAnalysisPath(str, Enum):
# PROCESS_KG_ONLY_ANSWERS = "process_kg_only_answers"
# CONSTRUCT_DEEP_SEARCH_FILTERS = "construct_deep_search_filters"
class KGAnalysisPath(str, Enum):
    """Routing targets whose .value is returned by the simple_vs_search edge."""

    PROCESS_KG_ONLY_ANSWERS = "process_kg_only_answers"
    CONSTRUCT_DEEP_SEARCH_FILTERS = "construct_deep_search_filters"
# def simple_vs_search(
# state: MainState,
# ) -> str:
def simple_vs_search(
state: MainState,
) -> str:
# identified_strategy = state.updated_strategy or state.strategy
identified_strategy = state.updated_strategy or state.strategy
# if (
# identified_strategy == KGAnswerStrategy.DEEP
# or state.search_type == KGSearchType.SEARCH
# ):
# return KGAnalysisPath.CONSTRUCT_DEEP_SEARCH_FILTERS.value
# else:
# return KGAnalysisPath.PROCESS_KG_ONLY_ANSWERS.value
if (
identified_strategy == KGAnswerStrategy.DEEP
or state.search_type == KGSearchType.SEARCH
):
return KGAnalysisPath.CONSTRUCT_DEEP_SEARCH_FILTERS.value
else:
return KGAnalysisPath.PROCESS_KG_ONLY_ANSWERS.value
# def research_individual_object(
# state: MainState,
# ) -> list[Send | Hashable] | str:
# edge_start_time = datetime.now()
def research_individual_object(
state: MainState,
) -> list[Send | Hashable] | str:
edge_start_time = datetime.now()
# assert state.div_con_entities is not None
# assert state.broken_down_question is not None
# assert state.vespa_filter_results is not None
assert state.div_con_entities is not None
assert state.broken_down_question is not None
assert state.vespa_filter_results is not None
# if (
# state.search_type == KGSearchType.SQL
# and state.strategy == KGAnswerStrategy.DEEP
# ):
if (
state.search_type == KGSearchType.SQL
and state.strategy == KGAnswerStrategy.DEEP
):
# if state.source_filters and state.source_division:
# segments = state.source_filters
# segment_type = KGSourceDivisionType.SOURCE.value
# else:
# segments = state.div_con_entities
# segment_type = KGSourceDivisionType.ENTITY.value
if state.source_filters and state.source_division:
segments = state.source_filters
segment_type = KGSourceDivisionType.SOURCE.value
else:
segments = state.div_con_entities
segment_type = KGSourceDivisionType.ENTITY.value
# if segments and (len(segments) > KG_MAX_DECOMPOSITION_SEGMENTS):
# logger.debug(f"Too many sources ({len(segments)}), usingfiltered search")
# return "filtered_search"
if segments and (len(segments) > KG_MAX_DECOMPOSITION_SEGMENTS):
logger.debug(f"Too many sources ({len(segments)}), usingfiltered search")
return "filtered_search"
# return [
# Send(
# "process_individual_deep_search",
# ResearchObjectInput(
# research_nr=research_nr + 1,
# entity=entity,
# broken_down_question=state.broken_down_question,
# vespa_filter_results=state.vespa_filter_results,
# source_division=state.source_division,
# source_entity_filters=state.source_filters,
# segment_type=segment_type,
# log_messages=[
# f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
# ],
# step_results=[],
# ),
# )
# for research_nr, entity in enumerate(segments)
# ]
# elif state.search_type == KGSearchType.SEARCH:
# return "filtered_search"
# else:
# raise ValueError(
# f"Invalid combination of search type: {state.search_type} and strategy: {state.strategy}"
# )
return [
Send(
"process_individual_deep_search",
ResearchObjectInput(
research_nr=research_nr + 1,
entity=entity,
broken_down_question=state.broken_down_question,
vespa_filter_results=state.vespa_filter_results,
source_division=state.source_division,
source_entity_filters=state.source_filters,
segment_type=segment_type,
log_messages=[
f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
],
step_results=[],
),
)
for research_nr, entity in enumerate(segments)
]
elif state.search_type == KGSearchType.SEARCH:
return "filtered_search"
else:
raise ValueError(
f"Invalid combination of search type: {state.search_type} and strategy: {state.strategy}"
)

View File

@@ -1,143 +1,143 @@
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from onyx.agents.agent_search.kb_search.conditional_edges import (
# research_individual_object,
# )
# from onyx.agents.agent_search.kb_search.conditional_edges import simple_vs_search
# from onyx.agents.agent_search.kb_search.nodes.a1_extract_ert import extract_ert
# from onyx.agents.agent_search.kb_search.nodes.a2_analyze import analyze
# from onyx.agents.agent_search.kb_search.nodes.a3_generate_simple_sql import (
# generate_simple_sql,
# )
# from onyx.agents.agent_search.kb_search.nodes.b1_construct_deep_search_filters import (
# construct_deep_search_filters,
# )
# from onyx.agents.agent_search.kb_search.nodes.b2p_process_individual_deep_search import (
# process_individual_deep_search,
# )
# from onyx.agents.agent_search.kb_search.nodes.b2s_filtered_search import filtered_search
# from onyx.agents.agent_search.kb_search.nodes.b3_consolidate_individual_deep_search import (
# consolidate_individual_deep_search,
# )
# from onyx.agents.agent_search.kb_search.nodes.c1_process_kg_only_answers import (
# process_kg_only_answers,
# )
# from onyx.agents.agent_search.kb_search.nodes.d1_generate_answer import generate_answer
# from onyx.agents.agent_search.kb_search.nodes.d2_logging_node import log_data
# from onyx.agents.agent_search.kb_search.states import MainInput
# from onyx.agents.agent_search.kb_search.states import MainState
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.kb_search.conditional_edges import (
research_individual_object,
)
from onyx.agents.agent_search.kb_search.conditional_edges import simple_vs_search
from onyx.agents.agent_search.kb_search.nodes.a1_extract_ert import extract_ert
from onyx.agents.agent_search.kb_search.nodes.a2_analyze import analyze
from onyx.agents.agent_search.kb_search.nodes.a3_generate_simple_sql import (
generate_simple_sql,
)
from onyx.agents.agent_search.kb_search.nodes.b1_construct_deep_search_filters import (
construct_deep_search_filters,
)
from onyx.agents.agent_search.kb_search.nodes.b2p_process_individual_deep_search import (
process_individual_deep_search,
)
from onyx.agents.agent_search.kb_search.nodes.b2s_filtered_search import filtered_search
from onyx.agents.agent_search.kb_search.nodes.b3_consolidate_individual_deep_search import (
consolidate_individual_deep_search,
)
from onyx.agents.agent_search.kb_search.nodes.c1_process_kg_only_answers import (
process_kg_only_answers,
)
from onyx.agents.agent_search.kb_search.nodes.d1_generate_answer import generate_answer
from onyx.agents.agent_search.kb_search.nodes.d2_logging_node import log_data
from onyx.agents.agent_search.kb_search.states import MainInput
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def kb_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for the knowledge graph search process.
# """
def kb_graph_builder() -> StateGraph:
"""
LangGraph graph builder for the knowledge graph search process.
"""
# graph = StateGraph(
# state_schema=MainState,
# input=MainInput,
# )
graph = StateGraph(
state_schema=MainState,
input=MainInput,
)
# ### Add nodes ###
### Add nodes ###
# graph.add_node(
# "extract_ert",
# extract_ert,
# )
graph.add_node(
"extract_ert",
extract_ert,
)
# graph.add_node(
# "generate_simple_sql",
# generate_simple_sql,
# )
graph.add_node(
"generate_simple_sql",
generate_simple_sql,
)
# graph.add_node(
# "filtered_search",
# filtered_search,
# )
graph.add_node(
"filtered_search",
filtered_search,
)
# graph.add_node(
# "analyze",
# analyze,
# )
graph.add_node(
"analyze",
analyze,
)
# graph.add_node(
# "generate_answer",
# generate_answer,
# )
graph.add_node(
"generate_answer",
generate_answer,
)
# graph.add_node(
# "log_data",
# log_data,
# )
graph.add_node(
"log_data",
log_data,
)
# graph.add_node(
# "construct_deep_search_filters",
# construct_deep_search_filters,
# )
graph.add_node(
"construct_deep_search_filters",
construct_deep_search_filters,
)
# graph.add_node(
# "process_individual_deep_search",
# process_individual_deep_search,
# )
graph.add_node(
"process_individual_deep_search",
process_individual_deep_search,
)
# graph.add_node(
# "consolidate_individual_deep_search",
# consolidate_individual_deep_search,
# )
graph.add_node(
"consolidate_individual_deep_search",
consolidate_individual_deep_search,
)
# graph.add_node("process_kg_only_answers", process_kg_only_answers)
graph.add_node("process_kg_only_answers", process_kg_only_answers)
# ### Add edges ###
### Add edges ###
# graph.add_edge(start_key=START, end_key="extract_ert")
graph.add_edge(start_key=START, end_key="extract_ert")
# graph.add_edge(
# start_key="extract_ert",
# end_key="analyze",
# )
graph.add_edge(
start_key="extract_ert",
end_key="analyze",
)
# graph.add_edge(
# start_key="analyze",
# end_key="generate_simple_sql",
# )
graph.add_edge(
start_key="analyze",
end_key="generate_simple_sql",
)
# graph.add_conditional_edges("generate_simple_sql", simple_vs_search)
graph.add_conditional_edges("generate_simple_sql", simple_vs_search)
# graph.add_edge(start_key="process_kg_only_answers", end_key="generate_answer")
graph.add_edge(start_key="process_kg_only_answers", end_key="generate_answer")
# graph.add_conditional_edges(
# source="construct_deep_search_filters",
# path=research_individual_object,
# path_map=["process_individual_deep_search", "filtered_search"],
# )
graph.add_conditional_edges(
source="construct_deep_search_filters",
path=research_individual_object,
path_map=["process_individual_deep_search", "filtered_search"],
)
# graph.add_edge(
# start_key="process_individual_deep_search",
# end_key="consolidate_individual_deep_search",
# )
graph.add_edge(
start_key="process_individual_deep_search",
end_key="consolidate_individual_deep_search",
)
# graph.add_edge(
# start_key="consolidate_individual_deep_search", end_key="generate_answer"
# )
graph.add_edge(
start_key="consolidate_individual_deep_search", end_key="generate_answer"
)
# graph.add_edge(
# start_key="filtered_search",
# end_key="generate_answer",
# )
graph.add_edge(
start_key="filtered_search",
end_key="generate_answer",
)
# graph.add_edge(
# start_key="generate_answer",
# end_key="log_data",
# )
graph.add_edge(
start_key="generate_answer",
end_key="log_data",
)
# graph.add_edge(
# start_key="log_data",
# end_key=END,
# )
graph.add_edge(
start_key="log_data",
end_key=END,
)
# return graph
return graph

View File

@@ -1,244 +1,244 @@
# import re
import re
# from onyx.agents.agent_search.kb_search.models import KGEntityDocInfo
# from onyx.agents.agent_search.kb_search.models import KGExpandedGraphObjects
# from onyx.agents.agent_search.kb_search.states import SubQuestionAnswerResults
# from onyx.agents.agent_search.kb_search.step_definitions import (
# KG_SEARCH_STEP_DESCRIPTIONS,
# )
# from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
# from onyx.chat.models import LlmDoc
# from onyx.context.search.models import InferenceChunk
# from onyx.context.search.models import InferenceSection
# from onyx.db.document import get_kg_doc_info_for_entity_name
# from onyx.db.engine.sql_engine import get_session_with_current_tenant
# from onyx.db.entities import get_document_id_for_entity
# from onyx.db.entities import get_entity_name
# from onyx.db.entity_type import get_entity_types
# from onyx.kg.utils.formatting_utils import make_entity_id
# from onyx.kg.utils.formatting_utils import split_relationship_id
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.kb_search.models import KGEntityDocInfo
from onyx.agents.agent_search.kb_search.models import KGExpandedGraphObjects
from onyx.agents.agent_search.kb_search.states import SubQuestionAnswerResults
from onyx.agents.agent_search.kb_search.step_definitions import (
KG_SEARCH_STEP_DESCRIPTIONS,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.chat.models import LlmDoc
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
from onyx.db.document import get_kg_doc_info_for_entity_name
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.entities import get_document_id_for_entity
from onyx.db.entities import get_entity_name
from onyx.db.entity_type import get_entity_types
from onyx.kg.utils.formatting_utils import make_entity_id
from onyx.kg.utils.formatting_utils import split_relationship_id
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def _check_entities_disconnected(
# current_entities: list[str], current_relationships: list[str]
# ) -> bool:
# """
# Check if all entities in current_entities are disconnected via the given relationships.
# Relationships are in the format: source_entity__relationship_name__target_entity
def _check_entities_disconnected(
current_entities: list[str], current_relationships: list[str]
) -> bool:
"""
Check if all entities in current_entities are disconnected via the given relationships.
Relationships are in the format: source_entity__relationship_name__target_entity
# Args:
# current_entities: List of entity IDs to check connectivity for
# current_relationships: List of relationships in format source__relationship__target
Args:
current_entities: List of entity IDs to check connectivity for
current_relationships: List of relationships in format source__relationship__target
# Returns:
# bool: True if all entities are disconnected, False otherwise
# """
# if not current_entities:
# return True
Returns:
bool: True if all entities are disconnected, False otherwise
"""
if not current_entities:
return True
# # Create a graph representation using adjacency list
# graph: dict[str, set[str]] = {entity: set() for entity in current_entities}
# Create a graph representation using adjacency list
graph: dict[str, set[str]] = {entity: set() for entity in current_entities}
# # Build the graph from relationships
# for relationship in current_relationships:
# try:
# source, _, target = split_relationship_id(relationship)
# if source in graph and target in graph:
# graph[source].add(target)
# # Add reverse edge to capture that we do also have a relationship in the other direction,
# # albeit not quite the same one.
# graph[target].add(source)
# except ValueError:
# raise ValueError(f"Invalid relationship format: {relationship}")
# Build the graph from relationships
for relationship in current_relationships:
try:
source, _, target = split_relationship_id(relationship)
if source in graph and target in graph:
graph[source].add(target)
# Add reverse edge to capture that we do also have a relationship in the other direction,
# albeit not quite the same one.
graph[target].add(source)
except ValueError:
raise ValueError(f"Invalid relationship format: {relationship}")
# # Use BFS to check if all entities are connected
# visited: set[str] = set()
# start_entity = current_entities[0]
# Use BFS to check if all entities are connected
visited: set[str] = set()
start_entity = current_entities[0]
# def _bfs(start: str) -> None:
# queue = [start]
# visited.add(start)
# while queue:
# current = queue.pop(0)
# for neighbor in graph[current]:
# if neighbor not in visited:
# visited.add(neighbor)
# queue.append(neighbor)
def _bfs(start: str) -> None:
queue = [start]
visited.add(start)
while queue:
current = queue.pop(0)
for neighbor in graph[current]:
if neighbor not in visited:
visited.add(neighbor)
queue.append(neighbor)
# # Start BFS from the first entity
# _bfs(start_entity)
# Start BFS from the first entity
_bfs(start_entity)
# logger.debug(f"Number of visited entities: {len(visited)}")
logger.debug(f"Number of visited entities: {len(visited)}")
# # Check if all current_entities are in visited
# return not all(entity in visited for entity in current_entities)
# Check if all current_entities are in visited
return not all(entity in visited for entity in current_entities)
# def create_minimal_connected_query_graph(
# entities: list[str], relationships: list[str], max_depth: int = 1
# ) -> KGExpandedGraphObjects:
# """
# TODO: Implement this. For now we'll trust the SQL generation to do the right thing.
# Return the original entities and relationships.
# """
# return KGExpandedGraphObjects(entities=entities, relationships=relationships)
def create_minimal_connected_query_graph(
entities: list[str], relationships: list[str], max_depth: int = 1
) -> KGExpandedGraphObjects:
"""
TODO: Implement this. For now we'll trust the SQL generation to do the right thing.
Return the original entities and relationships.
"""
return KGExpandedGraphObjects(entities=entities, relationships=relationships)
# def get_doc_information_for_entity(entity_id_name: str) -> KGEntityDocInfo:
# """
# Get document information for an entity, including its semantic name and document details.
# """
# if "::" not in entity_id_name:
# return KGEntityDocInfo(
# doc_id=None,
# doc_semantic_id=None,
# doc_link=None,
# semantic_entity_name=entity_id_name,
# semantic_linked_entity_name=entity_id_name,
# )
def get_doc_information_for_entity(entity_id_name: str) -> KGEntityDocInfo:
"""
Get document information for an entity, including its semantic name and document details.
"""
if "::" not in entity_id_name:
return KGEntityDocInfo(
doc_id=None,
doc_semantic_id=None,
doc_link=None,
semantic_entity_name=entity_id_name,
semantic_linked_entity_name=entity_id_name,
)
# entity_type, entity_name = map(str.strip, entity_id_name.split("::", 1))
entity_type, entity_name = map(str.strip, entity_id_name.split("::", 1))
# with get_session_with_current_tenant() as db_session:
# entity_document_id = get_document_id_for_entity(db_session, entity_id_name)
# if entity_document_id:
# return get_kg_doc_info_for_entity_name(
# db_session, entity_document_id, entity_type
# )
# else:
# entity_actual_name = get_entity_name(db_session, entity_id_name)
with get_session_with_current_tenant() as db_session:
entity_document_id = get_document_id_for_entity(db_session, entity_id_name)
if entity_document_id:
return get_kg_doc_info_for_entity_name(
db_session, entity_document_id, entity_type
)
else:
entity_actual_name = get_entity_name(db_session, entity_id_name)
# return KGEntityDocInfo(
# doc_id=None,
# doc_semantic_id=None,
# doc_link=None,
# semantic_entity_name=f"{entity_type} {entity_actual_name or entity_id_name}",
# semantic_linked_entity_name=f"{entity_type} {entity_actual_name or entity_id_name}",
# )
return KGEntityDocInfo(
doc_id=None,
doc_semantic_id=None,
doc_link=None,
semantic_entity_name=f"{entity_type} {entity_actual_name or entity_id_name}",
semantic_linked_entity_name=f"{entity_type} {entity_actual_name or entity_id_name}",
)
# def rename_entities_in_answer(answer: str) -> str:
# """
# Process entity references in the answer string by:
# 1. Extracting all strings matching <str>:<str> or <str>: <str> patterns
# 2. Looking up these references in the entity table
# 3. Replacing valid references with their corresponding values
def rename_entities_in_answer(answer: str) -> str:
"""
Process entity references in the answer string by:
1. Extracting all strings matching <str>:<str> or <str>: <str> patterns
2. Looking up these references in the entity table
3. Replacing valid references with their corresponding values
# Args:
# answer: The input string containing potential entity references
Args:
answer: The input string containing potential entity references
# Returns:
# str: The processed string with entity references replaced
# """
# logger.debug(f"Input answer: {answer}")
Returns:
str: The processed string with entity references replaced
"""
logger.debug(f"Input answer: {answer}")
# # Clean up any spaces around ::
# answer = re.sub(r"::\s+", "::", answer)
# Clean up any spaces around ::
answer = re.sub(r"::\s+", "::", answer)
# # Pattern to match entity_type::entity_name, with optional quotes
# pattern = r"(?:')?([a-zA-Z0-9-]+)::([a-zA-Z0-9]+)(?:')?"
# Pattern to match entity_type::entity_name, with optional quotes
pattern = r"(?:')?([a-zA-Z0-9-]+)::([a-zA-Z0-9]+)(?:')?"
# matches = list(re.finditer(pattern, answer))
# logger.debug(f"Found {len(matches)} matches")
matches = list(re.finditer(pattern, answer))
logger.debug(f"Found {len(matches)} matches")
# # get active entity types
# with get_session_with_current_tenant() as db_session:
# active_entity_types = [
# x.id_name for x in get_entity_types(db_session, active=True)
# ]
# logger.debug(f"Active entity types: {active_entity_types}")
# get active entity types
with get_session_with_current_tenant() as db_session:
active_entity_types = [
x.id_name for x in get_entity_types(db_session, active=True)
]
logger.debug(f"Active entity types: {active_entity_types}")
# # Create dictionary for processed references
# processed_refs = {}
# Create dictionary for processed references
processed_refs = {}
# for match in matches:
# entity_type = match.group(1).upper().strip()
# entity_name = match.group(2).strip()
# potential_entity_id_name = make_entity_id(entity_type, entity_name)
for match in matches:
entity_type = match.group(1).upper().strip()
entity_name = match.group(2).strip()
potential_entity_id_name = make_entity_id(entity_type, entity_name)
# if entity_type not in active_entity_types:
# continue
if entity_type not in active_entity_types:
continue
# replacement_candidate = get_doc_information_for_entity(potential_entity_id_name)
replacement_candidate = get_doc_information_for_entity(potential_entity_id_name)
# if replacement_candidate.doc_id:
# # Store both the original match and the entity_id_name for replacement
# processed_refs[match.group(0)] = (
# replacement_candidate.semantic_linked_entity_name
# )
# else:
# processed_refs[match.group(0)] = replacement_candidate.semantic_entity_name
if replacement_candidate.doc_id:
# Store both the original match and the entity_id_name for replacement
processed_refs[match.group(0)] = (
replacement_candidate.semantic_linked_entity_name
)
else:
processed_refs[match.group(0)] = replacement_candidate.semantic_entity_name
# # Replace all references in the answer
# for ref, replacement in processed_refs.items():
# answer = answer.replace(ref, replacement)
# logger.debug(f"Replaced {ref} with {replacement}")
# Replace all references in the answer
for ref, replacement in processed_refs.items():
answer = answer.replace(ref, replacement)
logger.debug(f"Replaced {ref} with {replacement}")
# return answer
return answer
# def build_document_context(
# document: InferenceSection | LlmDoc, document_number: int
# ) -> str:
# """
# Build a context string for a document.
# """
def build_document_context(
document: InferenceSection | LlmDoc, document_number: int
) -> str:
"""
Build a context string for a document.
"""
# metadata_list: list[str] = []
# document_content: str | None = None
# info_source: InferenceChunk | LlmDoc | None = None
# info_content: str | None = None
metadata_list: list[str] = []
document_content: str | None = None
info_source: InferenceChunk | LlmDoc | None = None
info_content: str | None = None
# if isinstance(document, InferenceSection):
# info_source = document.center_chunk
# info_content = document.combined_content
# elif isinstance(document, LlmDoc):
# info_source = document
# info_content = document.content
if isinstance(document, InferenceSection):
info_source = document.center_chunk
info_content = document.combined_content
elif isinstance(document, LlmDoc):
info_source = document
info_content = document.content
# for key, value in info_source.metadata.items():
# metadata_list.append(f" - {key}: {value}")
for key, value in info_source.metadata.items():
metadata_list.append(f" - {key}: {value}")
# if metadata_list:
# metadata_str = "- Document Metadata:\n" + "\n".join(metadata_list)
# else:
# metadata_str = ""
if metadata_list:
metadata_str = "- Document Metadata:\n" + "\n".join(metadata_list)
else:
metadata_str = ""
# # Construct document header with number and semantic identifier
# doc_header = f"Document {str(document_number)}: {info_source.semantic_identifier}"
# Construct document header with number and semantic identifier
doc_header = f"Document {str(document_number)}: {info_source.semantic_identifier}"
# # Combine all parts with proper spacing
# document_content = f"{doc_header}\n\n{metadata_str}\n\n{info_content}"
# Combine all parts with proper spacing
document_content = f"{doc_header}\n\n{metadata_str}\n\n{info_content}"
# return document_content
return document_content
# def get_near_empty_step_results(
# step_number: int,
# step_answer: str,
# verified_reranked_documents: list[InferenceSection] = [],
# ) -> SubQuestionAnswerResults:
# """
# Get near-empty step results from a list of step results.
# """
# return SubQuestionAnswerResults(
# question=KG_SEARCH_STEP_DESCRIPTIONS[step_number].description,
# question_id="0_" + str(step_number),
# answer=step_answer,
# verified_high_quality=True,
# sub_query_retrieval_results=[],
# verified_reranked_documents=verified_reranked_documents,
# context_documents=[],
# cited_documents=[],
# sub_question_retrieval_stats=AgentChunkRetrievalStats(
# verified_count=None,
# verified_avg_scores=None,
# rejected_count=None,
# rejected_avg_scores=None,
# verified_doc_chunk_ids=[],
# dismissed_doc_chunk_ids=[],
# ),
# )
def get_near_empty_step_results(
step_number: int,
step_answer: str,
verified_reranked_documents: list[InferenceSection] = [],
) -> SubQuestionAnswerResults:
"""
Get near-empty step results from a list of step results.
"""
return SubQuestionAnswerResults(
question=KG_SEARCH_STEP_DESCRIPTIONS[step_number].description,
question_id="0_" + str(step_number),
answer=step_answer,
verified_high_quality=True,
sub_query_retrieval_results=[],
verified_reranked_documents=verified_reranked_documents,
context_documents=[],
cited_documents=[],
sub_question_retrieval_stats=AgentChunkRetrievalStats(
verified_count=None,
verified_avg_scores=None,
rejected_count=None,
rejected_avg_scores=None,
verified_doc_chunk_ids=[],
dismissed_doc_chunk_ids=[],
),
)

View File

@@ -1,55 +1,55 @@
# from pydantic import BaseModel
from pydantic import BaseModel
# from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
# from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
# from onyx.agents.agent_search.kb_search.states import KGRelationshipDetection
# from onyx.agents.agent_search.kb_search.states import KGSearchType
# from onyx.agents.agent_search.kb_search.states import YesNoEnum
from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import KGRelationshipDetection
from onyx.agents.agent_search.kb_search.states import KGSearchType
from onyx.agents.agent_search.kb_search.states import YesNoEnum
# class KGQuestionEntityExtractionResult(BaseModel):
# entities: list[str]
# time_filter: str | None
class KGQuestionEntityExtractionResult(BaseModel):
entities: list[str]
time_filter: str | None
# class KGViewNames(BaseModel):
# allowed_docs_view_name: str
# kg_relationships_view_name: str
# kg_entity_view_name: str
class KGViewNames(BaseModel):
allowed_docs_view_name: str
kg_relationships_view_name: str
kg_entity_view_name: str
# class KGAnswerApproach(BaseModel):
# search_type: KGSearchType
# search_strategy: KGAnswerStrategy
# relationship_detection: KGRelationshipDetection
# format: KGAnswerFormat
# broken_down_question: str | None = None
# divide_and_conquer: YesNoEnum | None = None
class KGAnswerApproach(BaseModel):
search_type: KGSearchType
search_strategy: KGAnswerStrategy
relationship_detection: KGRelationshipDetection
format: KGAnswerFormat
broken_down_question: str | None = None
divide_and_conquer: YesNoEnum | None = None
# class KGQuestionRelationshipExtractionResult(BaseModel):
# relationships: list[str]
class KGQuestionRelationshipExtractionResult(BaseModel):
relationships: list[str]
# class KGQuestionExtractionResult(BaseModel):
# entities: list[str]
# relationships: list[str]
# time_filter: str | None
class KGQuestionExtractionResult(BaseModel):
entities: list[str]
relationships: list[str]
time_filter: str | None
# class KGExpandedGraphObjects(BaseModel):
# entities: list[str]
# relationships: list[str]
class KGExpandedGraphObjects(BaseModel):
entities: list[str]
relationships: list[str]
# class KGSteps(BaseModel):
# description: str
# activities: list[str]
class KGSteps(BaseModel):
description: str
activities: list[str]
# class KGEntityDocInfo(BaseModel):
# doc_id: str | None
# doc_semantic_id: str | None
# doc_link: str | None
# semantic_entity_name: str
# semantic_linked_entity_name: str
class KGEntityDocInfo(BaseModel):
doc_id: str | None
doc_semantic_id: str | None
doc_link: str | None
semantic_entity_name: str
semantic_linked_entity_name: str

View File

@@ -1,255 +1,255 @@
# from datetime import datetime
# from typing import cast
from datetime import datetime
from typing import cast
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
# from pydantic import ValidationError
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from pydantic import ValidationError
# from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
# from onyx.agents.agent_search.kb_search.models import KGQuestionEntityExtractionResult
# from onyx.agents.agent_search.kb_search.models import (
# KGQuestionRelationshipExtractionResult,
# )
# from onyx.agents.agent_search.kb_search.states import EntityRelationshipExtractionUpdate
# from onyx.agents.agent_search.kb_search.states import MainState
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.configs.kg_configs import KG_ENTITY_EXTRACTION_TIMEOUT
# from onyx.configs.kg_configs import KG_RELATIONSHIP_EXTRACTION_TIMEOUT
# from onyx.db.engine.sql_engine import get_session_with_current_tenant
# from onyx.db.kg_temp_view import create_views
# from onyx.db.kg_temp_view import get_user_view_names
# from onyx.db.relationships import get_allowed_relationship_type_pairs
# from onyx.kg.utils.extraction_utils import get_entity_types_str
# from onyx.kg.utils.extraction_utils import get_relationship_types_str
# from onyx.prompts.kg_prompts import QUERY_ENTITY_EXTRACTION_PROMPT
# from onyx.prompts.kg_prompts import QUERY_RELATIONSHIP_EXTRACTION_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
# from shared_configs.contextvars import get_current_tenant_id
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.models import KGQuestionEntityExtractionResult
from onyx.agents.agent_search.kb_search.models import (
KGQuestionRelationshipExtractionResult,
)
from onyx.agents.agent_search.kb_search.states import EntityRelationshipExtractionUpdate
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.kg_configs import KG_ENTITY_EXTRACTION_TIMEOUT
from onyx.configs.kg_configs import KG_RELATIONSHIP_EXTRACTION_TIMEOUT
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.kg_temp_view import create_views
from onyx.db.kg_temp_view import get_user_view_names
from onyx.db.relationships import get_allowed_relationship_type_pairs
from onyx.kg.utils.extraction_utils import get_entity_types_str
from onyx.kg.utils.extraction_utils import get_relationship_types_str
from onyx.prompts.kg_prompts import QUERY_ENTITY_EXTRACTION_PROMPT
from onyx.prompts.kg_prompts import QUERY_RELATIONSHIP_EXTRACTION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from shared_configs.contextvars import get_current_tenant_id
# Module-level logger for this KG extraction node.
logger = setup_logger()
def extract_ert(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> EntityRelationshipExtractionUpdate:
    """
    LangGraph node that starts the KG agent flow (KG step 1).

    Extracts entities (plus an optional time filter) and relationships from
    the user's question via two LLM calls, and creates the temporary per-user
    KG views that later nodes query.

    Args:
        state: Current main-graph state; only ``state.question`` is read here.
        config: Runnable config whose ``metadata["config"]`` is a GraphConfig.
        writer: LangGraph stream writer (unused in this node).

    Returns:
        EntityRelationshipExtractionUpdate with the extracted entities /
        relationships, the temp-view names, and step-1 log/stream results.

    Raises:
        ValueError: if the KG feature is disabled, or the search tool / user
            is not configured on the graph config.
    """
    # Recheck KG enablement at the outset of the KG graph.
    if not config["metadata"]["config"].behavior.kg_config_settings.KG_ENABLED:
        logger.error("KG approach is not enabled, the KG agent flow cannot run.")
        raise ValueError("KG approach is not enabled, the KG agent flow cannot run.")

    _KG_STEP_NR = 1

    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])

    if graph_config.tooling.search_tool is None:
        raise ValueError("Search tool is not set")
    elif graph_config.tooling.search_tool.user is None:
        raise ValueError("User is not set")
    else:
        user_email = graph_config.tooling.search_tool.user.email
        user_name = user_email.split("@")[0] or "unknown"

    # First lines duplicate logic from generate_initial_answer.
    question = state.question
    today_date = datetime.now().strftime("%A, %Y-%m-%d")

    all_entity_types = get_entity_types_str(active=True)
    all_relationship_types = get_relationship_types_str(active=True)

    # Create temporary views. TODO: move into parallel step, if ultimately materialized
    tenant_id = get_current_tenant_id()
    kg_views = get_user_view_names(user_email, tenant_id)
    with get_session_with_current_tenant() as db_session:
        create_views(
            db_session,
            tenant_id=tenant_id,
            user_email=user_email,
            allowed_docs_view_name=kg_views.allowed_docs_view_name,
            kg_relationships_view_name=kg_views.kg_relationships_view_name,
            kg_entity_view_name=kg_views.kg_entity_view_name,
        )

    ### get the entities, terms, and filters

    query_extraction_pre_prompt = QUERY_ENTITY_EXTRACTION_PROMPT.format(
        entity_types=all_entity_types,
        relationship_types=all_relationship_types,
    )

    query_extraction_prompt = (
        query_extraction_pre_prompt.replace("---content---", question)
        .replace("---today_date---", today_date)
        .replace("---user_name---", f"EMPLOYEE:{user_name}")
        .replace("{{", "{")
        .replace("}}", "}")
    )

    msg = [
        HumanMessage(
            content=query_extraction_prompt,
        )
    ]
    primary_llm = graph_config.tooling.primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            KG_ENTITY_EXTRACTION_TIMEOUT,
            primary_llm.invoke_langchain,
            prompt=msg,
            timeout_override=15,
            max_tokens=300,
        )

        # Strip code fences / newlines and keep only the outermost JSON object.
        cleaned_response = (
            str(llm_response.content)
            .replace("{{", "{")
            .replace("}}", "}")
            .replace("```json\n", "")
            .replace("\n```", "")
            .replace("\n", "")
        )
        first_bracket = cleaned_response.find("{")
        last_bracket = cleaned_response.rfind("}")
        cleaned_response = cleaned_response[first_bracket : last_bracket + 1]

        entity_extraction_result = KGQuestionEntityExtractionResult.model_validate_json(
            cleaned_response
        )
    except ValidationError:
        logger.error("Failed to parse LLM response as JSON in Entity Extraction")
        entity_extraction_result = KGQuestionEntityExtractionResult(
            entities=[], time_filter=""
        )
    except Exception as e:
        logger.error(f"Error in extract_ert: {e}")
        entity_extraction_result = KGQuestionEntityExtractionResult(
            entities=[], time_filter=""
        )

    # Remove the attribute filters from the entities for the purpose of the
    # relationship extraction (attributes follow a "--" separator).
    entities_no_attributes = [
        entity.split("--")[0] for entity in entity_extraction_result.entities
    ]
    ert_entities_string = f"Entities: {entities_no_attributes}\n"

    ### get the relationships

    # Find the relationship types that match the extracted entity types.
    with get_session_with_current_tenant() as db_session:
        allowed_relationship_pairs = get_allowed_relationship_type_pairs(
            db_session, entity_extraction_result.entities
        )

    query_relationship_extraction_prompt = (
        QUERY_RELATIONSHIP_EXTRACTION_PROMPT.replace("---question---", question)
        .replace("---today_date---", today_date)
        .replace(
            "---relationship_type_options---",
            " - " + "\n - ".join(allowed_relationship_pairs),
        )
        .replace("---identified_entities---", ert_entities_string)
        .replace("---entity_types---", all_entity_types)
        .replace("{{", "{")
        .replace("}}", "}")
    )

    msg = [
        HumanMessage(
            content=query_relationship_extraction_prompt,
        )
    ]
    primary_llm = graph_config.tooling.primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            KG_RELATIONSHIP_EXTRACTION_TIMEOUT,
            primary_llm.invoke_langchain,
            prompt=msg,
            timeout_override=15,
            max_tokens=300,
        )

        cleaned_response = (
            str(llm_response.content)
            .replace("{{", "{")
            .replace("}}", "}")
            .replace("```json\n", "")
            .replace("\n```", "")
            .replace("\n", "")
        )
        first_bracket = cleaned_response.find("{")
        last_bracket = cleaned_response.rfind("}")
        cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
        cleaned_response = cleaned_response.replace("{{", '{"')
        cleaned_response = cleaned_response.replace("}}", '"}')

        try:
            relationship_extraction_result = (
                KGQuestionRelationshipExtractionResult.model_validate_json(
                    cleaned_response
                )
            )
        except ValidationError:
            logger.error(
                "Failed to parse LLM response as JSON in Relationship Extraction"
            )
            relationship_extraction_result = KGQuestionRelationshipExtractionResult(
                relationships=[],
            )
    except Exception as e:
        logger.error(f"Error in extract_ert: {e}")
        relationship_extraction_result = KGQuestionRelationshipExtractionResult(
            relationships=[],
        )

    ## STEP 1
    # Stream answer pieces out to the UI for Step 1
    extracted_entity_string = " \n ".join(
        [x.split("--")[0] for x in entity_extraction_result.entities]
    )
    extracted_relationship_string = " \n ".join(
        relationship_extraction_result.relationships
    )

    step_answer = f"""Entities and relationships have been extracted from query - \n \
Entities: {extracted_entity_string} - \n Relationships: {extracted_relationship_string}"""

    return EntityRelationshipExtractionUpdate(
        entities_types_str=all_entity_types,
        relationship_types_str=all_relationship_types,
        extracted_entities_w_attributes=entity_extraction_result.entities,
        extracted_entities_no_attributes=entities_no_attributes,
        extracted_relationships=relationship_extraction_result.relationships,
        time_filter=entity_extraction_result.time_filter,
        kg_doc_temp_view_name=kg_views.allowed_docs_view_name,
        kg_rel_temp_view_name=kg_views.kg_relationships_view_name,
        kg_entity_temp_view_name=kg_views.kg_entity_view_name,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="extract entities terms",
                node_start_time=node_start_time,
            )
        ],
        step_results=[
            get_near_empty_step_results(
                step_number=_KG_STEP_NR,
                step_answer=step_answer,
                verified_reranked_documents=[],
            )
        ],
    )

View File

@@ -1,312 +1,312 @@
# from datetime import datetime
# from typing import cast
from datetime import datetime
from typing import cast
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.kb_search.graph_utils import (
# create_minimal_connected_query_graph,
# )
# from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
# from onyx.agents.agent_search.kb_search.models import KGAnswerApproach
# from onyx.agents.agent_search.kb_search.states import AnalysisUpdate
# from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
# from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
# from onyx.agents.agent_search.kb_search.states import KGRelationshipDetection
# from onyx.agents.agent_search.kb_search.states import KGSearchType
# from onyx.agents.agent_search.kb_search.states import MainState
# from onyx.agents.agent_search.kb_search.states import YesNoEnum
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.configs.kg_configs import KG_STRATEGY_GENERATION_TIMEOUT
# from onyx.db.engine.sql_engine import get_session_with_current_tenant
# from onyx.db.entities import get_document_id_for_entity
# from onyx.kg.clustering.normalizations import normalize_entities
# from onyx.kg.clustering.normalizations import normalize_relationships
# from onyx.kg.utils.formatting_utils import split_relationship_id
# from onyx.prompts.kg_prompts import STRATEGY_GENERATION_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.agents.agent_search.kb_search.graph_utils import (
create_minimal_connected_query_graph,
)
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.models import KGAnswerApproach
from onyx.agents.agent_search.kb_search.states import AnalysisUpdate
from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import KGRelationshipDetection
from onyx.agents.agent_search.kb_search.states import KGSearchType
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import YesNoEnum
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.kg_configs import KG_STRATEGY_GENERATION_TIMEOUT
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.entities import get_document_id_for_entity
from onyx.kg.clustering.normalizations import normalize_entities
from onyx.kg.clustering.normalizations import normalize_relationships
from onyx.kg.utils.formatting_utils import split_relationship_id
from onyx.prompts.kg_prompts import STRATEGY_GENERATION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# Module-level logger for this KG analysis node.
logger = setup_logger()
# def _articulate_normalizations(
# entity_normalization_map: dict[str, str],
# relationship_normalization_map: dict[str, str],
# ) -> str:
def _articulate_normalizations(
entity_normalization_map: dict[str, str],
relationship_normalization_map: dict[str, str],
) -> str:
# remark_list: list[str] = []
remark_list: list[str] = []
# if entity_normalization_map:
# remark_list.append("\n Entities:")
# for extracted_entity, normalized_entity in entity_normalization_map.items():
# remark_list.append(f" - {extracted_entity} -> {normalized_entity}")
if entity_normalization_map:
remark_list.append("\n Entities:")
for extracted_entity, normalized_entity in entity_normalization_map.items():
remark_list.append(f" - {extracted_entity} -> {normalized_entity}")
# if relationship_normalization_map:
# remark_list.append(" \n Relationships:")
# for (
# extracted_relationship,
# normalized_relationship,
# ) in relationship_normalization_map.items():
# remark_list.append(
# f" - {extracted_relationship} -> {normalized_relationship}"
# )
if relationship_normalization_map:
remark_list.append(" \n Relationships:")
for (
extracted_relationship,
normalized_relationship,
) in relationship_normalization_map.items():
remark_list.append(
f" - {extracted_relationship} -> {normalized_relationship}"
)
# return " \n ".join(remark_list)
return " \n ".join(remark_list)
def _get_fully_connected_entities(
    entities: list[str], relationships: list[str]
) -> list[str]:
    """
    Analyze the connectedness of the entities and relationships.

    Treats each relationship as a bidirectional edge between its source and
    target entities and returns the entities that are directly connected to
    every other entity in *entities*.

    Raises:
        ValueError: if a relationship id does not split into exactly three
            parts (source, type, target).
    """
    # Build a dictionary to track connections for each entity
    entity_connections: dict[str, set[str]] = {entity: set() for entity in entities}

    # Parse relationships to build connection graph
    for relationship in relationships:
        # Split relationship into parts. Test for proper formatting just in case.
        # Should never be an error though at this point.
        parts = split_relationship_id(relationship)
        if len(parts) != 3:
            raise ValueError(f"Invalid relationship: {relationship}")

        entity1 = parts[0]
        entity2 = parts[2]

        # Add bidirectional connections; endpoints outside the requested
        # entity list are recorded only from the in-list side.
        if entity1 in entity_connections:
            entity_connections[entity1].add(entity2)
        if entity2 in entity_connections:
            entity_connections[entity2].add(entity1)

    # Find entities connected to all others
    fully_connected_entities = []
    all_entities = set(entities)

    for entity, connections in entity_connections.items():
        # Check if this entity is connected to all other entities
        if connections == all_entities - {entity}:
            fully_connected_entities.append(entity)

    return fully_connected_entities
def _check_for_single_doc(
    normalized_entities: list[str],
    raw_entities: list[str],
    normalized_relationship_strings: list[str],
    raw_relationships: list[str],
    normalized_time_filter: str | None,
) -> str | None:
    """
    Check if the query is for a single document, like 'Summarize ticket ENG-2243K'.
    None is returned if the query is not for a single document.

    The query only qualifies when exactly one entity was extracted and
    normalized, with no relationships (raw or normalized) and no time filter;
    the document id is then looked up for that single normalized entity.
    """
    if (
        len(normalized_entities) == 1
        and len(raw_entities) == 1
        and len(normalized_relationship_strings) == 0
        and len(raw_relationships) == 0
        and normalized_time_filter is None
    ):
        with get_session_with_current_tenant() as db_session:
            single_doc_id = get_document_id_for_entity(
                db_session, normalized_entities[0]
            )
    else:
        single_doc_id = None
    return single_doc_id
def analyze(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> AnalysisUpdate:
    """
    LangGraph node for KG step 2: normalize and analyze the extracted graph.

    Normalizes the extracted entities/relationships, detects single-document
    queries, expands the query graph to a minimally connected one, and asks
    the LLM for the answering strategy/format. Falls back to a default
    strategy when the LLM response cannot be parsed.

    Args:
        state: Main-graph state carrying step-1 extraction results.
        config: Runnable config whose ``metadata["config"]`` is a GraphConfig.
        writer: LangGraph stream writer (unused in this node).

    Returns:
        AnalysisUpdate with the normalized graph, chosen strategy/format, and
        step-2 log/stream results.

    Raises:
        Exception: re-raises any error from the strategy-generation LLM call,
            including the ValueError raised when the parsed strategy/format
            is missing.
    """
    _KG_STEP_NR = 2

    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = state.question
    entities = (
        state.extracted_entities_no_attributes
    )  # attribute knowledge is not required for this step
    relationships = state.extracted_relationships
    time_filter = state.time_filter

    ## STEP 2 - stream out goals

    # Continue with node
    normalized_entities = normalize_entities(
        entities,
        state.extracted_entities_w_attributes,
        allowed_docs_temp_view_name=state.kg_doc_temp_view_name,
    )

    normalized_relationships = normalize_relationships(
        relationships, normalized_entities.entity_normalization_map
    )
    normalized_time_filter = time_filter

    # If single-doc inquiry, send to single-doc processing directly
    single_doc_id = _check_for_single_doc(
        normalized_entities=normalized_entities.entities,
        raw_entities=entities,
        normalized_relationship_strings=normalized_relationships.relationships,
        raw_relationships=relationships,
        normalized_time_filter=normalized_time_filter,
    )

    # Expand the entities and relationships to make sure that entities are connected
    graph_expansion = create_minimal_connected_query_graph(
        normalized_entities.entities,
        normalized_relationships.relationships,
        max_depth=2,
    )

    query_graph_entities = graph_expansion.entities
    query_graph_relationships = graph_expansion.relationships

    # Evaluate whether a search needs to be done after identifying all entities and relationships
    strategy_generation_prompt = (
        STRATEGY_GENERATION_PROMPT.replace(
            "---entities---", "\n".join(query_graph_entities)
        )
        .replace("---relationships---", "\n".join(query_graph_relationships))
        .replace("---possible_entities---", state.entities_types_str)
        .replace("---possible_relationships---", state.relationship_types_str)
        .replace("---question---", question)
    )

    msg = [
        HumanMessage(
            content=strategy_generation_prompt,
        )
    ]
    primary_llm = graph_config.tooling.primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            KG_STRATEGY_GENERATION_TIMEOUT,
            # fast_llm.invoke,
            primary_llm.invoke_langchain,
            prompt=msg,
            timeout_override=5,
            max_tokens=100,
        )

        # Strip code fences / newlines and keep only the outermost JSON object.
        cleaned_response = (
            str(llm_response.content)
            .replace("```json\n", "")
            .replace("\n```", "")
            .replace("\n", "")
        )
        first_bracket = cleaned_response.find("{")
        last_bracket = cleaned_response.rfind("}")
        cleaned_response = cleaned_response[first_bracket : last_bracket + 1]

        try:
            approach_extraction_result = KGAnswerApproach.model_validate_json(
                cleaned_response
            )
            search_type = approach_extraction_result.search_type
            search_strategy = approach_extraction_result.search_strategy
            relationship_detection = (
                approach_extraction_result.relationship_detection.value
            )
            output_format = approach_extraction_result.format
            broken_down_question = approach_extraction_result.broken_down_question
            divide_and_conquer = approach_extraction_result.divide_and_conquer
        except ValueError:
            logger.error(
                "Failed to parse LLM response as JSON in Entity-Term Extraction"
            )
            # Conservative defaults so the flow can still proceed.
            search_type = KGSearchType.SEARCH
            search_strategy = KGAnswerStrategy.DEEP
            relationship_detection = KGRelationshipDetection.RELATIONSHIPS.value
            output_format = KGAnswerFormat.TEXT
            broken_down_question = None
            divide_and_conquer = YesNoEnum.NO

        if search_strategy is None or output_format is None:
            raise ValueError(f"Invalid strategy: {cleaned_response}")

    except Exception as e:
        logger.error(f"Error in strategy generation: {e}")
        raise e

    # Stream out relevant results
    if single_doc_id:
        search_strategy = (
            KGAnswerStrategy.DEEP
        )  # if a single doc is identified, we will want to look at the details.

    step_answer = f"Strategy and format have been extracted from query. Strategy: {search_strategy.value}, \
Format: {output_format.value}, Broken down question: {broken_down_question}"

    extraction_detected_relationships = len(query_graph_relationships) > 0
    if extraction_detected_relationships:
        query_type = KGRelationshipDetection.RELATIONSHIPS.value

        # NOTE(review): this inner condition duplicates the outer one, so the
        # warning always fires when relationships were extracted — it looks
        # like it was meant to compare against `relationship_detection`;
        # behavior intentionally left unchanged here.
        if extraction_detected_relationships:
            logger.warning(
                "Fyi - Extraction detected relationships: "
                f"{extraction_detected_relationships}, "
                f"but relationship detection: {relationship_detection}"
            )
    else:
        query_type = KGRelationshipDetection.NO_RELATIONSHIPS.value

    # End node
    return AnalysisUpdate(
        normalized_core_entities=normalized_entities.entities,
        normalized_core_relationships=normalized_relationships.relationships,
        entity_normalization_map=normalized_entities.entity_normalization_map,
        relationship_normalization_map=normalized_relationships.relationship_normalization_map,
        query_graph_entities_no_attributes=query_graph_entities,
        query_graph_entities_w_attributes=normalized_entities.entities_w_attributes,
        query_graph_relationships=query_graph_relationships,
        normalized_terms=[],  # TODO: remove fully later
        normalized_time_filter=normalized_time_filter,
        strategy=search_strategy,
        broken_down_question=broken_down_question,
        output_format=output_format,
        divide_and_conquer=divide_and_conquer,
        single_doc_id=single_doc_id,
        search_type=search_type,
        query_type=query_type,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="analyze",
                node_start_time=node_start_time,
            )
        ],
        step_results=[
            get_near_empty_step_results(
                step_number=_KG_STEP_NR,
                step_answer=step_answer,
                verified_reranked_documents=[],
            )
        ],
        remarks=[
            _articulate_normalizations(
                entity_normalization_map=normalized_entities.entity_normalization_map,
                relationship_normalization_map=normalized_relationships.relationship_normalization_map,
            )
        ],
    )

View File

@@ -1,185 +1,185 @@
# from datetime import datetime
# from typing import cast
from datetime import datetime
from typing import cast
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.kb_search.states import DeepSearchFilterUpdate
# from onyx.agents.agent_search.kb_search.states import KGFilterConstructionResults
# from onyx.agents.agent_search.kb_search.states import MainState
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.configs.kg_configs import KG_FILTER_CONSTRUCTION_TIMEOUT
# from onyx.db.engine.sql_engine import get_session_with_current_tenant
# from onyx.db.entity_type import get_entity_types_with_grounded_source_name
# from onyx.kg.utils.formatting_utils import make_entity_id
# from onyx.prompts.kg_prompts import SEARCH_FILTER_CONSTRUCTION_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.agents.agent_search.kb_search.states import DeepSearchFilterUpdate
from onyx.agents.agent_search.kb_search.states import KGFilterConstructionResults
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.kg_configs import KG_FILTER_CONSTRUCTION_TIMEOUT
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.entity_type import get_entity_types_with_grounded_source_name
from onyx.kg.utils.formatting_utils import make_entity_id
from onyx.prompts.kg_prompts import SEARCH_FILTER_CONSTRUCTION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# logger = setup_logger()
logger = setup_logger()
def construct_deep_search_filters(
    state: MainState, config: RunnableConfig, writer: StreamWriter
) -> DeepSearchFilterUpdate:
    """LangGraph node that constructs Vespa filters for the deep (agentic) search.

    Prompts the primary LLM with the user question, the entity/relationship
    filters extracted so far, and any earlier SQL / source-document results,
    then parses the response into a ``KGFilterConstructionResults``. Any LLM or
    parsing failure falls back to empty filters rather than aborting the graph.
    Also decides whether the divide-and-conquer structure splits along source
    documents (``source_division``).

    Args:
        state: Current main-graph state with the question and extracted filters.
        config: LangGraph runnable config; ``config["metadata"]["config"]``
            carries the ``GraphConfig`` with LLM tooling.
        writer: Stream writer supplied by LangGraph (unused here).

    Returns:
        DeepSearchFilterUpdate with the constructed filters, the
        divide-and-conquer structure, and node log messages.
    """
    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = state.question

    entities_types_str = state.entities_types_str
    entities = state.query_graph_entities_no_attributes
    relationships = state.query_graph_relationships
    simple_sql_query = state.sql_query
    simple_sql_results = state.sql_query_results
    source_document_results = state.source_document_results

    if simple_sql_results:
        simple_sql_results_str = "\n".join([str(x) for x in simple_sql_results])
    else:
        simple_sql_results_str = "(no SQL results generated)"

    if source_document_results:
        source_document_results_str = "\n".join(
            [str(x) for x in source_document_results]
        )
    else:
        source_document_results_str = "(no source document results generated)"

    # Fill the prompt template via placeholder replacement — the template uses
    # literal ---marker--- tokens rather than str.format fields.
    search_filter_construction_prompt = (
        SEARCH_FILTER_CONSTRUCTION_PROMPT.replace(
            "---entity_type_descriptions---",
            entities_types_str,
        )
        .replace(
            "---entity_filters---",
            "\n".join(entities),
        )
        .replace(
            "---relationship_filters---",
            "\n".join(relationships),
        )
        .replace(
            "---sql_query---",
            simple_sql_query or "(no SQL generated)",
        )
        .replace(
            "---sql_results---",
            simple_sql_results_str or "(no SQL results generated)",
        )
        .replace(
            "---source_document_results---",
            source_document_results_str or "(no source document results generated)",
        )
        .replace(
            "---question---",
            question,
        )
    )

    msg = [
        HumanMessage(
            content=search_filter_construction_prompt,
        )
    ]
    llm = graph_config.tooling.primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            KG_FILTER_CONSTRUCTION_TIMEOUT,
            llm.invoke_langchain,
            prompt=msg,
            timeout_override=15,
            max_tokens=1400,
        )

        # Strip markdown code fencing and newlines, then cut the response down
        # to the outermost JSON object.
        cleaned_response = (
            str(llm_response.content)
            .replace("```json\n", "")
            .replace("\n```", "")
            .replace("\n", "")
        )
        first_bracket = cleaned_response.find("{")
        last_bracket = cleaned_response.rfind("}")

        if last_bracket == -1 or first_bracket == -1:
            raise ValueError("No valid JSON found in LLM response - no brackets found")
        cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
        # Repair doubled braces (template-escaping artifacts) into quoted JSON.
        cleaned_response = cleaned_response.replace("{{", '{"')
        cleaned_response = cleaned_response.replace("}}", '"}')

        try:
            filter_results = KGFilterConstructionResults.model_validate_json(
                cleaned_response
            )
        except ValueError:
            logger.error(
                "Failed to parse LLM response as JSON in Entity-Term Extraction"
            )
            # Fall back to empty filters so the graph can proceed.
            filter_results = KGFilterConstructionResults(
                global_entity_filters=[],
                global_relationship_filters=[],
                local_entity_filters=[],
                source_document_filters=[],
                structure=[],
            )
    except Exception as e:
        # LLM call itself failed (timeout, provider error, no JSON found) —
        # same empty-filter fallback.
        logger.error(f"Error in construct_deep_search_filters: {e}")
        filter_results = KGFilterConstructionResults(
            global_entity_filters=[],
            global_relationship_filters=[],
            local_entity_filters=[],
            source_document_filters=[],
            structure=[],
        )

    div_con_structure = filter_results.structure

    logger.info(f"div_con_structure: {div_con_structure}")

    with get_session_with_current_tenant() as db_session:
        double_grounded_entity_types = get_entity_types_with_grounded_source_name(
            db_session
        )

    # Split along source documents iff the first element of the structure
    # names a grounded source type.
    source_division = False
    if div_con_structure:
        for entity_type in double_grounded_entity_types:
            # entity_type is guaranteed to have grounded_source_name
            if (
                cast(str, entity_type.grounded_source_name).lower()
                in div_con_structure[0].lower()
            ):
                source_division = True
                break

    return DeepSearchFilterUpdate(
        vespa_filter_results=filter_results,
        div_con_entities=div_con_structure,
        source_division=source_division,
        global_entity_filters=[
            make_entity_id(global_filter, "*")
            for global_filter in filter_results.global_entity_filters
        ],
        global_relationship_filters=filter_results.global_relationship_filters,
        local_entity_filters=filter_results.local_entity_filters,
        source_filters=filter_results.source_document_filters,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="construct deep search filters",
                node_start_time=node_start_time,
            )
        ],
        step_results=[],
    )

View File

@@ -1,168 +1,168 @@
# import copy
# from datetime import datetime
# from typing import cast
import copy
from datetime import datetime
from typing import cast
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
# from onyx.agents.agent_search.kb_search.ops import research
# from onyx.agents.agent_search.kb_search.states import KGSourceDivisionType
# from onyx.agents.agent_search.kb_search.states import ResearchObjectInput
# from onyx.agents.agent_search.kb_search.states import ResearchObjectUpdate
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
# trim_prompt_piece,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.chat.models import LlmDoc
# from onyx.configs.kg_configs import KG_MAX_SEARCH_DOCUMENTS
# from onyx.configs.kg_configs import KG_OBJECT_SOURCE_RESEARCH_TIMEOUT
# from onyx.context.search.models import InferenceSection
# from onyx.kg.utils.formatting_utils import split_entity_id
# from onyx.prompts.kg_prompts import KG_OBJECT_SOURCE_RESEARCH_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.kb_search.ops import research
from onyx.agents.agent_search.kb_search.states import KGSourceDivisionType
from onyx.agents.agent_search.kb_search.states import ResearchObjectInput
from onyx.agents.agent_search.kb_search.states import ResearchObjectUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.chat.models import LlmDoc
from onyx.configs.kg_configs import KG_MAX_SEARCH_DOCUMENTS
from onyx.configs.kg_configs import KG_OBJECT_SOURCE_RESEARCH_TIMEOUT
from onyx.context.search.models import InferenceSection
from onyx.kg.utils.formatting_utils import split_entity_id
from onyx.prompts.kg_prompts import KG_OBJECT_SOURCE_RESEARCH_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# logger = setup_logger()
logger = setup_logger()
def process_individual_deep_search(
    state: ResearchObjectInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> ResearchObjectUpdate:
    """LangGraph node that researches one object (an entity or a source doc).

    Depending on ``state.segment_type`` the object is either a KG entity (the
    question is extended with the entity and global KG filters are applied) or
    a source document (the document id alone is used as the source filter).
    Retrieved documents are formatted into a prompt and summarized by the
    primary LLM.

    Args:
        state: Per-object research input, including the entity/document id and
            the broken-down question.
        config: LangGraph runnable config carrying the ``GraphConfig``.
        writer: Stream writer supplied by LangGraph (unused here).

    Returns:
        ResearchObjectUpdate with one ``{"object", "results"}`` entry.

    Raises:
        ValueError: If no search tool is configured, a retrieved document has
            an unexpected type, or the LLM call fails.
    """
    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    search_tool = graph_config.tooling.search_tool
    question = state.broken_down_question
    segment_type = state.segment_type

    # Insert a space after the entity-type separator for readability in prompts.
    research_object = state.entity.replace("::", ":: ").lower()

    if not search_tool:
        raise ValueError("search_tool is not provided")

    if segment_type == KGSourceDivisionType.ENTITY.value:
        object_id = split_entity_id(research_object)[1].strip()
        extended_question = f"{question} in regards to {research_object}"
        source_filters = state.source_entity_filters

        # TODO: this does not really occur in V1. But needs to be changed for V2
        raw_kg_entity_filters = copy.deepcopy(
            list(
                set((state.vespa_filter_results.global_entity_filters + [state.entity]))
            )
        )

        # Entity filters without an explicit id become wildcards over the type.
        kg_entity_filters = []
        for raw_kg_entity_filter in raw_kg_entity_filters:
            if "::" not in raw_kg_entity_filter:
                raw_kg_entity_filter += "::*"
            kg_entity_filters.append(raw_kg_entity_filter)

        kg_relationship_filters = copy.deepcopy(
            state.vespa_filter_results.global_relationship_filters
        )

        logger.debug("Research for object: " + research_object)
        logger.debug(f"kg_entity_filters: {kg_entity_filters}")
        logger.debug(f"kg_relationship_filters: {kg_relationship_filters}")

    else:
        # if we came through the entity view route, in KG V1 the state entity
        # is the document to search for. No need to set other filters then.
        object_id = state.entity  # source doc in this case
        extended_question = f"{question}"
        source_filters = [object_id]

        kg_entity_filters = None
        kg_relationship_filters = None

    if source_filters and (len(source_filters) > KG_MAX_SEARCH_DOCUMENTS):
        logger.debug(
            f"Too many sources ({len(source_filters)}), setting to None and effectively filtered search"
        )
        source_filters = None

    retrieved_docs = research(
        question=extended_question,
        kg_entities=kg_entity_filters,
        kg_relationships=kg_relationship_filters,
        kg_sources=source_filters,
        search_tool=search_tool,
    )

    document_texts_list = []

    for doc_num, retrieved_doc in enumerate(retrieved_docs):
        if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
            raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
        chunk_text = build_document_context(retrieved_doc, doc_num + 1)
        document_texts_list.append(chunk_text)

    document_texts = "\n\n".join(document_texts_list)

    # Built prompt
    kg_object_source_research_prompt = KG_OBJECT_SOURCE_RESEARCH_PROMPT.format(
        question=extended_question,
        document_text=document_texts,
    )

    # Run LLM
    msg = [
        HumanMessage(
            content=trim_prompt_piece(
                config=graph_config.tooling.primary_llm.config,
                prompt_piece=kg_object_source_research_prompt,
                reserved_str="",
            ),
        )
    ]
    primary_llm = graph_config.tooling.primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            KG_OBJECT_SOURCE_RESEARCH_TIMEOUT,
            primary_llm.invoke_langchain,
            prompt=msg,
            timeout_override=KG_OBJECT_SOURCE_RESEARCH_TIMEOUT,
            max_tokens=300,
        )

        object_research_results = str(llm_response.content).replace("```json\n", "")

    except Exception as e:
        raise ValueError(f"Error in research_object_source: {e}")

    logger.debug("DivCon Step A2 - Object Source Research - completed for an object")

    return ResearchObjectUpdate(
        research_object_results=[
            {
                "object": research_object.replace("::", ":: ").capitalize(),
                "results": object_research_results,
            }
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="process individual deep search",
                node_start_time=node_start_time,
            )
        ],
        step_results=[],
    )

View File

@@ -1,166 +1,166 @@
# from datetime import datetime
# from typing import cast
from datetime import datetime
from typing import cast
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
# from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
# from onyx.agents.agent_search.kb_search.ops import research
# from onyx.agents.agent_search.kb_search.states import ConsolidatedResearchUpdate
# from onyx.agents.agent_search.kb_search.states import MainState
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
# trim_prompt_piece,
# )
# from onyx.agents.agent_search.shared_graph_utils.calculations import (
# get_answer_generation_documents,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.configs.kg_configs import KG_FILTERED_SEARCH_TIMEOUT
# from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
# from onyx.context.search.models import InferenceSection
# from onyx.prompts.kg_prompts import KG_SEARCH_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.ops import research
from onyx.agents.agent_search.kb_search.states import ConsolidatedResearchUpdate
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.calculations import (
get_answer_generation_documents,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.kg_configs import KG_FILTERED_SEARCH_TIMEOUT
from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
from onyx.context.search.models import InferenceSection
from onyx.prompts.kg_prompts import KG_SEARCH_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# logger = setup_logger()
logger = setup_logger()
def filtered_search(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ConsolidatedResearchUpdate:
    """LangGraph node that runs a KG-filtered document search and answers.

    Retrieves inference sections constrained by the global entity/relationship
    filters from the deep-search filter-construction step, builds a context
    prompt, and asks the primary LLM for an answer. Bracketed citation markers
    in the answer are rewritten into markdown links using the retrieved
    documents' first source links.

    Args:
        state: Main-graph state; must carry ``vespa_filter_results``.
        config: LangGraph runnable config carrying the ``GraphConfig``.
        writer: Stream writer supplied by LangGraph (unused here).

    Returns:
        ConsolidatedResearchUpdate with the answer, log messages, and a step
        result that records the retrieved documents.

    Raises:
        ValueError: If the search tool or filter results are missing, or the
            LLM call fails.
    """
    _KG_STEP_NR = 4

    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    search_tool = graph_config.tooling.search_tool
    question = state.question

    if not search_tool:
        raise ValueError("search_tool is not provided")

    if not state.vespa_filter_results:
        raise ValueError("vespa_filter_results is not provided")
    raw_kg_entity_filters = list(
        set((state.vespa_filter_results.global_entity_filters))
    )

    # Entity filters without an explicit id become wildcards over the type.
    kg_entity_filters = []
    for raw_kg_entity_filter in raw_kg_entity_filters:
        if "::" not in raw_kg_entity_filter:
            raw_kg_entity_filter += "::*"
        kg_entity_filters.append(raw_kg_entity_filter)

    kg_relationship_filters = state.vespa_filter_results.global_relationship_filters

    logger.debug("Starting filtered search")
    logger.debug(f"kg_entity_filters: {kg_entity_filters}")
    logger.debug(f"kg_relationship_filters: {kg_relationship_filters}")

    retrieved_docs = cast(
        list[InferenceSection],
        research(
            question=question,
            kg_entities=kg_entity_filters,
            kg_relationships=kg_relationship_filters,
            kg_sources=None,
            search_tool=search_tool,
            inference_sections_only=True,
        ),
    )

    # 1-based doc number -> first source link, for citation rewriting below.
    source_link_dict = {
        num + 1: doc.center_chunk.source_links[0]
        for num, doc in enumerate(retrieved_docs)
        if doc.center_chunk.source_links
    }

    answer_generation_documents = get_answer_generation_documents(
        relevant_docs=retrieved_docs,
        context_documents=retrieved_docs,
        original_question_docs=retrieved_docs,
        max_docs=KG_RESEARCH_NUM_RETRIEVED_DOCS,
    )

    document_texts_list = []

    for doc_num, retrieved_doc in enumerate(
        answer_generation_documents.context_documents
    ):
        chunk_text = build_document_context(retrieved_doc, doc_num + 1)
        document_texts_list.append(chunk_text)

    document_texts = "\n\n".join(document_texts_list)

    # Built prompt
    kg_object_source_research_prompt = KG_SEARCH_PROMPT.format(
        question=question,
        document_text=document_texts,
    )

    # Run LLM
    msg = [
        HumanMessage(
            content=trim_prompt_piece(
                config=graph_config.tooling.primary_llm.config,
                prompt_piece=kg_object_source_research_prompt,
                reserved_str="",
            ),
        )
    ]
    primary_llm = graph_config.tooling.primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            KG_FILTERED_SEARCH_TIMEOUT,
            primary_llm.invoke_langchain,
            prompt=msg,
            timeout_override=30,
            max_tokens=300,
        )

        filtered_search_answer = str(llm_response.content).replace("```json\n", "")

        # TODO: make sure the citations look correct. Currently, they do not.
        for source_link_num, source_link in source_link_dict.items():
            if f"[{source_link_num}]" in filtered_search_answer:
                filtered_search_answer = filtered_search_answer.replace(
                    f"[{source_link_num}]", f"[{source_link_num}]({source_link})"
                )

    except Exception as e:
        raise ValueError(f"Error in filtered_search: {e}")

    step_answer = "Filtered search is complete."

    return ConsolidatedResearchUpdate(
        consolidated_research_object_results_str=filtered_search_answer,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="filtered search",
                node_start_time=node_start_time,
            )
        ],
        step_results=[
            get_near_empty_step_results(
                step_number=_KG_STEP_NR,
                step_answer=step_answer,
                verified_reranked_documents=retrieved_docs,
            )
        ],
    )

View File

@@ -1,54 +1,54 @@
# from datetime import datetime
from datetime import datetime
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
# from onyx.agents.agent_search.kb_search.graph_utils import rename_entities_in_answer
# from onyx.agents.agent_search.kb_search.states import ConsolidatedResearchUpdate
# from onyx.agents.agent_search.kb_search.states import MainState
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import rename_entities_in_answer
from onyx.agents.agent_search.kb_search.states import ConsolidatedResearchUpdate
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def consolidate_individual_deep_search(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> ConsolidatedResearchUpdate:
# """
# LangGraph node to start the agentic search process.
# """
def consolidate_individual_deep_search(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ConsolidatedResearchUpdate:
    """Merge the per-object deep-search results into one research string.

    Every entry of ``state.research_object_results`` is rendered as a
    ``<object>: <results>`` line; entity ids in the combined text are then
    replaced with display names via ``rename_entities_in_answer``.
    """
    _KG_STEP_NR = 4
    node_start_time = datetime.now()

    # One "<object>: <results>" line per researched object.
    merged_research = "\n".join(
        f"{entry['object']}: {entry['results']}"
        for entry in state.research_object_results
    )
    consolidated_research_object_results_str = rename_entities_in_answer(
        merged_research
    )

    step_answer = "All research is complete. Consolidating results..."

    return ConsolidatedResearchUpdate(
        consolidated_research_object_results_str=consolidated_research_object_results_str,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="consolidate individual deep search",
                node_start_time=node_start_time,
            )
        ],
        step_results=[
            get_near_empty_step_results(
                step_number=_KG_STEP_NR, step_answer=step_answer
            )
        ],
    )

View File

@@ -1,96 +1,96 @@
# from datetime import datetime
from datetime import datetime
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
# from onyx.agents.agent_search.kb_search.states import MainState
# from onyx.agents.agent_search.kb_search.states import ResultsDataUpdate
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.db.document import get_base_llm_doc_information
# from onyx.db.engine.sql_engine import get_session_with_current_tenant
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import ResultsDataUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.db.document import get_base_llm_doc_information
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def _get_formated_source_reference_results(
# source_document_results: list[str] | None,
# ) -> str | None:
# """
# Generate reference results from the query results data string.
# """
def _get_formated_source_reference_results(
    source_document_results: list[str] | None,
) -> str | None:
    """Format source document ids into a readable reference section.

    Returns ``None`` when no id list was provided, ``""`` when none of the
    ids resolve to a document, and otherwise a header line followed by one
    formatted reference per resolved document.
    """
    if source_document_results is None:
        return None

    # Resolve the document ids into display-ready reference strings.
    with get_session_with_current_tenant() as session:
        references = get_base_llm_doc_information(session, source_document_results)

    if not references:
        return ""

    header = f"\n \n Here are {len(references)} supporting documents or examples: \n \n "
    return header + " \n \n ".join(references)
# def process_kg_only_answers(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> ResultsDataUpdate:
# """
# LangGraph node to start the agentic search process.
# """
def process_kg_only_answers(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ResultsDataUpdate:
    """Render the KG-only SQL results into the strings the answer step uses.

    No further document research happens on this path: the SQL rows are
    rendered line by line, and any source documents are formatted into a
    reference section.
    """
    _KG_STEP_NR = 4
    node_start_time = datetime.now()

    rows = state.sql_query_results
    if rows:
        # Add a space after "::" separators, then capitalize each rendered row.
        rendered = [str(row).replace("::", ":: ").capitalize() for row in rows]
        query_results_data_str = "\n".join(rendered)
    else:
        logger.warning("No query results were found")
        query_results_data_str = "(No query results were found)"

    source_reference_result_str = _get_formated_source_reference_results(
        state.source_document_results
    )

    ## STEP 4 - same components as Step 1
    step_answer = (
        "No further research is needed, the answer is derived from the knowledge graph."
    )

    return ResultsDataUpdate(
        query_results_data_str=query_results_data_str,
        individualized_query_results_data_str="",
        reference_results_str=source_reference_result_str,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="kg query results data processing",
                node_start_time=node_start_time,
            )
        ],
        step_results=[
            get_near_empty_step_results(
                step_number=_KG_STEP_NR, step_answer=step_answer
            )
        ],
    )

View File

@@ -1,187 +1,213 @@
# from datetime import datetime
# from typing import cast
from datetime import datetime
from typing import cast
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.access.access import get_acl_for_user
# from onyx.agents.agent_search.kb_search.graph_utils import rename_entities_in_answer
# from onyx.agents.agent_search.kb_search.ops import research
# from onyx.agents.agent_search.kb_search.states import FinalAnswerUpdate
# from onyx.agents.agent_search.kb_search.states import MainState
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.calculations import (
# get_answer_generation_documents,
# )
# from onyx.agents.agent_search.shared_graph_utils.llm import get_answer_from_llm
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
# from onyx.configs.kg_configs import KG_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION
# from onyx.context.search.models import InferenceSection
# from onyx.db.engine.sql_engine import get_session_with_current_tenant
# from onyx.prompts.kg_prompts import OUTPUT_FORMAT_NO_EXAMPLES_PROMPT
# from onyx.prompts.kg_prompts import OUTPUT_FORMAT_NO_OVERALL_ANSWER_PROMPT
# from onyx.utils.logger import setup_logger
from onyx.access.access import get_acl_for_user
from onyx.agents.agent_search.kb_search.graph_utils import rename_entities_in_answer
from onyx.agents.agent_search.kb_search.ops import research
from onyx.agents.agent_search.kb_search.states import FinalAnswerUpdate
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.calculations import (
get_answer_generation_documents,
)
from onyx.agents.agent_search.shared_graph_utils.llm import get_answer_from_llm
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
from onyx.configs.kg_configs import KG_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION
from onyx.context.search.enums import SearchType
from onyx.context.search.models import InferenceSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.prompts.kg_prompts import OUTPUT_FORMAT_NO_EXAMPLES_PROMPT
from onyx.prompts.kg_prompts import OUTPUT_FORMAT_NO_OVERALL_ANSWER_PROMPT
from onyx.tools.tool_implementations.search.search_tool import IndexFilters
from onyx.tools.tool_implementations.search.search_tool import SearchQueryInfo
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def generate_answer(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> FinalAnswerUpdate:
# """
# LangGraph node to start the agentic search process.
# """
def generate_answer(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> FinalAnswerUpdate:
    """LangGraph node that streams the supporting documents and generates
    the final KG answer.

    Gathers documents from earlier steps (or runs a fallback KG-source
    search), streams them to the UI via ``yield_search_responses``, builds
    the output-format prompt from either the simple-path or deep-path
    results, and calls the primary LLM for the final answer.
    """

    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = state.question

    final_answer: str | None = None

    # The user is taken from the search tool; both must be configured.
    user = (
        graph_config.tooling.search_tool.user
        if graph_config.tooling.search_tool
        else None
    )

    if not user:
        raise ValueError("User is not set")

    search_tool = graph_config.tooling.search_tool
    if search_tool is None:
        raise ValueError("Search tool is not set")

    # Close out previous streams of steps

    # DECLARE STEPS DONE

    ## MAIN ANSWER

    # identify whether documents have already been retrieved by earlier steps
    retrieved_docs: list[InferenceSection] = []
    for step_result in state.step_results:
        retrieved_docs += step_result.verified_reranked_documents

    # if still needed, get a search done and send the results to the UI
    if not retrieved_docs and state.source_document_results:
        assert graph_config.tooling.search_tool is not None
        retrieved_docs = cast(
            list[InferenceSection],
            research(
                question=question,
                kg_entities=[],
                kg_relationships=[],
                kg_sources=state.source_document_results[
                    :KG_RESEARCH_NUM_RETRIEVED_DOCS
                ],
                search_tool=graph_config.tooling.search_tool,
                kg_chunk_id_zero_only=True,
                inference_sections_only=True,
            ),
        )

    # The same docs serve as relevant/context/original sets here.
    answer_generation_documents = get_answer_generation_documents(
        relevant_docs=retrieved_docs,
        context_documents=retrieved_docs,
        original_question_docs=retrieved_docs,
        max_docs=KG_RESEARCH_NUM_RETRIEVED_DOCS,
    )

    relevance_list = relevance_from_docs(
        answer_generation_documents.streaming_documents
    )

    assert graph_config.tooling.search_tool is not None

    with get_session_with_current_tenant() as graph_db_session:
        user_acl = list(get_acl_for_user(user, graph_db_session))

    # Stream the already-retrieved documents out to the UI; the loop body is
    # intentionally empty — yield_search_responses does the streaming.
    for tool_response in yield_search_responses(
        query=question,
        get_retrieved_sections=lambda: answer_generation_documents.context_documents,
        get_final_context_sections=lambda: answer_generation_documents.context_documents,
        search_query_info=SearchQueryInfo(
            predicted_search=SearchType.KEYWORD,
            # the search already happened; these filters only accompany the
            # streamed-out results.
            final_filters=IndexFilters(access_control_list=user_acl),
            recency_bias_multiplier=1.0,
        ),
        get_section_relevance=lambda: relevance_list,
        search_tool=graph_config.tooling.search_tool,
    ):
        # original document streaming
        pass

    # continue with the answer generation
    output_format = (
        state.output_format.value
        if state.output_format
        else "<you be the judge how to best present the data>"
    )

    # if deep path was taken:
    consolidated_research_object_results_str = (
        state.consolidated_research_object_results_str
    )
    # reference_results_str = (
    #     state.reference_results_str
    # ) # will not be part of LLM. Manually added to the answer

    # if simple path was taken:
    introductory_answer = state.query_results_data_str # from simple answer path only
    if consolidated_research_object_results_str:
        research_results = consolidated_research_object_results_str
    else:
        research_results = ""

    # NOTE(review): research_results is truthy exactly when
    # consolidated_research_object_results_str is truthy, so branches 3 and 4
    # below are unreachable — only branch 1, branch 2 and the else can fire.
    if introductory_answer:
        output_format_prompt = (
            OUTPUT_FORMAT_NO_EXAMPLES_PROMPT.replace("---question---", question)
            .replace(
                "---introductory_answer---",
                rename_entities_in_answer(introductory_answer),
            )
            .replace("---output_format---", str(output_format) if output_format else "")
        )
    elif research_results and consolidated_research_object_results_str:
        output_format_prompt = (
            OUTPUT_FORMAT_NO_EXAMPLES_PROMPT.replace("---question---", question)
            .replace(
                "---introductory_answer---",
                rename_entities_in_answer(consolidated_research_object_results_str),
            )
            .replace("---output_format---", str(output_format) if output_format else "")
        )
    elif research_results and not consolidated_research_object_results_str:
        output_format_prompt = (
            OUTPUT_FORMAT_NO_OVERALL_ANSWER_PROMPT.replace("---question---", question)
            .replace("---output_format---", str(output_format) if output_format else "")
            .replace(
                "---research_results---", rename_entities_in_answer(research_results)
            )
        )
    elif consolidated_research_object_results_str:
        output_format_prompt = (
            OUTPUT_FORMAT_NO_EXAMPLES_PROMPT.replace("---question---", question)
            .replace("---output_format---", str(output_format) if output_format else "")
            .replace(
                "---research_results---", rename_entities_in_answer(research_results)
            )
        )
    else:
        raise ValueError("No research results or introductory answer provided")

    try:
        final_answer = get_answer_from_llm(
            llm=graph_config.tooling.primary_llm,
            prompt=output_format_prompt,
            stream=False,
            json_string_flag=False,
            timeout_override=KG_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
        )
    except Exception as e:
        raise ValueError(f"Could not generate the answer. Error {e}")

    return FinalAnswerUpdate(
        final_answer=final_answer,
        retrieved_documents=answer_generation_documents.context_documents,
        step_results=[],
        remarks=[],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="query completed",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -1,60 +1,60 @@
# from datetime import datetime
# from typing import cast
from datetime import datetime
from typing import cast
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.kb_search.states import MainOutput
# from onyx.agents.agent_search.kb_search.states import MainState
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.db.chat import log_agent_sub_question_results
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.kb_search.states import MainOutput
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.db.chat import log_agent_sub_question_results
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def log_data(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> MainOutput:
# """
# LangGraph node to start the agentic search process.
# """
def log_data(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> MainOutput:
    """Persist the sub-question results of the KG run and emit the output.

    Commits the request's DB session first so earlier writes are flushed,
    then logs each step result against the primary chat message.
    """
    node_start_time = datetime.now()
    graph_config = cast(GraphConfig, config["metadata"]["config"])

    if graph_config.tooling.search_tool is None:
        raise ValueError("Search tool is not set")

    persistence = graph_config.persistence

    # commit original db_session
    query_db_session = persistence.db_session
    query_db_session.commit()

    log_agent_sub_question_results(
        db_session=query_db_session,
        chat_session_id=persistence.chat_session_id,
        primary_message_id=persistence.message_id,
        sub_question_answer_results=state.step_results,
    )

    return MainOutput(
        final_answer=state.final_answer,
        retrieved_documents=state.retrieved_documents,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="query completed",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -1,47 +1,65 @@
# from datetime import datetime
# from typing import cast
from datetime import datetime
from typing import cast
# from onyx.chat.models import LlmDoc
# from onyx.configs.constants import DocumentSource
# from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
# from onyx.context.search.models import InferenceSection
# from onyx.tools.models import SearchToolOverrideKwargs
# from onyx.tools.tool_implementations.search.search_tool import (
# FINAL_CONTEXT_DOCUMENTS_ID,
# )
# from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.chat.models import LlmDoc
from onyx.configs.constants import DocumentSource
from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
from onyx.context.search.models import InferenceSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
# def research(
# question: str,
# search_tool: SearchTool,
# document_sources: list[DocumentSource] | None = None,
# time_cutoff: datetime | None = None,
# kg_entities: list[str] | None = None,
# kg_relationships: list[str] | None = None,
# kg_terms: list[str] | None = None,
# kg_sources: list[str] | None = None,
# kg_chunk_id_zero_only: bool = False,
# inference_sections_only: bool = False,
# ) -> list[LlmDoc] | list[InferenceSection]:
# # new db session to avoid concurrency issues
def research(
    question: str,
    search_tool: SearchTool,
    document_sources: list[DocumentSource] | None = None,
    time_cutoff: datetime | None = None,
    kg_entities: list[str] | None = None,
    kg_relationships: list[str] | None = None,
    kg_terms: list[str] | None = None,
    kg_sources: list[str] | None = None,
    kg_chunk_id_zero_only: bool = False,
    inference_sections_only: bool = False,
) -> list[LlmDoc] | list[InferenceSection]:
    """Run one KG-filtered search and return its top retrieved documents.

    With ``inference_sections_only`` the raw ``InferenceSection`` list from
    the search summary is returned; otherwise the final ``LlmDoc`` context
    list. Either result is capped at ``KG_RESEARCH_NUM_RETRIEVED_DOCS``.
    """
    sections_sink: list[list[InferenceSection]] = []
    results: list[LlmDoc] | list[InferenceSection] = []

    # new db session to avoid concurrency issues
    with get_session_with_current_tenant() as db_session:
        overrides = SearchToolOverrideKwargs(
            force_no_rerank=False,
            alternate_db_session=db_session,
            retrieved_sections_callback=sections_sink.append,
            skip_query_analysis=True,
            document_sources=document_sources,
            time_cutoff=time_cutoff,
            kg_entities=kg_entities,
            kg_relationships=kg_relationships,
            kg_terms=kg_terms,
            kg_sources=kg_sources,
            kg_chunk_id_zero_only=kg_chunk_id_zero_only,
        )
        for response in search_tool.run(query=question, override_kwargs=overrides):
            if inference_sections_only and response.id == "search_response_summary":
                results = cast(
                    list[InferenceSection],
                    response.response.top_sections[:KG_RESEARCH_NUM_RETRIEVED_DOCS],
                )
                break
            # get retrieved docs to send to the rest of the graph
            if response.id == FINAL_CONTEXT_DOCUMENTS_ID:
                results = cast(list[LlmDoc], response.response)[
                    :KG_RESEARCH_NUM_RETRIEVED_DOCS
                ]
                break

    return results

View File

@@ -1,191 +1,191 @@
# from enum import Enum
# from operator import add
# from typing import Annotated
# from typing import Any
# from typing import Dict
# from typing import TypedDict
from enum import Enum
from operator import add
from typing import Annotated
from typing import Any
from typing import Dict
from typing import TypedDict
# from pydantic import BaseModel
from pydantic import BaseModel
# from onyx.agents.agent_search.core_state import CoreState
# from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
# from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
# from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
# from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
# from onyx.agents.agent_search.shared_graph_utils.models import SubQuestionAnswerResults
# from onyx.context.search.models import InferenceSection
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
from onyx.agents.agent_search.shared_graph_utils.models import SubQuestionAnswerResults
from onyx.context.search.models import InferenceSection
# ### States ###
### States ###
# class StepResults(BaseModel):
# question: str
# question_id: str
# answer: str
# sub_query_retrieval_results: list[QueryRetrievalResult]
# verified_reranked_documents: list[InferenceSection]
# context_documents: list[InferenceSection]
# cited_documents: list[InferenceSection]
class StepResults(BaseModel):
    """Collected artifacts of a single answered sub-question step."""

    question: str
    question_id: str
    answer: str
    sub_query_retrieval_results: list[QueryRetrievalResult]
    verified_reranked_documents: list[InferenceSection]
    context_documents: list[InferenceSection]
    cited_documents: list[InferenceSection]
# class LoggerUpdate(BaseModel):
# log_messages: Annotated[list[str], add] = []
# step_results: Annotated[list[SubQuestionAnswerResults], add]
# remarks: Annotated[list[str], add] = []
class LoggerUpdate(BaseModel):
    """Base state update for graph nodes.

    The ``Annotated[..., add]`` reducers make LangGraph concatenate these
    lists when merging updates from multiple nodes.
    """

    log_messages: Annotated[list[str], add] = []
    # NOTE(review): no default — every update must supply step_results
    step_results: Annotated[list[SubQuestionAnswerResults], add]
    remarks: Annotated[list[str], add] = []
# class KGFilterConstructionResults(BaseModel):
# global_entity_filters: list[str]
# global_relationship_filters: list[str]
# local_entity_filters: list[list[str]]
# source_document_filters: list[str]
# structure: list[str]
class KGFilterConstructionResults(BaseModel):
    """Filter sets constructed for the deep-search retrieval step."""

    global_entity_filters: list[str]
    global_relationship_filters: list[str]
    # one entity-filter list per divide-and-conquer branch — TODO confirm
    local_entity_filters: list[list[str]]
    source_document_filters: list[str]
    structure: list[str]
# class KGSearchType(Enum):
# SEARCH = "SEARCH"
# SQL = "SQL"
class KGSearchType(Enum):
    """Which retrieval mechanism answers the KG question: search or SQL."""

    SEARCH = "SEARCH"
    SQL = "SQL"
# class KGAnswerStrategy(Enum):
# DEEP = "DEEP"
# SIMPLE = "SIMPLE"
class KGAnswerStrategy(Enum):
    """Answering strategy: deep per-object research vs. a simple answer."""

    DEEP = "DEEP"
    SIMPLE = "SIMPLE"
# class KGSourceDivisionType(Enum):
# SOURCE = "SOURCE"
# ENTITY = "ENTITY"
class KGSourceDivisionType(Enum):
    """Axis along which research work is divided: by source or by entity."""

    SOURCE = "SOURCE"
    ENTITY = "ENTITY"
# class KGRelationshipDetection(Enum):
# RELATIONSHIPS = "RELATIONSHIPS"
# NO_RELATIONSHIPS = "NO_RELATIONSHIPS"
class KGRelationshipDetection(Enum):
    """Whether relationships were detected for the analyzed question."""

    RELATIONSHIPS = "RELATIONSHIPS"
    NO_RELATIONSHIPS = "NO_RELATIONSHIPS"
# class KGAnswerFormat(Enum):
# LIST = "LIST"
# TEXT = "TEXT"
class KGAnswerFormat(Enum):
    """Requested shape of the final answer: a list or free-form text."""

    LIST = "LIST"
    TEXT = "TEXT"
# class YesNoEnum(str, Enum):
# YES = "yes"
# NO = "no"
class YesNoEnum(str, Enum):
    """String-valued yes/no flag (str subclass, so it compares to "yes"/"no")."""

    YES = "yes"
    NO = "no"
# class AnalysisUpdate(LoggerUpdate):
# normalized_core_entities: list[str] = []
# normalized_core_relationships: list[str] = []
# entity_normalization_map: dict[str, str] = {}
# relationship_normalization_map: dict[str, str] = {}
# query_graph_entities_no_attributes: list[str] = []
# query_graph_entities_w_attributes: list[str] = []
# query_graph_relationships: list[str] = []
# normalized_terms: list[str] = []
# normalized_time_filter: str | None = None
# strategy: KGAnswerStrategy | None = None
# output_format: KGAnswerFormat | None = None
# broken_down_question: str | None = None
# divide_and_conquer: YesNoEnum | None = None
# single_doc_id: str | None = None
# search_type: KGSearchType | None = None
# query_type: str | None = None
class AnalysisUpdate(LoggerUpdate):
    """State update from the query-analysis phase.

    Carries the normalized entities/relationships extracted from the user
    question plus the strategy and output-format decisions.
    """

    normalized_core_entities: list[str] = []
    normalized_core_relationships: list[str] = []
    # presumably raw extracted name -> normalized KG name — TODO confirm
    entity_normalization_map: dict[str, str] = {}
    relationship_normalization_map: dict[str, str] = {}
    query_graph_entities_no_attributes: list[str] = []
    query_graph_entities_w_attributes: list[str] = []
    query_graph_relationships: list[str] = []
    normalized_terms: list[str] = []
    normalized_time_filter: str | None = None
    strategy: KGAnswerStrategy | None = None
    output_format: KGAnswerFormat | None = None
    broken_down_question: str | None = None
    divide_and_conquer: YesNoEnum | None = None
    single_doc_id: str | None = None
    search_type: KGSearchType | None = None
    query_type: str | None = None
class SQLSimpleGenerationUpdate(LoggerUpdate):
    """State update holding the generated SQL queries and their raw results."""

    sql_query: str | None = None
    sql_query_results: list[Dict[Any, Any]] | None = None
    individualized_sql_query: str | None = None
    individualized_query_results: list[Dict[Any, Any]] | None = None
    # SQL used to fetch the source documents backing the answer
    source_documents_sql: str | None = None
    source_document_results: list[str] | None = None
    # strategy may be revised after seeing the SQL results
    updated_strategy: KGAnswerStrategy | None = None
class ConsolidatedResearchUpdate(LoggerUpdate):
    """State update with the consolidated research results rendered as a string."""

    consolidated_research_object_results_str: str | None = None
class DeepSearchFilterUpdate(LoggerUpdate):
    """State update with the Vespa filters constructed for deep search."""

    vespa_filter_results: KGFilterConstructionResults | None = None
    # entities used for the divide-and-conquer per-entity research
    div_con_entities: list[str] | None = None
    source_division: bool | None = None
    global_entity_filters: list[str] | None = None
    global_relationship_filters: list[str] | None = None
    # one filter list per divide-and-conquer branch
    local_entity_filters: list[list[str]] | None = None
    source_filters: list[str] | None = None
class ResearchObjectOutput(LoggerUpdate):
    """Output of one research branch; `add` merges results from parallel branches."""

    research_object_results: Annotated[list[dict[str, Any]], add] = []
class EntityRelationshipExtractionUpdate(LoggerUpdate):
    """State update from the entity/relationship extraction step.

    Holds the stringified type catalogs, the raw extracted entities and
    relationships, and the names of the KG temp views created for querying.
    """

    entities_types_str: str = ""
    relationship_types_str: str = ""
    extracted_entities_w_attributes: list[str] = []
    extracted_entities_no_attributes: list[str] = []
    extracted_relationships: list[str] = []
    time_filter: str | None = None
    # temporary SQL view names used by the later KG query steps
    kg_doc_temp_view_name: str | None = None
    kg_rel_temp_view_name: str | None = None
    kg_entity_temp_view_name: str | None = None
class ResultsDataUpdate(LoggerUpdate):
    """State update with query results rendered as strings for the answer step."""

    query_results_data_str: str | None = None
    individualized_query_results_data_str: str | None = None
    reference_results_str: str | None = None
class ResearchObjectUpdate(LoggerUpdate):
    """Research-branch results merged into the main state; `add` concatenates lists."""

    research_object_results: Annotated[list[dict[str, Any]], add] = []
## Graph Input State
class MainInput(CoreState):
    """Input state for the KG search graph: the user question plus a UI flag."""

    question: str
    individual_flow: bool = True  # used for UI display purposes
class FinalAnswerUpdate(LoggerUpdate):
    """State update with the final answer text and the documents it was built from."""

    final_answer: str | None = None
    retrieved_documents: list[InferenceSection] | None = None
## Graph State
class MainState(
    # This includes the core state
    MainInput,
    ToolChoiceInput,
    ToolCallUpdate,
    ToolChoiceUpdate,
    EntityRelationshipExtractionUpdate,
    AnalysisUpdate,
    SQLSimpleGenerationUpdate,
    ResultsDataUpdate,
    ResearchObjectOutput,
    DeepSearchFilterUpdate,
    ResearchObjectUpdate,
    ConsolidatedResearchUpdate,
    FinalAnswerUpdate,
):
    """Combined LangGraph state: the union of all per-step update models."""

    pass
## Graph Output State - presently not used
class MainOutput(TypedDict):
    """Final graph output: log lines, the answer, and its supporting documents."""

    log_messages: list[str]
    final_answer: str | None
    retrieved_documents: list[InferenceSection] | None
class ResearchObjectInput(LoggerUpdate):
    """Input for a single parallel research branch (one entity / sub-question)."""

    research_nr: int
    entity: str
    broken_down_question: str
    vespa_filter_results: KGFilterConstructionResults
    source_division: bool | None
    # filters restricted to this branch's sources, if source division is on
    source_entity_filters: list[str] | None
    segment_type: str
    individual_flow: bool = True  # used for UI display purposes

View File

@@ -1,33 +1,33 @@
# from onyx.agents.agent_search.kb_search.models import KGSteps
from onyx.agents.agent_search.kb_search.models import KGSteps
# Step number -> UI description and the activity labels shown under that step
# while the KG search graph runs.
KG_SEARCH_STEP_DESCRIPTIONS: dict[int, KGSteps] = {
    1: KGSteps(
        description="Analyzing the question...",
        activities=[
            "Entities in Query",
            "Relationships in Query",
            "Terms in Query",
            "Time Filters",
        ],
    ),
    2: KGSteps(
        description="Planning the response approach...",
        activities=["Query Execution Strategy", "Answer Format"],
    ),
    3: KGSteps(
        description="Querying the Knowledge Graph...",
        activities=[
            "Knowledge Graph Query",
            "Knowledge Graph Query Results",
            "Query for Source Documents",
            "Source Documents",
        ],
    ),
    4: KGSteps(
        description="Conducting further research on source documents...", activities=[]
    ),
}
# Single-step description used when falling back to a plain (non-KG) search.
BASIC_SEARCH_STEP_DESCRIPTIONS: dict[int, KGSteps] = {
    1: KGSteps(description="Conducting a standard search...", activities=[]),
}

View File

@@ -1,89 +1,88 @@
# from uuid import UUID
from uuid import UUID
# from pydantic import BaseModel
# from sqlalchemy.orm import Session
from pydantic import BaseModel
from pydantic import ConfigDict
from sqlalchemy.orm import Session
# from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
# from onyx.context.search.models import RerankingDetails
# from onyx.db.models import Persona
# from onyx.file_store.utils import InMemoryChatFile
# from onyx.kg.models import KGConfigSettings
# from onyx.llm.interfaces import LLM
# from onyx.tools.force import ForceUseTool
# from onyx.tools.tool import Tool
# from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.context.search.models import RerankingDetails
from onyx.db.models import Persona
from onyx.file_store.utils import InMemoryChatFile
from onyx.kg.models import KGConfigSettings
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
class GraphInputs(BaseModel):
    """Input data required for the graph execution"""

    persona: Persona | None = None
    rerank_settings: RerankingDetails | None = None
    prompt_builder: AnswerPromptBuilder
    files: list[InMemoryChatFile] | None = None
    # optional JSON schema the final response must conform to
    structured_response_format: dict | None = None
    project_instructions: str | None = None

    # non-pydantic types (Persona, AnswerPromptBuilder, ...) are allowed as fields
    model_config = ConfigDict(arbitrary_types_allowed=True)
class GraphTooling(BaseModel):
    """Tools and LLMs available to the graph"""

    primary_llm: LLM
    fast_llm: LLM
    search_tool: SearchTool | None = None
    tools: list[Tool]
    # Whether to force use of a tool, or to
    # force tool args IF the tool is used
    force_use_tool: ForceUseTool
    using_tool_calling_llm: bool = False

    # LLM/Tool instances are not pydantic models
    model_config = ConfigDict(arbitrary_types_allowed=True)
class GraphPersistence(BaseModel):
    """Configuration for data persistence"""

    chat_session_id: UUID
    # The message ID of the to-be-created first agent message
    # in response to the user message that triggered the Pro Search
    message_id: int

    # The database session the user and initial agent
    # message were flushed to; only needed for agentic search
    db_session: Session

    # SQLAlchemy Session is not a pydantic model
    model_config = ConfigDict(arbitrary_types_allowed=True)
class GraphSearchConfig(BaseModel):
    """Configuration controlling search behavior"""

    use_agentic_search: bool = False
    # Whether to perform initial search to inform decomposition
    perform_initial_search_decomposition: bool = True

    # Whether to allow creation of refinement questions (and entity extraction, etc.)
    allow_refinement: bool = True
    skip_gen_ai_answer_generation: bool = False
    allow_agent_reranking: bool = False
    kg_config_settings: KGConfigSettings = KGConfigSettings()
    research_type: ResearchType = ResearchType.THOUGHTFUL
class GraphConfig(BaseModel):
    """
    Main container for data needed for Langgraph execution
    """

    inputs: GraphInputs
    tooling: GraphTooling
    behavior: GraphSearchConfig
    # Only needed for agentic search
    persistence: GraphPersistence

    # nested configs contain non-pydantic types
    model_config = ConfigDict(arbitrary_types_allowed=True)

View File

@@ -1,50 +1,50 @@
# from pydantic import BaseModel
from pydantic import BaseModel
from pydantic import ConfigDict
# from onyx.chat.prompt_builder.schemas import PromptSnapshot
# from onyx.tools.message import ToolCallSummary
# from onyx.tools.models import SearchToolOverrideKwargs
# from onyx.tools.models import ToolCallFinalResult
# from onyx.tools.models import ToolCallKickoff
# from onyx.tools.models import ToolResponse
# from onyx.tools.tool import Tool
from onyx.chat.prompt_builder.schemas import PromptSnapshot
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
# TODO: adapt the tool choice/tool call to allow for parallel tool calls by
# creating a subgraph that can be invoked in parallel via Send/Command APIs
class ToolChoiceInput(BaseModel):
    """Input state for the tool-choice node."""

    should_stream_answer: bool = True
    # default to the prompt builder from the config, but
    # allow overrides for arbitrary tool calls
    prompt_snapshot: PromptSnapshot | None = None

    # names of tools to use for tool calling. Filters the tools available in the config
    tools: list[str] = []
class ToolCallOutput(BaseModel):
    """Full record of one executed tool call: kickoff, responses, and final result."""

    tool_call_summary: ToolCallSummary
    tool_call_kickoff: ToolCallKickoff
    tool_call_responses: list[ToolResponse]
    tool_call_final_result: ToolCallFinalResult
class ToolCallUpdate(BaseModel):
    """State update carrying the (optional) output of the tool-call node."""

    tool_call_output: ToolCallOutput | None = None
class ToolChoice(BaseModel):
    """A selected tool together with the arguments it should be called with."""

    tool: Tool
    tool_args: dict
    id: str | None
    search_tool_override_kwargs: SearchToolOverrideKwargs = SearchToolOverrideKwargs()

    # Tool instances are not pydantic models
    model_config = ConfigDict(arbitrary_types_allowed=True)
class ToolChoiceUpdate(BaseModel):
    """State update carrying the (optional) chosen tool; None means no tool call."""

    tool_choice: ToolChoice | None = None
class ToolChoiceState(ToolChoiceUpdate, ToolChoiceInput):
    """Combined input + update state for the tool-choice subgraph."""

    pass

View File

@@ -1,93 +1,93 @@
# from collections.abc import Iterable
# from typing import cast
from collections.abc import Iterable
from typing import cast
# from langchain_core.runnables.schema import CustomStreamEvent
# from langchain_core.runnables.schema import StreamEvent
# from langfuse.langchain import CallbackHandler
# from langgraph.graph.state import CompiledStateGraph
from langchain_core.runnables.schema import CustomStreamEvent
from langchain_core.runnables.schema import StreamEvent
from langfuse.langchain import CallbackHandler
from langgraph.graph.state import CompiledStateGraph
# from onyx.agents.agent_search.dc_search_analysis.graph_builder import (
# divide_and_conquer_graph_builder,
# )
# from onyx.agents.agent_search.dc_search_analysis.states import MainInput as DCMainInput
# from onyx.agents.agent_search.dr.graph_builder import dr_graph_builder
# from onyx.agents.agent_search.dr.states import MainInput as DRMainInput
# from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
# from onyx.agents.agent_search.kb_search.states import MainInput as KBMainInput
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.chat.models import AnswerStream
# from onyx.configs.app_configs import LANGFUSE_PUBLIC_KEY
# from onyx.configs.app_configs import LANGFUSE_SECRET_KEY
# from onyx.server.query_and_chat.streaming_models import Packet
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dc_search_analysis.graph_builder import (
divide_and_conquer_graph_builder,
)
from onyx.agents.agent_search.dc_search_analysis.states import MainInput as DCMainInput
from onyx.agents.agent_search.dr.graph_builder import dr_graph_builder
from onyx.agents.agent_search.dr.states import MainInput as DRMainInput
from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
from onyx.agents.agent_search.kb_search.states import MainInput as KBMainInput
from onyx.agents.agent_search.models import GraphConfig
from onyx.chat.models import AnswerStream
from onyx.configs.app_configs import LANGFUSE_PUBLIC_KEY
from onyx.configs.app_configs import LANGFUSE_SECRET_KEY
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.utils.logger import setup_logger
logger = setup_logger()

# Union of the input state types accepted by the three runnable graphs.
GraphInput = DCMainInput | KBMainInput | DRMainInput
def manage_sync_streaming(
    compiled_graph: CompiledStateGraph,
    config: GraphConfig,
    graph_input: GraphInput,
) -> Iterable[StreamEvent]:
    """Run the compiled graph synchronously and yield its custom stream events.

    Attaches a Langfuse callback handler only when both Langfuse keys are
    configured; the message id (when persistence is set) is used as the
    LangGraph thread id.
    """
    message_id = config.persistence.message_id if config.persistence else None
    callbacks: list[CallbackHandler] = []
    if LANGFUSE_SECRET_KEY and LANGFUSE_PUBLIC_KEY:
        callbacks.append(CallbackHandler())
    for event in compiled_graph.stream(
        stream_mode="custom",
        input=graph_input,
        config={
            "metadata": {"config": config, "thread_id": str(message_id)},
            "callbacks": callbacks,  # type: ignore
        },
    ):
        yield cast(CustomStreamEvent, event)
def run_graph(
    compiled_graph: CompiledStateGraph,
    config: GraphConfig,
    input: GraphInput,
) -> AnswerStream:
    """Stream a compiled graph, unwrapping each custom event into its Packet payload."""
    for event in manage_sync_streaming(
        compiled_graph=compiled_graph, config=config, graph_input=input
    ):
        yield cast(Packet, event["data"])
def run_kb_graph(
    config: GraphConfig,
) -> AnswerStream:
    """Build, compile, and stream the knowledge-base search graph."""
    graph = kb_graph_builder()
    compiled_graph = graph.compile()
    # the KB graph takes the raw user query as its question
    input = KBMainInput(
        log_messages=[], question=config.inputs.prompt_builder.raw_user_query
    )

    yield from run_graph(compiled_graph, config, input)
def run_dr_graph(
    config: GraphConfig,
) -> AnswerStream:
    """Build, compile, and stream the deep-research graph."""
    graph = dr_graph_builder()
    compiled_graph = graph.compile()
    input = DRMainInput(log_messages=[])

    yield from run_graph(compiled_graph, config, input)
def run_dc_graph(
    config: GraphConfig,
) -> AnswerStream:
    """Build, compile, and stream the divide-and-conquer analysis graph.

    NOTE: mutates the prompt builder's raw user query in place (strips
    surrounding whitespace) before running the graph.
    """
    graph = divide_and_conquer_graph_builder()
    compiled_graph = graph.compile()
    input = DCMainInput(log_messages=[])
    config.inputs.prompt_builder.raw_user_query = (
        config.inputs.prompt_builder.raw_user_query.strip()
    )
    return run_graph(compiled_graph, config, input)

View File

@@ -1,176 +1,176 @@
# from langchain.schema import AIMessage
# from langchain.schema import HumanMessage
# from langchain.schema import SystemMessage
# from langchain_core.messages.tool import ToolMessage
from langchain.schema import AIMessage
from langchain.schema import HumanMessage
from langchain.schema import SystemMessage
from langchain_core.messages.tool import ToolMessage
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.models import (
# AgentPromptEnrichmentComponents,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_persona_agent_prompt_expressions,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import remove_document_citations
# from onyx.agents.agent_search.shared_graph_utils.utils import summarize_history
# from onyx.configs.agent_configs import AGENT_MAX_STATIC_HISTORY_WORD_LENGTH
# from onyx.configs.constants import MessageType
# from onyx.context.search.models import InferenceSection
# from onyx.llm.interfaces import LLMConfig
# from onyx.natural_language_processing.utils import get_tokenizer
# from onyx.natural_language_processing.utils import tokenizer_trim_content
# from onyx.prompts.agent_search import HISTORY_FRAMING_PROMPT
# from onyx.prompts.agent_search import SUB_QUESTION_RAG_PROMPT
# from onyx.prompts.prompt_utils import build_date_time_string
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.models import (
AgentPromptEnrichmentComponents,
)
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_persona_agent_prompt_expressions,
)
from onyx.agents.agent_search.shared_graph_utils.utils import remove_document_citations
from onyx.agents.agent_search.shared_graph_utils.utils import summarize_history
from onyx.configs.agent_configs import AGENT_MAX_STATIC_HISTORY_WORD_LENGTH
from onyx.configs.constants import MessageType
from onyx.context.search.models import InferenceSection
from onyx.llm.interfaces import LLMConfig
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.natural_language_processing.utils import tokenizer_trim_content
from onyx.prompts.agent_search import HISTORY_FRAMING_PROMPT
from onyx.prompts.agent_search import SUB_QUESTION_RAG_PROMPT
from onyx.prompts.prompt_utils import build_date_time_string
from onyx.utils.logger import setup_logger
logger = setup_logger()
def build_sub_question_answer_prompt(
    question: str,
    original_question: str,
    docs: list[InferenceSection],
    persona_specification: str,
    config: LLMConfig,
) -> list[SystemMessage | HumanMessage | AIMessage | ToolMessage]:
    """Build the [system, human] message pair for answering a sub-question via RAG.

    The formatted docs are trimmed so the full prompt fits within the LLM's
    input-token budget before being inserted into SUB_QUESTION_RAG_PROMPT.
    """
    system_message = SystemMessage(
        content=persona_specification,
    )

    date_str = build_date_time_string()

    docs_str = format_docs(docs)

    # reserve room for the prompt template, both questions, and the date string
    docs_str = trim_prompt_piece(
        config=config,
        prompt_piece=docs_str,
        reserved_str=SUB_QUESTION_RAG_PROMPT + question + original_question + date_str,
    )
    human_message = HumanMessage(
        content=SUB_QUESTION_RAG_PROMPT.format(
            question=question,
            original_question=original_question,
            context=docs_str,
            date_prompt=date_str,
        )
    )

    return [system_message, human_message]
def trim_prompt_piece(config: LLMConfig, prompt_piece: str, reserved_str: str) -> str:
    """Trim prompt_piece so it plus reserved_str fits in the model's input budget.

    Uses a cheap character-count check first; only tokenizes when the
    conservative estimate says trimming may be needed.
    """
    # no need to trim if a conservative estimate of one token
    # per character is already less than the max tokens
    if len(prompt_piece) + len(reserved_str) < config.max_input_tokens:
        return prompt_piece

    llm_tokenizer = get_tokenizer(
        provider_type=config.model_provider,
        model_name=config.model_name,
    )

    # slightly conservative trimming
    return tokenizer_trim_content(
        content=prompt_piece,
        desired_length=config.max_input_tokens
        - len(llm_tokenizer.encode(reserved_str)),
        tokenizer=llm_tokenizer,
    )
def build_history_prompt(config: GraphConfig, question: str) -> str:
    """Render chat history into the HISTORY_FRAMING_PROMPT, or "" if none.

    Prefers a precomputed single-message history; otherwise rebuilds it from
    the raw message history, strips document citations, and summarizes it with
    the fast LLM when it exceeds AGENT_MAX_STATIC_HISTORY_WORD_LENGTH words.
    """
    prompt_builder = config.inputs.prompt_builder
    persona_base = get_persona_agent_prompt_expressions(
        config.inputs.persona, db_session=config.persistence.db_session
    ).base_prompt

    if prompt_builder is None:
        return ""

    if prompt_builder.single_message_history is not None:
        history = prompt_builder.single_message_history
    else:
        history_components = []
        previous_message_type = None
        for message in prompt_builder.raw_message_history:
            if message.message_type == MessageType.USER:
                history_components.append(f"User: {message.message}\n")
                previous_message_type = MessageType.USER
            elif message.message_type == MessageType.ASSISTANT:
                # Previously there could be multiple assistant messages in a row
                # Now this is handled at the message history construction
                assert previous_message_type is not MessageType.ASSISTANT
                history_components.append(f"You/Agent: {message.message}\n")
                previous_message_type = MessageType.ASSISTANT
            else:
                # Other message types are not included here, currently there should be no other message types
                logger.error(
                    f"Unhandled message type: {message.message_type} with message: {message.message}"
                )
                continue

        history = "\n".join(history_components)
        history = remove_document_citations(history)
        if len(history.split()) > AGENT_MAX_STATIC_HISTORY_WORD_LENGTH:
            history = summarize_history(
                history=history,
                question=question,
                persona_specification=persona_base,
                llm=config.tooling.fast_llm,
            )

    return HISTORY_FRAMING_PROMPT.format(history=history) if history else ""
def get_prompt_enrichment_components(
    config: GraphConfig,
) -> AgentPromptEnrichmentComponents:
    """Gather persona prompts, framed history, and the date string for prompt building."""
    persona_prompts = get_persona_agent_prompt_expressions(
        config.inputs.persona, db_session=config.persistence.db_session
    )

    history = build_history_prompt(config, config.inputs.prompt_builder.raw_user_query)

    date_str = build_date_time_string()

    return AgentPromptEnrichmentComponents(
        persona_prompts=persona_prompts,
        history=history,
        date_str=date_str,
    )
def binary_string_test(text: str, positive_value: str = "yes") -> bool:
    """
    Tests if a string contains a positive value (case-insensitive).

    Args:
        text: The string to test
        positive_value: The value to look for (defaults to "yes")

    Returns:
        True if the positive value is found in the text
    """
    return positive_value.lower() in text.lower()
def binary_string_test_after_answer_separator(
    text: str, positive_value: str = "yes", separator: str = "Answer:"
) -> bool:
    """
    Tests if a string contains a positive value (case-insensitive), looking
    only at the portion after the last occurrence of the separator.

    Args:
        text: The string to test
        positive_value: The value to look for (defaults to "yes")
        separator: Marker whose trailing text is inspected (defaults to "Answer:")

    Returns:
        True if the positive value is found after the separator; False when the
        separator is absent from the text
    """
    if separator not in text:
        return False
    relevant_text = text.split(f"{separator}")[-1]

    return binary_string_test(relevant_text, positive_value)

View File

@@ -1,205 +1,205 @@
# import numpy as np
import numpy as np
# from onyx.agents.agent_search.shared_graph_utils.models import AnswerGenerationDocuments
# from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitScoreMetrics
# from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
# from onyx.agents.agent_search.shared_graph_utils.operators import (
# dedup_inference_section_list,
# )
# from onyx.chat.models import SectionRelevancePiece
# from onyx.context.search.models import InferenceSection
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.shared_graph_utils.models import AnswerGenerationDocuments
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitScoreMetrics
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_section_list,
)
from onyx.chat.models import SectionRelevancePiece
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger
# Module-level logger for this file's retrieval-metrics helpers.
logger = setup_logger()
def unique_chunk_id(doc: InferenceSection) -> str:
    """Build a globally unique id for a section's center chunk.

    Combines the parent document id with the chunk id so chunks from
    different documents can never collide.
    """
    return f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}"
def calculate_rank_shift(list1: list, list2: list, top_n: int = 20) -> float:
    """Measure how much the top ``top_n`` items of ``list1`` moved in ``list2``.

    For each of the first ``top_n`` ids in ``list1``, the absolute rank
    difference between the two lists is accumulated, discounted by
    ``log(1 + rank_first * rank_second)`` so that shifts deep in the ranking
    matter less. Ids missing from ``list2`` are treated as ranked last.

    Args:
        list1: First ranking (list of ids); only its top ``top_n`` are scored
        list2: Second ranking to compare against
        top_n: Number of leading items of ``list1`` to evaluate and the
            normalization denominator

    Returns:
        Average discounted rank shift; 0.0 when the orderings agree
    """
    shift = 0
    for rank_first, doc_id in enumerate(list1[:top_n], 1):
        try:
            rank_second = list2.index(doc_id) + 1
        except ValueError:
            rank_second = len(list2)  # Document not found in second list

        shift += np.abs(rank_first - rank_second) / np.log(1 + rank_first * rank_second)

    return shift / top_n
def get_fit_scores(
    pre_reranked_results: list[InferenceSection],
    post_reranked_results: list[InferenceSection] | list[SectionRelevancePiece],
) -> RetrievalFitStats | None:
    """Calculate retrieval fit metrics for the pre- and post-rerank orderings.

    Args:
        pre_reranked_results: Sections in their original retrieval order
        post_reranked_results: Sections (or relevance pieces) after reranking

    Returns:
        RetrievalFitStats with per-ordering fit scores, the fit-score lift of
        reranked over initial, and the rank-shift effect; None when either
        input list is empty
    """
    if len(pre_reranked_results) == 0 or len(post_reranked_results) == 0:
        return None

    ranked_sections = {
        "initial": pre_reranked_results,
        "reranked": post_reranked_results,
    }

    fit_eval: RetrievalFitStats = RetrievalFitStats(
        fit_score_lift=0,
        rerank_effect=0,
        fit_scores={
            "initial": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
            "reranked": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
        },
    )

    for rank_type, docs in ranked_sections.items():
        logger.debug(f"rank_type: {rank_type}")

        # Average chunk score over the top-1/5/10 cutoffs; non-InferenceSection
        # entries and None scores contribute nothing to the sum.
        for i in [1, 5, 10]:
            fit_eval.fit_scores[rank_type].scores[str(i)] = (
                sum(
                    [
                        float(doc.center_chunk.score)
                        for doc in docs[:i]
                        if isinstance(doc, InferenceSection)
                        and doc.center_chunk.score is not None
                    ]
                )
                / i
            )

        # NOTE(review): a previous version also averaged the 1/5/10 cutoffs
        # into "fit_score", but that value was immediately overwritten by the
        # top-1 score; only the effective assignment is kept here.
        fit_eval.fit_scores[rank_type].scores["fit_score"] = fit_eval.fit_scores[
            rank_type
        ].scores["1"]

        fit_eval.fit_scores[rank_type].chunk_ids = [
            unique_chunk_id(doc) for doc in docs if isinstance(doc, InferenceSection)
        ]

    fit_eval.fit_score_lift = (
        fit_eval.fit_scores["reranked"].scores["fit_score"]
        / fit_eval.fit_scores["initial"].scores["fit_score"]
    )

    fit_eval.rerank_effect = calculate_rank_shift(
        fit_eval.fit_scores["initial"].chunk_ids,
        fit_eval.fit_scores["reranked"].chunk_ids,
    )

    return fit_eval
def get_answer_generation_documents(
    relevant_docs: list[InferenceSection],
    context_documents: list[InferenceSection],
    original_question_docs: list[InferenceSection],
    max_docs: int,
) -> AnswerGenerationDocuments:
    """
    Create a deduplicated list of documents to stream, prioritizing relevant docs.

    Args:
        relevant_docs: Primary documents to include
        context_documents: Additional context documents to append
        original_question_docs: Original question documents to append
        max_docs: Maximum number of documents to return

    Returns:
        AnswerGenerationDocuments with relevant documents first, followed by
        score-demoted additional documents, each list capped at max_docs
    """
    # Use a set for O(1) membership checks on relevant document ids.
    relevant_doc_ids = {doc.center_chunk.document_id for doc in relevant_docs}

    # Start with relevant docs, then append any unseen context / original
    # question docs (a set gives O(1) lookups of already-added ids).
    streaming_documents = relevant_docs.copy()
    seen_doc_ids = {doc.center_chunk.document_id for doc in streaming_documents}

    for doc in context_documents + original_question_docs:
        doc_id = doc.center_chunk.document_id
        if doc_id not in seen_doc_ids:
            streaming_documents.append(doc)
            seen_doc_ids.add(doc_id)

    # dedupe/merge with existing framework
    streaming_documents = dedup_inference_section_list(streaming_documents)

    relevant_streaming_docs = [
        doc
        for doc in streaming_documents
        if doc.center_chunk.document_id in relevant_doc_ids
    ]
    relevant_streaming_docs = dedup_sort_inference_section_list(relevant_streaming_docs)

    additional_streaming_docs = [
        doc
        for doc in streaming_documents
        if doc.center_chunk.document_id not in relevant_doc_ids
    ]
    additional_streaming_docs = dedup_sort_inference_section_list(
        additional_streaming_docs
    )

    # Demote non-relevant docs so they sort below relevant ones downstream.
    for doc in additional_streaming_docs:
        if doc.center_chunk.score:
            doc.center_chunk.score += -2.0
        else:
            doc.center_chunk.score = -2.0

    sorted_streaming_documents = relevant_streaming_docs + additional_streaming_docs

    return AnswerGenerationDocuments(
        streaming_documents=sorted_streaming_documents[:max_docs],
        context_documents=relevant_streaming_docs[:max_docs],
    )
def dedup_sort_inference_section_list(
    sections: list[InferenceSection],
) -> list[InferenceSection]:
    """Deduplicates InferenceSections by document_id and sorts by score.

    Args:
        sections: List of InferenceSections to deduplicate and sort

    Returns:
        Deduplicated list of InferenceSections sorted by score in descending order
    """
    # dedupe/merge with existing framework
    sections = dedup_inference_section_list(sections)

    # Keep only the highest-scored section per document_id (None scores
    # compare as 0).
    unique_sections: dict[str, InferenceSection] = {}
    for section in sections:
        doc_id = section.center_chunk.document_id
        existing = unique_sections.get(doc_id)
        if existing is None or (section.center_chunk.score or 0) > (
            existing.center_chunk.score or 0
        ):
            unique_sections[doc_id] = section

    # Sort by score in descending order, handling None scores.
    return sorted(
        unique_sections.values(), key=lambda x: x.center_chunk.score or 0, reverse=True
    )

Some files were not shown because too many files have changed in this diff Show More