Compare commits

...

3 Commits
v2.9.4 ... oops

Author        SHA1          Message                                                                        Date
Yuhong Sun    8a38fdf8a5    ok                                                                             2025-11-29 22:46:22 -08:00
Vega          9155d4aa21    [FIX] Fix citation document mismatch and standardize citation format (#6484)  2025-11-29 22:43:03 -08:00
Yuhong Sun    b20591611a    Single Commit Rebased                                                          2025-11-29 21:22:58 -08:00
280 changed files with 27579 additions and 26755 deletions

.gitignore vendored
View File

@@ -49,5 +49,7 @@ CLAUDE.md
# Local .terraform.lock.hcl file
.terraform.lock.hcl
node_modules
# MCP configs
.playwright-mcp

View File

@@ -0,0 +1,104 @@
"""add_open_url_tool
Revision ID: 4f8a2b3c1d9e
Revises: a852cbe15577
Create Date: 2025-11-24 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "4f8a2b3c1d9e"
down_revision = "a852cbe15577"
branch_labels = None
depends_on = None
OPEN_URL_TOOL = {
"name": "OpenURLTool",
"display_name": "Open URL",
"description": (
"The Open URL Action allows the agent to fetch and read contents of web pages."
),
"in_code_tool_id": "OpenURLTool",
"enabled": True,
}
def upgrade() -> None:
conn = op.get_bind()
# Check if tool already exists
existing = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
{"in_code_tool_id": OPEN_URL_TOOL["in_code_tool_id"]},
).fetchone()
if existing:
tool_id = existing[0]
# Update existing tool
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :name,
display_name = :display_name,
description = :description
WHERE in_code_tool_id = :in_code_tool_id
"""
),
OPEN_URL_TOOL,
)
else:
# Insert new tool
conn.execute(
sa.text(
"""
INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled)
VALUES (:name, :display_name, :description, :in_code_tool_id, :enabled)
"""
),
OPEN_URL_TOOL,
)
# Get the newly inserted tool's id
result = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
{"in_code_tool_id": OPEN_URL_TOOL["in_code_tool_id"]},
).fetchone()
tool_id = result[0] # type: ignore
# Associate the tool with all existing personas
# Get all persona IDs
persona_ids = conn.execute(sa.text("SELECT id FROM persona")).fetchall()
for (persona_id,) in persona_ids:
# Check if association already exists
exists = conn.execute(
sa.text(
"""
SELECT 1 FROM persona__tool
WHERE persona_id = :persona_id AND tool_id = :tool_id
"""
),
{"persona_id": persona_id, "tool_id": tool_id},
).fetchone()
if not exists:
conn.execute(
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (:persona_id, :tool_id)
"""
),
{"persona_id": persona_id, "tool_id": tool_id},
)
def downgrade() -> None:
# We don't remove the tool on downgrade since it's fine to have it around.
# If we upgrade again, it will be a no-op.
pass
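
For reference, a minimal sketch of applying this revision programmatically through Alembic's Python API; the "alembic.ini" path is an assumption about the project layout, and most deployments would simply run the equivalent of an upgrade to head.

# Hypothetical sketch, not part of the migration: apply the add_open_url_tool
# revision (4f8a2b3c1d9e) via Alembic's command API.
from alembic import command
from alembic.config import Config

alembic_cfg = Config("alembic.ini")  # assumed config location
command.upgrade(alembic_cfg, "4f8a2b3c1d9e")  # or "head" to apply all pending revisions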

View File

@@ -0,0 +1,572 @@
"""New Chat History
Revision ID: a852cbe15577
Revises: 6436661d5b65
Create Date: 2025-11-08 15:16:37.781308
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "a852cbe15577"
down_revision = "6436661d5b65"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Drop research agent tables (if they exist)
op.execute("DROP TABLE IF EXISTS research_agent_iteration_sub_step CASCADE")
op.execute("DROP TABLE IF EXISTS research_agent_iteration CASCADE")
# Drop agent sub query and sub question tables (if they exist)
op.execute("DROP TABLE IF EXISTS agent__sub_query__search_doc CASCADE")
op.execute("DROP TABLE IF EXISTS agent__sub_query CASCADE")
op.execute("DROP TABLE IF EXISTS agent__sub_question CASCADE")
# Update ChatMessage table
# Rename parent_message to parent_message_id and make it a foreign key (if not already done)
conn = op.get_bind()
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'chat_message' AND column_name = 'parent_message'
"""
)
)
if result.fetchone():
op.alter_column(
"chat_message", "parent_message", new_column_name="parent_message_id"
)
op.create_foreign_key(
"fk_chat_message_parent_message_id",
"chat_message",
"chat_message",
["parent_message_id"],
["id"],
)
# Rename latest_child_message to latest_child_message_id and make it a foreign key (if not already done)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'chat_message' AND column_name = 'latest_child_message'
"""
)
)
if result.fetchone():
op.alter_column(
"chat_message",
"latest_child_message",
new_column_name="latest_child_message_id",
)
op.create_foreign_key(
"fk_chat_message_latest_child_message_id",
"chat_message",
"chat_message",
["latest_child_message_id"],
["id"],
)
# Add reasoning_tokens column (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'chat_message' AND column_name = 'reasoning_tokens'
"""
)
)
if not result.fetchone():
op.add_column(
"chat_message", sa.Column("reasoning_tokens", sa.Text(), nullable=True)
)
# Drop columns no longer needed (if they exist)
for col in [
"rephrased_query",
"alternate_assistant_id",
"overridden_model",
"is_agentic",
"refined_answer_improvement",
"research_type",
"research_plan",
"research_answer_purpose",
]:
result = conn.execute(
sa.text(
f"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'chat_message' AND column_name = '{col}'
"""
)
)
if result.fetchone():
op.drop_column("chat_message", col)
# Update ToolCall table
# Add chat_session_id column (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'chat_session_id'
"""
)
)
if not result.fetchone():
op.add_column(
"tool_call",
sa.Column("chat_session_id", postgresql.UUID(as_uuid=True), nullable=False),
)
op.create_foreign_key(
"fk_tool_call_chat_session_id",
"tool_call",
"chat_session",
["chat_session_id"],
["id"],
)
# Rename message_id to parent_chat_message_id and make nullable (if not already done)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'message_id'
"""
)
)
if result.fetchone():
op.alter_column(
"tool_call",
"message_id",
new_column_name="parent_chat_message_id",
nullable=True,
)
# Add parent_tool_call_id (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'parent_tool_call_id'
"""
)
)
if not result.fetchone():
op.add_column(
"tool_call", sa.Column("parent_tool_call_id", sa.Integer(), nullable=True)
)
op.create_foreign_key(
"fk_tool_call_parent_tool_call_id",
"tool_call",
"tool_call",
["parent_tool_call_id"],
["id"],
)
op.drop_constraint("uq_tool_call_message_id", "tool_call", type_="unique")
# Add turn_number, tool_id (if not exists)
for col_name in ["turn_number", "tool_id"]:
result = conn.execute(
sa.text(
f"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = '{col_name}'
"""
)
)
if not result.fetchone():
op.add_column(
"tool_call",
sa.Column(col_name, sa.Integer(), nullable=False, server_default="0"),
)
# Add tool_call_id as String (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'tool_call_id'
"""
)
)
if not result.fetchone():
op.add_column(
"tool_call",
sa.Column("tool_call_id", sa.String(), nullable=False, server_default=""),
)
# Add reasoning_tokens (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'reasoning_tokens'
"""
)
)
if not result.fetchone():
op.add_column(
"tool_call", sa.Column("reasoning_tokens", sa.Text(), nullable=True)
)
# Rename tool_arguments to tool_call_arguments (if not already done)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'tool_arguments'
"""
)
)
if result.fetchone():
op.alter_column(
"tool_call", "tool_arguments", new_column_name="tool_call_arguments"
)
# Rename tool_result to tool_call_response and change type from JSONB to Text (if not already done)
result = conn.execute(
sa.text(
"""
SELECT column_name, data_type FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'tool_result'
"""
)
)
tool_result_row = result.fetchone()
if tool_result_row:
op.alter_column(
"tool_call", "tool_result", new_column_name="tool_call_response"
)
# Change type from JSONB to Text
op.execute(
sa.text(
"""
ALTER TABLE tool_call
ALTER COLUMN tool_call_response TYPE TEXT
USING tool_call_response::text
"""
)
)
else:
# Check if tool_call_response already exists and is JSONB, then convert to Text
result = conn.execute(
sa.text(
"""
SELECT data_type FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'tool_call_response'
"""
)
)
tool_call_response_row = result.fetchone()
if tool_call_response_row and tool_call_response_row[0] == "jsonb":
op.execute(
sa.text(
"""
ALTER TABLE tool_call
ALTER COLUMN tool_call_response TYPE TEXT
USING tool_call_response::text
"""
)
)
# Add tool_call_tokens (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'tool_call_tokens'
"""
)
)
if not result.fetchone():
op.add_column(
"tool_call",
sa.Column(
"tool_call_tokens", sa.Integer(), nullable=False, server_default="0"
),
)
# Add generated_images column for image generation tool replay (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'generated_images'
"""
)
)
if not result.fetchone():
op.add_column(
"tool_call",
sa.Column("generated_images", postgresql.JSONB(), nullable=True),
)
# Drop tool_name column (if exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'tool_name'
"""
)
)
if result.fetchone():
op.drop_column("tool_call", "tool_name")
# Create tool_call__search_doc association table (if not exists)
result = conn.execute(
sa.text(
"""
SELECT table_name FROM information_schema.tables
WHERE table_name = 'tool_call__search_doc'
"""
)
)
if not result.fetchone():
op.create_table(
"tool_call__search_doc",
sa.Column("tool_call_id", sa.Integer(), nullable=False),
sa.Column("search_doc_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["tool_call_id"], ["tool_call.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(
["search_doc_id"], ["search_doc.id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("tool_call_id", "search_doc_id"),
)
# Add replace_base_system_prompt to persona table (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'persona' AND column_name = 'replace_base_system_prompt'
"""
)
)
if not result.fetchone():
op.add_column(
"persona",
sa.Column(
"replace_base_system_prompt",
sa.Boolean(),
nullable=False,
server_default="false",
),
)
def downgrade() -> None:
# Reverse persona changes
op.drop_column("persona", "replace_base_system_prompt")
# Drop tool_call__search_doc association table
op.execute("DROP TABLE IF EXISTS tool_call__search_doc CASCADE")
# Reverse ToolCall changes
op.add_column("tool_call", sa.Column("tool_name", sa.String(), nullable=False))
op.drop_column("tool_call", "tool_id")
op.drop_column("tool_call", "tool_call_tokens")
op.drop_column("tool_call", "generated_images")
# Change tool_call_response back to JSONB before renaming
op.execute(
sa.text(
"""
ALTER TABLE tool_call
ALTER COLUMN tool_call_response TYPE JSONB
USING tool_call_response::jsonb
"""
)
)
op.alter_column("tool_call", "tool_call_response", new_column_name="tool_result")
op.alter_column(
"tool_call", "tool_call_arguments", new_column_name="tool_arguments"
)
op.drop_column("tool_call", "reasoning_tokens")
op.drop_column("tool_call", "tool_call_id")
op.drop_column("tool_call", "turn_number")
op.drop_constraint(
"fk_tool_call_parent_tool_call_id", "tool_call", type_="foreignkey"
)
op.drop_column("tool_call", "parent_tool_call_id")
op.alter_column(
"tool_call",
"parent_chat_message_id",
new_column_name="message_id",
nullable=False,
)
op.drop_constraint("fk_tool_call_chat_session_id", "tool_call", type_="foreignkey")
op.drop_column("tool_call", "chat_session_id")
op.add_column(
"chat_message",
sa.Column(
"research_answer_purpose",
sa.Enum("INTRO", "DEEP_DIVE", name="researchanswerpurpose"),
nullable=True,
),
)
op.add_column(
"chat_message", sa.Column("research_plan", postgresql.JSONB(), nullable=True)
)
op.add_column(
"chat_message",
sa.Column(
"research_type",
sa.Enum("SIMPLE", "DEEP", name="researchtype"),
nullable=True,
),
)
op.add_column(
"chat_message",
sa.Column("refined_answer_improvement", sa.Boolean(), nullable=True),
)
op.add_column(
"chat_message",
sa.Column("is_agentic", sa.Boolean(), nullable=False, server_default="false"),
)
op.add_column(
"chat_message", sa.Column("overridden_model", sa.String(), nullable=True)
)
op.add_column(
"chat_message", sa.Column("alternate_assistant_id", sa.Integer(), nullable=True)
)
op.add_column(
"chat_message", sa.Column("rephrased_query", sa.Text(), nullable=True)
)
op.drop_column("chat_message", "reasoning_tokens")
op.drop_constraint(
"fk_chat_message_latest_child_message_id", "chat_message", type_="foreignkey"
)
op.alter_column(
"chat_message",
"latest_child_message_id",
new_column_name="latest_child_message",
)
op.drop_constraint(
"fk_chat_message_parent_message_id", "chat_message", type_="foreignkey"
)
op.alter_column(
"chat_message", "parent_message_id", new_column_name="parent_message"
)
# Recreate agent sub question and sub query tables
op.create_table(
"agent__sub_question",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("primary_question_id", sa.Integer(), nullable=False),
sa.Column("chat_session_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("sub_question", sa.Text(), nullable=False),
sa.Column("level", sa.Integer(), nullable=False),
sa.Column("level_question_num", sa.Integer(), nullable=False),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("sub_answer", sa.Text(), nullable=False),
sa.Column("sub_question_doc_results", postgresql.JSONB(), nullable=False),
sa.ForeignKeyConstraint(
["primary_question_id"], ["chat_message.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["chat_session_id"], ["chat_session.id"]),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"agent__sub_query",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("parent_question_id", sa.Integer(), nullable=False),
sa.Column("chat_session_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("sub_query", sa.Text(), nullable=False),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["parent_question_id"], ["agent__sub_question.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["chat_session_id"], ["chat_session.id"]),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"agent__sub_query__search_doc",
sa.Column("sub_query_id", sa.Integer(), nullable=False),
sa.Column("search_doc_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["sub_query_id"], ["agent__sub_query.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["search_doc_id"], ["search_doc.id"]),
sa.PrimaryKeyConstraint("sub_query_id", "search_doc_id"),
)
# Recreate research agent tables
op.create_table(
"research_agent_iteration",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("primary_question_id", sa.Integer(), nullable=False),
sa.Column("iteration_nr", sa.Integer(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("purpose", sa.String(), nullable=True),
sa.Column("reasoning", sa.String(), nullable=True),
sa.ForeignKeyConstraint(
["primary_question_id"], ["chat_message.id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint(
"primary_question_id",
"iteration_nr",
name="_research_agent_iteration_unique_constraint",
),
)
op.create_table(
"research_agent_iteration_sub_step",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("primary_question_id", sa.Integer(), nullable=False),
sa.Column("iteration_nr", sa.Integer(), nullable=False),
sa.Column("iteration_sub_step_nr", sa.Integer(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("sub_step_instructions", sa.String(), nullable=True),
sa.Column("sub_step_tool_id", sa.Integer(), nullable=True),
sa.Column("reasoning", sa.String(), nullable=True),
sa.Column("sub_answer", sa.String(), nullable=True),
sa.Column("cited_doc_results", postgresql.JSONB(), nullable=False),
sa.Column("claims", postgresql.JSONB(), nullable=True),
sa.Column("is_web_fetch", sa.Boolean(), nullable=True),
sa.Column("queries", postgresql.JSONB(), nullable=True),
sa.Column("generated_images", postgresql.JSONB(), nullable=True),
sa.Column("additional_data", postgresql.JSONB(), nullable=True),
sa.ForeignKeyConstraint(
["primary_question_id", "iteration_nr"],
[
"research_agent_iteration.primary_question_id",
"research_agent_iteration.iteration_nr",
],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(["sub_step_tool_id"], ["tool.id"], ondelete="SET NULL"),
sa.PrimaryKeyConstraint("id"),
)
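
The upgrade above guards every add, rename, and type change with the same information_schema existence check so the migration stays idempotent on partially migrated databases. A minimal helper sketch of that pattern (hypothetical, not part of the migration):

# Hypothetical refactor sketch: the repeated information_schema lookups in
# upgrade() could be expressed as a single reusable check.
import sqlalchemy as sa
from sqlalchemy.engine import Connection

def _column_exists(conn: Connection, table: str, column: str) -> bool:
    # True if the column is already present, so the caller can skip the
    # corresponding add_column/alter_column and keep the migration idempotent.
    row = conn.execute(
        sa.text(
            """
            SELECT 1 FROM information_schema.columns
            WHERE table_name = :table AND column_name = :column
            """
        ),
        {"table": table, "column": column},
    ).fetchone()
    return row is not None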

View File

@@ -199,10 +199,7 @@ def fetch_persona_message_analytics(
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
or_(
ChatMessage.alternate_assistant_id == persona_id,
ChatSession.persona_id == persona_id,
),
ChatSession.persona_id == persona_id,
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,
@@ -231,10 +228,7 @@ def fetch_persona_unique_users(
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
or_(
ChatMessage.alternate_assistant_id == persona_id,
ChatSession.persona_id == persona_id,
),
ChatSession.persona_id == persona_id,
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,
@@ -265,10 +259,7 @@ def fetch_assistant_message_analytics(
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
or_(
ChatMessage.alternate_assistant_id == assistant_id,
ChatSession.persona_id == assistant_id,
),
ChatSession.persona_id == assistant_id,
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,
@@ -299,10 +290,7 @@ def fetch_assistant_unique_users(
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
or_(
ChatMessage.alternate_assistant_id == assistant_id,
ChatSession.persona_id == assistant_id,
),
ChatSession.persona_id == assistant_id,
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,
@@ -332,10 +320,7 @@ def fetch_assistant_unique_users_total(
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
or_(
ChatMessage.alternate_assistant_id == assistant_id,
ChatSession.persona_id == assistant_id,
),
ChatSession.persona_id == assistant_id,
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,

View File

@@ -55,18 +55,7 @@ def get_empty_chat_messages_entries__paginated(
# Get assistant name (from session persona, or alternate if specified)
assistant_name = None
if message.alternate_assistant_id:
# If there's an alternate assistant, we need to fetch it
from onyx.db.models import Persona
alternate_persona = (
db_session.query(Persona)
.filter(Persona.id == message.alternate_assistant_id)
.first()
)
if alternate_persona:
assistant_name = alternate_persona.name
elif chat_session.persona:
if chat_session.persona:
assistant_name = chat_session.persona.name
message_skeletons.append(

View File

@@ -9,7 +9,7 @@ from ee.onyx.server.query_and_chat.models import (
)
from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import create_chat_history_chain
from onyx.chat.models import ChatBasicResponse
from onyx.chat.process_message import gather_stream
from onyx.chat.process_message import stream_chat_message_objects
@@ -69,7 +69,7 @@ def handle_simplified_chat_message(
chat_session_id = chat_message_req.chat_session_id
try:
parent_message, _ = create_chat_chain(
parent_message, _ = create_chat_history_chain(
chat_session_id=chat_session_id, db_session=db_session
)
except Exception:

View File

@@ -8,10 +8,29 @@ from pydantic import model_validator
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import BasicChunkRequest
from onyx.context.search.models import ChunkContext
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import RetrievalDetails
from onyx.server.manage.models import StandardAnswer
from onyx.server.query_and_chat.streaming_models import SubQuestionIdentifier
class StandardAnswerRequest(BaseModel):
message: str
slack_bot_categories: list[str]
class StandardAnswerResponse(BaseModel):
standard_answers: list[StandardAnswer] = Field(default_factory=list)
class DocumentSearchRequest(BasicChunkRequest):
user_selected_filters: BaseFilters | None = None
class DocumentSearchResponse(BaseModel):
top_documents: list[InferenceChunk]
class BasicCreateChatMessageRequest(ChunkContext):
@@ -71,17 +90,17 @@ class SimpleDoc(BaseModel):
metadata: dict | None
class AgentSubQuestion(SubQuestionIdentifier):
class AgentSubQuestion(BaseModel):
sub_question: str
document_ids: list[str]
class AgentAnswer(SubQuestionIdentifier):
class AgentAnswer(BaseModel):
answer: str
answer_type: Literal["agent_sub_answer", "agent_level_answer"]
class AgentSubQuery(SubQuestionIdentifier):
class AgentSubQuery(BaseModel):
sub_query: str
query_id: int
@@ -127,12 +146,3 @@ class AgentSubQuery(SubQuestionIdentifier):
sorted(level_question_dict.items(), key=lambda x: (x is None, x))
)
return sorted_dict
class StandardAnswerRequest(BaseModel):
message: str
slack_bot_categories: list[str]
class StandardAnswerResponse(BaseModel):
standard_answers: list[StandardAnswer] = Field(default_factory=list)

View File

@@ -24,7 +24,7 @@ from onyx.auth.users import current_admin_user
from onyx.auth.users import get_display_email
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.background.task_utils import construct_query_history_report_name
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import create_chat_history_chain
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import FileType
@@ -123,10 +123,9 @@ def snapshot_from_chat_session(
) -> ChatSessionSnapshot | None:
try:
# Older chats may not have the right structure
last_message, messages = create_chat_chain(
messages = create_chat_history_chain(
chat_session_id=chat_session.id, db_session=db_session
)
messages.append(last_message)
except RuntimeError:
return None
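
The two call sites changed in this commit (here and in the EE chat endpoint above) suggest the renamed helper returns the complete message chain directly, so callers no longer append the last message themselves. A minimal caller-side sketch under that assumption:

# Hypothetical sketch inferred from the diff above; the exact signature of
# create_chat_history_chain is not shown in this compare view.
from onyx.chat.chat_utils import create_chat_history_chain

messages = create_chat_history_chain(
    chat_session_id=chat_session.id, db_session=db_session
)
# previously: last_message, messages = create_chat_chain(...); messages.append(last_message)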

View File

@@ -1,365 +1,309 @@
import json
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
# import json
# from collections.abc import Callable
# from collections.abc import Iterator
# from collections.abc import Sequence
# from dataclasses import dataclass
# from typing import Any
import onyx.tracing.framework._error_tracing as _error_tracing
from onyx.agents.agent_framework.models import RunItemStreamEvent
from onyx.agents.agent_framework.models import StreamEvent
from onyx.agents.agent_framework.models import ToolCallOutputStreamItem
from onyx.agents.agent_framework.models import ToolCallStreamItem
from onyx.llm.interfaces import LanguageModelInput
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.message_types import ChatCompletionMessage
from onyx.llm.message_types import ToolCall
from onyx.llm.model_response import ModelResponseStream
from onyx.tools.tool import RunContextWrapper
from onyx.tools.tool import Tool
from onyx.tracing.framework.create import agent_span
from onyx.tracing.framework.create import function_span
from onyx.tracing.framework.create import generation_span
from onyx.tracing.framework.spans import SpanError
# from onyx.agents.agent_framework.models import RunItemStreamEvent
# from onyx.agents.agent_framework.models import StreamEvent
# from onyx.agents.agent_framework.models import ToolCallStreamItem
# from onyx.llm.interfaces import LanguageModelInput
# from onyx.llm.interfaces import LLM
# from onyx.llm.interfaces import ToolChoiceOptions
# from onyx.llm.message_types import ChatCompletionMessage
# from onyx.llm.message_types import ToolCall
# from onyx.llm.model_response import ModelResponseStream
# from onyx.tools.tool import Tool
# from onyx.tracing.framework.create import agent_span
# from onyx.tracing.framework.create import generation_span
@dataclass
class QueryResult:
stream: Iterator[StreamEvent]
new_messages_stateful: list[ChatCompletionMessage]
# @dataclass
# class QueryResult:
# stream: Iterator[StreamEvent]
# new_messages_stateful: list[ChatCompletionMessage]
def _serialize_tool_output(output: Any) -> str:
if isinstance(output, str):
return output
try:
return json.dumps(output)
except TypeError:
return str(output)
# def _serialize_tool_output(output: Any) -> str:
# if isinstance(output, str):
# return output
# try:
# return json.dumps(output)
# except TypeError:
# return str(output)
def _parse_tool_calls_from_message_content(
content: str,
) -> list[dict[str, Any]]:
"""Parse JSON content that represents tool call instructions."""
try:
parsed_content = json.loads(content)
except json.JSONDecodeError:
return []
# def _parse_tool_calls_from_message_content(
# content: str,
# ) -> list[dict[str, Any]]:
# """Parse JSON content that represents tool call instructions."""
# try:
# parsed_content = json.loads(content)
# except json.JSONDecodeError:
# return []
if isinstance(parsed_content, dict):
candidates = [parsed_content]
elif isinstance(parsed_content, list):
candidates = [item for item in parsed_content if isinstance(item, dict)]
else:
return []
# if isinstance(parsed_content, dict):
# candidates = [parsed_content]
# elif isinstance(parsed_content, list):
# candidates = [item for item in parsed_content if isinstance(item, dict)]
# else:
# return []
tool_calls: list[dict[str, Any]] = []
# tool_calls: list[dict[str, Any]] = []
for candidate in candidates:
name = candidate.get("name")
arguments = candidate.get("arguments")
# for candidate in candidates:
# name = candidate.get("name")
# arguments = candidate.get("arguments")
if not isinstance(name, str) or arguments is None:
continue
# if not isinstance(name, str) or arguments is None:
# continue
if not isinstance(arguments, dict):
continue
# if not isinstance(arguments, dict):
# continue
call_id = candidate.get("id")
arguments_str = json.dumps(arguments)
tool_calls.append(
{
"id": call_id,
"name": name,
"arguments": arguments_str,
}
)
# call_id = candidate.get("id")
# arguments_str = json.dumps(arguments)
# tool_calls.append(
# {
# "id": call_id,
# "name": name,
# "arguments": arguments_str,
# }
# )
return tool_calls
# return tool_calls
def _try_convert_content_to_tool_calls_for_non_tool_calling_llms(
tool_calls_in_progress: dict[int, dict[str, Any]],
content_parts: list[str],
structured_response_format: dict | None,
next_synthetic_tool_call_id: Callable[[], str],
) -> None:
"""Populate tool_calls_in_progress when a non-tool-calling LLM returns JSON content describing tool calls."""
if tool_calls_in_progress or not content_parts or structured_response_format:
return
# def _try_convert_content_to_tool_calls_for_non_tool_calling_llms(
# tool_calls_in_progress: dict[int, dict[str, Any]],
# content_parts: list[str],
# structured_response_format: dict | None,
# next_synthetic_tool_call_id: Callable[[], str],
# ) -> None:
# """Populate tool_calls_in_progress when a non-tool-calling LLM returns JSON content describing tool calls."""
# if tool_calls_in_progress or not content_parts or structured_response_format:
# return
tool_calls_from_content = _parse_tool_calls_from_message_content(
"".join(content_parts)
)
# tool_calls_from_content = _parse_tool_calls_from_message_content(
# "".join(content_parts)
# )
if not tool_calls_from_content:
return
# if not tool_calls_from_content:
# return
content_parts.clear()
# content_parts.clear()
for index, tool_call_data in enumerate(tool_calls_from_content):
call_id = tool_call_data["id"] or next_synthetic_tool_call_id()
tool_calls_in_progress[index] = {
"id": call_id,
"name": tool_call_data["name"],
"arguments": tool_call_data["arguments"],
}
# for index, tool_call_data in enumerate(tool_calls_from_content):
# call_id = tool_call_data["id"] or next_synthetic_tool_call_id()
# tool_calls_in_progress[index] = {
# "id": call_id,
# "name": tool_call_data["name"],
# "arguments": tool_call_data["arguments"],
# }
def _update_tool_call_with_delta(
tool_calls_in_progress: dict[int, dict[str, Any]],
tool_call_delta: Any,
) -> None:
index = tool_call_delta.index
# def _update_tool_call_with_delta(
# tool_calls_in_progress: dict[int, dict[str, Any]],
# tool_call_delta: Any,
# ) -> None:
# index = tool_call_delta.index
if index not in tool_calls_in_progress:
tool_calls_in_progress[index] = {
"id": None,
"name": None,
"arguments": "",
}
# if index not in tool_calls_in_progress:
# tool_calls_in_progress[index] = {
# "id": None,
# "name": None,
# "arguments": "",
# }
if tool_call_delta.id:
tool_calls_in_progress[index]["id"] = tool_call_delta.id
# if tool_call_delta.id:
# tool_calls_in_progress[index]["id"] = tool_call_delta.id
if tool_call_delta.function:
if tool_call_delta.function.name:
tool_calls_in_progress[index]["name"] = tool_call_delta.function.name
# if tool_call_delta.function:
# if tool_call_delta.function.name:
# tool_calls_in_progress[index]["name"] = tool_call_delta.function.name
if tool_call_delta.function.arguments:
tool_calls_in_progress[index][
"arguments"
] += tool_call_delta.function.arguments
# if tool_call_delta.function.arguments:
# tool_calls_in_progress[index][
# "arguments"
# ] += tool_call_delta.function.arguments
def query(
llm_with_default_settings: LLM,
messages: LanguageModelInput,
tools: Sequence[Tool],
context: Any,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> QueryResult:
tool_definitions = [tool.tool_definition() for tool in tools]
tools_by_name = {tool.name: tool for tool in tools}
# def query(
# llm_with_default_settings: LLM,
# messages: LanguageModelInput,
# tools: Sequence[Tool],
# context: Any,
# tool_choice: ToolChoiceOptions | None = None,
# structured_response_format: dict | None = None,
# ) -> QueryResult:
# tool_definitions = [tool.tool_definition() for tool in tools]
# tools_by_name = {tool.name: tool for tool in tools}
new_messages_stateful: list[ChatCompletionMessage] = []
# new_messages_stateful: list[ChatCompletionMessage] = []
current_span = agent_span(
name="agent_framework_query",
output_type="dict" if structured_response_format else "str",
)
current_span.start(mark_as_current=True)
current_span.span_data.tools = [t.name for t in tools]
# current_span = agent_span(
# name="agent_framework_query",
# output_type="dict" if structured_response_format else "str",
# )
# current_span.start(mark_as_current=True)
# current_span.span_data.tools = [t.name for t in tools]
def stream_generator() -> Iterator[StreamEvent]:
message_started = False
reasoning_started = False
# def stream_generator() -> Iterator[StreamEvent]:
# message_started = False
# reasoning_started = False
tool_calls_in_progress: dict[int, dict[str, Any]] = {}
# tool_calls_in_progress: dict[int, dict[str, Any]] = {}
content_parts: list[str] = []
# content_parts: list[str] = []
synthetic_tool_call_counter = 0
# synthetic_tool_call_counter = 0
def _next_synthetic_tool_call_id() -> str:
nonlocal synthetic_tool_call_counter
call_id = f"synthetic_tool_call_{synthetic_tool_call_counter}"
synthetic_tool_call_counter += 1
return call_id
# def _next_synthetic_tool_call_id() -> str:
# nonlocal synthetic_tool_call_counter
# call_id = f"synthetic_tool_call_{synthetic_tool_call_counter}"
# synthetic_tool_call_counter += 1
# return call_id
with generation_span( # type: ignore[misc]
model=llm_with_default_settings.config.model_name,
model_config={
"base_url": str(llm_with_default_settings.config.api_base or ""),
"model_impl": "litellm",
},
) as span_generation:
# Only set input if messages is a sequence (not a string)
# ChatCompletionMessage TypedDicts are compatible with Mapping[str, Any] at runtime
if isinstance(messages, Sequence) and not isinstance(messages, str):
# Convert ChatCompletionMessage sequence to Sequence[Mapping[str, Any]]
span_generation.span_data.input = [dict(msg) for msg in messages] # type: ignore[assignment]
for chunk in llm_with_default_settings.stream(
prompt=messages,
tools=tool_definitions,
tool_choice=tool_choice,
structured_response_format=structured_response_format,
):
assert isinstance(chunk, ModelResponseStream)
usage = getattr(chunk, "usage", None)
if usage:
span_generation.span_data.usage = {
"input_tokens": usage.prompt_tokens,
"output_tokens": usage.completion_tokens,
"cache_read_input_tokens": usage.cache_read_input_tokens,
"cache_creation_input_tokens": usage.cache_creation_input_tokens,
}
# with generation_span( # type: ignore[misc]
# model=llm_with_default_settings.config.model_name,
# model_config={
# "base_url": str(llm_with_default_settings.config.api_base or ""),
# "model_impl": "litellm",
# },
# ) as span_generation:
# # Only set input if messages is a sequence (not a string)
# # ChatCompletionMessage TypedDicts are compatible with Mapping[str, Any] at runtime
# if isinstance(messages, Sequence) and not isinstance(messages, str):
# # Convert ChatCompletionMessage sequence to Sequence[Mapping[str, Any]]
# span_generation.span_data.input = [dict(msg) for msg in messages] # type: ignore[assignment]
# for chunk in llm_with_default_settings.stream(
# prompt=messages,
# tools=tool_definitions,
# tool_choice=tool_choice,
# structured_response_format=structured_response_format,
# ):
# assert isinstance(chunk, ModelResponseStream)
# usage = getattr(chunk, "usage", None)
# if usage:
# span_generation.span_data.usage = {
# "input_tokens": usage.prompt_tokens,
# "output_tokens": usage.completion_tokens,
# "cache_read_input_tokens": usage.cache_read_input_tokens,
# "cache_creation_input_tokens": usage.cache_creation_input_tokens,
# }
delta = chunk.choice.delta
finish_reason = chunk.choice.finish_reason
# delta = chunk.choice.delta
# finish_reason = chunk.choice.finish_reason
if delta.reasoning_content:
if not reasoning_started:
yield RunItemStreamEvent(type="reasoning_start")
reasoning_started = True
# if delta.reasoning_content:
# if not reasoning_started:
# yield RunItemStreamEvent(type="reasoning_start")
# reasoning_started = True
if delta.content:
if reasoning_started:
yield RunItemStreamEvent(type="reasoning_done")
reasoning_started = False
content_parts.append(delta.content)
if not message_started:
yield RunItemStreamEvent(type="message_start")
message_started = True
# if delta.content:
# if reasoning_started:
# yield RunItemStreamEvent(type="reasoning_done")
# reasoning_started = False
# content_parts.append(delta.content)
# if not message_started:
# yield RunItemStreamEvent(type="message_start")
# message_started = True
if delta.tool_calls:
if reasoning_started:
yield RunItemStreamEvent(type="reasoning_done")
reasoning_started = False
if message_started:
yield RunItemStreamEvent(type="message_done")
message_started = False
# if delta.tool_calls:
# if reasoning_started:
# yield RunItemStreamEvent(type="reasoning_done")
# reasoning_started = False
# if message_started:
# yield RunItemStreamEvent(type="message_done")
# message_started = False
for tool_call_delta in delta.tool_calls:
_update_tool_call_with_delta(
tool_calls_in_progress, tool_call_delta
)
# for tool_call_delta in delta.tool_calls:
# _update_tool_call_with_delta(
# tool_calls_in_progress, tool_call_delta
# )
yield chunk
# yield chunk
if not finish_reason:
continue
# if not finish_reason:
# continue
if reasoning_started:
yield RunItemStreamEvent(type="reasoning_done")
reasoning_started = False
if message_started:
yield RunItemStreamEvent(type="message_done")
message_started = False
# if reasoning_started:
# yield RunItemStreamEvent(type="reasoning_done")
# reasoning_started = False
# if message_started:
# yield RunItemStreamEvent(type="message_done")
# message_started = False
if tool_choice != "none":
_try_convert_content_to_tool_calls_for_non_tool_calling_llms(
tool_calls_in_progress,
content_parts,
structured_response_format,
_next_synthetic_tool_call_id,
)
# if tool_choice != "none":
# _try_convert_content_to_tool_calls_for_non_tool_calling_llms(
# tool_calls_in_progress,
# content_parts,
# structured_response_format,
# _next_synthetic_tool_call_id,
# )
if content_parts:
new_messages_stateful.append(
{
"role": "assistant",
"content": "".join(content_parts),
}
)
span_generation.span_data.output = new_messages_stateful
# if content_parts:
# new_messages_stateful.append(
# {
# "role": "assistant",
# "content": "".join(content_parts),
# }
# )
# span_generation.span_data.output = new_messages_stateful
# Execute tool calls outside of the stream loop and generation_span
if tool_calls_in_progress:
sorted_tool_calls = sorted(tool_calls_in_progress.items())
# # Execute tool calls outside of the stream loop and generation_span
# if tool_calls_in_progress:
# sorted_tool_calls = sorted(tool_calls_in_progress.items())
# Build tool calls for the message and execute tools
assistant_tool_calls: list[ToolCall] = []
tool_outputs: dict[str, str] = {}
# # Build tool calls for the message and execute tools
# assistant_tool_calls: list[ToolCall] = []
for _, tool_call_data in sorted_tool_calls:
call_id = tool_call_data["id"]
name = tool_call_data["name"]
arguments_str = tool_call_data["arguments"]
# for _, tool_call_data in sorted_tool_calls:
# call_id = tool_call_data["id"]
# name = tool_call_data["name"]
# arguments_str = tool_call_data["arguments"]
if call_id is None or name is None:
continue
# if call_id is None or name is None:
# continue
assistant_tool_calls.append(
{
"id": call_id,
"type": "function",
"function": {
"name": name,
"arguments": arguments_str,
},
}
)
# assistant_tool_calls.append(
# {
# "id": call_id,
# "type": "function",
# "function": {
# "name": name,
# "arguments": arguments_str,
# },
# }
# )
yield RunItemStreamEvent(
type="tool_call",
details=ToolCallStreamItem(
call_id=call_id,
name=name,
arguments=arguments_str,
),
)
# yield RunItemStreamEvent(
# type="tool_call",
# details=ToolCallStreamItem(
# call_id=call_id,
# name=name,
# arguments=arguments_str,
# ),
# )
if name in tools_by_name:
tool = tools_by_name[name]
arguments = json.loads(arguments_str)
# if name in tools_by_name:
# tools_by_name[name]
# json.loads(arguments_str)
run_context = RunContextWrapper(context=context)
# run_context = RunContextWrapper(context=context)
# TODO: Instead of executing sequentially, execute in parallel
# In practice, it's not a must right now since we don't use parallel
# tool calls, so kicking the can down the road for now.
with function_span(tool.name) as span_fn:
span_fn.span_data.input = arguments
try:
output = tool.run_v2(run_context, **arguments)
tool_outputs[call_id] = _serialize_tool_output(output)
span_fn.span_data.output = output
except Exception as e:
_error_tracing.attach_error_to_current_span(
SpanError(
message="Error running tool",
data={"tool_name": tool.name, "error": str(e)},
)
)
# Treat the error as the tool output so the framework can continue
error_output = f"Error: {str(e)}"
tool_outputs[call_id] = error_output
output = error_output
# TODO: Instead of executing sequentially, execute in parallel
# In practice, it's not a must right now since we don't use parallel
# tool calls, so kicking the can down the road for now.
yield RunItemStreamEvent(
type="tool_call_output",
details=ToolCallOutputStreamItem(
call_id=call_id,
output=output,
),
)
else:
not_found_output = f"Tool {name} not found"
tool_outputs[call_id] = _serialize_tool_output(not_found_output)
yield RunItemStreamEvent(
type="tool_call_output",
details=ToolCallOutputStreamItem(
call_id=call_id,
output=not_found_output,
),
)
# TODO broken for now, no need for a run_v2
# output = tool.run_v2(run_context, **arguments)
new_messages_stateful.append(
{
"role": "assistant",
"content": None,
"tool_calls": assistant_tool_calls,
}
)
for _, tool_call_data in sorted_tool_calls:
call_id = tool_call_data["id"]
if call_id in tool_outputs:
new_messages_stateful.append(
{
"role": "tool",
"content": tool_outputs[call_id],
"tool_call_id": call_id,
}
)
current_span.finish(reset_current=True)
return QueryResult(
stream=stream_generator(),
new_messages_stateful=new_messages_stateful,
)
# yield RunItemStreamEvent(
# type="tool_call_output",
# details=ToolCallOutputStreamItem(
# call_id=call_id,
# output=output,
# ),
# )

View File

@@ -1,21 +1,21 @@
from operator import add
from typing import Annotated
# from operator import add
# from typing import Annotated
from pydantic import BaseModel
# from pydantic import BaseModel
class CoreState(BaseModel):
"""
This is the core state that is shared across all subgraphs.
"""
# class CoreState(BaseModel):
# """
# This is the core state that is shared across all subgraphs.
# """
log_messages: Annotated[list[str], add] = []
current_step_nr: int = 1
# log_messages: Annotated[list[str], add] = []
# current_step_nr: int = 1
class SubgraphCoreState(BaseModel):
"""
This is the core state that is shared across all subgraphs.
"""
# class SubgraphCoreState(BaseModel):
# """
# This is the core state that is shared across all subgraphs.
# """
log_messages: Annotated[list[str], add] = []
# log_messages: Annotated[list[str], add] = []

View File

@@ -1,62 +1,62 @@
from collections.abc import Hashable
from typing import cast
# from collections.abc import Hashable
# from typing import cast
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import Send
# from langchain_core.runnables.config import RunnableConfig
# from langgraph.types import Send
from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
from onyx.agents.agent_search.dc_search_analysis.states import (
ObjectResearchInformationUpdate,
)
from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
from onyx.agents.agent_search.dc_search_analysis.states import (
SearchSourcesObjectsUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
# from onyx.agents.agent_search.dc_search_analysis.states import (
# ObjectResearchInformationUpdate,
# )
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
# from onyx.agents.agent_search.dc_search_analysis.states import (
# SearchSourcesObjectsUpdate,
# )
# from onyx.agents.agent_search.models import GraphConfig
def parallel_object_source_research_edge(
state: SearchSourcesObjectsUpdate, config: RunnableConfig
) -> list[Send | Hashable]:
"""
LangGraph edge to parallelize the research for an individual object and source
"""
# def parallel_object_source_research_edge(
# state: SearchSourcesObjectsUpdate, config: RunnableConfig
# ) -> list[Send | Hashable]:
# """
# LangGraph edge to parallelize the research for an individual object and source
# """
search_objects = state.analysis_objects
search_sources = state.analysis_sources
# search_objects = state.analysis_objects
# search_sources = state.analysis_sources
object_source_combinations = [
(object, source) for object in search_objects for source in search_sources
]
# object_source_combinations = [
# (object, source) for object in search_objects for source in search_sources
# ]
return [
Send(
"research_object_source",
ObjectSourceInput(
object_source_combination=object_source_combination,
log_messages=[],
),
)
for object_source_combination in object_source_combinations
]
# return [
# Send(
# "research_object_source",
# ObjectSourceInput(
# object_source_combination=object_source_combination,
# log_messages=[],
# ),
# )
# for object_source_combination in object_source_combinations
# ]
def parallel_object_research_consolidation_edge(
state: ObjectResearchInformationUpdate, config: RunnableConfig
) -> list[Send | Hashable]:
"""
LangGraph edge to parallelize the research for an individual object and source
"""
cast(GraphConfig, config["metadata"]["config"])
object_research_information_results = state.object_research_information_results
# def parallel_object_research_consolidation_edge(
# state: ObjectResearchInformationUpdate, config: RunnableConfig
# ) -> list[Send | Hashable]:
# """
# LangGraph edge to parallelize the research for an individual object and source
# """
# cast(GraphConfig, config["metadata"]["config"])
# object_research_information_results = state.object_research_information_results
return [
Send(
"consolidate_object_research",
ObjectInformationInput(
object_information=object_information,
log_messages=[],
),
)
for object_information in object_research_information_results
]
# return [
# Send(
# "consolidate_object_research",
# ObjectInformationInput(
# object_information=object_information,
# log_messages=[],
# ),
# )
# for object_information in object_research_information_results
# ]

View File

@@ -1,103 +1,103 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from onyx.agents.agent_search.dc_search_analysis.edges import (
parallel_object_research_consolidation_edge,
)
from onyx.agents.agent_search.dc_search_analysis.edges import (
parallel_object_source_research_edge,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a1_search_objects import (
search_objects,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a2_research_object_source import (
research_object_source,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a3_structure_research_by_object import (
structure_research_by_object,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a4_consolidate_object_research import (
consolidate_object_research,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a5_consolidate_research import (
consolidate_research,
)
from onyx.agents.agent_search.dc_search_analysis.states import MainInput
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dc_search_analysis.edges import (
# parallel_object_research_consolidation_edge,
# )
# from onyx.agents.agent_search.dc_search_analysis.edges import (
# parallel_object_source_research_edge,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a1_search_objects import (
# search_objects,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a2_research_object_source import (
# research_object_source,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a3_structure_research_by_object import (
# structure_research_by_object,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a4_consolidate_object_research import (
# consolidate_object_research,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a5_consolidate_research import (
# consolidate_research,
# )
# from onyx.agents.agent_search.dc_search_analysis.states import MainInput
# from onyx.agents.agent_search.dc_search_analysis.states import MainState
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
test_mode = False
# test_mode = False
def divide_and_conquer_graph_builder(test_mode: bool = False) -> StateGraph:
"""
LangGraph graph builder for the knowledge graph search process.
"""
# def divide_and_conquer_graph_builder(test_mode: bool = False) -> StateGraph:
# """
# LangGraph graph builder for the knowledge graph search process.
# """
graph = StateGraph(
state_schema=MainState,
input=MainInput,
)
# graph = StateGraph(
# state_schema=MainState,
# input=MainInput,
# )
### Add nodes ###
# ### Add nodes ###
graph.add_node(
"search_objects",
search_objects,
)
# graph.add_node(
# "search_objects",
# search_objects,
# )
graph.add_node(
"structure_research_by_source",
structure_research_by_object,
)
# graph.add_node(
# "structure_research_by_source",
# structure_research_by_object,
# )
graph.add_node(
"research_object_source",
research_object_source,
)
# graph.add_node(
# "research_object_source",
# research_object_source,
# )
graph.add_node(
"consolidate_object_research",
consolidate_object_research,
)
# graph.add_node(
# "consolidate_object_research",
# consolidate_object_research,
# )
graph.add_node(
"consolidate_research",
consolidate_research,
)
# graph.add_node(
# "consolidate_research",
# consolidate_research,
# )
### Add edges ###
# ### Add edges ###
graph.add_edge(start_key=START, end_key="search_objects")
# graph.add_edge(start_key=START, end_key="search_objects")
graph.add_conditional_edges(
source="search_objects",
path=parallel_object_source_research_edge,
path_map=["research_object_source"],
)
# graph.add_conditional_edges(
# source="search_objects",
# path=parallel_object_source_research_edge,
# path_map=["research_object_source"],
# )
graph.add_edge(
start_key="research_object_source",
end_key="structure_research_by_source",
)
# graph.add_edge(
# start_key="research_object_source",
# end_key="structure_research_by_source",
# )
graph.add_conditional_edges(
source="structure_research_by_source",
path=parallel_object_research_consolidation_edge,
path_map=["consolidate_object_research"],
)
# graph.add_conditional_edges(
# source="structure_research_by_source",
# path=parallel_object_research_consolidation_edge,
# path_map=["consolidate_object_research"],
# )
graph.add_edge(
start_key="consolidate_object_research",
end_key="consolidate_research",
)
# graph.add_edge(
# start_key="consolidate_object_research",
# end_key="consolidate_research",
# )
graph.add_edge(
start_key="consolidate_research",
end_key=END,
)
# graph.add_edge(
# start_key="consolidate_research",
# end_key=END,
# )
return graph
# return graph

View File

@@ -1,146 +1,146 @@
from typing import cast
# from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.ops import research
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.agents.agent_search.dc_search_analysis.states import (
SearchSourcesObjectsUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.prompts.agents.dc_prompts import DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT
from onyx.prompts.agents.dc_prompts import DC_OBJECT_SEPARATOR
from onyx.prompts.agents.dc_prompts import DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT
from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
# from onyx.agents.agent_search.dc_search_analysis.ops import research
# from onyx.agents.agent_search.dc_search_analysis.states import MainState
# from onyx.agents.agent_search.dc_search_analysis.states import (
# SearchSourcesObjectsUpdate,
# )
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
# trim_prompt_piece,
# )
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_SEPARATOR
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT
# from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
# logger = setup_logger()
def search_objects(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> SearchSourcesObjectsUpdate:
"""
LangGraph node to start the agentic search process.
"""
# def search_objects(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> SearchSourcesObjectsUpdate:
# """
# LangGraph node to start the agentic search process.
# """
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.prompt_builder.raw_user_query
search_tool = graph_config.tooling.search_tool
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# question = graph_config.inputs.prompt_builder.raw_user_query
# search_tool = graph_config.tooling.search_tool
if search_tool is None or graph_config.inputs.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")
# if search_tool is None or graph_config.inputs.persona is None:
# raise ValueError("Search tool and persona must be provided for DivCon search")
try:
instructions = graph_config.inputs.persona.system_prompt or ""
# try:
# instructions = graph_config.inputs.persona.system_prompt or ""
agent_1_instructions = extract_section(
instructions, "Agent Step 1:", "Agent Step 2:"
)
if agent_1_instructions is None:
raise ValueError("Agent 1 instructions not found")
# agent_1_instructions = extract_section(
# instructions, "Agent Step 1:", "Agent Step 2:"
# )
# if agent_1_instructions is None:
# raise ValueError("Agent 1 instructions not found")
agent_1_base_data = extract_section(instructions, "|Start Data|", "|End Data|")
# agent_1_base_data = extract_section(instructions, "|Start Data|", "|End Data|")
agent_1_task = extract_section(
agent_1_instructions, "Task:", "Independent Research Sources:"
)
if agent_1_task is None:
raise ValueError("Agent 1 task not found")
# agent_1_task = extract_section(
# agent_1_instructions, "Task:", "Independent Research Sources:"
# )
# if agent_1_task is None:
# raise ValueError("Agent 1 task not found")
agent_1_independent_sources_str = extract_section(
agent_1_instructions, "Independent Research Sources:", "Output Objective:"
)
if agent_1_independent_sources_str is None:
raise ValueError("Agent 1 Independent Research Sources not found")
# agent_1_independent_sources_str = extract_section(
# agent_1_instructions, "Independent Research Sources:", "Output Objective:"
# )
# if agent_1_independent_sources_str is None:
# raise ValueError("Agent 1 Independent Research Sources not found")
document_sources = strings_to_document_sources(
[
x.strip().lower()
for x in agent_1_independent_sources_str.split(DC_OBJECT_SEPARATOR)
]
)
# document_sources = strings_to_document_sources(
# [
# x.strip().lower()
# for x in agent_1_independent_sources_str.split(DC_OBJECT_SEPARATOR)
# ]
# )
agent_1_output_objective = extract_section(
agent_1_instructions, "Output Objective:"
)
if agent_1_output_objective is None:
raise ValueError("Agent 1 output objective not found")
# agent_1_output_objective = extract_section(
# agent_1_instructions, "Output Objective:"
# )
# if agent_1_output_objective is None:
# raise ValueError("Agent 1 output objective not found")
except Exception as e:
raise ValueError(
f"Agent 1 instructions not found or not formatted correctly: {e}"
)
# except Exception as e:
# raise ValueError(
# f"Agent 1 instructions not found or not formatted correctly: {e}"
# )
# Extract objects
# # Extract objects
if agent_1_base_data is None:
# Retrieve chunks for objects
# if agent_1_base_data is None:
# # Retrieve chunks for objects
retrieved_docs = research(question, search_tool)[:10]
# retrieved_docs = research(question, search_tool)[:10]
document_texts_list = []
for doc_num, doc in enumerate(retrieved_docs):
chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
document_texts_list.append(chunk_text)
# document_texts_list = []
# for doc_num, doc in enumerate(retrieved_docs):
# chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
# document_texts_list.append(chunk_text)
document_texts = "\n\n".join(document_texts_list)
# document_texts = "\n\n".join(document_texts_list)
dc_object_extraction_prompt = DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT.format(
question=question,
task=agent_1_task,
document_text=document_texts,
objects_of_interest=agent_1_output_objective,
)
else:
dc_object_extraction_prompt = DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT.format(
question=question,
task=agent_1_task,
base_data=agent_1_base_data,
objects_of_interest=agent_1_output_objective,
)
# dc_object_extraction_prompt = DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT.format(
# question=question,
# task=agent_1_task,
# document_text=document_texts,
# objects_of_interest=agent_1_output_objective,
# )
# else:
# dc_object_extraction_prompt = DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT.format(
# question=question,
# task=agent_1_task,
# base_data=agent_1_base_data,
# objects_of_interest=agent_1_output_objective,
# )
msg = [
HumanMessage(
content=trim_prompt_piece(
config=graph_config.tooling.primary_llm.config,
prompt_piece=dc_object_extraction_prompt,
reserved_str="",
),
)
]
primary_llm = graph_config.tooling.primary_llm
# Grader
try:
llm_response = run_with_timeout(
30,
primary_llm.invoke_langchain,
prompt=msg,
timeout_override=30,
max_tokens=300,
)
# msg = [
# HumanMessage(
# content=trim_prompt_piece(
# config=graph_config.tooling.primary_llm.config,
# prompt_piece=dc_object_extraction_prompt,
# reserved_str="",
# ),
# )
# ]
# primary_llm = graph_config.tooling.primary_llm
# # Grader
# try:
# llm_response = run_with_timeout(
# 30,
# primary_llm.invoke_langchain,
# prompt=msg,
# timeout_override=30,
# max_tokens=300,
# )
cleaned_response = (
str(llm_response.content)
.replace("```json\n", "")
.replace("\n```", "")
.replace("\n", "")
)
cleaned_response = cleaned_response.split("OBJECTS:")[1]
object_list = [x.strip() for x in cleaned_response.split(";")]
# cleaned_response = (
# str(llm_response.content)
# .replace("```json\n", "")
# .replace("\n```", "")
# .replace("\n", "")
# )
# cleaned_response = cleaned_response.split("OBJECTS:")[1]
# object_list = [x.strip() for x in cleaned_response.split(";")]
except Exception as e:
raise ValueError(f"Error in search_objects: {e}")
# except Exception as e:
# raise ValueError(f"Error in search_objects: {e}")
return SearchSourcesObjectsUpdate(
analysis_objects=object_list,
analysis_sources=document_sources,
log_messages=["Agent 1 Task done"],
)
# return SearchSourcesObjectsUpdate(
# analysis_objects=object_list,
# analysis_sources=document_sources,
# log_messages=["Agent 1 Task done"],
# )
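
For reference, the OBJECTS-parsing step above reduces to the following standalone sketch; the raw LLM response string is invented for the demo and the cleanup mirrors the logic in the hunk:

# Illustrative only: mirrors the response cleanup in the function above; input is a made-up example.
raw_llm_response = "```json\nOBJECTS: ACME Corp; Globex; Initech\n```"

cleaned_response = (
    raw_llm_response
    .replace("```json\n", "")
    .replace("\n```", "")
    .replace("\n", "")
)
object_list = [x.strip() for x in cleaned_response.split("OBJECTS:")[1].split(";")]
print(object_list)  # ['ACME Corp', 'Globex', 'Initech']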

View File

@@ -1,180 +1,180 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import cast
# from datetime import datetime
# from datetime import timedelta
# from datetime import timezone
# from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.ops import research
from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
from onyx.agents.agent_search.dc_search_analysis.states import (
ObjectSourceResearchUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.prompts.agents.dc_prompts import DC_OBJECT_SOURCE_RESEARCH_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
# from onyx.agents.agent_search.dc_search_analysis.ops import research
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
# from onyx.agents.agent_search.dc_search_analysis.states import (
# ObjectSourceResearchUpdate,
# )
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
# trim_prompt_piece,
# )
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_SOURCE_RESEARCH_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
# logger = setup_logger()
def research_object_source(
state: ObjectSourceInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> ObjectSourceResearchUpdate:
"""
LangGraph node to start the agentic search process.
"""
datetime.now()
# def research_object_source(
# state: ObjectSourceInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> ObjectSourceResearchUpdate:
# """
# LangGraph node to start the agentic search process.
# """
# datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
search_tool = graph_config.tooling.search_tool
question = graph_config.inputs.prompt_builder.raw_user_query
object, document_source = state.object_source_combination
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# search_tool = graph_config.tooling.search_tool
# question = graph_config.inputs.prompt_builder.raw_user_query
# object, document_source = state.object_source_combination
if search_tool is None or graph_config.inputs.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")
# if search_tool is None or graph_config.inputs.persona is None:
# raise ValueError("Search tool and persona must be provided for DivCon search")
try:
instructions = graph_config.inputs.persona.system_prompt or ""
# try:
# instructions = graph_config.inputs.persona.system_prompt or ""
agent_2_instructions = extract_section(
instructions, "Agent Step 2:", "Agent Step 3:"
)
if agent_2_instructions is None:
raise ValueError("Agent 2 instructions not found")
# agent_2_instructions = extract_section(
# instructions, "Agent Step 2:", "Agent Step 3:"
# )
# if agent_2_instructions is None:
# raise ValueError("Agent 2 instructions not found")
agent_2_task = extract_section(
agent_2_instructions, "Task:", "Independent Research Sources:"
)
if agent_2_task is None:
raise ValueError("Agent 2 task not found")
# agent_2_task = extract_section(
# agent_2_instructions, "Task:", "Independent Research Sources:"
# )
# if agent_2_task is None:
# raise ValueError("Agent 2 task not found")
agent_2_time_cutoff = extract_section(
agent_2_instructions, "Time Cutoff:", "Research Topics:"
)
# agent_2_time_cutoff = extract_section(
# agent_2_instructions, "Time Cutoff:", "Research Topics:"
# )
agent_2_research_topics = extract_section(
agent_2_instructions, "Research Topics:", "Output Objective"
)
# agent_2_research_topics = extract_section(
# agent_2_instructions, "Research Topics:", "Output Objective"
# )
agent_2_output_objective = extract_section(
agent_2_instructions, "Output Objective:"
)
if agent_2_output_objective is None:
raise ValueError("Agent 2 output objective not found")
# agent_2_output_objective = extract_section(
# agent_2_instructions, "Output Objective:"
# )
# if agent_2_output_objective is None:
# raise ValueError("Agent 2 output objective not found")
except Exception as e:
raise ValueError(
f"Agent 2 instructions not found or not formatted correctly: {e}"
)
# except Exception as e:
# raise ValueError(
# f"Agent 2 instructions not found or not formatted correctly: {e}"
# )
# Populate prompt
# # Populate prompt
# Retrieve chunks for objects
# # Retrieve chunks for objects
if agent_2_time_cutoff is not None and agent_2_time_cutoff.strip() != "":
if agent_2_time_cutoff.strip().endswith("d"):
try:
days = int(agent_2_time_cutoff.strip()[:-1])
agent_2_source_start_time = datetime.now(timezone.utc) - timedelta(
days=days
)
except ValueError:
raise ValueError(
f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
)
else:
raise ValueError(
f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
)
else:
agent_2_source_start_time = None
# if agent_2_time_cutoff is not None and agent_2_time_cutoff.strip() != "":
# if agent_2_time_cutoff.strip().endswith("d"):
# try:
# days = int(agent_2_time_cutoff.strip()[:-1])
# agent_2_source_start_time = datetime.now(timezone.utc) - timedelta(
# days=days
# )
# except ValueError:
# raise ValueError(
# f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
# )
# else:
# raise ValueError(
# f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
# )
# else:
# agent_2_source_start_time = None
document_sources = [document_source] if document_source else None
# document_sources = [document_source] if document_source else None
if len(question.strip()) > 0:
research_area = f"{question} for {object}"
elif agent_2_research_topics and len(agent_2_research_topics.strip()) > 0:
research_area = f"{agent_2_research_topics} for {object}"
else:
research_area = object
# if len(question.strip()) > 0:
# research_area = f"{question} for {object}"
# elif agent_2_research_topics and len(agent_2_research_topics.strip()) > 0:
# research_area = f"{agent_2_research_topics} for {object}"
# else:
# research_area = object
retrieved_docs = research(
question=research_area,
search_tool=search_tool,
document_sources=document_sources,
time_cutoff=agent_2_source_start_time,
)
# retrieved_docs = research(
# question=research_area,
# search_tool=search_tool,
# document_sources=document_sources,
# time_cutoff=agent_2_source_start_time,
# )
# Generate document text
# # Generate document text
document_texts_list = []
for doc_num, doc in enumerate(retrieved_docs):
chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
document_texts_list.append(chunk_text)
# document_texts_list = []
# for doc_num, doc in enumerate(retrieved_docs):
# chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
# document_texts_list.append(chunk_text)
document_texts = "\n\n".join(document_texts_list)
# document_texts = "\n\n".join(document_texts_list)
# Built prompt
# # Built prompt
today = datetime.now().strftime("%A, %Y-%m-%d")
# today = datetime.now().strftime("%A, %Y-%m-%d")
dc_object_source_research_prompt = (
DC_OBJECT_SOURCE_RESEARCH_PROMPT.format(
today=today,
question=question,
task=agent_2_task,
document_text=document_texts,
format=agent_2_output_objective,
)
.replace("---object---", object)
.replace("---source---", document_source.value)
)
# dc_object_source_research_prompt = (
# DC_OBJECT_SOURCE_RESEARCH_PROMPT.format(
# today=today,
# question=question,
# task=agent_2_task,
# document_text=document_texts,
# format=agent_2_output_objective,
# )
# .replace("---object---", object)
# .replace("---source---", document_source.value)
# )
# Run LLM
# # Run LLM
msg = [
HumanMessage(
content=trim_prompt_piece(
config=graph_config.tooling.primary_llm.config,
prompt_piece=dc_object_source_research_prompt,
reserved_str="",
),
)
]
primary_llm = graph_config.tooling.primary_llm
# Grader
try:
llm_response = run_with_timeout(
30,
primary_llm.invoke_langchain,
prompt=msg,
timeout_override=30,
max_tokens=300,
)
# msg = [
# HumanMessage(
# content=trim_prompt_piece(
# config=graph_config.tooling.primary_llm.config,
# prompt_piece=dc_object_source_research_prompt,
# reserved_str="",
# ),
# )
# ]
# primary_llm = graph_config.tooling.primary_llm
# # Grader
# try:
# llm_response = run_with_timeout(
# 30,
# primary_llm.invoke_langchain,
# prompt=msg,
# timeout_override=30,
# max_tokens=300,
# )
cleaned_response = str(llm_response.content).replace("```json\n", "")
cleaned_response = cleaned_response.split("RESEARCH RESULTS:")[1]
object_research_results = {
"object": object,
"source": document_source.value,
"research_result": cleaned_response,
}
# cleaned_response = str(llm_response.content).replace("```json\n", "")
# cleaned_response = cleaned_response.split("RESEARCH RESULTS:")[1]
# object_research_results = {
# "object": object,
# "source": document_source.value,
# "research_result": cleaned_response,
# }
except Exception as e:
raise ValueError(f"Error in research_object_source: {e}")
# except Exception as e:
# raise ValueError(f"Error in research_object_source: {e}")
logger.debug("DivCon Step A2 - Object Source Research - completed for an object")
# logger.debug("DivCon Step A2 - Object Source Research - completed for an object")
return ObjectSourceResearchUpdate(
object_source_research_results=[object_research_results],
log_messages=["Agent Step 2 done for one object"],
)
# return ObjectSourceResearchUpdate(
# object_source_research_results=[object_research_results],
# log_messages=["Agent Step 2 done for one object"],
# )
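
For reference, the "<number>d" time-cutoff handling above can be sketched standalone; the cutoff value is invented for the demo:

from datetime import datetime, timedelta, timezone

# Illustrative only: mirrors the cutoff parsing in the function above; "30d" is a made-up example.
agent_2_time_cutoff = "30d"

agent_2_source_start_time = None
if agent_2_time_cutoff and agent_2_time_cutoff.strip():
    value = agent_2_time_cutoff.strip()
    if not value.endswith("d"):
        raise ValueError(f"Invalid time cutoff format: {value}. Expected format: '<number>d'")
    days = int(value[:-1])  # int() raises ValueError for malformed input such as "thirtyd"
    agent_2_source_start_time = datetime.now(timezone.utc) - timedelta(days=days)

print(agent_2_source_start_time)  # a timezone-aware datetime 30 days in the past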

View File

@@ -1,48 +1,48 @@
from collections import defaultdict
from typing import Dict
from typing import List
# from collections import defaultdict
# from typing import Dict
# from typing import List
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.agents.agent_search.dc_search_analysis.states import (
ObjectResearchInformationUpdate,
)
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dc_search_analysis.states import MainState
# from onyx.agents.agent_search.dc_search_analysis.states import (
# ObjectResearchInformationUpdate,
# )
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def structure_research_by_object(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ObjectResearchInformationUpdate:
"""
LangGraph node to start the agentic search process.
"""
# def structure_research_by_object(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> ObjectResearchInformationUpdate:
# """
# LangGraph node to start the agentic search process.
# """
object_source_research_results = state.object_source_research_results
# object_source_research_results = state.object_source_research_results
object_research_information_results: List[Dict[str, str]] = []
object_research_information_results_list: Dict[str, List[str]] = defaultdict(list)
# object_research_information_results: List[Dict[str, str]] = []
# object_research_information_results_list: Dict[str, List[str]] = defaultdict(list)
for object_source_research in object_source_research_results:
object = object_source_research["object"]
source = object_source_research["source"]
research_result = object_source_research["research_result"]
# for object_source_research in object_source_research_results:
# object = object_source_research["object"]
# source = object_source_research["source"]
# research_result = object_source_research["research_result"]
object_research_information_results_list[object].append(
f"Source: {source}\n{research_result}"
)
# object_research_information_results_list[object].append(
# f"Source: {source}\n{research_result}"
# )
for object, information in object_research_information_results_list.items():
object_research_information_results.append(
{"object": object, "information": "\n".join(information)}
)
# for object, information in object_research_information_results_list.items():
# object_research_information_results.append(
# {"object": object, "information": "\n".join(information)}
# )
logger.debug("DivCon Step A3 - Object Research Information Structuring - completed")
# logger.debug("DivCon Step A3 - Object Research Information Structuring - completed")
return ObjectResearchInformationUpdate(
object_research_information_results=object_research_information_results,
log_messages=["A3 - Object Research Information structured"],
)
# return ObjectResearchInformationUpdate(
# object_research_information_results=object_research_information_results,
# log_messages=["A3 - Object Research Information structured"],
# )
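
For reference, the per-object grouping above reduces to this standalone sketch; the sample records are invented:

from collections import defaultdict

# Illustrative only: mirrors the grouping in the function above; records are made-up examples.
object_source_research_results = [
    {"object": "ACME Corp", "source": "web", "research_result": "Result A"},
    {"object": "ACME Corp", "source": "salesforce", "research_result": "Result B"},
    {"object": "Globex", "source": "web", "research_result": "Result C"},
]

grouped: dict[str, list[str]] = defaultdict(list)
for record in object_source_research_results:
    grouped[record["object"]].append(f"Source: {record['source']}\n{record['research_result']}")

object_research_information_results = [
    {"object": obj, "information": "\n".join(info)} for obj, info in grouped.items()
]
print(object_research_information_results)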

View File

@@ -1,103 +1,103 @@
from typing import cast
# from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
from onyx.agents.agent_search.dc_search_analysis.states import ObjectResearchUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.prompts.agents.dc_prompts import DC_OBJECT_CONSOLIDATION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectResearchUpdate
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
# trim_prompt_piece,
# )
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_CONSOLIDATION_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
# logger = setup_logger()
def consolidate_object_research(
state: ObjectInformationInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> ObjectResearchUpdate:
"""
LangGraph node to start the agentic search process.
"""
graph_config = cast(GraphConfig, config["metadata"]["config"])
search_tool = graph_config.tooling.search_tool
question = graph_config.inputs.prompt_builder.raw_user_query
# def consolidate_object_research(
# state: ObjectInformationInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> ObjectResearchUpdate:
# """
# LangGraph node to start the agentic search process.
# """
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# search_tool = graph_config.tooling.search_tool
# question = graph_config.inputs.prompt_builder.raw_user_query
if search_tool is None or graph_config.inputs.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")
# if search_tool is None or graph_config.inputs.persona is None:
# raise ValueError("Search tool and persona must be provided for DivCon search")
instructions = graph_config.inputs.persona.system_prompt or ""
# instructions = graph_config.inputs.persona.system_prompt or ""
agent_4_instructions = extract_section(
instructions, "Agent Step 4:", "Agent Step 5:"
)
if agent_4_instructions is None:
raise ValueError("Agent 4 instructions not found")
agent_4_output_objective = extract_section(
agent_4_instructions, "Output Objective:"
)
if agent_4_output_objective is None:
raise ValueError("Agent 4 output objective not found")
# agent_4_instructions = extract_section(
# instructions, "Agent Step 4:", "Agent Step 5:"
# )
# if agent_4_instructions is None:
# raise ValueError("Agent 4 instructions not found")
# agent_4_output_objective = extract_section(
# agent_4_instructions, "Output Objective:"
# )
# if agent_4_output_objective is None:
# raise ValueError("Agent 4 output objective not found")
object_information = state.object_information
# object_information = state.object_information
object = object_information["object"]
information = object_information["information"]
# object = object_information["object"]
# information = object_information["information"]
# Create a prompt for the object consolidation
# # Create a prompt for the object consolidation
dc_object_consolidation_prompt = DC_OBJECT_CONSOLIDATION_PROMPT.format(
question=question,
object=object,
information=information,
format=agent_4_output_objective,
)
# dc_object_consolidation_prompt = DC_OBJECT_CONSOLIDATION_PROMPT.format(
# question=question,
# object=object,
# information=information,
# format=agent_4_output_objective,
# )
# Run LLM
# # Run LLM
msg = [
HumanMessage(
content=trim_prompt_piece(
config=graph_config.tooling.primary_llm.config,
prompt_piece=dc_object_consolidation_prompt,
reserved_str="",
),
)
]
primary_llm = graph_config.tooling.primary_llm
# Grader
try:
llm_response = run_with_timeout(
30,
primary_llm.invoke_langchain,
prompt=msg,
timeout_override=30,
max_tokens=300,
)
# msg = [
# HumanMessage(
# content=trim_prompt_piece(
# config=graph_config.tooling.primary_llm.config,
# prompt_piece=dc_object_consolidation_prompt,
# reserved_str="",
# ),
# )
# ]
# primary_llm = graph_config.tooling.primary_llm
# # Grader
# try:
# llm_response = run_with_timeout(
# 30,
# primary_llm.invoke_langchain,
# prompt=msg,
# timeout_override=30,
# max_tokens=300,
# )
cleaned_response = str(llm_response.content).replace("```json\n", "")
consolidated_information = cleaned_response.split("INFORMATION:")[1]
# cleaned_response = str(llm_response.content).replace("```json\n", "")
# consolidated_information = cleaned_response.split("INFORMATION:")[1]
except Exception as e:
raise ValueError(f"Error in consolidate_object_research: {e}")
# except Exception as e:
# raise ValueError(f"Error in consolidate_object_research: {e}")
object_research_results = {
"object": object,
"research_result": consolidated_information,
}
# object_research_results = {
# "object": object,
# "research_result": consolidated_information,
# }
logger.debug(
"DivCon Step A4 - Object Research Consolidation - completed for an object"
)
# logger.debug(
# "DivCon Step A4 - Object Research Consolidation - completed for an object"
# )
return ObjectResearchUpdate(
object_research_results=[object_research_results],
log_messages=["Agent Source Consilidation done"],
)
# return ObjectResearchUpdate(
# object_research_results=[object_research_results],
# log_messages=["Agent Source Consilidation done"],
# )

View File

@@ -1,127 +1,127 @@
from typing import cast
# from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.agents.agent_search.dc_search_analysis.states import ResearchUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.prompts.agents.dc_prompts import DC_FORMATTING_NO_BASE_DATA_PROMPT
from onyx.prompts.agents.dc_prompts import DC_FORMATTING_WITH_BASE_DATA_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
# from onyx.agents.agent_search.dc_search_analysis.states import MainState
# from onyx.agents.agent_search.dc_search_analysis.states import ResearchUpdate
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
# trim_prompt_piece,
# )
# from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
# from onyx.prompts.agents.dc_prompts import DC_FORMATTING_NO_BASE_DATA_PROMPT
# from onyx.prompts.agents.dc_prompts import DC_FORMATTING_WITH_BASE_DATA_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
# logger = setup_logger()
def consolidate_research(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ResearchUpdate:
"""
LangGraph node to start the agentic search process.
"""
# def consolidate_research(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> ResearchUpdate:
# """
# LangGraph node to start the agentic search process.
# """
graph_config = cast(GraphConfig, config["metadata"]["config"])
# graph_config = cast(GraphConfig, config["metadata"]["config"])
search_tool = graph_config.tooling.search_tool
# search_tool = graph_config.tooling.search_tool
if search_tool is None or graph_config.inputs.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")
# if search_tool is None or graph_config.inputs.persona is None:
# raise ValueError("Search tool and persona must be provided for DivCon search")
# Populate prompt
instructions = graph_config.inputs.persona.system_prompt or ""
# # Populate prompt
# instructions = graph_config.inputs.persona.system_prompt or ""
try:
agent_5_instructions = extract_section(
instructions, "Agent Step 5:", "Agent End"
)
if agent_5_instructions is None:
raise ValueError("Agent 5 instructions not found")
agent_5_base_data = extract_section(instructions, "|Start Data|", "|End Data|")
agent_5_task = extract_section(
agent_5_instructions, "Task:", "Independent Research Sources:"
)
if agent_5_task is None:
raise ValueError("Agent 5 task not found")
agent_5_output_objective = extract_section(
agent_5_instructions, "Output Objective:"
)
if agent_5_output_objective is None:
raise ValueError("Agent 5 output objective not found")
except ValueError as e:
raise ValueError(
f"Instructions for Agent Step 5 were not properly formatted: {e}"
)
# try:
# agent_5_instructions = extract_section(
# instructions, "Agent Step 5:", "Agent End"
# )
# if agent_5_instructions is None:
# raise ValueError("Agent 5 instructions not found")
# agent_5_base_data = extract_section(instructions, "|Start Data|", "|End Data|")
# agent_5_task = extract_section(
# agent_5_instructions, "Task:", "Independent Research Sources:"
# )
# if agent_5_task is None:
# raise ValueError("Agent 5 task not found")
# agent_5_output_objective = extract_section(
# agent_5_instructions, "Output Objective:"
# )
# if agent_5_output_objective is None:
# raise ValueError("Agent 5 output objective not found")
# except ValueError as e:
# raise ValueError(
# f"Instructions for Agent Step 5 were not properly formatted: {e}"
# )
research_result_list = []
# research_result_list = []
if agent_5_task.strip() == "*concatenate*":
object_research_results = state.object_research_results
# if agent_5_task.strip() == "*concatenate*":
# object_research_results = state.object_research_results
for object_research_result in object_research_results:
object = object_research_result["object"]
research_result = object_research_result["research_result"]
research_result_list.append(f"Object: {object}\n\n{research_result}")
# for object_research_result in object_research_results:
# object = object_research_result["object"]
# research_result = object_research_result["research_result"]
# research_result_list.append(f"Object: {object}\n\n{research_result}")
research_results = "\n\n".join(research_result_list)
# research_results = "\n\n".join(research_result_list)
else:
raise NotImplementedError("Only '*concatenate*' is currently supported")
# else:
# raise NotImplementedError("Only '*concatenate*' is currently supported")
# Create a prompt for the object consolidation
# # Create a prompt for the object consolidation
if agent_5_base_data is None:
dc_formatting_prompt = DC_FORMATTING_NO_BASE_DATA_PROMPT.format(
text=research_results,
format=agent_5_output_objective,
)
else:
dc_formatting_prompt = DC_FORMATTING_WITH_BASE_DATA_PROMPT.format(
base_data=agent_5_base_data,
text=research_results,
format=agent_5_output_objective,
)
# if agent_5_base_data is None:
# dc_formatting_prompt = DC_FORMATTING_NO_BASE_DATA_PROMPT.format(
# text=research_results,
# format=agent_5_output_objective,
# )
# else:
# dc_formatting_prompt = DC_FORMATTING_WITH_BASE_DATA_PROMPT.format(
# base_data=agent_5_base_data,
# text=research_results,
# format=agent_5_output_objective,
# )
# Run LLM
# # Run LLM
msg = [
HumanMessage(
content=trim_prompt_piece(
config=graph_config.tooling.primary_llm.config,
prompt_piece=dc_formatting_prompt,
reserved_str="",
),
)
]
# msg = [
# HumanMessage(
# content=trim_prompt_piece(
# config=graph_config.tooling.primary_llm.config,
# prompt_piece=dc_formatting_prompt,
# reserved_str="",
# ),
# )
# ]
try:
_ = run_with_timeout(
60,
lambda: stream_llm_answer(
llm=graph_config.tooling.primary_llm,
prompt=msg,
event_name="initial_agent_answer",
writer=writer,
agent_answer_level=0,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=30,
max_tokens=None,
),
)
# try:
# _ = run_with_timeout(
# 60,
# lambda: stream_llm_answer(
# llm=graph_config.tooling.primary_llm,
# prompt=msg,
# event_name="initial_agent_answer",
# writer=writer,
# agent_answer_level=0,
# agent_answer_question_num=0,
# agent_answer_type="agent_level_answer",
# timeout_override=30,
# max_tokens=None,
# ),
# )
except Exception as e:
raise ValueError(f"Error in consolidate_research: {e}")
# except Exception as e:
# raise ValueError(f"Error in consolidate_research: {e}")
logger.debug("DivCon Step A5 - Final Generation - completed")
# logger.debug("DivCon Step A5 - Final Generation - completed")
return ResearchUpdate(
research_results=research_results,
log_messages=["Agent Source Consilidation done"],
)
# return ResearchUpdate(
# research_results=research_results,
# log_messages=["Agent Source Consilidation done"],
# )

View File

@@ -1,61 +1,50 @@
from datetime import datetime
from typing import cast
# from datetime import datetime
# from typing import cast
from onyx.chat.models import LlmDoc
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
# from onyx.chat.models import LlmDoc
# from onyx.configs.constants import DocumentSource
# from onyx.tools.models import SearchToolOverrideKwargs
# from onyx.tools.tool_implementations.search.search_tool import (
# FINAL_CONTEXT_DOCUMENTS_ID,
# )
# from onyx.tools.tool_implementations.search.search_tool import SearchTool
def research(
question: str,
search_tool: SearchTool,
document_sources: list[DocumentSource] | None = None,
time_cutoff: datetime | None = None,
) -> list[LlmDoc]:
# new db session to avoid concurrency issues
# def research(
# question: str,
# search_tool: SearchTool,
# document_sources: list[DocumentSource] | None = None,
# time_cutoff: datetime | None = None,
# ) -> list[LlmDoc]:
# # new db session to avoid concurrency issues
callback_container: list[list[InferenceSection]] = []
retrieved_docs: list[LlmDoc] = []
# retrieved_docs: list[LlmDoc] = []
with get_session_with_current_tenant() as db_session:
for tool_response in search_tool.run(
query=question,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=False,
alternate_db_session=db_session,
retrieved_sections_callback=callback_container.append,
skip_query_analysis=True,
document_sources=document_sources,
time_cutoff=time_cutoff,
),
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == FINAL_CONTEXT_DOCUMENTS_ID:
retrieved_docs = cast(list[LlmDoc], tool_response.response)[:10]
break
return retrieved_docs
# for tool_response in search_tool.run(
# query=question,
# override_kwargs=SearchToolOverrideKwargs(original_query=question),
# ):
# # get retrieved docs to send to the rest of the graph
# if tool_response.id == FINAL_CONTEXT_DOCUMENTS_ID:
# retrieved_docs = cast(list[LlmDoc], tool_response.response)[:10]
# break
# return retrieved_docs
def extract_section(
text: str, start_marker: str, end_marker: str | None = None
) -> str | None:
"""Extract text between markers, returning None if markers not found"""
parts = text.split(start_marker)
# def extract_section(
# text: str, start_marker: str, end_marker: str | None = None
# ) -> str | None:
# """Extract text between markers, returning None if markers not found"""
# parts = text.split(start_marker)
if len(parts) == 1:
return None
# if len(parts) == 1:
# return None
after_start = parts[1].strip()
# after_start = parts[1].strip()
if not end_marker:
return after_start
# if not end_marker:
# return after_start
extract = after_start.split(end_marker)[0]
# extract = after_start.split(end_marker)[0]
return extract.strip()
# return extract.strip()
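
For reference, extract_section behaves as follows on a toy persona prompt; the function body is copied from the hunk above and the prompt text is invented:

def extract_section(
    text: str, start_marker: str, end_marker: str | None = None
) -> str | None:
    """Extract text between markers, returning None if markers not found"""
    parts = text.split(start_marker)
    if len(parts) == 1:
        return None
    after_start = parts[1].strip()
    if not end_marker:
        return after_start
    return after_start.split(end_marker)[0].strip()

instructions = "Agent Step 1: Task: Find accounts. Output Objective: A list. Agent Step 2: ..."
print(extract_section(instructions, "Agent Step 1:", "Agent Step 2:"))  # Task: Find accounts. Output Objective: A list.
print(extract_section(instructions, "Task:", "Output Objective:"))      # Find accounts.
print(extract_section(instructions, "Missing Marker:"))                 # None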

View File

@@ -1,72 +1,72 @@
from operator import add
from typing import Annotated
from typing import Dict
from typing import TypedDict
# from operator import add
# from typing import Annotated
# from typing import Dict
# from typing import TypedDict
from pydantic import BaseModel
# from pydantic import BaseModel
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.configs.constants import DocumentSource
# from onyx.agents.agent_search.core_state import CoreState
# from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
# from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
# from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
# from onyx.configs.constants import DocumentSource
### States ###
class LoggerUpdate(BaseModel):
log_messages: Annotated[list[str], add] = []
# ### States ###
# class LoggerUpdate(BaseModel):
# log_messages: Annotated[list[str], add] = []
class SearchSourcesObjectsUpdate(LoggerUpdate):
analysis_objects: list[str] = []
analysis_sources: list[DocumentSource] = []
# class SearchSourcesObjectsUpdate(LoggerUpdate):
# analysis_objects: list[str] = []
# analysis_sources: list[DocumentSource] = []
class ObjectSourceInput(LoggerUpdate):
object_source_combination: tuple[str, DocumentSource]
# class ObjectSourceInput(LoggerUpdate):
# object_source_combination: tuple[str, DocumentSource]
class ObjectSourceResearchUpdate(LoggerUpdate):
object_source_research_results: Annotated[list[Dict[str, str]], add] = []
# class ObjectSourceResearchUpdate(LoggerUpdate):
# object_source_research_results: Annotated[list[Dict[str, str]], add] = []
class ObjectInformationInput(LoggerUpdate):
object_information: Dict[str, str]
# class ObjectInformationInput(LoggerUpdate):
# object_information: Dict[str, str]
class ObjectResearchInformationUpdate(LoggerUpdate):
object_research_information_results: Annotated[list[Dict[str, str]], add] = []
# class ObjectResearchInformationUpdate(LoggerUpdate):
# object_research_information_results: Annotated[list[Dict[str, str]], add] = []
class ObjectResearchUpdate(LoggerUpdate):
object_research_results: Annotated[list[Dict[str, str]], add] = []
# class ObjectResearchUpdate(LoggerUpdate):
# object_research_results: Annotated[list[Dict[str, str]], add] = []
class ResearchUpdate(LoggerUpdate):
research_results: str | None = None
# class ResearchUpdate(LoggerUpdate):
# research_results: str | None = None
## Graph Input State
class MainInput(CoreState):
pass
# ## Graph Input State
# class MainInput(CoreState):
# pass
## Graph State
class MainState(
# This includes the core state
MainInput,
ToolChoiceInput,
ToolCallUpdate,
ToolChoiceUpdate,
SearchSourcesObjectsUpdate,
ObjectSourceResearchUpdate,
ObjectResearchInformationUpdate,
ObjectResearchUpdate,
ResearchUpdate,
):
pass
# ## Graph State
# class MainState(
# # This includes the core state
# MainInput,
# ToolChoiceInput,
# ToolCallUpdate,
# ToolChoiceUpdate,
# SearchSourcesObjectsUpdate,
# ObjectSourceResearchUpdate,
# ObjectResearchInformationUpdate,
# ObjectResearchUpdate,
# ResearchUpdate,
# ):
# pass
## Graph Output State - presently not used
class MainOutput(TypedDict):
log_messages: list[str]
# ## Graph Output State - presently not used
# class MainOutput(TypedDict):
# log_messages: list[str]
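
For reference, the Annotated-reducer convention used by these state models (LangGraph merges node updates with the reducer stored in the field's Annotated metadata, here operator.add) can be shown standalone; this assumes pydantic is installed and applies the merge by hand for illustration:

from operator import add
from typing import Annotated

from pydantic import BaseModel

# Illustrative only: same shape as the LoggerUpdate model above.
class LoggerUpdate(BaseModel):
    log_messages: Annotated[list[str], add] = []

u1 = LoggerUpdate(log_messages=["Agent 1 Task done"])
u2 = LoggerUpdate(log_messages=["A3 - Object Research Information structured"])

# LangGraph would apply the reducer (operator.add) when combining node updates;
# done manually here to show the accumulated log.
merged = add(u1.log_messages, u2.log_messages)
print(merged)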

View File

@@ -1,36 +1,36 @@
from pydantic import BaseModel
# from pydantic import BaseModel
class RefinementSubQuestion(BaseModel):
sub_question: str
sub_question_id: str
verified: bool
answered: bool
answer: str
# class RefinementSubQuestion(BaseModel):
# sub_question: str
# sub_question_id: str
# verified: bool
# answered: bool
# answer: str
class AgentTimings(BaseModel):
base_duration_s: float | None
refined_duration_s: float | None
full_duration_s: float | None
# class AgentTimings(BaseModel):
# base_duration_s: float | None
# refined_duration_s: float | None
# full_duration_s: float | None
class AgentBaseMetrics(BaseModel):
num_verified_documents_total: int | None
num_verified_documents_core: int | None
verified_avg_score_core: float | None
num_verified_documents_base: int | float | None
verified_avg_score_base: float | None = None
base_doc_boost_factor: float | None = None
support_boost_factor: float | None = None
duration_s: float | None = None
# class AgentBaseMetrics(BaseModel):
# num_verified_documents_total: int | None
# num_verified_documents_core: int | None
# verified_avg_score_core: float | None
# num_verified_documents_base: int | float | None
# verified_avg_score_base: float | None = None
# base_doc_boost_factor: float | None = None
# support_boost_factor: float | None = None
# duration_s: float | None = None
class AgentRefinedMetrics(BaseModel):
refined_doc_boost_factor: float | None = None
refined_question_boost_factor: float | None = None
duration_s: float | None = None
# class AgentRefinedMetrics(BaseModel):
# refined_doc_boost_factor: float | None = None
# refined_question_boost_factor: float | None = None
# duration_s: float | None = None
class AgentAdditionalMetrics(BaseModel):
pass
# class AgentAdditionalMetrics(BaseModel):
# pass

View File

@@ -1,61 +1,61 @@
from collections.abc import Hashable
# from collections.abc import Hashable
from langgraph.graph import END
from langgraph.types import Send
# from langgraph.graph import END
# from langgraph.types import Send
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.states import MainState
# from onyx.agents.agent_search.dr.enums import DRPath
# from onyx.agents.agent_search.dr.states import MainState
def decision_router(state: MainState) -> list[Send | Hashable] | DRPath | str:
if not state.tools_used:
raise IndexError("state.tools_used cannot be empty")
# def decision_router(state: MainState) -> list[Send | Hashable] | DRPath | str:
# if not state.tools_used:
# raise IndexError("state.tools_used cannot be empty")
# next_tool is either a generic tool name or a DRPath string
next_tool_name = state.tools_used[-1]
# # next_tool is either a generic tool name or a DRPath string
# next_tool_name = state.tools_used[-1]
available_tools = state.available_tools
if not available_tools:
raise ValueError("No tool is available. This should not happen.")
# available_tools = state.available_tools
# if not available_tools:
# raise ValueError("No tool is available. This should not happen.")
if next_tool_name in available_tools:
next_tool_path = available_tools[next_tool_name].path
elif next_tool_name == DRPath.END.value:
return END
elif next_tool_name == DRPath.LOGGER.value:
return DRPath.LOGGER
elif next_tool_name == DRPath.CLOSER.value:
return DRPath.CLOSER
else:
return DRPath.ORCHESTRATOR
# if next_tool_name in available_tools:
# next_tool_path = available_tools[next_tool_name].path
# elif next_tool_name == DRPath.END.value:
# return END
# elif next_tool_name == DRPath.LOGGER.value:
# return DRPath.LOGGER
# elif next_tool_name == DRPath.CLOSER.value:
# return DRPath.CLOSER
# else:
# return DRPath.ORCHESTRATOR
# handle invalid paths
if next_tool_path == DRPath.CLARIFIER:
raise ValueError("CLARIFIER is not a valid path during iteration")
# # handle invalid paths
# if next_tool_path == DRPath.CLARIFIER:
# raise ValueError("CLARIFIER is not a valid path during iteration")
# handle tool calls without a query
if (
next_tool_path
in (
DRPath.INTERNAL_SEARCH,
DRPath.WEB_SEARCH,
DRPath.KNOWLEDGE_GRAPH,
DRPath.IMAGE_GENERATION,
)
and len(state.query_list) == 0
):
return DRPath.CLOSER
# # handle tool calls without a query
# if (
# next_tool_path
# in (
# DRPath.INTERNAL_SEARCH,
# DRPath.WEB_SEARCH,
# DRPath.KNOWLEDGE_GRAPH,
# DRPath.IMAGE_GENERATION,
# )
# and len(state.query_list) == 0
# ):
# return DRPath.CLOSER
return next_tool_path
# return next_tool_path
def completeness_router(state: MainState) -> DRPath | str:
if not state.tools_used:
raise IndexError("tools_used cannot be empty")
# def completeness_router(state: MainState) -> DRPath | str:
# if not state.tools_used:
# raise IndexError("tools_used cannot be empty")
# go to closer if path is CLOSER or no queries
next_path = state.tools_used[-1]
# # go to closer if path is CLOSER or no queries
# next_path = state.tools_used[-1]
if next_path == DRPath.ORCHESTRATOR.value:
return DRPath.ORCHESTRATOR
return DRPath.LOGGER
# if next_path == DRPath.ORCHESTRATOR.value:
# return DRPath.ORCHESTRATOR
# return DRPath.LOGGER

View File

@@ -1,31 +1,27 @@
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchType
# from onyx.agents.agent_search.dr.enums import DRPath
MAX_CHAT_HISTORY_MESSAGES = (
3 # note: actual count is x2 to account for user and assistant messages
)
# MAX_CHAT_HISTORY_MESSAGES = (
# 3 # note: actual count is x2 to account for user and assistant messages
# )
MAX_DR_PARALLEL_SEARCH = 4
# MAX_DR_PARALLEL_SEARCH = 4
# TODO: test more, generally not needed/adds unnecessary iterations
MAX_NUM_CLOSER_SUGGESTIONS = (
0 # how many times the closer can send back to the orchestrator
)
# # TODO: test more, generally not needed/adds unnecessary iterations
# MAX_NUM_CLOSER_SUGGESTIONS = (
# 0 # how many times the closer can send back to the orchestrator
# )
CLARIFICATION_REQUEST_PREFIX = "PLEASE CLARIFY:"
HIGH_LEVEL_PLAN_PREFIX = "The Plan:"
# CLARIFICATION_REQUEST_PREFIX = "PLEASE CLARIFY:"
# HIGH_LEVEL_PLAN_PREFIX = "The Plan:"
AVERAGE_TOOL_COSTS: dict[DRPath, float] = {
DRPath.INTERNAL_SEARCH: 1.0,
DRPath.KNOWLEDGE_GRAPH: 2.0,
DRPath.WEB_SEARCH: 1.5,
DRPath.IMAGE_GENERATION: 3.0,
DRPath.GENERIC_TOOL: 1.5, # TODO: see todo in OrchestratorTool
DRPath.CLOSER: 0.0,
}
# AVERAGE_TOOL_COSTS: dict[DRPath, float] = {
# DRPath.INTERNAL_SEARCH: 1.0,
# DRPath.KNOWLEDGE_GRAPH: 2.0,
# DRPath.WEB_SEARCH: 1.5,
# DRPath.IMAGE_GENERATION: 3.0,
# DRPath.GENERIC_TOOL: 1.5, # TODO: see todo in OrchestratorTool
# DRPath.CLOSER: 0.0,
# }
DR_TIME_BUDGET_BY_TYPE = {
ResearchType.THOUGHTFUL: 3.0,
ResearchType.DEEP: 12.0,
ResearchType.FAST: 0.5,
}
# # Default time budget for agentic search (when use_agentic_search is True)
# DR_TIME_BUDGET_DEFAULT = 12.0

View File

@@ -1,112 +1,111 @@
from datetime import datetime
# from datetime import datetime
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import DRPromptPurpose
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.prompts.dr_prompts import GET_CLARIFICATION_PROMPT
from onyx.prompts.dr_prompts import KG_TYPES_DESCRIPTIONS
from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
from onyx.prompts.dr_prompts import TOOL_DIFFERENTIATION_HINTS
from onyx.prompts.dr_prompts import TOOL_QUESTION_HINTS
from onyx.prompts.prompt_template import PromptTemplate
# from onyx.agents.agent_search.dr.enums import DRPath
# from onyx.agents.agent_search.dr.models import DRPromptPurpose
# from onyx.agents.agent_search.dr.models import OrchestratorTool
# from onyx.prompts.dr_prompts import GET_CLARIFICATION_PROMPT
# from onyx.prompts.dr_prompts import KG_TYPES_DESCRIPTIONS
# from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
# from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
# from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
# from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
# from onyx.prompts.dr_prompts import ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
# from onyx.prompts.dr_prompts import TOOL_DIFFERENTIATION_HINTS
# from onyx.prompts.dr_prompts import TOOL_QUESTION_HINTS
# from onyx.prompts.prompt_template import PromptTemplate
def get_dr_prompt_orchestration_templates(
purpose: DRPromptPurpose,
research_type: ResearchType,
available_tools: dict[str, OrchestratorTool],
entity_types_string: str | None = None,
relationship_types_string: str | None = None,
reasoning_result: str | None = None,
tool_calls_string: str | None = None,
) -> PromptTemplate:
available_tools = available_tools or {}
tool_names = list(available_tools.keys())
tool_description_str = "\n\n".join(
f"- {tool_name}: {tool.description}"
for tool_name, tool in available_tools.items()
)
tool_cost_str = "\n".join(
f"{tool_name}: {tool.cost}" for tool_name, tool in available_tools.items()
)
# def get_dr_prompt_orchestration_templates(
# purpose: DRPromptPurpose,
# use_agentic_search: bool,
# available_tools: dict[str, OrchestratorTool],
# entity_types_string: str | None = None,
# relationship_types_string: str | None = None,
# reasoning_result: str | None = None,
# tool_calls_string: str | None = None,
# ) -> PromptTemplate:
# available_tools = available_tools or {}
# tool_names = list(available_tools.keys())
# tool_description_str = "\n\n".join(
# f"- {tool_name}: {tool.description}"
# for tool_name, tool in available_tools.items()
# )
# tool_cost_str = "\n".join(
# f"{tool_name}: {tool.cost}" for tool_name, tool in available_tools.items()
# )
tool_differentiations: list[str] = [
TOOL_DIFFERENTIATION_HINTS[(tool_1, tool_2)]
for tool_1 in available_tools
for tool_2 in available_tools
if (tool_1, tool_2) in TOOL_DIFFERENTIATION_HINTS
]
tool_differentiation_hint_string = (
"\n".join(tool_differentiations) or "(No differentiating hints available)"
)
# TODO: add tool delineation pairs for custom tools as well
# tool_differentiations: list[str] = [
# TOOL_DIFFERENTIATION_HINTS[(tool_1, tool_2)]
# for tool_1 in available_tools
# for tool_2 in available_tools
# if (tool_1, tool_2) in TOOL_DIFFERENTIATION_HINTS
# ]
# tool_differentiation_hint_string = (
# "\n".join(tool_differentiations) or "(No differentiating hints available)"
# )
# # TODO: add tool delineation pairs for custom tools as well
tool_question_hint_string = (
"\n".join(
"- " + TOOL_QUESTION_HINTS[tool]
for tool in available_tools
if tool in TOOL_QUESTION_HINTS
)
or "(No examples available)"
)
# tool_question_hint_string = (
# "\n".join(
# "- " + TOOL_QUESTION_HINTS[tool]
# for tool in available_tools
# if tool in TOOL_QUESTION_HINTS
# )
# or "(No examples available)"
# )
if DRPath.KNOWLEDGE_GRAPH.value in available_tools and (
entity_types_string or relationship_types_string
):
# if DRPath.KNOWLEDGE_GRAPH.value in available_tools and (
# entity_types_string or relationship_types_string
# ):
kg_types_descriptions = KG_TYPES_DESCRIPTIONS.build(
possible_entities=entity_types_string or "",
possible_relationships=relationship_types_string or "",
)
else:
kg_types_descriptions = "(The Knowledge Graph is not used.)"
# kg_types_descriptions = KG_TYPES_DESCRIPTIONS.build(
# possible_entities=entity_types_string or "",
# possible_relationships=relationship_types_string or "",
# )
# else:
# kg_types_descriptions = "(The Knowledge Graph is not used.)"
if purpose == DRPromptPurpose.PLAN:
if research_type == ResearchType.THOUGHTFUL:
raise ValueError("plan generation is not supported for FAST time budget")
base_template = ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
# if purpose == DRPromptPurpose.PLAN:
# if not use_agentic_search:
# raise ValueError("plan generation is only supported for agentic search")
# base_template = ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
elif purpose == DRPromptPurpose.NEXT_STEP_REASONING:
if research_type == ResearchType.THOUGHTFUL:
base_template = ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
else:
raise ValueError(
"reasoning is not separately required for DEEP time budget"
)
# elif purpose == DRPromptPurpose.NEXT_STEP_REASONING:
# if not use_agentic_search:
# base_template = ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
# else:
# raise ValueError(
# "reasoning is not separately required for agentic search"
# )
elif purpose == DRPromptPurpose.NEXT_STEP_PURPOSE:
base_template = ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
# elif purpose == DRPromptPurpose.NEXT_STEP_PURPOSE:
# base_template = ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
elif purpose == DRPromptPurpose.NEXT_STEP:
if research_type == ResearchType.THOUGHTFUL:
base_template = ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
else:
base_template = ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
# elif purpose == DRPromptPurpose.NEXT_STEP:
# if not use_agentic_search:
# base_template = ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
# else:
# base_template = ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
elif purpose == DRPromptPurpose.CLARIFICATION:
if research_type == ResearchType.THOUGHTFUL:
raise ValueError("clarification is not supported for FAST time budget")
base_template = GET_CLARIFICATION_PROMPT
# elif purpose == DRPromptPurpose.CLARIFICATION:
# if not use_agentic_search:
# raise ValueError("clarification is only supported for agentic search")
# base_template = GET_CLARIFICATION_PROMPT
else:
# for mypy, clearly a mypy bug
raise ValueError(f"Invalid purpose: {purpose}")
# else:
# # for mypy, clearly a mypy bug
# raise ValueError(f"Invalid purpose: {purpose}")
return base_template.partial_build(
num_available_tools=str(len(tool_names)),
available_tools=", ".join(tool_names),
tool_choice_options=" or ".join(tool_names),
current_time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
kg_types_descriptions=kg_types_descriptions,
tool_descriptions=tool_description_str,
tool_differentiation_hints=tool_differentiation_hint_string,
tool_question_hints=tool_question_hint_string,
average_tool_costs=tool_cost_str,
reasoning_result=reasoning_result or "(No reasoning result provided.)",
tool_calls_string=tool_calls_string or "(No tool calls provided.)",
)
# return base_template.partial_build(
# num_available_tools=str(len(tool_names)),
# available_tools=", ".join(tool_names),
# tool_choice_options=" or ".join(tool_names),
# current_time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
# kg_types_descriptions=kg_types_descriptions,
# tool_descriptions=tool_description_str,
# tool_differentiation_hints=tool_differentiation_hint_string,
# tool_question_hints=tool_question_hint_string,
# average_tool_costs=tool_cost_str,
# reasoning_result=reasoning_result or "(No reasoning result provided.)",
# tool_calls_string=tool_calls_string or "(No tool calls provided.)",
# )
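
For reference, the tool description and cost strings assembled above can be sketched standalone; a small dataclass stands in for OrchestratorTool and the tool entries are invented:

from dataclasses import dataclass

# Illustrative stand-in for OrchestratorTool (the real model lives in onyx.agents.agent_search.dr.models).
@dataclass
class ToolInfo:
    description: str
    cost: float

available_tools = {
    "Internal Search": ToolInfo("Searches connected internal documents.", 1.0),
    "Web Search": ToolInfo("Searches the public web.", 1.5),
}

tool_names = list(available_tools.keys())
tool_description_str = "\n\n".join(
    f"- {name}: {tool.description}" for name, tool in available_tools.items()
)
tool_cost_str = "\n".join(f"{name}: {tool.cost}" for name, tool in available_tools.items())

print(tool_description_str)
print(tool_cost_str)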

View File

@@ -1,32 +1,22 @@
from enum import Enum
# from enum import Enum
class ResearchType(str, Enum):
"""Research type options for agent search operations"""
# class ResearchAnswerPurpose(str, Enum):
# """Research answer purpose options for agent search operations"""
# BASIC = "BASIC"
LEGACY_AGENTIC = "LEGACY_AGENTIC" # only used for legacy agentic search migrations
THOUGHTFUL = "THOUGHTFUL"
DEEP = "DEEP"
FAST = "FAST"
# ANSWER = "ANSWER"
# CLARIFICATION_REQUEST = "CLARIFICATION_REQUEST"
class ResearchAnswerPurpose(str, Enum):
"""Research answer purpose options for agent search operations"""
ANSWER = "ANSWER"
CLARIFICATION_REQUEST = "CLARIFICATION_REQUEST"
class DRPath(str, Enum):
CLARIFIER = "Clarifier"
ORCHESTRATOR = "Orchestrator"
INTERNAL_SEARCH = "Internal Search"
GENERIC_TOOL = "Generic Tool"
KNOWLEDGE_GRAPH = "Knowledge Graph Search"
WEB_SEARCH = "Web Search"
IMAGE_GENERATION = "Image Generation"
GENERIC_INTERNAL_TOOL = "Generic Internal Tool"
CLOSER = "Closer"
LOGGER = "Logger"
END = "End"
# class DRPath(str, Enum):
# CLARIFIER = "Clarifier"
# ORCHESTRATOR = "Orchestrator"
# INTERNAL_SEARCH = "Internal Search"
# GENERIC_TOOL = "Generic Tool"
# KNOWLEDGE_GRAPH = "Knowledge Graph Search"
# WEB_SEARCH = "Web Search"
# IMAGE_GENERATION = "Image Generation"
# GENERIC_INTERNAL_TOOL = "Generic Internal Tool"
# CLOSER = "Closer"
# LOGGER = "Logger"
# END = "End"

View File

@@ -1,88 +1,88 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.conditional_edges import completeness_router
from onyx.agents.agent_search.dr.conditional_edges import decision_router
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.nodes.dr_a0_clarification import clarifier
from onyx.agents.agent_search.dr.nodes.dr_a1_orchestrator import orchestrator
from onyx.agents.agent_search.dr.nodes.dr_a2_closer import closer
from onyx.agents.agent_search.dr.nodes.dr_a3_logger import logging
from onyx.agents.agent_search.dr.states import MainInput
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_graph_builder import (
dr_basic_search_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_graph_builder import (
dr_custom_tool_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_graph_builder import (
dr_generic_internal_tool_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_graph_builder import (
dr_image_generation_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_graph_builder import (
dr_kg_search_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_graph_builder import (
dr_ws_graph_builder,
)
# from onyx.agents.agent_search.dr.conditional_edges import completeness_router
# from onyx.agents.agent_search.dr.conditional_edges import decision_router
# from onyx.agents.agent_search.dr.enums import DRPath
# from onyx.agents.agent_search.dr.nodes.dr_a0_clarification import clarifier
# from onyx.agents.agent_search.dr.nodes.dr_a1_orchestrator import orchestrator
# from onyx.agents.agent_search.dr.nodes.dr_a2_closer import closer
# from onyx.agents.agent_search.dr.nodes.dr_a3_logger import logging
# from onyx.agents.agent_search.dr.states import MainInput
# from onyx.agents.agent_search.dr.states import MainState
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_graph_builder import (
# dr_basic_search_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_graph_builder import (
# dr_custom_tool_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_graph_builder import (
# dr_generic_internal_tool_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_graph_builder import (
# dr_image_generation_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_graph_builder import (
# dr_kg_search_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_graph_builder import (
# dr_ws_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import search
# # from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import search
def dr_graph_builder() -> StateGraph:
"""
LangGraph graph builder for the deep research agent.
"""
# def dr_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for the deep research agent.
# """
graph = StateGraph(state_schema=MainState, input=MainInput)
# graph = StateGraph(state_schema=MainState, input=MainInput)
### Add nodes ###
# ### Add nodes ###
graph.add_node(DRPath.CLARIFIER, clarifier)
# graph.add_node(DRPath.CLARIFIER, clarifier)
graph.add_node(DRPath.ORCHESTRATOR, orchestrator)
# graph.add_node(DRPath.ORCHESTRATOR, orchestrator)
basic_search_graph = dr_basic_search_graph_builder().compile()
graph.add_node(DRPath.INTERNAL_SEARCH, basic_search_graph)
# basic_search_graph = dr_basic_search_graph_builder().compile()
# graph.add_node(DRPath.INTERNAL_SEARCH, basic_search_graph)
kg_search_graph = dr_kg_search_graph_builder().compile()
graph.add_node(DRPath.KNOWLEDGE_GRAPH, kg_search_graph)
# kg_search_graph = dr_kg_search_graph_builder().compile()
# graph.add_node(DRPath.KNOWLEDGE_GRAPH, kg_search_graph)
internet_search_graph = dr_ws_graph_builder().compile()
graph.add_node(DRPath.WEB_SEARCH, internet_search_graph)
# internet_search_graph = dr_ws_graph_builder().compile()
# graph.add_node(DRPath.WEB_SEARCH, internet_search_graph)
image_generation_graph = dr_image_generation_graph_builder().compile()
graph.add_node(DRPath.IMAGE_GENERATION, image_generation_graph)
# image_generation_graph = dr_image_generation_graph_builder().compile()
# graph.add_node(DRPath.IMAGE_GENERATION, image_generation_graph)
custom_tool_graph = dr_custom_tool_graph_builder().compile()
graph.add_node(DRPath.GENERIC_TOOL, custom_tool_graph)
# custom_tool_graph = dr_custom_tool_graph_builder().compile()
# graph.add_node(DRPath.GENERIC_TOOL, custom_tool_graph)
generic_internal_tool_graph = dr_generic_internal_tool_graph_builder().compile()
graph.add_node(DRPath.GENERIC_INTERNAL_TOOL, generic_internal_tool_graph)
# generic_internal_tool_graph = dr_generic_internal_tool_graph_builder().compile()
# graph.add_node(DRPath.GENERIC_INTERNAL_TOOL, generic_internal_tool_graph)
graph.add_node(DRPath.CLOSER, closer)
graph.add_node(DRPath.LOGGER, logging)
# graph.add_node(DRPath.CLOSER, closer)
# graph.add_node(DRPath.LOGGER, logging)
### Add edges ###
# ### Add edges ###
graph.add_edge(start_key=START, end_key=DRPath.CLARIFIER)
# graph.add_edge(start_key=START, end_key=DRPath.CLARIFIER)
graph.add_conditional_edges(DRPath.CLARIFIER, decision_router)
# graph.add_conditional_edges(DRPath.CLARIFIER, decision_router)
graph.add_conditional_edges(DRPath.ORCHESTRATOR, decision_router)
# graph.add_conditional_edges(DRPath.ORCHESTRATOR, decision_router)
graph.add_edge(start_key=DRPath.INTERNAL_SEARCH, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.KNOWLEDGE_GRAPH, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.WEB_SEARCH, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.IMAGE_GENERATION, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.GENERIC_TOOL, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.GENERIC_INTERNAL_TOOL, end_key=DRPath.ORCHESTRATOR)
# graph.add_edge(start_key=DRPath.INTERNAL_SEARCH, end_key=DRPath.ORCHESTRATOR)
# graph.add_edge(start_key=DRPath.KNOWLEDGE_GRAPH, end_key=DRPath.ORCHESTRATOR)
# graph.add_edge(start_key=DRPath.WEB_SEARCH, end_key=DRPath.ORCHESTRATOR)
# graph.add_edge(start_key=DRPath.IMAGE_GENERATION, end_key=DRPath.ORCHESTRATOR)
# graph.add_edge(start_key=DRPath.GENERIC_TOOL, end_key=DRPath.ORCHESTRATOR)
# graph.add_edge(start_key=DRPath.GENERIC_INTERNAL_TOOL, end_key=DRPath.ORCHESTRATOR)
graph.add_conditional_edges(DRPath.CLOSER, completeness_router)
graph.add_edge(start_key=DRPath.LOGGER, end_key=END)
# graph.add_conditional_edges(DRPath.CLOSER, completeness_router)
# graph.add_edge(start_key=DRPath.LOGGER, end_key=END)
return graph
# return graph
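# Illustrative sketch (added note, not part of the diff above): a graph built this way is
# typically compiled and invoked with the GraphConfig threaded through the RunnableConfig
# metadata, which is how the nodes further below read it. The MainInput construction is
# left elided because it depends on the surrounding chat plumbing.
#
#   compiled = dr_graph_builder().compile()
#   final_state = compiled.invoke(
#       MainInput(...),
#       config={"metadata": {"config": graph_config}},
#   )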

View File

@@ -1,131 +1,131 @@
from enum import Enum
# from enum import Enum
from pydantic import BaseModel
from pydantic import ConfigDict
# from pydantic import BaseModel
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
GeneratedImage,
)
from onyx.context.search.models import InferenceSection
from onyx.tools.tool import Tool
# from onyx.agents.agent_search.dr.enums import DRPath
# from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
# GeneratedImage,
# )
# from onyx.context.search.models import InferenceSection
# from onyx.tools.tool import Tool
class OrchestratorStep(BaseModel):
tool: str
questions: list[str]
# class OrchestratorStep(BaseModel):
# tool: str
# questions: list[str]
class OrchestratorDecisonsNoPlan(BaseModel):
reasoning: str
next_step: OrchestratorStep
# class OrchestratorDecisonsNoPlan(BaseModel):
# reasoning: str
# next_step: OrchestratorStep
class OrchestrationPlan(BaseModel):
reasoning: str
plan: str
# class OrchestrationPlan(BaseModel):
# reasoning: str
# plan: str
class ClarificationGenerationResponse(BaseModel):
clarification_needed: bool
clarification_question: str
# class ClarificationGenerationResponse(BaseModel):
# clarification_needed: bool
# clarification_question: str
class DecisionResponse(BaseModel):
reasoning: str
decision: str
# class DecisionResponse(BaseModel):
# reasoning: str
# decision: str
class QueryEvaluationResponse(BaseModel):
reasoning: str
query_permitted: bool
# class QueryEvaluationResponse(BaseModel):
# reasoning: str
# query_permitted: bool
class OrchestrationClarificationInfo(BaseModel):
clarification_question: str
clarification_response: str | None = None
# class OrchestrationClarificationInfo(BaseModel):
# clarification_question: str
# clarification_response: str | None = None
class WebSearchAnswer(BaseModel):
urls_to_open_indices: list[int]
# class WebSearchAnswer(BaseModel):
# urls_to_open_indices: list[int]
class SearchAnswer(BaseModel):
reasoning: str
answer: str
claims: list[str] | None = None
# class SearchAnswer(BaseModel):
# reasoning: str
# answer: str
# claims: list[str] | None = None
class TestInfoCompleteResponse(BaseModel):
reasoning: str
complete: bool
gaps: list[str]
# class TestInfoCompleteResponse(BaseModel):
# reasoning: str
# complete: bool
# gaps: list[str]
# TODO: revisit with custom tools implementation in v2
# each tool should be a class with the attributes below, plus the actual tool implementation
# this will also allow custom tools to have their own cost
class OrchestratorTool(BaseModel):
tool_id: int
name: str
llm_path: str # the path for the LLM to refer by
path: DRPath # the actual path in the graph
description: str
metadata: dict[str, str]
cost: float
tool_object: Tool | None = None # None for CLOSER
# # TODO: revisit with custom tools implementation in v2
# # each tool should be a class with the attributes below, plus the actual tool implementation
# # this will also allow custom tools to have their own cost
# class OrchestratorTool(BaseModel):
# tool_id: int
# name: str
# llm_path: str # the path for the LLM to refer by
# path: DRPath # the actual path in the graph
# description: str
# metadata: dict[str, str]
# cost: float
# tool_object: Tool | None = None # None for CLOSER
model_config = ConfigDict(arbitrary_types_allowed=True)
# class Config:
# arbitrary_types_allowed = True
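# Illustrative sketch (added, not part of the diff): a hypothetical registry entry built
# from this model; the id, cost, and description values are made up, and tool_object is
# left as None, mirroring the CLOSER case noted above.
_example_internal_search_tool = OrchestratorTool(
    tool_id=1,
    name="Internal Search",
    llm_path="internal_search",
    path=DRPath.INTERNAL_SEARCH,
    description="Searches internally indexed documents.",
    metadata={},
    cost=1.0,
)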
class IterationInstructions(BaseModel):
iteration_nr: int
plan: str | None
reasoning: str
purpose: str
# class IterationInstructions(BaseModel):
# iteration_nr: int
# plan: str | None
# reasoning: str
# purpose: str
class IterationAnswer(BaseModel):
tool: str
tool_id: int
iteration_nr: int
parallelization_nr: int
question: str
reasoning: str | None
answer: str
cited_documents: dict[int, InferenceSection]
background_info: str | None = None
claims: list[str] | None = None
additional_data: dict[str, str] | None = None
response_type: str | None = None
data: dict | list | str | int | float | bool | None = None
file_ids: list[str] | None = None
# TODO: This is not ideal, but we can rework the schema
# for deep research later
is_web_fetch: bool = False
# for image generation step-types
generated_images: list[GeneratedImage] | None = None
# for multi-query search tools (v2 web search and internal search)
# TODO: Clean this up to be more flexible to tools
queries: list[str] | None = None
# class IterationAnswer(BaseModel):
# tool: str
# tool_id: int
# iteration_nr: int
# parallelization_nr: int
# question: str
# reasoning: str | None
# answer: str
# cited_documents: dict[int, InferenceSection]
# background_info: str | None = None
# claims: list[str] | None = None
# additional_data: dict[str, str] | None = None
# response_type: str | None = None
# data: dict | list | str | int | float | bool | None = None
# file_ids: list[str] | None = None
# # TODO: This is not ideal, but we can rework the schema
# # for deep research later
# is_web_fetch: bool = False
# # for image generation step-types
# generated_images: list[GeneratedImage] | None = None
# # for multi-query search tools (v2 web search and internal search)
# # TODO: Clean this up to be more flexible to tools
# queries: list[str] | None = None
class AggregatedDRContext(BaseModel):
context: str
cited_documents: list[InferenceSection]
is_internet_marker_dict: dict[str, bool]
global_iteration_responses: list[IterationAnswer]
# class AggregatedDRContext(BaseModel):
# context: str
# cited_documents: list[InferenceSection]
# is_internet_marker_dict: dict[str, bool]
# global_iteration_responses: list[IterationAnswer]
class DRPromptPurpose(str, Enum):
PLAN = "PLAN"
NEXT_STEP = "NEXT_STEP"
NEXT_STEP_REASONING = "NEXT_STEP_REASONING"
NEXT_STEP_PURPOSE = "NEXT_STEP_PURPOSE"
CLARIFICATION = "CLARIFICATION"
# class DRPromptPurpose(str, Enum):
# PLAN = "PLAN"
# NEXT_STEP = "NEXT_STEP"
# NEXT_STEP_REASONING = "NEXT_STEP_REASONING"
# NEXT_STEP_PURPOSE = "NEXT_STEP_PURPOSE"
# CLARIFICATION = "CLARIFICATION"
class BaseSearchProcessingResponse(BaseModel):
specified_source_types: list[str]
rewritten_query: str
time_filter: str
# class BaseSearchProcessingResponse(BaseModel):
# specified_source_types: list[str]
# rewritten_query: str
# time_filter: str
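# Illustrative note (added, not part of the diff): these BaseModel classes serve as
# structured-output schemas for JSON-mode LLM calls, roughly as in the closer node
# further below; the follow-up handling shown here is hypothetical.
#
#   response = invoke_llm_json(
#       llm=primary_llm,
#       prompt=create_question_prompt(system_prompt, user_prompt),
#       schema=TestInfoCompleteResponse,
#       timeout_override=TF_DR_TIMEOUT_LONG,
#   )
#   if not response.complete:
#       ...  # feed response.gaps back into another research iteration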

File diff suppressed because it is too large

View File

@@ -1,423 +1,418 @@
import re
from datetime import datetime
from typing import cast
# import re
# from datetime import datetime
# from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from sqlalchemy.orm import Session
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
# from sqlalchemy.orm import Session
from onyx.agents.agent_search.dr.constants import MAX_CHAT_HISTORY_MESSAGES
from onyx.agents.agent_search.dr.constants import MAX_NUM_CLOSER_SUGGESTIONS
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import AggregatedDRContext
from onyx.agents.agent_search.dr.models import TestInfoCompleteResponse
from onyx.agents.agent_search.dr.states import FinalUpdate
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.states import OrchestrationUpdate
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
GeneratedImageFullResult,
)
from onyx.agents.agent_search.dr.utils import aggregate_context
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.dr.utils import get_chat_history_string
from onyx.agents.agent_search.dr.utils import get_prompt_question
from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.chat.chat_utils import llm_doc_from_inference_section
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
from onyx.context.search.models import InferenceSection
from onyx.db.chat import create_search_doc_from_inference_section
from onyx.db.chat import update_db_session_with_messages
from onyx.db.models import ChatMessage__SearchDoc
from onyx.db.models import ResearchAgentIteration
from onyx.db.models import ResearchAgentIterationSubStep
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.llm.utils import check_number_of_tokens
from onyx.prompts.chat_prompts import PROJECT_INSTRUCTIONS_SEPARATOR
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
from onyx.prompts.dr_prompts import TEST_INFO_COMPLETE_PROMPT
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationStart
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.server.query_and_chat.streaming_models import StreamingType
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# from onyx.agents.agent_search.dr.constants import MAX_CHAT_HISTORY_MESSAGES
# from onyx.agents.agent_search.dr.constants import MAX_NUM_CLOSER_SUGGESTIONS
# from onyx.agents.agent_search.dr.enums import DRPath
# from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
# from onyx.agents.agent_search.dr.models import AggregatedDRContext
# from onyx.agents.agent_search.dr.models import TestInfoCompleteResponse
# from onyx.agents.agent_search.dr.states import FinalUpdate
# from onyx.agents.agent_search.dr.states import MainState
# from onyx.agents.agent_search.dr.states import OrchestrationUpdate
# from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
# GeneratedImageFullResult,
# )
# from onyx.agents.agent_search.dr.utils import aggregate_context
# from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
# from onyx.agents.agent_search.dr.utils import get_chat_history_string
# from onyx.agents.agent_search.dr.utils import get_prompt_question
# from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
# from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.agents.agent_search.utils import create_question_prompt
# from onyx.chat.chat_utils import llm_doc_from_inference_section
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
# from onyx.context.search.models import InferenceSection
# from onyx.db.chat import create_search_doc_from_inference_section
# from onyx.db.chat import update_db_session_with_messages
# from onyx.db.models import ChatMessage__SearchDoc
# from onyx.db.models import ResearchAgentIteration
# from onyx.db.models import ResearchAgentIterationSubStep
# from onyx.db.models import SearchDoc as DbSearchDoc
# from onyx.llm.utils import check_number_of_tokens
# from onyx.prompts.chat_prompts import PROJECT_INSTRUCTIONS_SEPARATOR
# from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
# from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
# from onyx.prompts.dr_prompts import TEST_INFO_COMPLETE_PROMPT
# from onyx.server.query_and_chat.streaming_models import CitationDelta
# from onyx.server.query_and_chat.streaming_models import CitationStart
# from onyx.server.query_and_chat.streaming_models import MessageStart
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.server.query_and_chat.streaming_models import StreamingType
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
# logger = setup_logger()
def extract_citation_numbers(text: str) -> list[int]:
"""
Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
Returns a list of all unique citation numbers found.
"""
# Pattern to match [[number]] or [[number1, number2, ...]]
pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
matches = re.findall(pattern, text)
# def extract_citation_numbers(text: str) -> list[int]:
# """
# Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
# Returns a list of all unique citation numbers found.
# """
# # Pattern to match [[number]] or [[number1, number2, ...]]
# pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
# matches = re.findall(pattern, text)
cited_numbers = []
for match in matches:
# Split by comma and extract all numbers
numbers = [int(num.strip()) for num in match.split(",")]
cited_numbers.extend(numbers)
# cited_numbers = []
# for match in matches:
# # Split by comma and extract all numbers
# numbers = [int(num.strip()) for num in match.split(",")]
# cited_numbers.extend(numbers)
return list(set(cited_numbers)) # Return unique numbers
# return list(set(cited_numbers)) # Return unique numbers
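# Illustrative sketch (added, not part of the diff): a self-contained check of the
# [[n]] / [[n1, n2]] citation pattern handled above; the _example_* names are placeholders.
_example_text = "Revenue grew [[1]] while churn fell [[2, 5]]."
_example_matches = re.findall(r"\[\[(\d+(?:,\s*\d+)*)\]\]", _example_text)  # ['1', '2, 5']
_example_numbers = sorted(
    {int(num.strip()) for group in _example_matches for num in group.split(",")}
)  # [1, 2, 5]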
def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
citation_content = match.group(1) # e.g., "3" or "3, 5, 7"
numbers = [int(num.strip()) for num in citation_content.split(",")]
# def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
# citation_content = match.group(1) # e.g., "3" or "3, 5, 7"
# numbers = [int(num.strip()) for num in citation_content.split(",")]
# For multiple citations like [[3, 5, 7]], create separate linked citations
linked_citations = []
for num in numbers:
if num - 1 < len(docs): # Check bounds
link = docs[num - 1].link or ""
linked_citations.append(f"[[{num}]]({link})")
else:
linked_citations.append(f"[[{num}]]") # No link if out of bounds
# # For multiple citations like [[3, 5, 7]], create separate linked citations
# linked_citations = []
# for num in numbers:
# if num - 1 < len(docs): # Check bounds
# link = docs[num - 1].link or ""
# linked_citations.append(f"[[{num}]]({link})")
# else:
# linked_citations.append(f"[[{num}]]") # No link if out of bounds
return "".join(linked_citations)
# return "".join(linked_citations)
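# Illustrative sketch (added, not part of the diff): how the replacer above is typically
# wired up with re.sub; _example_docs is an empty placeholder, so every citation falls
# back to the unlinked form.
_example_docs: list[DbSearchDoc] = []
_example_linked = re.sub(
    r"\[\[(\d+(?:,\s*\d+)*)\]\]",
    lambda m: replace_citation_with_link(m, _example_docs),
    "See [[1]] and [[2, 3]].",
)  # -> "See [[1]] and [[2]][[3]]."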
def insert_chat_message_search_doc_pair(
message_id: int, search_doc_ids: list[int], db_session: Session
) -> None:
"""
Insert a (message_id, search_doc_id) pair into the chat_message__search_doc table for each provided search_doc_id.
# def insert_chat_message_search_doc_pair(
# message_id: int, search_doc_ids: list[int], db_session: Session
# ) -> None:
# """
# Insert a (message_id, search_doc_id) pair into the chat_message__search_doc table for each provided search_doc_id.
Args:
message_id: The ID of the chat message
search_doc_ids: The IDs of the search documents
db_session: The database session
"""
for search_doc_id in search_doc_ids:
chat_message_search_doc = ChatMessage__SearchDoc(
chat_message_id=message_id, search_doc_id=search_doc_id
)
db_session.add(chat_message_search_doc)
# Args:
# message_id: The ID of the chat message
# search_doc_ids: The IDs of the search documents
# db_session: The database session
# """
# for search_doc_id in search_doc_ids:
# chat_message_search_doc = ChatMessage__SearchDoc(
# chat_message_id=message_id, search_doc_id=search_doc_id
# )
# db_session.add(chat_message_search_doc)
def save_iteration(
state: MainState,
graph_config: GraphConfig,
aggregated_context: AggregatedDRContext,
final_answer: str,
all_cited_documents: list[InferenceSection],
is_internet_marker_dict: dict[str, bool],
) -> None:
db_session = graph_config.persistence.db_session
message_id = graph_config.persistence.message_id
research_type = graph_config.behavior.research_type
db_session = graph_config.persistence.db_session
# def save_iteration(
# state: MainState,
# graph_config: GraphConfig,
# aggregated_context: AggregatedDRContext,
# final_answer: str,
# all_cited_documents: list[InferenceSection],
# is_internet_marker_dict: dict[str, bool],
# ) -> None:
# db_session = graph_config.persistence.db_session
# message_id = graph_config.persistence.message_id
# first, insert the search_docs
search_docs = [
create_search_doc_from_inference_section(
inference_section=inference_section,
is_internet=is_internet_marker_dict.get(
inference_section.center_chunk.document_id, False
), # TODO: revisit
db_session=db_session,
commit=False,
)
for inference_section in all_cited_documents
]
# # first, insert the search_docs
# search_docs = [
# create_search_doc_from_inference_section(
# inference_section=inference_section,
# is_internet=is_internet_marker_dict.get(
# inference_section.center_chunk.document_id, False
# ), # TODO: revisit
# db_session=db_session,
# commit=False,
# )
# for inference_section in all_cited_documents
# ]
# then, map_search_docs to message
insert_chat_message_search_doc_pair(
message_id, [search_doc.id for search_doc in search_docs], db_session
)
# # then, map_search_docs to message
# insert_chat_message_search_doc_pair(
# message_id, [search_doc.id for search_doc in search_docs], db_session
# )
# lastly, insert the citations
citation_dict: dict[int, int] = {}
cited_doc_nrs = extract_citation_numbers(final_answer)
for cited_doc_nr in cited_doc_nrs:
citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
# # lastly, insert the citations
# citation_dict: dict[int, int] = {}
# cited_doc_nrs = extract_citation_numbers(final_answer)
# for cited_doc_nr in cited_doc_nrs:
# citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
# TODO: generate plan as dict in the first place
plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
plan_of_record_dict = parse_plan_to_dict(plan_of_record)
# # TODO: generate plan as dict in the first place
# plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
# plan_of_record_dict = parse_plan_to_dict(plan_of_record)
# Update the chat message and its parent message in database
update_db_session_with_messages(
db_session=db_session,
chat_message_id=message_id,
chat_session_id=graph_config.persistence.chat_session_id,
is_agentic=graph_config.behavior.use_agentic_search,
message=final_answer,
citations=citation_dict,
research_type=research_type,
research_plan=plan_of_record_dict,
final_documents=search_docs,
update_parent_message=True,
research_answer_purpose=ResearchAnswerPurpose.ANSWER,
)
# # Update the chat message and its parent message in database
# update_db_session_with_messages(
# db_session=db_session,
# chat_message_id=message_id,
# chat_session_id=graph_config.persistence.chat_session_id,
# is_agentic=graph_config.behavior.use_agentic_search,
# message=final_answer,
# citations=citation_dict,
# research_type=None, # research_type is deprecated
# research_plan=plan_of_record_dict,
# final_documents=search_docs,
# update_parent_message=True,
# research_answer_purpose=ResearchAnswerPurpose.ANSWER,
# )
for iteration_preparation in state.iteration_instructions:
research_agent_iteration_step = ResearchAgentIteration(
primary_question_id=message_id,
reasoning=iteration_preparation.reasoning,
purpose=iteration_preparation.purpose,
iteration_nr=iteration_preparation.iteration_nr,
)
db_session.add(research_agent_iteration_step)
# for iteration_preparation in state.iteration_instructions:
# research_agent_iteration_step = ResearchAgentIteration(
# primary_question_id=message_id,
# reasoning=iteration_preparation.reasoning,
# purpose=iteration_preparation.purpose,
# iteration_nr=iteration_preparation.iteration_nr,
# )
# db_session.add(research_agent_iteration_step)
for iteration_answer in aggregated_context.global_iteration_responses:
# for iteration_answer in aggregated_context.global_iteration_responses:
retrieved_search_docs = convert_inference_sections_to_search_docs(
list(iteration_answer.cited_documents.values())
)
# retrieved_search_docs = convert_inference_sections_to_search_docs(
# list(iteration_answer.cited_documents.values())
# )
# Convert SavedSearchDoc objects to JSON-serializable format
serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
# # Convert SavedSearchDoc objects to JSON-serializable format
# serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
primary_question_id=message_id,
iteration_nr=iteration_answer.iteration_nr,
iteration_sub_step_nr=iteration_answer.parallelization_nr,
sub_step_instructions=iteration_answer.question,
sub_step_tool_id=iteration_answer.tool_id,
sub_answer=iteration_answer.answer,
reasoning=iteration_answer.reasoning,
claims=iteration_answer.claims,
cited_doc_results=serialized_search_docs,
generated_images=(
GeneratedImageFullResult(images=iteration_answer.generated_images)
if iteration_answer.generated_images
else None
),
additional_data=iteration_answer.additional_data,
queries=iteration_answer.queries,
)
db_session.add(research_agent_iteration_sub_step)
# research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
# primary_question_id=message_id,
# iteration_nr=iteration_answer.iteration_nr,
# iteration_sub_step_nr=iteration_answer.parallelization_nr,
# sub_step_instructions=iteration_answer.question,
# sub_step_tool_id=iteration_answer.tool_id,
# sub_answer=iteration_answer.answer,
# reasoning=iteration_answer.reasoning,
# claims=iteration_answer.claims,
# cited_doc_results=serialized_search_docs,
# generated_images=(
# GeneratedImageFullResult(images=iteration_answer.generated_images)
# if iteration_answer.generated_images
# else None
# ),
# additional_data=iteration_answer.additional_data,
# queries=iteration_answer.queries,
# )
# db_session.add(research_agent_iteration_sub_step)
db_session.commit()
# db_session.commit()
def closer(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> FinalUpdate | OrchestrationUpdate:
"""
LangGraph node to close the DR process and finalize the answer.
"""
# def closer(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> FinalUpdate | OrchestrationUpdate:
# """
# LangGraph node to close the DR process and finalize the answer.
# """
node_start_time = datetime.now()
# TODO: generate final answer using all the previous steps
# (right now, answers from each step are concatenated onto each other)
# Also, add missing fields once usage in UI is clear.
# node_start_time = datetime.now()
# # TODO: generate final answer using all the previous steps
# # (right now, answers from each step are concatenated onto each other)
# # Also, add missing fields once usage in UI is clear.
current_step_nr = state.current_step_nr
# current_step_nr = state.current_step_nr
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = state.original_question
if not base_question:
raise ValueError("Question is required for closer")
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# base_question = state.original_question
# if not base_question:
# raise ValueError("Question is required for closer")
research_type = graph_config.behavior.research_type
# use_agentic_search = graph_config.behavior.use_agentic_search
assistant_system_prompt: str = state.assistant_system_prompt or ""
assistant_task_prompt = state.assistant_task_prompt
# assistant_system_prompt: str = state.assistant_system_prompt or ""
# assistant_task_prompt = state.assistant_task_prompt
uploaded_context = state.uploaded_test_context or ""
# uploaded_context = state.uploaded_test_context or ""
clarification = state.clarification
prompt_question = get_prompt_question(base_question, clarification)
# clarification = state.clarification
# prompt_question = get_prompt_question(base_question, clarification)
chat_history_string = (
get_chat_history_string(
graph_config.inputs.prompt_builder.message_history,
MAX_CHAT_HISTORY_MESSAGES,
)
or "(No chat history yet available)"
)
# chat_history_string = (
# get_chat_history_string(
# graph_config.inputs.prompt_builder.message_history,
# MAX_CHAT_HISTORY_MESSAGES,
# )
# or "(No chat history yet available)"
# )
aggregated_context_w_docs = aggregate_context(
state.iteration_responses, include_documents=True
)
# aggregated_context_w_docs = aggregate_context(
# state.iteration_responses, include_documents=True
# )
aggregated_context_wo_docs = aggregate_context(
state.iteration_responses, include_documents=False
)
# aggregated_context_wo_docs = aggregate_context(
# state.iteration_responses, include_documents=False
# )
iteration_responses_w_docs_string = aggregated_context_w_docs.context
iteration_responses_wo_docs_string = aggregated_context_wo_docs.context
all_cited_documents = aggregated_context_w_docs.cited_documents
# iteration_responses_w_docs_string = aggregated_context_w_docs.context
# iteration_responses_wo_docs_string = aggregated_context_wo_docs.context
# all_cited_documents = aggregated_context_w_docs.cited_documents
num_closer_suggestions = state.num_closer_suggestions
# num_closer_suggestions = state.num_closer_suggestions
if (
num_closer_suggestions < MAX_NUM_CLOSER_SUGGESTIONS
and research_type == ResearchType.DEEP
):
test_info_complete_prompt = TEST_INFO_COMPLETE_PROMPT.build(
base_question=prompt_question,
questions_answers_claims=iteration_responses_wo_docs_string,
chat_history_string=chat_history_string,
high_level_plan=(
state.plan_of_record.plan
if state.plan_of_record
else "No plan available"
),
)
# if (
# num_closer_suggestions < MAX_NUM_CLOSER_SUGGESTIONS
# and use_agentic_search
# ):
# test_info_complete_prompt = TEST_INFO_COMPLETE_PROMPT.build(
# base_question=prompt_question,
# questions_answers_claims=iteration_responses_wo_docs_string,
# chat_history_string=chat_history_string,
# high_level_plan=(
# state.plan_of_record.plan
# if state.plan_of_record
# else "No plan available"
# ),
# )
test_info_complete_json = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt,
test_info_complete_prompt + (assistant_task_prompt or ""),
),
schema=TestInfoCompleteResponse,
timeout_override=TF_DR_TIMEOUT_LONG,
# max_tokens=1000,
)
# test_info_complete_json = invoke_llm_json(
# llm=graph_config.tooling.primary_llm,
# prompt=create_question_prompt(
# assistant_system_prompt,
# test_info_complete_prompt + (assistant_task_prompt or ""),
# ),
# schema=TestInfoCompleteResponse,
# timeout_override=TF_DR_TIMEOUT_LONG,
# # max_tokens=1000,
# )
if test_info_complete_json.complete:
pass
# if test_info_complete_json.complete:
# pass
else:
return OrchestrationUpdate(
tools_used=[DRPath.ORCHESTRATOR.value],
query_list=[],
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="closer",
node_start_time=node_start_time,
)
],
gaps=test_info_complete_json.gaps,
num_closer_suggestions=num_closer_suggestions + 1,
)
# else:
# return OrchestrationUpdate(
# tools_used=[DRPath.ORCHESTRATOR.value],
# query_list=[],
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="main",
# node_name="closer",
# node_start_time=node_start_time,
# )
# ],
# gaps=test_info_complete_json.gaps,
# num_closer_suggestions=num_closer_suggestions + 1,
# )
retrieved_search_docs = convert_inference_sections_to_search_docs(
all_cited_documents
)
# retrieved_search_docs = convert_inference_sections_to_search_docs(
# all_cited_documents
# )
write_custom_event(
current_step_nr,
MessageStart(
content="",
final_documents=retrieved_search_docs,
),
writer,
)
# write_custom_event(
# current_step_nr,
# MessageStart(
# content="",
# final_documents=retrieved_search_docs,
# ),
# writer,
# )
if research_type in [ResearchType.THOUGHTFUL, ResearchType.FAST]:
final_answer_base_prompt = FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
elif research_type == ResearchType.DEEP:
final_answer_base_prompt = FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
else:
raise ValueError(f"Invalid research type: {research_type}")
# if not use_agentic_search:
# final_answer_base_prompt = FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
# else:
# final_answer_base_prompt = FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
estimated_final_answer_prompt_tokens = check_number_of_tokens(
final_answer_base_prompt.build(
base_question=prompt_question,
iteration_responses_string=iteration_responses_w_docs_string,
chat_history_string=chat_history_string,
uploaded_context=uploaded_context,
)
)
# estimated_final_answer_prompt_tokens = check_number_of_tokens(
# final_answer_base_prompt.build(
# base_question=prompt_question,
# iteration_responses_string=iteration_responses_w_docs_string,
# chat_history_string=chat_history_string,
# uploaded_context=uploaded_context,
# )
# )
# for DR, rely only on sub-answers and claims to save tokens if context is too long
# TODO: consider compression step for Thoughtful mode if context is too long.
# Should generally not be the case though.
# # for DR, rely only on sub-answers and claims to save tokens if context is too long
# # TODO: consider compression step for Thoughtful mode if context is too long.
# # Should generally not be the case though.
max_allowed_input_tokens = graph_config.tooling.primary_llm.config.max_input_tokens
# max_allowed_input_tokens = graph_config.tooling.primary_llm.config.max_input_tokens
if (
estimated_final_answer_prompt_tokens > 0.8 * max_allowed_input_tokens
and research_type == ResearchType.DEEP
):
iteration_responses_string = iteration_responses_wo_docs_string
else:
iteration_responses_string = iteration_responses_w_docs_string
# if (
# estimated_final_answer_prompt_tokens > 0.8 * max_allowed_input_tokens
# and use_agentic_search
# ):
# iteration_responses_string = iteration_responses_wo_docs_string
# else:
# iteration_responses_string = iteration_responses_w_docs_string
final_answer_prompt = final_answer_base_prompt.build(
base_question=prompt_question,
iteration_responses_string=iteration_responses_string,
chat_history_string=chat_history_string,
uploaded_context=uploaded_context,
)
# final_answer_prompt = final_answer_base_prompt.build(
# base_question=prompt_question,
# iteration_responses_string=iteration_responses_string,
# chat_history_string=chat_history_string,
# uploaded_context=uploaded_context,
# )
if graph_config.inputs.project_instructions:
assistant_system_prompt = (
assistant_system_prompt
+ PROJECT_INSTRUCTIONS_SEPARATOR
+ (graph_config.inputs.project_instructions or "")
)
# if graph_config.inputs.project_instructions:
# assistant_system_prompt = (
# assistant_system_prompt
# + PROJECT_INSTRUCTIONS_SEPARATOR
# + (graph_config.inputs.project_instructions or "")
# )
all_context_llmdocs = [
llm_doc_from_inference_section(inference_section)
for inference_section in all_cited_documents
]
# all_context_llmdocs = [
# llm_doc_from_inference_section(inference_section)
# for inference_section in all_cited_documents
# ]
try:
streamed_output, _, citation_infos = run_with_timeout(
int(3 * TF_DR_TIMEOUT_LONG),
lambda: stream_llm_answer(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt,
final_answer_prompt + (assistant_task_prompt or ""),
),
event_name="basic_response",
writer=writer,
agent_answer_level=0,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=int(2 * TF_DR_TIMEOUT_LONG),
answer_piece=StreamingType.MESSAGE_DELTA.value,
ind=current_step_nr,
context_docs=all_context_llmdocs,
replace_citations=True,
# max_tokens=None,
),
)
# try:
# streamed_output, _, citation_infos = run_with_timeout(
# int(3 * TF_DR_TIMEOUT_LONG),
# lambda: stream_llm_answer(
# llm=graph_config.tooling.primary_llm,
# prompt=create_question_prompt(
# assistant_system_prompt,
# final_answer_prompt + (assistant_task_prompt or ""),
# ),
# event_name="basic_response",
# writer=writer,
# agent_answer_level=0,
# agent_answer_question_num=0,
# agent_answer_type="agent_level_answer",
# timeout_override=int(2 * TF_DR_TIMEOUT_LONG),
# answer_piece=StreamingType.MESSAGE_DELTA.value,
# ind=current_step_nr,
# context_docs=all_context_llmdocs,
# replace_citations=True,
# # max_tokens=None,
# ),
# )
final_answer = "".join(streamed_output)
except Exception as e:
raise ValueError(f"Error in closer: {e}")
# final_answer = "".join(streamed_output)
# except Exception as e:
# raise ValueError(f"Error in closer: {e}")
write_custom_event(current_step_nr, SectionEnd(), writer)
# write_custom_event(current_step_nr, SectionEnd(), writer)
current_step_nr += 1
# current_step_nr += 1
write_custom_event(current_step_nr, CitationStart(), writer)
write_custom_event(current_step_nr, CitationDelta(citations=citation_infos), writer)
write_custom_event(current_step_nr, SectionEnd(), writer)
# write_custom_event(current_step_nr, CitationStart(), writer)
# write_custom_event(current_step_nr, CitationDelta(citations=citation_infos), writer)
# write_custom_event(current_step_nr, SectionEnd(), writer)
current_step_nr += 1
# current_step_nr += 1
# Log the research agent steps
# save_iteration(
# state,
# graph_config,
# aggregated_context,
# final_answer,
# all_cited_documents,
# is_internet_marker_dict,
# )
# # Log the research agent steps
# # save_iteration(
# # state,
# # graph_config,
# # aggregated_context,
# # final_answer,
# # all_cited_documents,
# # is_internet_marker_dict,
# # )
return FinalUpdate(
final_answer=final_answer,
all_cited_documents=all_cited_documents,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="closer",
node_start_time=node_start_time,
)
],
)
# return FinalUpdate(
# final_answer=final_answer,
# all_cited_documents=all_cited_documents,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="main",
# node_name="closer",
# node_start_time=node_start_time,
# )
# ],
# )
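# Illustrative sketch (added, not part of the diff): the 80% input-token budget check used
# above, isolated as a self-contained helper with made-up numbers.
def _example_pick_context(
    with_docs: str, without_docs: str, prompt_tokens: int, max_input_tokens: int
) -> str:
    # Fall back to the document-free context when the full prompt would use more than
    # 80% of the model's input window, mirroring the heuristic in the closer node.
    if prompt_tokens > 0.8 * max_input_tokens:
        return without_docs
    return with_docs

# _example_pick_context("full context", "sub-answers only", 9_000, 10_000)
# -> "sub-answers only"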

View File

@@ -1,248 +1,246 @@
import re
from datetime import datetime
from typing import cast
# import re
# from datetime import datetime
# from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from sqlalchemy.orm import Session
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
# from sqlalchemy.orm import Session
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
from onyx.agents.agent_search.dr.models import AggregatedDRContext
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
GeneratedImageFullResult,
)
from onyx.agents.agent_search.dr.utils import aggregate_context
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.context.search.models import InferenceSection
from onyx.db.chat import create_search_doc_from_inference_section
from onyx.db.chat import update_db_session_with_messages
from onyx.db.models import ChatMessage__SearchDoc
from onyx.db.models import ResearchAgentIteration
from onyx.db.models import ResearchAgentIterationSubStep
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
# from onyx.agents.agent_search.dr.models import AggregatedDRContext
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.states import MainState
# from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
# GeneratedImageFullResult,
# )
# from onyx.agents.agent_search.dr.utils import aggregate_context
# from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
# from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.context.search.models import InferenceSection
# from onyx.db.chat import create_search_doc_from_inference_section
# from onyx.db.chat import update_db_session_with_messages
# from onyx.db.models import ChatMessage__SearchDoc
# from onyx.db.models import ResearchAgentIteration
# from onyx.db.models import ResearchAgentIterationSubStep
# from onyx.db.models import SearchDoc as DbSearchDoc
# from onyx.natural_language_processing.utils import get_tokenizer
# from onyx.server.query_and_chat.streaming_models import OverallStop
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def _extract_citation_numbers(text: str) -> list[int]:
"""
Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
Returns a list of all unique citation numbers found.
"""
# Pattern to match [[number]] or [[number1, number2, ...]]
pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
matches = re.findall(pattern, text)
# def _extract_citation_numbers(text: str) -> list[int]:
# """
# Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
# Returns a list of all unique citation numbers found.
# """
# # Pattern to match [[number]] or [[number1, number2, ...]]
# pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
# matches = re.findall(pattern, text)
cited_numbers = []
for match in matches:
# Split by comma and extract all numbers
numbers = [int(num.strip()) for num in match.split(",")]
cited_numbers.extend(numbers)
# cited_numbers = []
# for match in matches:
# # Split by comma and extract all numbers
# numbers = [int(num.strip()) for num in match.split(",")]
# cited_numbers.extend(numbers)
return list(set(cited_numbers)) # Return unique numbers
# return list(set(cited_numbers)) # Return unique numbers
def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
citation_content = match.group(1) # e.g., "3" or "3, 5, 7"
numbers = [int(num.strip()) for num in citation_content.split(",")]
# def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
# citation_content = match.group(1) # e.g., "3" or "3, 5, 7"
# numbers = [int(num.strip()) for num in citation_content.split(",")]
# For multiple citations like [[3, 5, 7]], create separate linked citations
linked_citations = []
for num in numbers:
if num - 1 < len(docs): # Check bounds
link = docs[num - 1].link or ""
linked_citations.append(f"[[{num}]]({link})")
else:
linked_citations.append(f"[[{num}]]") # No link if out of bounds
# # For multiple citations like [[3, 5, 7]], create separate linked citations
# linked_citations = []
# for num in numbers:
# if num - 1 < len(docs): # Check bounds
# link = docs[num - 1].link or ""
# linked_citations.append(f"[[{num}]]({link})")
# else:
# linked_citations.append(f"[[{num}]]") # No link if out of bounds
return "".join(linked_citations)
# return "".join(linked_citations)
def _insert_chat_message_search_doc_pair(
message_id: int, search_doc_ids: list[int], db_session: Session
) -> None:
"""
Insert a (message_id, search_doc_id) pair into the chat_message__search_doc table for each provided search_doc_id.
# def _insert_chat_message_search_doc_pair(
# message_id: int, search_doc_ids: list[int], db_session: Session
# ) -> None:
# """
# Insert a (message_id, search_doc_id) pair into the chat_message__search_doc table for each provided search_doc_id.
Args:
message_id: The ID of the chat message
search_doc_ids: The IDs of the search documents
db_session: The database session
"""
for search_doc_id in search_doc_ids:
chat_message_search_doc = ChatMessage__SearchDoc(
chat_message_id=message_id, search_doc_id=search_doc_id
)
db_session.add(chat_message_search_doc)
# Args:
# message_id: The ID of the chat message
# search_doc_ids: The IDs of the search documents
# db_session: The database session
# """
# for search_doc_id in search_doc_ids:
# chat_message_search_doc = ChatMessage__SearchDoc(
# chat_message_id=message_id, search_doc_id=search_doc_id
# )
# db_session.add(chat_message_search_doc)
def save_iteration(
state: MainState,
graph_config: GraphConfig,
aggregated_context: AggregatedDRContext,
final_answer: str,
all_cited_documents: list[InferenceSection],
is_internet_marker_dict: dict[str, bool],
num_tokens: int,
) -> None:
db_session = graph_config.persistence.db_session
message_id = graph_config.persistence.message_id
research_type = graph_config.behavior.research_type
db_session = graph_config.persistence.db_session
# def save_iteration(
# state: MainState,
# graph_config: GraphConfig,
# aggregated_context: AggregatedDRContext,
# final_answer: str,
# all_cited_documents: list[InferenceSection],
# is_internet_marker_dict: dict[str, bool],
# num_tokens: int,
# ) -> None:
# db_session = graph_config.persistence.db_session
# message_id = graph_config.persistence.message_id
# first, insert the search_docs
search_docs = [
create_search_doc_from_inference_section(
inference_section=inference_section,
is_internet=is_internet_marker_dict.get(
inference_section.center_chunk.document_id, False
), # TODO: revisit
db_session=db_session,
commit=False,
)
for inference_section in all_cited_documents
]
# # first, insert the search_docs
# search_docs = [
# create_search_doc_from_inference_section(
# inference_section=inference_section,
# is_internet=is_internet_marker_dict.get(
# inference_section.center_chunk.document_id, False
# ), # TODO: revisit
# db_session=db_session,
# commit=False,
# )
# for inference_section in all_cited_documents
# ]
# then, map_search_docs to message
_insert_chat_message_search_doc_pair(
message_id, [search_doc.id for search_doc in search_docs], db_session
)
# # then, map_search_docs to message
# _insert_chat_message_search_doc_pair(
# message_id, [search_doc.id for search_doc in search_docs], db_session
# )
# lastly, insert the citations
citation_dict: dict[int, int] = {}
cited_doc_nrs = _extract_citation_numbers(final_answer)
if search_docs:
for cited_doc_nr in cited_doc_nrs:
citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
# # lastly, insert the citations
# citation_dict: dict[int, int] = {}
# cited_doc_nrs = _extract_citation_numbers(final_answer)
# if search_docs:
# for cited_doc_nr in cited_doc_nrs:
# citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
# TODO: generate plan as dict in the first place
plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
plan_of_record_dict = parse_plan_to_dict(plan_of_record)
# # TODO: generate plan as dict in the first place
# plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
# plan_of_record_dict = parse_plan_to_dict(plan_of_record)
# Update the chat message and its parent message in database
update_db_session_with_messages(
db_session=db_session,
chat_message_id=message_id,
chat_session_id=graph_config.persistence.chat_session_id,
is_agentic=graph_config.behavior.use_agentic_search,
message=final_answer,
citations=citation_dict,
research_type=research_type,
research_plan=plan_of_record_dict,
final_documents=search_docs,
update_parent_message=True,
research_answer_purpose=ResearchAnswerPurpose.ANSWER,
token_count=num_tokens,
)
# # Update the chat message and its parent message in database
# update_db_session_with_messages(
# db_session=db_session,
# chat_message_id=message_id,
# chat_session_id=graph_config.persistence.chat_session_id,
# is_agentic=graph_config.behavior.use_agentic_search,
# message=final_answer,
# citations=citation_dict,
# research_type=None, # research_type is deprecated
# research_plan=plan_of_record_dict,
# final_documents=search_docs,
# update_parent_message=True,
# research_answer_purpose=ResearchAnswerPurpose.ANSWER,
# token_count=num_tokens,
# )
for iteration_preparation in state.iteration_instructions:
research_agent_iteration_step = ResearchAgentIteration(
primary_question_id=message_id,
reasoning=iteration_preparation.reasoning,
purpose=iteration_preparation.purpose,
iteration_nr=iteration_preparation.iteration_nr,
)
db_session.add(research_agent_iteration_step)
# for iteration_preparation in state.iteration_instructions:
# research_agent_iteration_step = ResearchAgentIteration(
# primary_question_id=message_id,
# reasoning=iteration_preparation.reasoning,
# purpose=iteration_preparation.purpose,
# iteration_nr=iteration_preparation.iteration_nr,
# )
# db_session.add(research_agent_iteration_step)
for iteration_answer in aggregated_context.global_iteration_responses:
# for iteration_answer in aggregated_context.global_iteration_responses:
retrieved_search_docs = convert_inference_sections_to_search_docs(
list(iteration_answer.cited_documents.values())
)
# retrieved_search_docs = convert_inference_sections_to_search_docs(
# list(iteration_answer.cited_documents.values())
# )
# Convert SavedSearchDoc objects to JSON-serializable format
serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
# # Convert SavedSearchDoc objects to JSON-serializable format
# serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
primary_question_id=message_id,
iteration_nr=iteration_answer.iteration_nr,
iteration_sub_step_nr=iteration_answer.parallelization_nr,
sub_step_instructions=iteration_answer.question,
sub_step_tool_id=iteration_answer.tool_id,
sub_answer=iteration_answer.answer,
reasoning=iteration_answer.reasoning,
claims=iteration_answer.claims,
cited_doc_results=serialized_search_docs,
generated_images=(
GeneratedImageFullResult(images=iteration_answer.generated_images)
if iteration_answer.generated_images
else None
),
additional_data=iteration_answer.additional_data,
queries=iteration_answer.queries,
)
db_session.add(research_agent_iteration_sub_step)
# research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
# primary_question_id=message_id,
# iteration_nr=iteration_answer.iteration_nr,
# iteration_sub_step_nr=iteration_answer.parallelization_nr,
# sub_step_instructions=iteration_answer.question,
# sub_step_tool_id=iteration_answer.tool_id,
# sub_answer=iteration_answer.answer,
# reasoning=iteration_answer.reasoning,
# claims=iteration_answer.claims,
# cited_doc_results=serialized_search_docs,
# generated_images=(
# GeneratedImageFullResult(images=iteration_answer.generated_images)
# if iteration_answer.generated_images
# else None
# ),
# additional_data=iteration_answer.additional_data,
# queries=iteration_answer.queries,
# )
# db_session.add(research_agent_iteration_sub_step)
db_session.commit()
# db_session.commit()
def logging(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
LangGraph node to log the research agent iterations and persist the final results.
"""
# def logging(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
# """
# LangGraph node to log the research agent iterations and persist the final results.
# """
node_start_time = datetime.now()
# TODO: generate final answer using all the previous steps
# (right now, answers from each step are concatenated onto each other)
# Also, add missing fields once usage in UI is clear.
# node_start_time = datetime.now()
# # TODO: generate final answer using all the previous steps
# # (right now, answers from each step are concatenated onto each other)
# # Also, add missing fields once usage in UI is clear.
current_step_nr = state.current_step_nr
# current_step_nr = state.current_step_nr
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = state.original_question
if not base_question:
raise ValueError("Question is required for logging")
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# base_question = state.original_question
# if not base_question:
# raise ValueError("Question is required for logging")
aggregated_context = aggregate_context(
state.iteration_responses, include_documents=True
)
# aggregated_context = aggregate_context(
# state.iteration_responses, include_documents=True
# )
all_cited_documents = aggregated_context.cited_documents
# all_cited_documents = aggregated_context.cited_documents
is_internet_marker_dict = aggregated_context.is_internet_marker_dict
# is_internet_marker_dict = aggregated_context.is_internet_marker_dict
final_answer = state.final_answer or ""
llm_provider = graph_config.tooling.primary_llm.config.model_provider
llm_model_name = graph_config.tooling.primary_llm.config.model_name
# final_answer = state.final_answer or ""
# llm_provider = graph_config.tooling.primary_llm.config.model_provider
# llm_model_name = graph_config.tooling.primary_llm.config.model_name
llm_tokenizer = get_tokenizer(
model_name=llm_model_name,
provider_type=llm_provider,
)
num_tokens = len(llm_tokenizer.encode(final_answer or ""))
# llm_tokenizer = get_tokenizer(
# model_name=llm_model_name,
# provider_type=llm_provider,
# )
# num_tokens = len(llm_tokenizer.encode(final_answer or ""))
write_custom_event(current_step_nr, OverallStop(), writer)
# write_custom_event(current_step_nr, OverallStop(), writer)
# Log the research agent steps
save_iteration(
state,
graph_config,
aggregated_context,
final_answer,
all_cited_documents,
is_internet_marker_dict,
num_tokens,
)
# # Log the research agent steps
# save_iteration(
# state,
# graph_config,
# aggregated_context,
# final_answer,
# all_cited_documents,
# is_internet_marker_dict,
# num_tokens,
# )
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="logger",
node_start_time=node_start_time,
)
],
)
# return LoggerUpdate(
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="main",
# node_name="logger",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,132 +1,131 @@
from collections.abc import Iterator
from typing import cast
# from collections.abc import Iterator
# from typing import cast
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langgraph.types import StreamWriter
from pydantic import BaseModel
# from langchain_core.messages import AIMessageChunk
# from langchain_core.messages import BaseMessage
# from langgraph.types import StreamWriter
# from pydantic import BaseModel
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import CitationInfo
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
PassThroughAnswerResponseHandler,
)
from onyx.chat.stream_processing.utils import map_document_id_order
from onyx.context.search.models import InferenceSection
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationStart
from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
# from onyx.chat.models import AgentAnswerPiece
# from onyx.chat.models import CitationInfo
# from onyx.chat.models import LlmDoc
# from onyx.chat.models import OnyxAnswerPiece
# from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
# from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
# from onyx.chat.stream_processing.answer_response_handler import (
# PassThroughAnswerResponseHandler,
# )
# from onyx.chat.stream_processing.utils import map_document_id_order
# from onyx.context.search.models import InferenceSection
# from onyx.server.query_and_chat.streaming_models import CitationDelta
# from onyx.server.query_and_chat.streaming_models import CitationStart
# from onyx.server.query_and_chat.streaming_models import MessageDelta
# from onyx.server.query_and_chat.streaming_models import MessageStart
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
class BasicSearchProcessedStreamResults(BaseModel):
ai_message_chunk: AIMessageChunk = AIMessageChunk(content="")
full_answer: str | None = None
cited_references: list[InferenceSection] = []
retrieved_documents: list[LlmDoc] = []
# class BasicSearchProcessedStreamResults(BaseModel):
# ai_message_chunk: AIMessageChunk = AIMessageChunk(content="")
# full_answer: str | None = None
# cited_references: list[InferenceSection] = []
# retrieved_documents: list[LlmDoc] = []
def process_llm_stream(
messages: Iterator[BaseMessage],
should_stream_answer: bool,
writer: StreamWriter,
ind: int,
search_results: list[LlmDoc] | None = None,
generate_final_answer: bool = False,
chat_message_id: str | None = None,
) -> BasicSearchProcessedStreamResults:
tool_call_chunk = AIMessageChunk(content="")
# def process_llm_stream(
# messages: Iterator[BaseMessage],
# should_stream_answer: bool,
# writer: StreamWriter,
# ind: int,
# search_results: list[LlmDoc] | None = None,
# generate_final_answer: bool = False,
# chat_message_id: str | None = None,
# ) -> BasicSearchProcessedStreamResults:
# tool_call_chunk = AIMessageChunk(content="")
if search_results:
answer_handler: AnswerResponseHandler = CitationResponseHandler(
context_docs=search_results,
doc_id_to_rank_map=map_document_id_order(search_results),
)
else:
answer_handler = PassThroughAnswerResponseHandler()
# if search_results:
# answer_handler: AnswerResponseHandler = CitationResponseHandler(
# context_docs=search_results,
# doc_id_to_rank_map=map_document_id_order(search_results),
# )
# else:
# answer_handler = PassThroughAnswerResponseHandler()
full_answer = ""
start_final_answer_streaming_set = False
# Accumulate citation infos if handler emits them
collected_citation_infos: list[CitationInfo] = []
# full_answer = ""
# start_final_answer_streaming_set = False
# # Accumulate citation infos if handler emits them
# collected_citation_infos: list[CitationInfo] = []
# This stream will be the llm answer if no tool is chosen. When a tool is chosen,
# the stream will contain AIMessageChunks with tool call information.
for message in messages:
# # This stream will be the llm answer if no tool is chosen. When a tool is chosen,
# # the stream will contain AIMessageChunks with tool call information.
# for message in messages:
answer_piece = message.content
if not isinstance(answer_piece, str):
# this is only used for logging, so fine to
# just add the string representation
answer_piece = str(answer_piece)
full_answer += answer_piece
# answer_piece = message.content
# if not isinstance(answer_piece, str):
# # this is only used for logging, so fine to
# # just add the string representation
# answer_piece = str(answer_piece)
# full_answer += answer_piece
if isinstance(message, AIMessageChunk) and (
message.tool_call_chunks or message.tool_calls
):
tool_call_chunk += message # type: ignore
elif should_stream_answer:
for response_part in answer_handler.handle_response_part(message):
# if isinstance(message, AIMessageChunk) and (
# message.tool_call_chunks or message.tool_calls
# ):
# tool_call_chunk += message # type: ignore
# elif should_stream_answer:
# for response_part in answer_handler.handle_response_part(message):
# only stream out answer parts
if (
isinstance(response_part, (OnyxAnswerPiece, AgentAnswerPiece))
and generate_final_answer
and response_part.answer_piece
):
if chat_message_id is None:
raise ValueError(
"chat_message_id is required when generating final answer"
)
# # only stream out answer parts
# if (
# isinstance(response_part, (OnyxAnswerPiece, AgentAnswerPiece))
# and generate_final_answer
# and response_part.answer_piece
# ):
# if chat_message_id is None:
# raise ValueError(
# "chat_message_id is required when generating final answer"
# )
if not start_final_answer_streaming_set:
# Convert LlmDocs to SavedSearchDocs
saved_search_docs = saved_search_docs_from_llm_docs(
search_results
)
write_custom_event(
ind,
MessageStart(content="", final_documents=saved_search_docs),
writer,
)
start_final_answer_streaming_set = True
# if not start_final_answer_streaming_set:
# # Convert LlmDocs to SavedSearchDocs
# saved_search_docs = saved_search_docs_from_llm_docs(
# search_results
# )
# write_custom_event(
# ind,
# MessageStart(content="", final_documents=saved_search_docs),
# writer,
# )
# start_final_answer_streaming_set = True
write_custom_event(
ind,
MessageDelta(content=response_part.answer_piece),
writer,
)
# collect citation info objects
elif isinstance(response_part, CitationInfo):
collected_citation_infos.append(response_part)
# write_custom_event(
# ind,
# MessageDelta(content=response_part.answer_piece),
# writer,
# )
# # collect citation info objects
# elif isinstance(response_part, CitationInfo):
# collected_citation_infos.append(response_part)
if generate_final_answer and start_final_answer_streaming_set:
# start_final_answer_streaming_set is only set if the answer is verbal and not a tool call
write_custom_event(
ind,
SectionEnd(),
writer,
)
# if generate_final_answer and start_final_answer_streaming_set:
# # start_final_answer_streaming_set is only set if the answer is verbal and not a tool call
# write_custom_event(
# ind,
# SectionEnd(),
# writer,
# )
# Emit citations section if any were collected
if collected_citation_infos:
write_custom_event(ind, CitationStart(), writer)
write_custom_event(
ind, CitationDelta(citations=collected_citation_infos), writer
)
write_custom_event(ind, SectionEnd(), writer)
# # Emit citations section if any were collected
# if collected_citation_infos:
# write_custom_event(ind, CitationStart(), writer)
# write_custom_event(
# ind, CitationDelta(citations=collected_citation_infos), writer
# )
# write_custom_event(ind, SectionEnd(), writer)
logger.debug(f"Full answer: {full_answer}")
return BasicSearchProcessedStreamResults(
ai_message_chunk=cast(AIMessageChunk, tool_call_chunk), full_answer=full_answer
)
# logger.debug(f"Full answer: {full_answer}")
# return BasicSearchProcessedStreamResults(
# ai_message_chunk=cast(AIMessageChunk, tool_call_chunk), full_answer=full_answer
# )
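
Note: the commented-out process_llm_stream above separates tool-call chunks from plain answer text while draining one LLM stream. Below is a minimal sketch of that split, assuming only langchain_core message types; the function name and sample chunks are illustrative, not part of the codebase.

from collections.abc import Iterator

from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage


def split_llm_stream(messages: Iterator[BaseMessage]) -> tuple[AIMessageChunk, str]:
    # Accumulates tool-call chunks into one AIMessageChunk and concatenates
    # plain text content into a running answer string (illustrative only).
    tool_call_chunk = AIMessageChunk(content="")
    full_answer = ""
    for message in messages:
        piece = message.content
        if not isinstance(piece, str):
            piece = str(piece)  # non-string content: stringify, since it is only logged
        full_answer += piece
        if isinstance(message, AIMessageChunk) and (
            message.tool_call_chunks or message.tool_calls
        ):
            tool_call_chunk += message  # merge partial tool-call chunks
    return tool_call_chunk, full_answer


# Example with a purely textual stream (no tool calls):
chunks = [AIMessageChunk(content="Hello, "), AIMessageChunk(content="world.")]
_, answer = split_llm_stream(iter(chunks))
print(answer)  # -> Hello, world.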

View File

@@ -1,82 +1,82 @@
from operator import add
from typing import Annotated
from typing import Any
from typing import TypedDict
# from operator import add
# from typing import Annotated
# from typing import Any
# from typing import TypedDict
from pydantic import BaseModel
# from pydantic import BaseModel
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
from onyx.agents.agent_search.dr.models import OrchestrationPlan
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.context.search.models import InferenceSection
from onyx.db.connector import DocumentSource
# from onyx.agents.agent_search.core_state import CoreState
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.models import IterationInstructions
# from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
# from onyx.agents.agent_search.dr.models import OrchestrationPlan
# from onyx.agents.agent_search.dr.models import OrchestratorTool
# from onyx.context.search.models import InferenceSection
# from onyx.db.connector import DocumentSource
### States ###
# ### States ###
class LoggerUpdate(BaseModel):
log_messages: Annotated[list[str], add] = []
# class LoggerUpdate(BaseModel):
# log_messages: Annotated[list[str], add] = []
class OrchestrationUpdate(LoggerUpdate):
tools_used: Annotated[list[str], add] = []
query_list: list[str] = []
iteration_nr: int = 0
current_step_nr: int = 1
plan_of_record: OrchestrationPlan | None = None # None for Thoughtful
remaining_time_budget: float = 2.0 # set by default to about 2 searches
num_closer_suggestions: int = 0 # how many times the closer was suggested
gaps: list[str] = (
[]
) # gaps that may be identified by the closer before being able to answer the question.
iteration_instructions: Annotated[list[IterationInstructions], add] = []
# class OrchestrationUpdate(LoggerUpdate):
# tools_used: Annotated[list[str], add] = []
# query_list: list[str] = []
# iteration_nr: int = 0
# current_step_nr: int = 1
# plan_of_record: OrchestrationPlan | None = None # None for Thoughtful
# remaining_time_budget: float = 2.0 # set by default to about 2 searches
# num_closer_suggestions: int = 0 # how many times the closer was suggested
# gaps: list[str] = (
# []
# ) # gaps that may be identified by the closer before being able to answer the question.
# iteration_instructions: Annotated[list[IterationInstructions], add] = []
class OrchestrationSetup(OrchestrationUpdate):
original_question: str | None = None
chat_history_string: str | None = None
clarification: OrchestrationClarificationInfo | None = None
available_tools: dict[str, OrchestratorTool] | None = None
num_closer_suggestions: int = 0 # how many times the closer was suggested
# class OrchestrationSetup(OrchestrationUpdate):
# original_question: str | None = None
# chat_history_string: str | None = None
# clarification: OrchestrationClarificationInfo | None = None
# available_tools: dict[str, OrchestratorTool] | None = None
# num_closer_suggestions: int = 0 # how many times the closer was suggested
active_source_types: list[DocumentSource] | None = None
active_source_types_descriptions: str | None = None
assistant_system_prompt: str | None = None
assistant_task_prompt: str | None = None
uploaded_test_context: str | None = None
uploaded_image_context: list[dict[str, Any]] | None = None
# active_source_types: list[DocumentSource] | None = None
# active_source_types_descriptions: str | None = None
# assistant_system_prompt: str | None = None
# assistant_task_prompt: str | None = None
# uploaded_test_context: str | None = None
# uploaded_image_context: list[dict[str, Any]] | None = None
class AnswerUpdate(LoggerUpdate):
iteration_responses: Annotated[list[IterationAnswer], add] = []
# class AnswerUpdate(LoggerUpdate):
# iteration_responses: Annotated[list[IterationAnswer], add] = []
class FinalUpdate(LoggerUpdate):
final_answer: str | None = None
all_cited_documents: list[InferenceSection] = []
# class FinalUpdate(LoggerUpdate):
# final_answer: str | None = None
# all_cited_documents: list[InferenceSection] = []
## Graph Input State
class MainInput(CoreState):
pass
# ## Graph Input State
# class MainInput(CoreState):
# pass
## Graph State
class MainState(
# This includes the core state
MainInput,
OrchestrationSetup,
AnswerUpdate,
FinalUpdate,
):
pass
# ## Graph State
# class MainState(
# # This includes the core state
# MainInput,
# OrchestrationSetup,
# AnswerUpdate,
# FinalUpdate,
# ):
# pass
## Graph Output State
class MainOutput(TypedDict):
log_messages: list[str]
final_answer: str | None
all_cited_documents: list[InferenceSection]
# ## Graph Output State
# class MainOutput(TypedDict):
# log_messages: list[str]
# final_answer: str | None
# all_cited_documents: list[InferenceSection]
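
Note: the state classes above type accumulating fields as Annotated[list[...], add]; in LangGraph the second argument is a reducer, so updates from different nodes are appended rather than overwritten. Below is a minimal, hypothetical sketch of that behavior using a TypedDict state and invented node names.

from operator import add
from typing import Annotated
from typing import TypedDict

from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph


class LogState(TypedDict):
    # "add" is the reducer: node updates are concatenated onto the list
    log_messages: Annotated[list[str], add]


def first_node(state: LogState) -> dict:
    return {"log_messages": ["first node ran"]}


def second_node(state: LogState) -> dict:
    return {"log_messages": ["second node ran"]}


builder = StateGraph(LogState)
builder.add_node("first", first_node)
builder.add_node("second", second_node)
builder.add_edge(START, "first")
builder.add_edge("first", "second")
builder.add_edge("second", END)

result = builder.compile().invoke({"log_messages": []})
print(result["log_messages"])  # -> ['first node ran', 'second node ran']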

View File

@@ -1,47 +1,47 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import SearchToolStart
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def basic_search_branch(
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
LangGraph node to perform a standard search as part of the DR process.
"""
# def basic_search_branch(
# state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
# """
# LangGraph node to perform a standard search as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
current_step_nr = state.current_step_nr
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
# current_step_nr = state.current_step_nr
logger.debug(f"Search start for Basic Search {iteration_nr} at {datetime.now()}")
# logger.debug(f"Search start for Basic Search {iteration_nr} at {datetime.now()}")
write_custom_event(
current_step_nr,
SearchToolStart(
is_internet_search=False,
),
writer,
)
# write_custom_event(
# current_step_nr,
# SearchToolStart(
# is_internet_search=False,
# ),
# writer,
# )
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="basic_search",
node_name="branching",
node_start_time=node_start_time,
)
],
)
# return LoggerUpdate(
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="basic_search",
# node_name="branching",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,286 +1,261 @@
import re
from datetime import datetime
from typing import cast
from uuid import UUID
# import re
# from datetime import datetime
# from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import BaseSearchProcessingResponse
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import SearchAnswer
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.dr.utils import extract_document_citations
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.chat.models import LlmDoc
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.context.search.models import InferenceSection
from onyx.db.connector import DocumentSource
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.prompts.dr_prompts import BASE_SEARCH_PROCESSING_PROMPT
from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.models import BaseSearchProcessingResponse
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.models import SearchAnswer
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
# from onyx.agents.agent_search.dr.utils import extract_document_citations
# from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.agents.agent_search.utils import create_question_prompt
# from onyx.chat.models import LlmDoc
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
# from onyx.context.search.models import InferenceSection
# from onyx.db.connector import DocumentSource
# from onyx.prompts.dr_prompts import BASE_SEARCH_PROCESSING_PROMPT
# from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
# from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
# from onyx.server.query_and_chat.streaming_models import SearchToolDelta
# from onyx.tools.models import SearchToolOverrideKwargs
# from onyx.tools.tool_implementations.search.search_tool import SearchTool
# from onyx.tools.tool_implementations.search_like_tool_utils import (
# SEARCH_INFERENCE_SECTIONS_ID,
# )
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def basic_search(
state: BranchInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
"""
LangGraph node to perform a standard search as part of the DR process.
"""
# def basic_search(
# state: BranchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> BranchUpdate:
# """
# LangGraph node to perform a standard search as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
parallelization_nr = state.parallelization_nr
current_step_nr = state.current_step_nr
assistant_system_prompt = state.assistant_system_prompt
assistant_task_prompt = state.assistant_task_prompt
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
# parallelization_nr = state.parallelization_nr
# current_step_nr = state.current_step_nr
# assistant_system_prompt = state.assistant_system_prompt
# assistant_task_prompt = state.assistant_task_prompt
branch_query = state.branch_question
if not branch_query:
raise ValueError("branch_query is not set")
# branch_query = state.branch_question
# if not branch_query:
# raise ValueError("branch_query is not set")
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = graph_config.inputs.prompt_builder.raw_user_query
research_type = graph_config.behavior.research_type
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# base_question = graph_config.inputs.prompt_builder.raw_user_query
# use_agentic_search = graph_config.behavior.use_agentic_search
if not state.available_tools:
raise ValueError("available_tools is not set")
# if not state.available_tools:
# raise ValueError("available_tools is not set")
elif len(state.tools_used) == 0:
raise ValueError("tools_used is empty")
# elif len(state.tools_used) == 0:
# raise ValueError("tools_used is empty")
search_tool_info = state.available_tools[state.tools_used[-1]]
search_tool = cast(SearchTool, search_tool_info.tool_object)
force_use_tool = graph_config.tooling.force_use_tool
# search_tool_info = state.available_tools[state.tools_used[-1]]
# search_tool = cast(SearchTool, search_tool_info.tool_object)
# graph_config.tooling.force_use_tool
# sanity check
if search_tool != graph_config.tooling.search_tool:
raise ValueError("search_tool does not match the configured search tool")
# # sanity check
# if search_tool != graph_config.tooling.search_tool:
# raise ValueError("search_tool does not match the configured search tool")
# rewrite query and identify source types
active_source_types_str = ", ".join(
[source.value for source in state.active_source_types or []]
)
# # rewrite query and identify source types
# active_source_types_str = ", ".join(
# [source.value for source in state.active_source_types or []]
# )
base_search_processing_prompt = BASE_SEARCH_PROCESSING_PROMPT.build(
active_source_types_str=active_source_types_str,
branch_query=branch_query,
current_time=datetime.now().strftime("%Y-%m-%d %H:%M"),
)
# base_search_processing_prompt = BASE_SEARCH_PROCESSING_PROMPT.build(
# active_source_types_str=active_source_types_str,
# branch_query=branch_query,
# current_time=datetime.now().strftime("%Y-%m-%d %H:%M"),
# )
try:
search_processing = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt, base_search_processing_prompt
),
schema=BaseSearchProcessingResponse,
timeout_override=TF_DR_TIMEOUT_SHORT,
# max_tokens=100,
)
except Exception as e:
logger.error(f"Could not process query: {e}")
raise e
# try:
# search_processing = invoke_llm_json(
# llm=graph_config.tooling.primary_llm,
# prompt=create_question_prompt(
# assistant_system_prompt, base_search_processing_prompt
# ),
# schema=BaseSearchProcessingResponse,
# timeout_override=TF_DR_TIMEOUT_SHORT,
# # max_tokens=100,
# )
# except Exception as e:
# logger.error(f"Could not process query: {e}")
# raise e
rewritten_query = search_processing.rewritten_query
# rewritten_query = search_processing.rewritten_query
# give back the query so we can render it in the UI
write_custom_event(
current_step_nr,
SearchToolDelta(
queries=[rewritten_query],
documents=[],
),
writer,
)
# # give back the query so we can render it in the UI
# write_custom_event(
# current_step_nr,
# SearchToolDelta(
# queries=[rewritten_query],
# documents=[],
# ),
# writer,
# )
implied_start_date = search_processing.time_filter
# implied_start_date = search_processing.time_filter
# Validate time_filter format if it exists
implied_time_filter = None
if implied_start_date:
# # Validate time_filter format if it exists
# implied_time_filter = None
# if implied_start_date:
# Check if time_filter is in YYYY-MM-DD format
date_pattern = r"^\d{4}-\d{2}-\d{2}$"
if re.match(date_pattern, implied_start_date):
implied_time_filter = datetime.strptime(implied_start_date, "%Y-%m-%d")
# # Check if time_filter is in YYYY-MM-DD format
# date_pattern = r"^\d{4}-\d{2}-\d{2}$"
# if re.match(date_pattern, implied_start_date):
# implied_time_filter = datetime.strptime(implied_start_date, "%Y-%m-%d")
specified_source_types: list[DocumentSource] | None = (
strings_to_document_sources(search_processing.specified_source_types)
if search_processing.specified_source_types
else None
)
# specified_source_types: list[DocumentSource] | None = (
# strings_to_document_sources(search_processing.specified_source_types)
# if search_processing.specified_source_types
# else None
# )
if specified_source_types is not None and len(specified_source_types) == 0:
specified_source_types = None
# if specified_source_types is not None and len(specified_source_types) == 0:
# specified_source_types = None
logger.debug(
f"Search start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# logger.debug(
# f"Search start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
retrieved_docs: list[InferenceSection] = []
callback_container: list[list[InferenceSection]] = []
# retrieved_docs: list[InferenceSection] = []
user_file_ids: list[UUID] | None = None
project_id: int | None = None
if force_use_tool.override_kwargs and isinstance(
force_use_tool.override_kwargs, SearchToolOverrideKwargs
):
override_kwargs = force_use_tool.override_kwargs
user_file_ids = override_kwargs.user_file_ids
project_id = override_kwargs.project_id
# for tool_response in search_tool.run(
# query=rewritten_query,
# document_sources=specified_source_types,
# time_filter=implied_time_filter,
# override_kwargs=SearchToolOverrideKwargs(original_query=rewritten_query),
# ):
# # get retrieved docs to send to the rest of the graph
# if tool_response.id == SEARCH_INFERENCE_SECTIONS_ID:
# retrieved_docs = cast(list[InferenceSection], tool_response.response)
# new db session to avoid concurrency issues
with get_session_with_current_tenant() as search_db_session:
for tool_response in search_tool.run(
query=rewritten_query,
document_sources=specified_source_types,
time_filter=implied_time_filter,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=True,
alternate_db_session=search_db_session,
retrieved_sections_callback=callback_container.append,
skip_query_analysis=True,
original_query=rewritten_query,
user_file_ids=user_file_ids,
project_id=project_id,
),
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
response = cast(SearchResponseSummary, tool_response.response)
retrieved_docs = response.top_sections
# break
break
# # render the retrieved docs in the UI
# write_custom_event(
# current_step_nr,
# SearchToolDelta(
# queries=[],
# documents=convert_inference_sections_to_search_docs(
# retrieved_docs, is_internet=False
# ),
# ),
# writer,
# )
# render the retrieved docs in the UI
write_custom_event(
current_step_nr,
SearchToolDelta(
queries=[],
documents=convert_inference_sections_to_search_docs(
retrieved_docs, is_internet=False
),
),
writer,
)
# document_texts_list = []
document_texts_list = []
# for doc_num, retrieved_doc in enumerate(retrieved_docs[:15]):
# if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
# raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
# chunk_text = build_document_context(retrieved_doc, doc_num + 1)
# document_texts_list.append(chunk_text)
for doc_num, retrieved_doc in enumerate(retrieved_docs[:15]):
if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
chunk_text = build_document_context(retrieved_doc, doc_num + 1)
document_texts_list.append(chunk_text)
# document_texts = "\n\n".join(document_texts_list)
document_texts = "\n\n".join(document_texts_list)
# logger.debug(
# f"Search end/LLM start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
logger.debug(
f"Search end/LLM start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# # Built prompt
# Built prompt
# if use_agentic_search:
# search_prompt = INTERNAL_SEARCH_PROMPTS[use_agentic_search].build(
# search_query=branch_query,
# base_question=base_question,
# document_text=document_texts,
# )
if research_type == ResearchType.DEEP:
search_prompt = INTERNAL_SEARCH_PROMPTS[research_type].build(
search_query=branch_query,
base_question=base_question,
document_text=document_texts,
)
# # Run LLM
# Run LLM
# # search_answer_json = None
# search_answer_json = invoke_llm_json(
# llm=graph_config.tooling.primary_llm,
# prompt=create_question_prompt(
# assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
# ),
# schema=SearchAnswer,
# timeout_override=TF_DR_TIMEOUT_LONG,
# # max_tokens=1500,
# )
# search_answer_json = None
search_answer_json = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
),
schema=SearchAnswer,
timeout_override=TF_DR_TIMEOUT_LONG,
# max_tokens=1500,
)
# logger.debug(
# f"LLM/all done for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
logger.debug(
f"LLM/all done for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# # get cited documents
# answer_string = search_answer_json.answer
# claims = search_answer_json.claims or []
# reasoning = search_answer_json.reasoning
# # answer_string = ""
# # claims = []
# get cited documents
answer_string = search_answer_json.answer
claims = search_answer_json.claims or []
reasoning = search_answer_json.reasoning
# answer_string = ""
# claims = []
# (
# citation_numbers,
# answer_string,
# claims,
# ) = extract_document_citations(answer_string, claims)
(
citation_numbers,
answer_string,
claims,
) = extract_document_citations(answer_string, claims)
# if citation_numbers and (
# (max(citation_numbers) > len(retrieved_docs)) or min(citation_numbers) < 1
# ):
# raise ValueError("Citation numbers are out of range for retrieved docs.")
if citation_numbers and (
(max(citation_numbers) > len(retrieved_docs)) or min(citation_numbers) < 1
):
raise ValueError("Citation numbers are out of range for retrieved docs.")
# cited_documents = {
# citation_number: retrieved_docs[citation_number - 1]
# for citation_number in citation_numbers
# }
cited_documents = {
citation_number: retrieved_docs[citation_number - 1]
for citation_number in citation_numbers
}
# else:
# answer_string = ""
# claims = []
# cited_documents = {
# doc_num + 1: retrieved_doc
# for doc_num, retrieved_doc in enumerate(retrieved_docs[:15])
# }
# reasoning = ""
else:
answer_string = ""
claims = []
cited_documents = {
doc_num + 1: retrieved_doc
for doc_num, retrieved_doc in enumerate(retrieved_docs[:15])
}
reasoning = ""
return BranchUpdate(
branch_iteration_responses=[
IterationAnswer(
tool=search_tool_info.llm_path,
tool_id=search_tool_info.tool_id,
iteration_nr=iteration_nr,
parallelization_nr=parallelization_nr,
question=branch_query,
answer=answer_string,
claims=claims,
cited_documents=cited_documents,
reasoning=reasoning,
additional_data=None,
)
],
log_messages=[
get_langgraph_node_log_string(
graph_component="basic_search",
node_name="searching",
node_start_time=node_start_time,
)
],
)
# return BranchUpdate(
# branch_iteration_responses=[
# IterationAnswer(
# tool=search_tool_info.llm_path,
# tool_id=search_tool_info.tool_id,
# iteration_nr=iteration_nr,
# parallelization_nr=parallelization_nr,
# question=branch_query,
# answer=answer_string,
# claims=claims,
# cited_documents=cited_documents,
# reasoning=reasoning,
# additional_data=None,
# )
# ],
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="basic_search",
# node_name="searching",
# node_start_time=node_start_time,
# )
# ],
# )
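
Note: in the new basic_search body, citation numbers extracted from the answer are 1-based indexes into retrieved_docs and are range-checked before being resolved. Below is a small standalone sketch of that mapping, with plain strings standing in for InferenceSection objects.

def map_citations_to_docs(
    citation_numbers: list[int], retrieved_docs: list[str]
) -> dict[int, str]:
    # Citation numbers are 1-based; reject anything outside the retrieved range.
    if citation_numbers and (
        max(citation_numbers) > len(retrieved_docs) or min(citation_numbers) < 1
    ):
        raise ValueError("Citation numbers are out of range for retrieved docs.")
    return {number: retrieved_docs[number - 1] for number in citation_numbers}


# Example: the answer cited [2] and [1] out of three retrieved documents.
print(map_citations_to_docs([2, 1], ["doc a", "doc b", "doc c"]))
# -> {2: 'doc b', 1: 'doc a'}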

View File

@@ -1,77 +1,77 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.context.search.models import SavedSearchDoc
from onyx.context.search.models import SearchDoc
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.context.search.models import SavedSearchDoc
# from onyx.context.search.models import SearchDoc
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def is_reducer(
state: SubAgentMainState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
"""
LangGraph node to perform a standard search as part of the DR process.
"""
# def is_reducer(
# state: SubAgentMainState,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> SubAgentUpdate:
# """
# LangGraph node to perform a standard search as part of the DR process.
# """
node_start_time = datetime.now()
# node_start_time = datetime.now()
branch_updates = state.branch_iteration_responses
current_iteration = state.iteration_nr
current_step_nr = state.current_step_nr
# branch_updates = state.branch_iteration_responses
# current_iteration = state.iteration_nr
# current_step_nr = state.current_step_nr
new_updates = [
update for update in branch_updates if update.iteration_nr == current_iteration
]
# new_updates = [
# update for update in branch_updates if update.iteration_nr == current_iteration
# ]
[update.question for update in new_updates]
doc_lists = [list(update.cited_documents.values()) for update in new_updates]
# [update.question for update in new_updates]
# doc_lists = [list(update.cited_documents.values()) for update in new_updates]
doc_list = []
# doc_list = []
for xs in doc_lists:
for x in xs:
doc_list.append(x)
# for xs in doc_lists:
# for x in xs:
# doc_list.append(x)
# Convert InferenceSections to SavedSearchDocs
search_docs = SearchDoc.from_chunks_or_sections(doc_list)
retrieved_saved_search_docs = [
SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
for search_doc in search_docs
]
# # Convert InferenceSections to SavedSearchDocs
# search_docs = SearchDoc.from_chunks_or_sections(doc_list)
# retrieved_saved_search_docs = [
# SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
# for search_doc in search_docs
# ]
for retrieved_saved_search_doc in retrieved_saved_search_docs:
retrieved_saved_search_doc.is_internet = False
# for retrieved_saved_search_doc in retrieved_saved_search_docs:
# retrieved_saved_search_doc.is_internet = False
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
# write_custom_event(
# current_step_nr,
# SectionEnd(),
# writer,
# )
current_step_nr += 1
# current_step_nr += 1
return SubAgentUpdate(
iteration_responses=new_updates,
current_step_nr=current_step_nr,
log_messages=[
get_langgraph_node_log_string(
graph_component="basic_search",
node_name="consolidation",
node_start_time=node_start_time,
)
],
)
# return SubAgentUpdate(
# iteration_responses=new_updates,
# current_step_nr=current_step_nr,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="basic_search",
# node_name="consolidation",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,50 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_1_branch import (
basic_search_branch,
)
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import (
basic_search,
)
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_3_reduce import (
is_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_image_generation_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_1_branch import (
# basic_search_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import (
# basic_search,
# )
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_3_reduce import (
# is_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_image_generation_conditional_edges import (
# branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def dr_basic_search_graph_builder() -> StateGraph:
"""
LangGraph graph builder for Basic Search Sub-Agent
"""
# def dr_basic_search_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for Basic Search Sub-Agent
# """
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
# graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
### Add nodes ###
# ### Add nodes ###
graph.add_node("branch", basic_search_branch)
# graph.add_node("branch", basic_search_branch)
graph.add_node("act", basic_search)
# graph.add_node("act", basic_search)
graph.add_node("reducer", is_reducer)
# graph.add_node("reducer", is_reducer)
### Add edges ###
# ### Add edges ###
graph.add_edge(start_key=START, end_key="branch")
# graph.add_edge(start_key=START, end_key="branch")
graph.add_conditional_edges("branch", branching_router)
# graph.add_conditional_edges("branch", branching_router)
graph.add_edge(start_key="act", end_key="reducer")
# graph.add_edge(start_key="act", end_key="reducer")
graph.add_edge(start_key="reducer", end_key=END)
# graph.add_edge(start_key="reducer", end_key=END)
return graph
# return graph
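
Note: dr_basic_search_graph_builder wires a branch -> act -> reducer sub-graph with a conditional edge out of the branch node. Below is a minimal runnable sketch of that wiring; the state, node bodies, and router are illustrative stand-ins, not the Onyx classes.

from operator import add
from typing import Annotated
from typing import TypedDict

from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph


class SubAgentSketchState(TypedDict):
    results: Annotated[list[str], add]


def branch(state: SubAgentSketchState) -> dict:
    return {"results": ["branched"]}


def act(state: SubAgentSketchState) -> dict:
    return {"results": ["acted"]}


def reduce_results(state: SubAgentSketchState) -> dict:
    return {"results": ["reduced"]}


def router(state: SubAgentSketchState) -> str:
    return "act"  # the real router fans out with Send objects instead


graph = StateGraph(SubAgentSketchState)
graph.add_node("branch", branch)
graph.add_node("act", act)
graph.add_node("reducer", reduce_results)
graph.add_edge(start_key=START, end_key="branch")
graph.add_conditional_edges("branch", router)
graph.add_edge(start_key="act", end_key="reducer")
graph.add_edge(start_key="reducer", end_key=END)

print(graph.compile().invoke({"results": []})["results"])
# -> ['branched', 'acted', 'reduced']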

View File

@@ -1,30 +1,30 @@
from collections.abc import Hashable
# from collections.abc import Hashable
from langgraph.types import Send
# from langgraph.types import Send
from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
return [
Send(
"act",
BranchInput(
iteration_nr=state.iteration_nr,
parallelization_nr=parallelization_nr,
branch_question=query,
current_step_nr=state.current_step_nr,
context="",
active_source_types=state.active_source_types,
tools_used=state.tools_used,
available_tools=state.available_tools,
assistant_system_prompt=state.assistant_system_prompt,
assistant_task_prompt=state.assistant_task_prompt,
),
)
for parallelization_nr, query in enumerate(
state.query_list[:MAX_DR_PARALLEL_SEARCH]
)
]
# def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
# return [
# Send(
# "act",
# BranchInput(
# iteration_nr=state.iteration_nr,
# parallelization_nr=parallelization_nr,
# branch_question=query,
# current_step_nr=state.current_step_nr,
# context="",
# active_source_types=state.active_source_types,
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# assistant_system_prompt=state.assistant_system_prompt,
# assistant_task_prompt=state.assistant_task_prompt,
# ),
# )
# for parallelization_nr, query in enumerate(
# state.query_list[:MAX_DR_PARALLEL_SEARCH]
# )
# ]
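
Note: branching_router fans the query list out as Send objects so each query runs the act node in parallel, capped at MAX_DR_PARALLEL_SEARCH. Below is a minimal sketch of that fan-out; the payload dict and cap value are illustrative stand-ins for BranchInput and the real constant.

from collections.abc import Hashable

from langgraph.types import Send

MAX_PARALLEL_SEARCH = 4  # illustrative cap standing in for MAX_DR_PARALLEL_SEARCH


def fan_out(queries: list[str]) -> list[Send | Hashable]:
    # One Send per query: each payload becomes a separate "act" invocation.
    return [
        Send("act", {"parallelization_nr": nr, "branch_question": query})
        for nr, query in enumerate(queries[:MAX_PARALLEL_SEARCH])
    ]


for send in fan_out(["query one", "query two"]):
    print(send.node, send.arg)
# -> act {'parallelization_nr': 0, 'branch_question': 'query one'}
# -> act {'parallelization_nr': 1, 'branch_question': 'query two'}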

View File

@@ -1,36 +1,36 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def custom_tool_branch(
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
LangGraph node to perform a generic tool call as part of the DR process.
"""
# def custom_tool_branch(
# state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
# """
# LangGraph node to perform a generic tool call as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")
# logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="custom_tool",
node_name="branching",
node_start_time=node_start_time,
)
],
)
# return LoggerUpdate(
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="custom_tool",
# node_name="branching",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,169 +1,164 @@
import json
from datetime import datetime
from typing import cast
# import json
# from datetime import datetime
# from typing import cast
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import AIMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
from onyx.tools.tool_implementations.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
from onyx.tools.tool_implementations.mcp.mcp_tool import MCP_TOOL_RESPONSE_ID
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
# from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
# from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
# from onyx.tools.tool_implementations.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
# from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
# from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
# from onyx.tools.tool_implementations.mcp.mcp_tool import MCP_TOOL_RESPONSE_ID
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def custom_tool_act(
state: BranchInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
"""
LangGraph node to perform a generic tool call as part of the DR process.
"""
# def custom_tool_act(
# state: BranchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> BranchUpdate:
# """
# LangGraph node to perform a generic tool call as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
parallelization_nr = state.parallelization_nr
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
# parallelization_nr = state.parallelization_nr
if not state.available_tools:
raise ValueError("available_tools is not set")
# if not state.available_tools:
# raise ValueError("available_tools is not set")
custom_tool_info = state.available_tools[state.tools_used[-1]]
custom_tool_name = custom_tool_info.name
custom_tool = cast(CustomTool, custom_tool_info.tool_object)
# custom_tool_info = state.available_tools[state.tools_used[-1]]
# custom_tool_name = custom_tool_info.name
# custom_tool = cast(CustomTool, custom_tool_info.tool_object)
branch_query = state.branch_question
if not branch_query:
raise ValueError("branch_query is not set")
# branch_query = state.branch_question
# if not branch_query:
# raise ValueError("branch_query is not set")
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = graph_config.inputs.prompt_builder.raw_user_query
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# base_question = graph_config.inputs.prompt_builder.raw_user_query
logger.debug(
f"Tool call start for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# logger.debug(
# f"Tool call start for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
# get tool call args
tool_args: dict | None = None
if graph_config.tooling.using_tool_calling_llm:
# get tool call args from tool-calling LLM
tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
query=branch_query,
base_question=base_question,
tool_description=custom_tool_info.description,
)
tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
tool_use_prompt,
tools=[custom_tool.tool_definition()],
tool_choice="required",
timeout_override=TF_DR_TIMEOUT_LONG,
)
# # get tool call args
# tool_args: dict | None = None
# if graph_config.tooling.using_tool_calling_llm:
# # get tool call args from tool-calling LLM
# tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
# query=branch_query,
# base_question=base_question,
# tool_description=custom_tool_info.description,
# )
# tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
# tool_use_prompt,
# tools=[custom_tool.tool_definition()],
# tool_choice="required",
# timeout_override=TF_DR_TIMEOUT_LONG,
# )
# make sure we got a tool call
if (
isinstance(tool_calling_msg, AIMessage)
and len(tool_calling_msg.tool_calls) == 1
):
tool_args = tool_calling_msg.tool_calls[0]["args"]
else:
logger.warning("Tool-calling LLM did not emit a tool call")
# # make sure we got a tool call
# if (
# isinstance(tool_calling_msg, AIMessage)
# and len(tool_calling_msg.tool_calls) == 1
# ):
# tool_args = tool_calling_msg.tool_calls[0]["args"]
# else:
# logger.warning("Tool-calling LLM did not emit a tool call")
if tool_args is None:
# get tool call args from non-tool-calling LLM or for failed tool-calling LLM
tool_args = custom_tool.get_args_for_non_tool_calling_llm(
query=branch_query,
history=[],
llm=graph_config.tooling.primary_llm,
force_run=True,
)
# if tool_args is None:
# raise ValueError(
# "Failed to obtain tool arguments from LLM - tool calling is required"
# )
if tool_args is None:
raise ValueError("Failed to obtain tool arguments from LLM")
# # run the tool
# response_summary: CustomToolCallSummary | None = None
# for tool_response in custom_tool.run(**tool_args):
# if tool_response.id in {CUSTOM_TOOL_RESPONSE_ID, MCP_TOOL_RESPONSE_ID}:
# response_summary = cast(CustomToolCallSummary, tool_response.response)
# break
# run the tool
response_summary: CustomToolCallSummary | None = None
for tool_response in custom_tool.run(**tool_args):
if tool_response.id in {CUSTOM_TOOL_RESPONSE_ID, MCP_TOOL_RESPONSE_ID}:
response_summary = cast(CustomToolCallSummary, tool_response.response)
break
# if not response_summary:
# raise ValueError("Custom tool did not return a valid response summary")
if not response_summary:
raise ValueError("Custom tool did not return a valid response summary")
# # summarise tool result
# if not response_summary.response_type:
# raise ValueError("Response type is not returned.")
# summarise tool result
if not response_summary.response_type:
raise ValueError("Response type is not returned.")
# if response_summary.response_type == "json":
# tool_result_str = json.dumps(response_summary.tool_result, ensure_ascii=False)
# elif response_summary.response_type in {"image", "csv"}:
# tool_result_str = f"{response_summary.response_type} files: {response_summary.tool_result.file_ids}"
# else:
# tool_result_str = str(response_summary.tool_result)
if response_summary.response_type == "json":
tool_result_str = json.dumps(response_summary.tool_result, ensure_ascii=False)
elif response_summary.response_type in {"image", "csv"}:
tool_result_str = f"{response_summary.response_type} files: {response_summary.tool_result.file_ids}"
else:
tool_result_str = str(response_summary.tool_result)
# tool_str = (
# f"Tool used: {custom_tool_name}\n"
# f"Description: {custom_tool_info.description}\n"
# f"Result: {tool_result_str}"
# )
tool_str = (
f"Tool used: {custom_tool_name}\n"
f"Description: {custom_tool_info.description}\n"
f"Result: {tool_result_str}"
)
# tool_summary_prompt = CUSTOM_TOOL_USE_PROMPT.build(
# query=branch_query, base_question=base_question, tool_response=tool_str
# )
# answer_string = str(
# graph_config.tooling.primary_llm.invoke(
# tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
# ).content
# ).strip()
tool_summary_prompt = CUSTOM_TOOL_USE_PROMPT.build(
query=branch_query, base_question=base_question, tool_response=tool_str
)
answer_string = str(
graph_config.tooling.primary_llm.invoke_langchain(
tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
).content
).strip()
# tool_summary_prompt = CUSTOM_TOOL_USE_PROMPT.build(
# query=branch_query, base_question=base_question, tool_response=tool_str
# )
# answer_string = str(
# graph_config.tooling.primary_llm.invoke_langchain(
# tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
# ).content
# ).strip()
# get file_ids:
file_ids = None
if response_summary.response_type in {"image", "csv"} and hasattr(
response_summary.tool_result, "file_ids"
):
file_ids = response_summary.tool_result.file_ids
# logger.debug(
# f"Tool call end for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
logger.debug(
f"Tool call end for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
return BranchUpdate(
branch_iteration_responses=[
IterationAnswer(
tool=custom_tool_name,
tool_id=custom_tool_info.tool_id,
iteration_nr=iteration_nr,
parallelization_nr=parallelization_nr,
question=branch_query,
answer=answer_string,
claims=[],
cited_documents={},
reasoning="",
additional_data=None,
response_type=response_summary.response_type,
data=response_summary.tool_result,
file_ids=file_ids,
)
],
log_messages=[
get_langgraph_node_log_string(
graph_component="custom_tool",
node_name="tool_calling",
node_start_time=node_start_time,
)
],
)
# return BranchUpdate(
# branch_iteration_responses=[
# IterationAnswer(
# tool=custom_tool_name,
# tool_id=custom_tool_info.tool_id,
# iteration_nr=iteration_nr,
# parallelization_nr=parallelization_nr,
# question=branch_query,
# answer=answer_string,
# claims=[],
# cited_documents={},
# reasoning="",
# additional_data=None,
# response_type=response_summary.response_type,
# data=response_summary.tool_result,
# file_ids=file_ids,
# )
# ],
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="custom_tool",
# node_name="tool_calling",
# node_start_time=node_start_time,
# )
# ],
# )
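
Note: custom_tool_act first asks a tool-calling LLM for arguments (tool_choice="required"), reads them off a single tool call, and falls back to a non-tool-calling path before giving up. Below is a minimal sketch of that fallback; invoke_with_tools and get_args_without_tool_calling are hypothetical callables standing in for the Onyx LLM and tool interfaces.

from typing import Callable

from langchain_core.messages import AIMessage
from langchain_core.messages import BaseMessage


def extract_tool_args(
    invoke_with_tools: Callable[[], BaseMessage],
    get_args_without_tool_calling: Callable[[], dict | None],
) -> dict:
    tool_args: dict | None = None

    # First attempt: a tool-calling LLM that should emit exactly one tool call.
    message = invoke_with_tools()
    if isinstance(message, AIMessage) and len(message.tool_calls) == 1:
        tool_args = message.tool_calls[0]["args"]

    if tool_args is None:
        # Fallback for non-tool-calling LLMs or a missing tool call.
        tool_args = get_args_without_tool_calling()

    if tool_args is None:
        raise ValueError("Failed to obtain tool arguments from LLM")
    return tool_args


# Example with stubbed callables:
stub_message = AIMessage(
    content="",
    tool_calls=[{"name": "lookup", "args": {"query": "status"}, "id": "call-1"}],
)
print(extract_tool_args(lambda: stub_message, lambda: None))  # -> {'query': 'status'}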

View File

@@ -1,82 +1,82 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import CustomToolDelta
# from onyx.server.query_and_chat.streaming_models import CustomToolStart
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def custom_tool_reducer(
state: SubAgentMainState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
"""
LangGraph node to perform a generic tool call as part of the DR process.
"""
# def custom_tool_reducer(
# state: SubAgentMainState,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> SubAgentUpdate:
# """
# LangGraph node to perform a generic tool call as part of the DR process.
# """
node_start_time = datetime.now()
# node_start_time = datetime.now()
current_step_nr = state.current_step_nr
# current_step_nr = state.current_step_nr
branch_updates = state.branch_iteration_responses
current_iteration = state.iteration_nr
# branch_updates = state.branch_iteration_responses
# current_iteration = state.iteration_nr
new_updates = [
update for update in branch_updates if update.iteration_nr == current_iteration
]
# new_updates = [
# update for update in branch_updates if update.iteration_nr == current_iteration
# ]
for new_update in new_updates:
# for new_update in new_updates:
if not new_update.response_type:
raise ValueError("Response type is not returned.")
# if not new_update.response_type:
# raise ValueError("Response type is not returned.")
write_custom_event(
current_step_nr,
CustomToolStart(
tool_name=new_update.tool,
),
writer,
)
# write_custom_event(
# current_step_nr,
# CustomToolStart(
# tool_name=new_update.tool,
# ),
# writer,
# )
write_custom_event(
current_step_nr,
CustomToolDelta(
tool_name=new_update.tool,
response_type=new_update.response_type,
data=new_update.data,
file_ids=new_update.file_ids,
),
writer,
)
# write_custom_event(
# current_step_nr,
# CustomToolDelta(
# tool_name=new_update.tool,
# response_type=new_update.response_type,
# data=new_update.data,
# file_ids=new_update.file_ids,
# ),
# writer,
# )
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
# write_custom_event(
# current_step_nr,
# SectionEnd(),
# writer,
# )
current_step_nr += 1
# current_step_nr += 1
return SubAgentUpdate(
iteration_responses=new_updates,
log_messages=[
get_langgraph_node_log_string(
graph_component="custom_tool",
node_name="consolidation",
node_start_time=node_start_time,
)
],
)
# return SubAgentUpdate(
# iteration_responses=new_updates,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="custom_tool",
# node_name="consolidation",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,28 +1,28 @@
from collections.abc import Hashable
# from collections.abc import Hashable
from langgraph.types import Send
# from langgraph.types import Send
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import (
SubAgentInput,
)
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import (
# SubAgentInput,
# )
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
return [
Send(
"act",
BranchInput(
iteration_nr=state.iteration_nr,
parallelization_nr=parallelization_nr,
branch_question=query,
context="",
active_source_types=state.active_source_types,
tools_used=state.tools_used,
available_tools=state.available_tools,
),
)
for parallelization_nr, query in enumerate(
state.query_list[:1] # no parallel call for now
)
]
# def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
# return [
# Send(
# "act",
# BranchInput(
# iteration_nr=state.iteration_nr,
# parallelization_nr=parallelization_nr,
# branch_question=query,
# context="",
# active_source_types=state.active_source_types,
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# ),
# )
# for parallelization_nr, query in enumerate(
# state.query_list[:1] # no parallel call for now
# )
# ]

View File

@@ -1,50 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_1_branch import (
custom_tool_branch,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_2_act import (
custom_tool_act,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_3_reduce import (
custom_tool_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_1_branch import (
# custom_tool_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_2_act import (
# custom_tool_act,
# )
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_3_reduce import (
# custom_tool_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_conditional_edges import (
# branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def dr_custom_tool_graph_builder() -> StateGraph:
"""
LangGraph graph builder for Generic Tool Sub-Agent
"""
# def dr_custom_tool_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for Generic Tool Sub-Agent
# """
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
# graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
### Add nodes ###
# ### Add nodes ###
graph.add_node("branch", custom_tool_branch)
# graph.add_node("branch", custom_tool_branch)
graph.add_node("act", custom_tool_act)
# graph.add_node("act", custom_tool_act)
graph.add_node("reducer", custom_tool_reducer)
# graph.add_node("reducer", custom_tool_reducer)
### Add edges ###
# ### Add edges ###
graph.add_edge(start_key=START, end_key="branch")
# graph.add_edge(start_key=START, end_key="branch")
graph.add_conditional_edges("branch", branching_router)
# graph.add_conditional_edges("branch", branching_router)
graph.add_edge(start_key="act", end_key="reducer")
# graph.add_edge(start_key="act", end_key="reducer")
graph.add_edge(start_key="reducer", end_key=END)
# graph.add_edge(start_key="reducer", end_key=END)
return graph
# return graph

View File

@@ -1,36 +1,36 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def generic_internal_tool_branch(
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
LangGraph node to perform a generic tool call as part of the DR process.
"""
# def generic_internal_tool_branch(
# state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
# """
# LangGraph node to perform a generic tool call as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")
# logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="generic_internal_tool",
node_name="branching",
node_start_time=node_start_time,
)
],
)
# return LoggerUpdate(
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="generic_internal_tool",
# node_name="branching",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,149 +1,147 @@
import json
from datetime import datetime
from typing import cast
# import json
# from datetime import datetime
# from typing import cast
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import AIMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
from onyx.prompts.dr_prompts import OKTA_TOOL_USE_SPECIAL_PROMPT
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
# from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
# from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
# from onyx.prompts.dr_prompts import OKTA_TOOL_USE_SPECIAL_PROMPT
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def generic_internal_tool_act(
state: BranchInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
"""
LangGraph node to perform a generic tool call as part of the DR process.
"""
# def generic_internal_tool_act(
# state: BranchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> BranchUpdate:
# """
# LangGraph node to perform a generic tool call as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
parallelization_nr = state.parallelization_nr
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
# parallelization_nr = state.parallelization_nr
if not state.available_tools:
raise ValueError("available_tools is not set")
# if not state.available_tools:
# raise ValueError("available_tools is not set")
generic_internal_tool_info = state.available_tools[state.tools_used[-1]]
generic_internal_tool_name = generic_internal_tool_info.llm_path
generic_internal_tool = generic_internal_tool_info.tool_object
# generic_internal_tool_info = state.available_tools[state.tools_used[-1]]
# generic_internal_tool_name = generic_internal_tool_info.llm_path
# generic_internal_tool = generic_internal_tool_info.tool_object
if generic_internal_tool is None:
raise ValueError("generic_internal_tool is not set")
# if generic_internal_tool is None:
# raise ValueError("generic_internal_tool is not set")
branch_query = state.branch_question
if not branch_query:
raise ValueError("branch_query is not set")
# branch_query = state.branch_question
# if not branch_query:
# raise ValueError("branch_query is not set")
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = graph_config.inputs.prompt_builder.raw_user_query
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# base_question = graph_config.inputs.prompt_builder.raw_user_query
logger.debug(
f"Tool call start for {generic_internal_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# logger.debug(
# f"Tool call start for {generic_internal_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
# get tool call args
tool_args: dict | None = None
if graph_config.tooling.using_tool_calling_llm:
# get tool call args from tool-calling LLM
tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
query=branch_query,
base_question=base_question,
tool_description=generic_internal_tool_info.description,
)
tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
tool_use_prompt,
tools=[generic_internal_tool.tool_definition()],
tool_choice="required",
timeout_override=TF_DR_TIMEOUT_SHORT,
)
# # get tool call args
# tool_args: dict | None = None
# if graph_config.tooling.using_tool_calling_llm:
# # get tool call args from tool-calling LLM
# tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
# query=branch_query,
# base_question=base_question,
# tool_description=generic_internal_tool_info.description,
# )
# tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
# tool_use_prompt,
# tools=[generic_internal_tool.tool_definition()],
# tool_choice="required",
# timeout_override=TF_DR_TIMEOUT_SHORT,
# )
# make sure we got a tool call
if (
isinstance(tool_calling_msg, AIMessage)
and len(tool_calling_msg.tool_calls) == 1
):
tool_args = tool_calling_msg.tool_calls[0]["args"]
else:
logger.warning("Tool-calling LLM did not emit a tool call")
# # make sure we got a tool call
# if (
# isinstance(tool_calling_msg, AIMessage)
# and len(tool_calling_msg.tool_calls) == 1
# ):
# tool_args = tool_calling_msg.tool_calls[0]["args"]
# else:
# logger.warning("Tool-calling LLM did not emit a tool call")
if tool_args is None:
# get tool call args from non-tool-calling LLM or for failed tool-calling LLM
tool_args = generic_internal_tool.get_args_for_non_tool_calling_llm(
query=branch_query,
history=[],
llm=graph_config.tooling.primary_llm,
force_run=True,
)
# if tool_args is None:
# raise ValueError(
# "Failed to obtain tool arguments from LLM - tool calling is required"
# )
if tool_args is None:
raise ValueError("Failed to obtain tool arguments from LLM")
# # run the tool
# tool_responses = list(generic_internal_tool.run(**tool_args))
# final_data = generic_internal_tool.final_result(*tool_responses)
# tool_result_str = json.dumps(final_data, ensure_ascii=False)
# run the tool
tool_responses = list(generic_internal_tool.run(**tool_args))
final_data = generic_internal_tool.final_result(*tool_responses)
tool_result_str = json.dumps(final_data, ensure_ascii=False)
# tool_str = (
# f"Tool used: {generic_internal_tool.display_name}\n"
# f"Description: {generic_internal_tool_info.description}\n"
# f"Result: {tool_result_str}"
# )
tool_str = (
f"Tool used: {generic_internal_tool.display_name}\n"
f"Description: {generic_internal_tool_info.description}\n"
f"Result: {tool_result_str}"
)
# if generic_internal_tool.display_name == "Okta Profile":
# tool_prompt = OKTA_TOOL_USE_SPECIAL_PROMPT
# else:
# tool_prompt = CUSTOM_TOOL_USE_PROMPT
if generic_internal_tool.display_name == "Okta Profile":
tool_prompt = OKTA_TOOL_USE_SPECIAL_PROMPT
else:
tool_prompt = CUSTOM_TOOL_USE_PROMPT
# tool_summary_prompt = tool_prompt.build(
# query=branch_query, base_question=base_question, tool_response=tool_str
# )
# answer_string = str(
# graph_config.tooling.primary_llm.invoke(
# tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
# ).content
# ).strip()
tool_summary_prompt = tool_prompt.build(
query=branch_query, base_question=base_question, tool_response=tool_str
)
answer_string = str(
graph_config.tooling.primary_llm.invoke_langchain(
tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
).content
).strip()
# tool_summary_prompt = tool_prompt.build(
# query=branch_query, base_question=base_question, tool_response=tool_str
# )
# answer_string = str(
# graph_config.tooling.primary_llm.invoke_langchain(
# tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
# ).content
# ).strip()
logger.debug(
f"Tool call end for {generic_internal_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
return BranchUpdate(
branch_iteration_responses=[
IterationAnswer(
tool=generic_internal_tool.llm_name,
tool_id=generic_internal_tool_info.tool_id,
iteration_nr=iteration_nr,
parallelization_nr=parallelization_nr,
question=branch_query,
answer=answer_string,
claims=[],
cited_documents={},
reasoning="",
additional_data=None,
response_type="text", # TODO: convert all response types to enums
data=answer_string,
)
],
log_messages=[
get_langgraph_node_log_string(
graph_component="custom_tool",
node_name="tool_calling",
node_start_time=node_start_time,
)
],
)
# return BranchUpdate(
# branch_iteration_responses=[
# IterationAnswer(
# tool=generic_internal_tool.name,
# tool_id=generic_internal_tool_info.tool_id,
# iteration_nr=iteration_nr,
# parallelization_nr=parallelization_nr,
# question=branch_query,
# answer=answer_string,
# claims=[],
# cited_documents={},
# reasoning="",
# additional_data=None,
# response_type="text", # TODO: convert all response types to enums
# data=answer_string,
# )
# ],
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="custom_tool",
# node_name="tool_calling",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,82 +1,82 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import CustomToolDelta
# from onyx.server.query_and_chat.streaming_models import CustomToolStart
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def generic_internal_tool_reducer(
state: SubAgentMainState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
"""
LangGraph node to perform a generic tool call as part of the DR process.
"""
# def generic_internal_tool_reducer(
# state: SubAgentMainState,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> SubAgentUpdate:
# """
# LangGraph node to perform a generic tool call as part of the DR process.
# """
node_start_time = datetime.now()
# node_start_time = datetime.now()
current_step_nr = state.current_step_nr
# current_step_nr = state.current_step_nr
branch_updates = state.branch_iteration_responses
current_iteration = state.iteration_nr
# branch_updates = state.branch_iteration_responses
# current_iteration = state.iteration_nr
new_updates = [
update for update in branch_updates if update.iteration_nr == current_iteration
]
# new_updates = [
# update for update in branch_updates if update.iteration_nr == current_iteration
# ]
for new_update in new_updates:
# for new_update in new_updates:
if not new_update.response_type:
raise ValueError("Response type is not returned.")
# if not new_update.response_type:
# raise ValueError("Response type is not returned.")
write_custom_event(
current_step_nr,
CustomToolStart(
tool_name=new_update.tool,
),
writer,
)
# write_custom_event(
# current_step_nr,
# CustomToolStart(
# tool_name=new_update.tool,
# ),
# writer,
# )
write_custom_event(
current_step_nr,
CustomToolDelta(
tool_name=new_update.tool,
response_type=new_update.response_type,
data=new_update.data,
file_ids=[],
),
writer,
)
# write_custom_event(
# current_step_nr,
# CustomToolDelta(
# tool_name=new_update.tool,
# response_type=new_update.response_type,
# data=new_update.data,
# file_ids=[],
# ),
# writer,
# )
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
# write_custom_event(
# current_step_nr,
# SectionEnd(),
# writer,
# )
current_step_nr += 1
# current_step_nr += 1
return SubAgentUpdate(
iteration_responses=new_updates,
log_messages=[
get_langgraph_node_log_string(
graph_component="custom_tool",
node_name="consolidation",
node_start_time=node_start_time,
)
],
)
# return SubAgentUpdate(
# iteration_responses=new_updates,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="custom_tool",
# node_name="consolidation",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,28 +1,28 @@
from collections.abc import Hashable
# from collections.abc import Hashable
from langgraph.types import Send
# from langgraph.types import Send
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import (
SubAgentInput,
)
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import (
# SubAgentInput,
# )
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
return [
Send(
"act",
BranchInput(
iteration_nr=state.iteration_nr,
parallelization_nr=parallelization_nr,
branch_question=query,
context="",
active_source_types=state.active_source_types,
tools_used=state.tools_used,
available_tools=state.available_tools,
),
)
for parallelization_nr, query in enumerate(
state.query_list[:1] # no parallel call for now
)
]
# def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
# return [
# Send(
# "act",
# BranchInput(
# iteration_nr=state.iteration_nr,
# parallelization_nr=parallelization_nr,
# branch_question=query,
# context="",
# active_source_types=state.active_source_types,
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# ),
# )
# for parallelization_nr, query in enumerate(
# state.query_list[:1] # no parallel call for now
# )
# ]

View File

@@ -1,50 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_1_branch import (
generic_internal_tool_branch,
)
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_2_act import (
generic_internal_tool_act,
)
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_3_reduce import (
generic_internal_tool_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_1_branch import (
# generic_internal_tool_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_2_act import (
# generic_internal_tool_act,
# )
# from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_3_reduce import (
# generic_internal_tool_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_conditional_edges import (
# branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def dr_generic_internal_tool_graph_builder() -> StateGraph:
"""
LangGraph graph builder for Generic Tool Sub-Agent
"""
# def dr_generic_internal_tool_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for Generic Tool Sub-Agent
# """
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
# graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
### Add nodes ###
# ### Add nodes ###
graph.add_node("branch", generic_internal_tool_branch)
# graph.add_node("branch", generic_internal_tool_branch)
graph.add_node("act", generic_internal_tool_act)
# graph.add_node("act", generic_internal_tool_act)
graph.add_node("reducer", generic_internal_tool_reducer)
# graph.add_node("reducer", generic_internal_tool_reducer)
### Add edges ###
# ### Add edges ###
graph.add_edge(start_key=START, end_key="branch")
# graph.add_edge(start_key=START, end_key="branch")
graph.add_conditional_edges("branch", branching_router)
# graph.add_conditional_edges("branch", branching_router)
graph.add_edge(start_key="act", end_key="reducer")
# graph.add_edge(start_key="act", end_key="reducer")
graph.add_edge(start_key="reducer", end_key=END)
# graph.add_edge(start_key="reducer", end_key=END)
return graph
# return graph

View File

@@ -1,45 +1,45 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def image_generation_branch(
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
LangGraph node to perform image generation as part of the DR process.
"""
# def image_generation_branch(
# state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
# """
# LangGraph node to perform image generation as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
logger.debug(f"Image generation start {iteration_nr} at {datetime.now()}")
# logger.debug(f"Image generation start {iteration_nr} at {datetime.now()}")
# tell frontend that we are starting the image generation tool
write_custom_event(
state.current_step_nr,
ImageGenerationToolStart(),
writer,
)
# # tell frontend that we are starting the image generation tool
# write_custom_event(
# state.current_step_nr,
# ImageGenerationToolStart(),
# writer,
# )
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="image_generation",
node_name="branching",
node_start_time=node_start_time,
)
],
)
# return LoggerUpdate(
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="image_generation",
# node_name="branching",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,189 +1,187 @@
import json
from datetime import datetime
from typing import cast
# import json
# from datetime import datetime
# from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.models import GeneratedImage
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.file_store.utils import build_frontend_file_url
from onyx.file_store.utils import save_files
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeartbeat
from onyx.tools.tool_implementations.images.image_generation_tool import (
IMAGE_GENERATION_HEARTBEAT_ID,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
IMAGE_GENERATION_RESPONSE_ID,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationResponse,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.tools.tool_implementations.images.image_generation_tool import ImageShape
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.models import GeneratedImage
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.file_store.utils import build_frontend_file_url
# from onyx.file_store.utils import save_files
# from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeartbeat
# from onyx.tools.tool_implementations.images.image_generation_tool import (
# IMAGE_GENERATION_HEARTBEAT_ID,
# )
# from onyx.tools.tool_implementations.images.image_generation_tool import (
# IMAGE_GENERATION_RESPONSE_ID,
# )
# from onyx.tools.tool_implementations.images.image_generation_tool import (
# ImageGenerationResponse,
# )
# from onyx.tools.tool_implementations.images.image_generation_tool import (
# ImageGenerationTool,
# )
# from onyx.tools.tool_implementations.images.image_generation_tool import ImageShape
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def image_generation(
state: BranchInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
"""
LangGraph node to perform image generation as part of the DR process.
"""
# def image_generation(
# state: BranchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> BranchUpdate:
# """
# LangGraph node to perform image generation as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
parallelization_nr = state.parallelization_nr
state.assistant_system_prompt
state.assistant_task_prompt
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
# parallelization_nr = state.parallelization_nr
# state.assistant_system_prompt
# state.assistant_task_prompt
branch_query = state.branch_question
if not branch_query:
raise ValueError("branch_query is not set")
# branch_query = state.branch_question
# if not branch_query:
# raise ValueError("branch_query is not set")
graph_config = cast(GraphConfig, config["metadata"]["config"])
graph_config.inputs.prompt_builder.raw_user_query
graph_config.behavior.research_type
# graph_config = cast(GraphConfig, config["metadata"]["config"])
if not state.available_tools:
raise ValueError("available_tools is not set")
# if not state.available_tools:
# raise ValueError("available_tools is not set")
image_tool_info = state.available_tools[state.tools_used[-1]]
image_tool = cast(ImageGenerationTool, image_tool_info.tool_object)
# image_tool_info = state.available_tools[state.tools_used[-1]]
# image_tool = cast(ImageGenerationTool, image_tool_info.tool_object)
image_prompt = branch_query
requested_shape: ImageShape | None = None
# image_prompt = branch_query
# requested_shape: ImageShape | None = None
try:
parsed_query = json.loads(branch_query)
except json.JSONDecodeError:
parsed_query = None
# try:
# parsed_query = json.loads(branch_query)
# except json.JSONDecodeError:
# parsed_query = None
if isinstance(parsed_query, dict):
prompt_from_llm = parsed_query.get("prompt")
if isinstance(prompt_from_llm, str) and prompt_from_llm.strip():
image_prompt = prompt_from_llm.strip()
# if isinstance(parsed_query, dict):
# prompt_from_llm = parsed_query.get("prompt")
# if isinstance(prompt_from_llm, str) and prompt_from_llm.strip():
# image_prompt = prompt_from_llm.strip()
raw_shape = parsed_query.get("shape")
if isinstance(raw_shape, str):
try:
requested_shape = ImageShape(raw_shape)
except ValueError:
logger.warning(
"Received unsupported image shape '%s' from LLM. Falling back to square.",
raw_shape,
)
# raw_shape = parsed_query.get("shape")
# if isinstance(raw_shape, str):
# try:
# requested_shape = ImageShape(raw_shape)
# except ValueError:
# logger.warning(
# "Received unsupported image shape '%s' from LLM. Falling back to square.",
# raw_shape,
# )
logger.debug(
f"Image generation start for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# logger.debug(
# f"Image generation start for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
# Generate images using the image generation tool
image_generation_responses: list[ImageGenerationResponse] = []
# # Generate images using the image generation tool
# image_generation_responses: list[ImageGenerationResponse] = []
if requested_shape is not None:
tool_iterator = image_tool.run(
prompt=image_prompt,
shape=requested_shape.value,
)
else:
tool_iterator = image_tool.run(prompt=image_prompt)
# if requested_shape is not None:
# tool_iterator = image_tool.run(
# prompt=image_prompt,
# shape=requested_shape.value,
# )
# else:
# tool_iterator = image_tool.run(prompt=image_prompt)
for tool_response in tool_iterator:
if tool_response.id == IMAGE_GENERATION_HEARTBEAT_ID:
# Stream heartbeat to frontend
write_custom_event(
state.current_step_nr,
ImageGenerationToolHeartbeat(),
writer,
)
elif tool_response.id == IMAGE_GENERATION_RESPONSE_ID:
response = cast(list[ImageGenerationResponse], tool_response.response)
image_generation_responses = response
break
# for tool_response in tool_iterator:
# if tool_response.id == IMAGE_GENERATION_HEARTBEAT_ID:
# # Stream heartbeat to frontend
# write_custom_event(
# state.current_step_nr,
# ImageGenerationToolHeartbeat(),
# writer,
# )
# elif tool_response.id == IMAGE_GENERATION_RESPONSE_ID:
# response = cast(list[ImageGenerationResponse], tool_response.response)
# image_generation_responses = response
# break
# save images to file store
file_ids = save_files(
urls=[],
base64_files=[img.image_data for img in image_generation_responses],
)
# # save images to file store
# file_ids = save_files(
# urls=[],
# base64_files=[img.image_data for img in image_generation_responses],
# )
final_generated_images = [
GeneratedImage(
file_id=file_id,
url=build_frontend_file_url(file_id),
revised_prompt=img.revised_prompt,
shape=(requested_shape or ImageShape.SQUARE).value,
)
for file_id, img in zip(file_ids, image_generation_responses)
]
# final_generated_images = [
# GeneratedImage(
# file_id=file_id,
# url=build_frontend_file_url(file_id),
# revised_prompt=img.revised_prompt,
# shape=(requested_shape or ImageShape.SQUARE).value,
# )
# for file_id, img in zip(file_ids, image_generation_responses)
# ]
logger.debug(
f"Image generation complete for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# logger.debug(
# f"Image generation complete for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
# Create answer string describing the generated images
if final_generated_images:
image_descriptions = []
for i, img in enumerate(final_generated_images, 1):
if img.shape and img.shape != ImageShape.SQUARE.value:
image_descriptions.append(
f"Image {i}: {img.revised_prompt} (shape: {img.shape})"
)
else:
image_descriptions.append(f"Image {i}: {img.revised_prompt}")
# # Create answer string describing the generated images
# if final_generated_images:
# image_descriptions = []
# for i, img in enumerate(final_generated_images, 1):
# if img.shape and img.shape != ImageShape.SQUARE.value:
# image_descriptions.append(
# f"Image {i}: {img.revised_prompt} (shape: {img.shape})"
# )
# else:
# image_descriptions.append(f"Image {i}: {img.revised_prompt}")
answer_string = (
f"Generated {len(final_generated_images)} image(s) based on the request: {image_prompt}\n\n"
+ "\n".join(image_descriptions)
)
if requested_shape:
reasoning = (
"Used image generation tool to create "
f"{len(final_generated_images)} image(s) in {requested_shape.value} orientation."
)
else:
reasoning = (
"Used image generation tool to create "
f"{len(final_generated_images)} image(s) based on the user's request."
)
else:
answer_string = f"Failed to generate images for request: {image_prompt}"
reasoning = "Image generation tool did not return any results."
# answer_string = (
# f"Generated {len(final_generated_images)} image(s) based on the request: {image_prompt}\n\n"
# + "\n".join(image_descriptions)
# )
# if requested_shape:
# reasoning = (
# "Used image generation tool to create "
# f"{len(final_generated_images)} image(s) in {requested_shape.value} orientation."
# )
# else:
# reasoning = (
# "Used image generation tool to create "
# f"{len(final_generated_images)} image(s) based on the user's request."
# )
# else:
# answer_string = f"Failed to generate images for request: {image_prompt}"
# reasoning = "Image generation tool did not return any results."
return BranchUpdate(
branch_iteration_responses=[
IterationAnswer(
tool=image_tool_info.llm_path,
tool_id=image_tool_info.tool_id,
iteration_nr=iteration_nr,
parallelization_nr=parallelization_nr,
question=branch_query,
answer=answer_string,
claims=[],
cited_documents={},
reasoning=reasoning,
generated_images=final_generated_images,
)
],
log_messages=[
get_langgraph_node_log_string(
graph_component="image_generation",
node_name="generating",
node_start_time=node_start_time,
)
],
)
# return BranchUpdate(
# branch_iteration_responses=[
# IterationAnswer(
# tool=image_tool_info.llm_path,
# tool_id=image_tool_info.tool_id,
# iteration_nr=iteration_nr,
# parallelization_nr=parallelization_nr,
# question=branch_query,
# answer=answer_string,
# claims=[],
# cited_documents={},
# reasoning=reasoning,
# generated_images=final_generated_images,
# )
# ],
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="image_generation",
# node_name="generating",
# node_start_time=node_start_time,
# )
# ],
# )
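
The prompt/shape handling in the node above is easy to lose inside the diff: the branch question may arrive as plain text or as a small JSON object with "prompt" and "shape" keys. A self-contained sketch of that parsing logic follows; ImageShape here is a stand-in enum (the real one lives in image_generation_tool and its members are assumed), and the sample query string is invented:

import json
from enum import Enum


class ImageShape(str, Enum):
    # Stand-in for the repo's ImageShape; member values are assumptions.
    SQUARE = "square"
    PORTRAIT = "portrait"
    LANDSCAPE = "landscape"


branch_query = '{"prompt": "a lighthouse at dusk", "shape": "landscape"}'
image_prompt = branch_query  # fall back to the raw query if parsing fails
requested_shape = None

try:
    parsed_query = json.loads(branch_query)
except json.JSONDecodeError:
    parsed_query = None

if isinstance(parsed_query, dict):
    prompt_from_llm = parsed_query.get("prompt")
    if isinstance(prompt_from_llm, str) and prompt_from_llm.strip():
        image_prompt = prompt_from_llm.strip()
    raw_shape = parsed_query.get("shape")
    if isinstance(raw_shape, str):
        try:
            requested_shape = ImageShape(raw_shape)
        except ValueError:
            requested_shape = None  # unsupported shape, square is used later

assert image_prompt == "a lighthouse at dusk"
assert requested_shape is ImageShape.LANDSCAPE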

View File

@@ -1,71 +1,71 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.models import GeneratedImage
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolDelta
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.models import GeneratedImage
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import ImageGenerationFinal
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def is_reducer(
state: SubAgentMainState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
"""
LangGraph node to consolidate image generation results as part of the DR process.
"""
# def is_reducer(
# state: SubAgentMainState,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> SubAgentUpdate:
# """
# LangGraph node to consolidate image generation results as part of the DR process.
# """
node_start_time = datetime.now()
# node_start_time = datetime.now()
branch_updates = state.branch_iteration_responses
current_iteration = state.iteration_nr
current_step_nr = state.current_step_nr
# branch_updates = state.branch_iteration_responses
# current_iteration = state.iteration_nr
# current_step_nr = state.current_step_nr
new_updates = [
update for update in branch_updates if update.iteration_nr == current_iteration
]
generated_images: list[GeneratedImage] = []
for update in new_updates:
if update.generated_images:
generated_images.extend(update.generated_images)
# new_updates = [
# update for update in branch_updates if update.iteration_nr == current_iteration
# ]
# generated_images: list[GeneratedImage] = []
# for update in new_updates:
# if update.generated_images:
# generated_images.extend(update.generated_images)
# Write the results to the stream
write_custom_event(
current_step_nr,
ImageGenerationToolDelta(
images=generated_images,
),
writer,
)
# # Write the results to the stream
# write_custom_event(
# current_step_nr,
# ImageGenerationFinal(
# images=generated_images,
# ),
# writer,
# )
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
# write_custom_event(
# current_step_nr,
# SectionEnd(),
# writer,
# )
current_step_nr += 1
# current_step_nr += 1
return SubAgentUpdate(
iteration_responses=new_updates,
current_step_nr=current_step_nr,
log_messages=[
get_langgraph_node_log_string(
graph_component="image_generation",
node_name="consolidation",
node_start_time=node_start_time,
)
],
)
# return SubAgentUpdate(
# iteration_responses=new_updates,
# current_step_nr=current_step_nr,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="image_generation",
# node_name="consolidation",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,29 +1,29 @@
from collections.abc import Hashable
# from collections.abc import Hashable
from langgraph.types import Send
# from langgraph.types import Send
from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
return [
Send(
"act",
BranchInput(
iteration_nr=state.iteration_nr,
parallelization_nr=parallelization_nr,
branch_question=query,
context="",
active_source_types=state.active_source_types,
tools_used=state.tools_used,
available_tools=state.available_tools,
assistant_system_prompt=state.assistant_system_prompt,
assistant_task_prompt=state.assistant_task_prompt,
),
)
for parallelization_nr, query in enumerate(
state.query_list[:MAX_DR_PARALLEL_SEARCH]
)
]
# def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
# return [
# Send(
# "act",
# BranchInput(
# iteration_nr=state.iteration_nr,
# parallelization_nr=parallelization_nr,
# branch_question=query,
# context="",
# active_source_types=state.active_source_types,
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# assistant_system_prompt=state.assistant_system_prompt,
# assistant_task_prompt=state.assistant_task_prompt,
# ),
# )
# for parallelization_nr, query in enumerate(
# state.query_list[:MAX_DR_PARALLEL_SEARCH]
# )
# ]

View File

@@ -1,50 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_1_branch import (
image_generation_branch,
)
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_2_act import (
image_generation,
)
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_3_reduce import (
is_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_1_branch import (
# image_generation_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_2_act import (
# image_generation,
# )
# from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_3_reduce import (
# is_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_conditional_edges import (
# branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def dr_image_generation_graph_builder() -> StateGraph:
"""
LangGraph graph builder for Image Generation Sub-Agent
"""
# def dr_image_generation_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for Image Generation Sub-Agent
# """
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
# graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
### Add nodes ###
# ### Add nodes ###
graph.add_node("branch", image_generation_branch)
# graph.add_node("branch", image_generation_branch)
graph.add_node("act", image_generation)
# graph.add_node("act", image_generation)
graph.add_node("reducer", is_reducer)
# graph.add_node("reducer", is_reducer)
### Add edges ###
# ### Add edges ###
graph.add_edge(start_key=START, end_key="branch")
# graph.add_edge(start_key=START, end_key="branch")
graph.add_conditional_edges("branch", branching_router)
# graph.add_conditional_edges("branch", branching_router)
graph.add_edge(start_key="act", end_key="reducer")
# graph.add_edge(start_key="act", end_key="reducer")
graph.add_edge(start_key="reducer", end_key=END)
# graph.add_edge(start_key="reducer", end_key=END)
return graph
# return graph

View File

@@ -1,13 +1,13 @@
from pydantic import BaseModel
# from pydantic import BaseModel
class GeneratedImage(BaseModel):
file_id: str
url: str
revised_prompt: str
shape: str | None = None
# class GeneratedImage(BaseModel):
# file_id: str
# url: str
# revised_prompt: str
# shape: str | None = None
# Needed for PydanticType
class GeneratedImageFullResult(BaseModel):
images: list[GeneratedImage]
# # Needed for PydanticType
# class GeneratedImageFullResult(BaseModel):
# images: list[GeneratedImage]
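
These are ordinary Pydantic models, so downstream code can construct and serialize them directly. A quick usage sketch, assuming Pydantic v2 (the field values below are invented, and the class bodies are repeated here only so the snippet runs on its own):

from pydantic import BaseModel


class GeneratedImage(BaseModel):
    file_id: str
    url: str
    revised_prompt: str
    shape: str | None = None


class GeneratedImageFullResult(BaseModel):
    images: list[GeneratedImage]


result = GeneratedImageFullResult(
    images=[
        GeneratedImage(
            file_id="file-123",
            url="/api/files/file-123",
            revised_prompt="a lighthouse at dusk, oil painting",
            shape="landscape",
        )
    ]
)
print(result.model_dump_json(indent=2))  # Pydantic v2 serialization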

View File

@@ -1,36 +1,36 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def kg_search_branch(
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
LangGraph node to perform a KG search as part of the DR process.
"""
# def kg_search_branch(
# state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
# """
# LangGraph node to perform a KG search as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
logger.debug(f"Search start for KG Search {iteration_nr} at {datetime.now()}")
# logger.debug(f"Search start for KG Search {iteration_nr} at {datetime.now()}")
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="kg_search",
node_name="branching",
node_start_time=node_start_time,
)
],
)
# return LoggerUpdate(
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="kg_search",
# node_name="branching",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,97 +1,97 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.utils import extract_document_citations
from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
from onyx.agents.agent_search.kb_search.states import MainInput as KbMainInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.dr.utils import extract_document_citations
# from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
# from onyx.agents.agent_search.kb_search.states import MainInput as KbMainInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.context.search.models import InferenceSection
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def kg_search(
state: BranchInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BranchUpdate:
"""
LangGraph node to perform a KG search as part of the DR process.
"""
# def kg_search(
# state: BranchInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> BranchUpdate:
# """
# LangGraph node to perform a KG search as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
state.current_step_nr
parallelization_nr = state.parallelization_nr
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
# state.current_step_nr
# parallelization_nr = state.parallelization_nr
search_query = state.branch_question
if not search_query:
raise ValueError("search_query is not set")
# search_query = state.branch_question
# if not search_query:
# raise ValueError("search_query is not set")
logger.debug(
f"Search start for KG Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# logger.debug(
# f"Search start for KG Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
if not state.available_tools:
raise ValueError("available_tools is not set")
# if not state.available_tools:
# raise ValueError("available_tools is not set")
kg_tool_info = state.available_tools[state.tools_used[-1]]
# kg_tool_info = state.available_tools[state.tools_used[-1]]
kb_graph = kb_graph_builder().compile()
# kb_graph = kb_graph_builder().compile()
kb_results = kb_graph.invoke(
input=KbMainInput(question=search_query, individual_flow=False),
config=config,
)
# kb_results = kb_graph.invoke(
# input=KbMainInput(question=search_query, individual_flow=False),
# config=config,
# )
# get cited documents
answer_string = kb_results.get("final_answer") or "No answer provided"
claims: list[str] = []
retrieved_docs: list[InferenceSection] = kb_results.get("retrieved_documents", [])
# # get cited documents
# answer_string = kb_results.get("final_answer") or "No answer provided"
# claims: list[str] = []
# retrieved_docs: list[InferenceSection] = kb_results.get("retrieved_documents", [])
(
citation_numbers,
answer_string,
claims,
) = extract_document_citations(answer_string, claims)
# (
# citation_numbers,
# answer_string,
# claims,
# ) = extract_document_citations(answer_string, claims)
# if citation is empty, the answer must have come from the KG rather than a doc
# in that case, simply cite the docs returned by the KG
if not citation_numbers:
citation_numbers = [i + 1 for i in range(len(retrieved_docs))]
# # if citation is empty, the answer must have come from the KG rather than a doc
# # in that case, simply cite the docs returned by the KG
# if not citation_numbers:
# citation_numbers = [i + 1 for i in range(len(retrieved_docs))]
cited_documents = {
citation_number: retrieved_docs[citation_number - 1]
for citation_number in citation_numbers
if citation_number <= len(retrieved_docs)
}
# cited_documents = {
# citation_number: retrieved_docs[citation_number - 1]
# for citation_number in citation_numbers
# if citation_number <= len(retrieved_docs)
# }
return BranchUpdate(
branch_iteration_responses=[
IterationAnswer(
tool=kg_tool_info.llm_path,
tool_id=kg_tool_info.tool_id,
iteration_nr=iteration_nr,
parallelization_nr=parallelization_nr,
question=search_query,
answer=answer_string,
claims=claims,
cited_documents=cited_documents,
reasoning=None,
additional_data=None,
)
],
log_messages=[
get_langgraph_node_log_string(
graph_component="kg_search",
node_name="searching",
node_start_time=node_start_time,
)
],
)
# return BranchUpdate(
# branch_iteration_responses=[
# IterationAnswer(
# tool=kg_tool_info.llm_path,
# tool_id=kg_tool_info.tool_id,
# iteration_nr=iteration_nr,
# parallelization_nr=parallelization_nr,
# question=search_query,
# answer=answer_string,
# claims=claims,
# cited_documents=cited_documents,
# reasoning=None,
# additional_data=None,
# )
# ],
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="kg_search",
# node_name="searching",
# node_start_time=node_start_time,
# )
# ],
# )
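
A small worked example of the citation mapping performed above, with plain strings standing in for InferenceSection objects (the numbers are made up):

retrieved_docs = ["doc A", "doc B", "doc C"]
citation_numbers = [2, 3, 7]  # 7 exceeds the retrieved list and is dropped

# Fallback from the node above: if the answer cited nothing, it came from the
# KG itself, so every retrieved doc gets cited.
if not citation_numbers:
    citation_numbers = [i + 1 for i in range(len(retrieved_docs))]

cited_documents = {
    citation_number: retrieved_docs[citation_number - 1]
    for citation_number in citation_numbers
    if citation_number <= len(retrieved_docs)
}
assert cited_documents == {2: "doc B", 3: "doc C"}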

View File

@@ -1,121 +1,121 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
# from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import ReasoningDelta
# from onyx.server.query_and_chat.streaming_models import ReasoningStart
# from onyx.server.query_and_chat.streaming_models import SearchToolDelta
# from onyx.server.query_and_chat.streaming_models import SearchToolStart
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
_MAX_KG_STEAMED_ANSWER_LENGTH = 1000 # num characters
# _MAX_KG_STEAMED_ANSWER_LENGTH = 1000 # num characters
def kg_search_reducer(
state: SubAgentMainState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
"""
LangGraph node to perform a KG search as part of the DR process.
"""
# def kg_search_reducer(
# state: SubAgentMainState,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> SubAgentUpdate:
# """
# LangGraph node to perform a KG search as part of the DR process.
# """
node_start_time = datetime.now()
# node_start_time = datetime.now()
branch_updates = state.branch_iteration_responses
current_iteration = state.iteration_nr
current_step_nr = state.current_step_nr
# branch_updates = state.branch_iteration_responses
# current_iteration = state.iteration_nr
# current_step_nr = state.current_step_nr
new_updates = [
update for update in branch_updates if update.iteration_nr == current_iteration
]
# new_updates = [
# update for update in branch_updates if update.iteration_nr == current_iteration
# ]
queries = [update.question for update in new_updates]
doc_lists = [list(update.cited_documents.values()) for update in new_updates]
# queries = [update.question for update in new_updates]
# doc_lists = [list(update.cited_documents.values()) for update in new_updates]
doc_list = []
# doc_list = []
for xs in doc_lists:
for x in xs:
doc_list.append(x)
# for xs in doc_lists:
# for x in xs:
# doc_list.append(x)
retrieved_search_docs = convert_inference_sections_to_search_docs(doc_list)
kg_answer = (
"The Knowledge Graph Answer:\n\n" + new_updates[0].answer
if len(queries) == 1
else None
)
# retrieved_search_docs = convert_inference_sections_to_search_docs(doc_list)
# kg_answer = (
# "The Knowledge Graph Answer:\n\n" + new_updates[0].answer
# if len(queries) == 1
# else None
# )
if len(retrieved_search_docs) > 0:
write_custom_event(
current_step_nr,
SearchToolStart(
is_internet_search=False,
),
writer,
)
write_custom_event(
current_step_nr,
SearchToolDelta(
queries=queries,
documents=retrieved_search_docs,
),
writer,
)
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
# if len(retrieved_search_docs) > 0:
# write_custom_event(
# current_step_nr,
# SearchToolStart(
# is_internet_search=False,
# ),
# writer,
# )
# write_custom_event(
# current_step_nr,
# SearchToolDelta(
# queries=queries,
# documents=retrieved_search_docs,
# ),
# writer,
# )
# write_custom_event(
# current_step_nr,
# SectionEnd(),
# writer,
# )
current_step_nr += 1
# current_step_nr += 1
if kg_answer is not None:
# if kg_answer is not None:
kg_display_answer = (
f"{kg_answer[:_MAX_KG_STEAMED_ANSWER_LENGTH]}..."
if len(kg_answer) > _MAX_KG_STEAMED_ANSWER_LENGTH
else kg_answer
)
# kg_display_answer = (
# f"{kg_answer[:_MAX_KG_STEAMED_ANSWER_LENGTH]}..."
# if len(kg_answer) > _MAX_KG_STEAMED_ANSWER_LENGTH
# else kg_answer
# )
write_custom_event(
current_step_nr,
ReasoningStart(),
writer,
)
write_custom_event(
current_step_nr,
ReasoningDelta(reasoning=kg_display_answer),
writer,
)
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
# write_custom_event(
# current_step_nr,
# ReasoningStart(),
# writer,
# )
# write_custom_event(
# current_step_nr,
# ReasoningDelta(reasoning=kg_display_answer),
# writer,
# )
# write_custom_event(
# current_step_nr,
# SectionEnd(),
# writer,
# )
current_step_nr += 1
# current_step_nr += 1
return SubAgentUpdate(
iteration_responses=new_updates,
current_step_nr=current_step_nr,
log_messages=[
get_langgraph_node_log_string(
graph_component="kg_search",
node_name="consolidation",
node_start_time=node_start_time,
)
],
)
# return SubAgentUpdate(
# iteration_responses=new_updates,
# current_step_nr=current_step_nr,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="kg_search",
# node_name="consolidation",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,27 +1,27 @@
from collections.abc import Hashable
# from collections.abc import Hashable
from langgraph.types import Send
# from langgraph.types import Send
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
return [
Send(
"act",
BranchInput(
iteration_nr=state.iteration_nr,
parallelization_nr=parallelization_nr,
branch_question=query,
context="",
tools_used=state.tools_used,
available_tools=state.available_tools,
assistant_system_prompt=state.assistant_system_prompt,
assistant_task_prompt=state.assistant_task_prompt,
),
)
for parallelization_nr, query in enumerate(
state.query_list[:1] # no parallel search for now
)
]
# def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
# return [
# Send(
# "act",
# BranchInput(
# iteration_nr=state.iteration_nr,
# parallelization_nr=parallelization_nr,
# branch_question=query,
# context="",
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# assistant_system_prompt=state.assistant_system_prompt,
# assistant_task_prompt=state.assistant_task_prompt,
# ),
# )
# for parallelization_nr, query in enumerate(
# state.query_list[:1] # no parallel search for now
# )
# ]

View File

@@ -1,50 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_1_branch import (
kg_search_branch,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_2_act import (
kg_search,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_3_reduce import (
kg_search_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_1_branch import (
# kg_search_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_2_act import (
# kg_search,
# )
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_3_reduce import (
# kg_search_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_conditional_edges import (
# branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def dr_kg_search_graph_builder() -> StateGraph:
"""
LangGraph graph builder for KG Search Sub-Agent
"""
# def dr_kg_search_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for KG Search Sub-Agent
# """
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
# graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
### Add nodes ###
# ### Add nodes ###
graph.add_node("branch", kg_search_branch)
# graph.add_node("branch", kg_search_branch)
graph.add_node("act", kg_search)
# graph.add_node("act", kg_search)
graph.add_node("reducer", kg_search_reducer)
# graph.add_node("reducer", kg_search_reducer)
### Add edges ###
# ### Add edges ###
graph.add_edge(start_key=START, end_key="branch")
# graph.add_edge(start_key=START, end_key="branch")
graph.add_conditional_edges("branch", branching_router)
# graph.add_conditional_edges("branch", branching_router)
graph.add_edge(start_key="act", end_key="reducer")
# graph.add_edge(start_key="act", end_key="reducer")
graph.add_edge(start_key="reducer", end_key=END)
# graph.add_edge(start_key="reducer", end_key=END)
return graph
# return graph
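
For reference, a builder like the one above returns an uncompiled StateGraph; the caller compiles it before running. The lines below are a minimal sketch of that step and are not part of the diff; the initial state and RunnableConfig contents are placeholders, since the real values come from the DR orchestrator and are not shown here.

# Sketch only (assumption): compiling and running the KG search sub-agent graph.
# The orchestrator normally supplies the initial SubAgentInput fields and a
# RunnableConfig carrying config["metadata"]["config"]; both are elided here.
graph = dr_kg_search_graph_builder()
compiled_graph = graph.compile()
# compiled_graph.invoke(initial_state, config=runnable_config) then runs
# branch -> branching_router -> act -> reducer as wired above.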

View File

@@ -1,46 +1,46 @@
from operator import add
from typing import Annotated
# from operator import add
# from typing import Annotated
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.db.connector import DocumentSource
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.models import OrchestratorTool
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.db.connector import DocumentSource
class SubAgentUpdate(LoggerUpdate):
iteration_responses: Annotated[list[IterationAnswer], add] = []
current_step_nr: int = 1
# class SubAgentUpdate(LoggerUpdate):
# iteration_responses: Annotated[list[IterationAnswer], add] = []
# current_step_nr: int = 1
class BranchUpdate(LoggerUpdate):
branch_iteration_responses: Annotated[list[IterationAnswer], add] = []
# class BranchUpdate(LoggerUpdate):
# branch_iteration_responses: Annotated[list[IterationAnswer], add] = []
class SubAgentInput(LoggerUpdate):
iteration_nr: int = 0
current_step_nr: int = 1
query_list: list[str] = []
context: str | None = None
active_source_types: list[DocumentSource] | None = None
tools_used: Annotated[list[str], add] = []
available_tools: dict[str, OrchestratorTool] | None = None
assistant_system_prompt: str | None = None
assistant_task_prompt: str | None = None
# class SubAgentInput(LoggerUpdate):
# iteration_nr: int = 0
# current_step_nr: int = 1
# query_list: list[str] = []
# context: str | None = None
# active_source_types: list[DocumentSource] | None = None
# tools_used: Annotated[list[str], add] = []
# available_tools: dict[str, OrchestratorTool] | None = None
# assistant_system_prompt: str | None = None
# assistant_task_prompt: str | None = None
class SubAgentMainState(
# This includes the core state
SubAgentInput,
SubAgentUpdate,
BranchUpdate,
):
pass
# class SubAgentMainState(
# # This includes the core state
# SubAgentInput,
# SubAgentUpdate,
# BranchUpdate,
# ):
# pass
class BranchInput(SubAgentInput):
parallelization_nr: int = 0
branch_question: str
# class BranchInput(SubAgentInput):
# parallelization_nr: int = 0
# branch_question: str
class CustomToolBranchInput(LoggerUpdate):
tool_info: OrchestratorTool
# class CustomToolBranchInput(LoggerUpdate):
# tool_info: OrchestratorTool
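
The Annotated[list[...], add] fields above are LangGraph reducer annotations: when parallel branches each return a partial update, LangGraph merges the lists with operator.add instead of letting the last writer win. The following is a minimal, self-contained sketch of that mechanism, using illustrative names rather than the onyx types.

# Standalone sketch of the Annotated-reducer pattern used by SubAgentUpdate and BranchUpdate.
# Field and class names here are illustrative only.
from operator import add
from typing import Annotated

from pydantic import BaseModel


class ExampleUpdate(BaseModel):
    # Each parallel branch contributes its own list; LangGraph concatenates
    # concurrent updates to this channel via `add` rather than overwriting.
    responses: Annotated[list[str], add] = []


branch_a = ExampleUpdate(responses=["answer from branch 0"])
branch_b = ExampleUpdate(responses=["answer from branch 1"])
merged = add(branch_a.responses, branch_b.responses)  # what LangGraph does internally
assert merged == ["answer from branch 0", "answer from branch 1"]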

View File

@@ -1,70 +1,71 @@
from collections.abc import Sequence
# from collections.abc import Sequence
from exa_py import Exa
from exa_py.api import HighlightsContentsOptions
# from exa_py import Exa
# from exa_py.api import HighlightsContentsOptions
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebContent,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.utils.retry_wrapper import retry_builder
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebContent,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebSearchProvider,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebSearchResult,
# )
# from onyx.configs.chat_configs import EXA_API_KEY
# from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
# from onyx.utils.retry_wrapper import retry_builder
class ExaClient(WebSearchProvider):
def __init__(self, api_key: str) -> None:
self.exa = Exa(api_key=api_key)
# class ExaClient(WebSearchProvider):
# def __init__(self, api_key: str | None = EXA_API_KEY) -> None:
# self.exa = Exa(api_key=api_key)
@retry_builder(tries=3, delay=1, backoff=2)
def search(self, query: str) -> list[WebSearchResult]:
response = self.exa.search_and_contents(
query,
type="auto",
highlights=HighlightsContentsOptions(
num_sentences=2,
highlights_per_url=1,
),
num_results=10,
)
# @retry_builder(tries=3, delay=1, backoff=2)
# def search(self, query: str) -> list[WebSearchResult]:
# response = self.exa.search_and_contents(
# query,
# type="auto",
# highlights=HighlightsContentsOptions(
# num_sentences=2,
# highlights_per_url=1,
# ),
# num_results=10,
# )
return [
WebSearchResult(
title=result.title or "",
link=result.url,
snippet=result.highlights[0] if result.highlights else "",
author=result.author,
published_date=(
time_str_to_utc(result.published_date)
if result.published_date
else None
),
)
for result in response.results
]
# return [
# WebSearchResult(
# title=result.title or "",
# link=result.url,
# snippet=result.highlights[0] if result.highlights else "",
# author=result.author,
# published_date=(
# time_str_to_utc(result.published_date)
# if result.published_date
# else None
# ),
# )
# for result in response.results
# ]
@retry_builder(tries=3, delay=1, backoff=2)
def contents(self, urls: Sequence[str]) -> list[WebContent]:
response = self.exa.get_contents(
urls=list(urls),
text=True,
livecrawl="preferred",
)
# @retry_builder(tries=3, delay=1, backoff=2)
# def contents(self, urls: Sequence[str]) -> list[WebContent]:
# response = self.exa.get_contents(
# urls=list(urls),
# text=True,
# livecrawl="preferred",
# )
return [
WebContent(
title=result.title or "",
link=result.url,
full_content=result.text or "",
published_date=(
time_str_to_utc(result.published_date)
if result.published_date
else None
),
)
for result in response.results
]
# return [
# WebContent(
# title=result.title or "",
# link=result.url,
# full_content=result.text or "",
# published_date=(
# time_str_to_utc(result.published_date)
# if result.published_date
# else None
# ),
# )
# for result in response.results
# ]
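
For context, the class above is a thin wrapper over the exa_py SDK: search() returns lightweight WebSearchResult entries and contents() fetches full page text. Below is a hedged usage sketch; the import is omitted because the file path is not shown in this diff, and the environment-variable name and query are assumptions.

# Illustrative use of the ExaClient interface shown above (not part of the diff).
# EXA_API_KEY is an assumed environment variable; the two sides of this diff differ
# on whether the key defaults from config or must be passed explicitly.
import os

client = ExaClient(api_key=os.environ["EXA_API_KEY"])

results = client.search("site reliability postmortem best practices")
for result in results[:3]:
    print(result.title, result.link)

pages = client.contents([result.link for result in results[:3]])
for page in pages:
    print(page.title, len(page.full_content))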

View File

@@ -1,159 +1,148 @@
import json
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
# import json
# from collections.abc import Sequence
# from concurrent.futures import ThreadPoolExecutor
import requests
# import requests
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebContent,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.utils.retry_wrapper import retry_builder
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebContent,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebSearchProvider,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebSearchResult,
# )
# from onyx.configs.chat_configs import SERPER_API_KEY
# from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
# from onyx.utils.retry_wrapper import retry_builder
SERPER_SEARCH_URL = "https://google.serper.dev/search"
SERPER_CONTENTS_URL = "https://scrape.serper.dev"
# SERPER_SEARCH_URL = "https://google.serper.dev/search"
# SERPER_CONTENTS_URL = "https://scrape.serper.dev"
class SerperClient(WebSearchProvider):
def __init__(self, api_key: str) -> None:
self.headers = {
"X-API-KEY": api_key,
"Content-Type": "application/json",
}
# class SerperClient(WebSearchProvider):
# def __init__(self, api_key: str | None = SERPER_API_KEY) -> None:
# self.headers = {
# "X-API-KEY": api_key,
# "Content-Type": "application/json",
# }
@retry_builder(tries=3, delay=1, backoff=2)
def search(self, query: str) -> list[WebSearchResult]:
payload = {
"q": query,
}
# @retry_builder(tries=3, delay=1, backoff=2)
# def search(self, query: str) -> list[WebSearchResult]:
# payload = {
# "q": query,
# }
response = requests.post(
SERPER_SEARCH_URL,
headers=self.headers,
data=json.dumps(payload),
)
# response = requests.post(
# SERPER_SEARCH_URL,
# headers=self.headers,
# data=json.dumps(payload),
# )
try:
response.raise_for_status()
except Exception:
# Avoid leaking API keys/URLs
raise ValueError(
"Serper search failed. Check credentials or quota."
) from None
# response.raise_for_status()
results = response.json()
organic_results = results["organic"]
# results = response.json()
# organic_results = results["organic"]
return [
WebSearchResult(
title=result["title"],
link=result["link"],
snippet=result["snippet"],
author=None,
published_date=None,
)
for result in organic_results
]
# return [
# WebSearchResult(
# title=result["title"],
# link=result["link"],
# snippet=result["snippet"],
# author=None,
# published_date=None,
# )
# for result in organic_results
# ]
def contents(self, urls: Sequence[str]) -> list[WebContent]:
if not urls:
return []
# def contents(self, urls: Sequence[str]) -> list[WebContent]:
# if not urls:
# return []
# Serper can respond with 500s regularly. We want to retry,
# but in the event of failure, return an unsuccessful scrape.
def safe_get_webpage_content(url: str) -> WebContent:
try:
return self._get_webpage_content(url)
except Exception:
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
# # Serper can respond with 500s regularly. We want to retry,
# # but in the event of failure, return an unsuccessful scrape.
# def safe_get_webpage_content(url: str) -> WebContent:
# try:
# return self._get_webpage_content(url)
# except Exception:
# return WebContent(
# title="",
# link=url,
# full_content="",
# published_date=None,
# scrape_successful=False,
# )
with ThreadPoolExecutor(max_workers=min(8, len(urls))) as e:
return list(e.map(safe_get_webpage_content, urls))
# with ThreadPoolExecutor(max_workers=min(8, len(urls))) as e:
# return list(e.map(safe_get_webpage_content, urls))
@retry_builder(tries=3, delay=1, backoff=2)
def _get_webpage_content(self, url: str) -> WebContent:
payload = {
"url": url,
}
# @retry_builder(tries=3, delay=1, backoff=2)
# def _get_webpage_content(self, url: str) -> WebContent:
# payload = {
# "url": url,
# }
response = requests.post(
SERPER_CONTENTS_URL,
headers=self.headers,
data=json.dumps(payload),
)
# response = requests.post(
# SERPER_CONTENTS_URL,
# headers=self.headers,
# data=json.dumps(payload),
# )
# 400 returned when serper cannot scrape
if response.status_code == 400:
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
# # 400 returned when serper cannot scrape
# if response.status_code == 400:
# return WebContent(
# title="",
# link=url,
# full_content="",
# published_date=None,
# scrape_successful=False,
# )
try:
response.raise_for_status()
except Exception:
# Avoid leaking API keys/URLs
raise ValueError(
"Serper content fetch failed. Check credentials."
) from None
# response.raise_for_status()
response_json = response.json()
# response_json = response.json()
# Response only guarantees text
text = response_json["text"]
# # Response only guarantees text
# text = response_json["text"]
# metadata & jsonld is not guaranteed to be present
metadata = response_json.get("metadata", {})
jsonld = response_json.get("jsonld", {})
# # metadata & jsonld is not guaranteed to be present
# metadata = response_json.get("metadata", {})
# jsonld = response_json.get("jsonld", {})
title = extract_title_from_metadata(metadata)
# title = extract_title_from_metadata(metadata)
# Serper does not provide a reliable mechanism to extract the url
response_url = url
published_date_str = extract_published_date_from_jsonld(jsonld)
published_date = None
# # Serper does not provide a reliable mechanism to extract the url
# response_url = url
# published_date_str = extract_published_date_from_jsonld(jsonld)
# published_date = None
if published_date_str:
try:
published_date = time_str_to_utc(published_date_str)
except Exception:
published_date = None
# if published_date_str:
# try:
# published_date = time_str_to_utc(published_date_str)
# except Exception:
# published_date = None
return WebContent(
title=title or "",
link=response_url,
full_content=text or "",
published_date=published_date,
)
# return WebContent(
# title=title or "",
# link=response_url,
# full_content=text or "",
# published_date=published_date,
# )
def extract_title_from_metadata(metadata: dict[str, str]) -> str | None:
keys = ["title", "og:title"]
return extract_value_from_dict(metadata, keys)
# def extract_title_from_metadata(metadata: dict[str, str]) -> str | None:
# keys = ["title", "og:title"]
# return extract_value_from_dict(metadata, keys)
def extract_published_date_from_jsonld(jsonld: dict[str, str]) -> str | None:
keys = ["dateModified"]
return extract_value_from_dict(jsonld, keys)
# def extract_published_date_from_jsonld(jsonld: dict[str, str]) -> str | None:
# keys = ["dateModified"]
# return extract_value_from_dict(jsonld, keys)
def extract_value_from_dict(data: dict[str, str], keys: list[str]) -> str | None:
for key in keys:
if key in data:
return data[key]
return None
# def extract_value_from_dict(data: dict[str, str], keys: list[str]) -> str | None:
# for key in keys:
# if key in data:
# return data[key]
# return None

View File

@@ -1,47 +1,47 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import SearchToolStart
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def is_branch(
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
LangGraph node to perform a web search as part of the DR process.
"""
# def is_branch(
# state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
# """
# LangGraph node to perform a web search as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
current_step_nr = state.current_step_nr
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
# current_step_nr = state.current_step_nr
logger.debug(f"Search start for Web Search {iteration_nr} at {datetime.now()}")
# logger.debug(f"Search start for Web Search {iteration_nr} at {datetime.now()}")
write_custom_event(
current_step_nr,
SearchToolStart(
is_internet_search=True,
),
writer,
)
# write_custom_event(
# current_step_nr,
# SearchToolStart(
# is_internet_search=True,
# ),
# writer,
# )
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="internet_search",
node_name="branching",
node_start_time=node_start_time,
)
],
)
# return LoggerUpdate(
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="internet_search",
# node_name="branching",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,137 +1,128 @@
from datetime import datetime
from typing import cast
# from datetime import datetime
# from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from langsmith import traceable
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
# from langsmith import traceable
from onyx.agents.agent_search.dr.models import WebSearchAnswer
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
get_default_provider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
InternetSearchInput,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
InternetSearchUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.prompts.dr_prompts import WEB_SEARCH_URL_SELECTION_PROMPT
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.models import WebSearchAnswer
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebSearchResult,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
# get_default_provider,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
# InternetSearchInput,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
# InternetSearchUpdate,
# )
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.agents.agent_search.utils import create_question_prompt
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
# from onyx.prompts.dr_prompts import WEB_SEARCH_URL_SELECTION_PROMPT
# from onyx.server.query_and_chat.streaming_models import SearchToolDelta
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def web_search(
state: InternetSearchInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> InternetSearchUpdate:
"""
LangGraph node to perform an internet search and decide which URLs to fetch.
"""
# def web_search(
# state: InternetSearchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> InternetSearchUpdate:
# """
# LangGraph node to perform an internet search and decide which URLs to fetch.
# """
node_start_time = datetime.now()
current_step_nr = state.current_step_nr
# node_start_time = datetime.now()
# current_step_nr = state.current_step_nr
if not current_step_nr:
raise ValueError("Current step number is not set. This should not happen.")
# if not current_step_nr:
# raise ValueError("Current step number is not set. This should not happen.")
assistant_system_prompt = state.assistant_system_prompt
assistant_task_prompt = state.assistant_task_prompt
# assistant_system_prompt = state.assistant_system_prompt
# assistant_task_prompt = state.assistant_task_prompt
if not state.available_tools:
raise ValueError("available_tools is not set")
search_query = state.branch_question
# if not state.available_tools:
# raise ValueError("available_tools is not set")
# search_query = state.branch_question
write_custom_event(
current_step_nr,
SearchToolDelta(
queries=[search_query],
documents=[],
),
writer,
)
# write_custom_event(
# current_step_nr,
# SearchToolDelta(
# queries=[search_query],
# documents=[],
# ),
# writer,
# )
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = graph_config.inputs.prompt_builder.raw_user_query
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# base_question = graph_config.inputs.prompt_builder.raw_user_query
if graph_config.inputs.persona is None:
raise ValueError("persona is not set")
# if graph_config.inputs.persona is None:
# raise ValueError("persona is not set")
provider = get_default_provider()
if not provider:
raise ValueError("No internet search provider found")
# provider = get_default_provider()
# if not provider:
# raise ValueError("No internet search provider found")
# Log which provider type is being used
provider_type = type(provider).__name__
logger.info(
f"Performing web search with {provider_type} for query: '{search_query}'"
)
# @traceable(name="Search Provider API Call")
# def _search(search_query: str) -> list[WebSearchResult]:
# search_results: list[WebSearchResult] = []
# try:
# search_results = list(provider.search(search_query))
# except Exception as e:
# logger.error(f"Error performing search: {e}")
# return search_results
@traceable(name="Search Provider API Call")
def _search(search_query: str) -> list[WebSearchResult]:
search_results: list[WebSearchResult] = []
try:
search_results = list(provider.search(search_query))
logger.info(
f"Search returned {len(search_results)} results using {provider_type}"
)
except Exception as e:
logger.error(f"Error performing search with {provider_type}: {e}")
return search_results
search_results: list[WebSearchResult] = _search(search_query)
search_results_text = "\n\n".join(
[
f"{i}. {result.title}\n URL: {result.link}\n"
+ (f" Author: {result.author}\n" if result.author else "")
+ (
f" Date: {result.published_date.strftime('%Y-%m-%d')}\n"
if result.published_date
else ""
)
+ (f" Snippet: {result.snippet}\n" if result.snippet else "")
for i, result in enumerate(search_results)
]
)
agent_decision_prompt = WEB_SEARCH_URL_SELECTION_PROMPT.build(
search_query=search_query,
base_question=base_question,
search_results_text=search_results_text,
)
agent_decision = invoke_llm_json(
llm=graph_config.tooling.fast_llm,
prompt=create_question_prompt(
assistant_system_prompt,
agent_decision_prompt + (assistant_task_prompt or ""),
),
schema=WebSearchAnswer,
timeout_override=TF_DR_TIMEOUT_SHORT,
)
results_to_open = [
(search_query, search_results[i])
for i in agent_decision.urls_to_open_indices
if i < len(search_results) and i >= 0
]
return InternetSearchUpdate(
results_to_open=results_to_open,
log_messages=[
get_langgraph_node_log_string(
graph_component="internet_search",
node_name="searching",
node_start_time=node_start_time,
)
],
)
# search_results: list[WebSearchResult] = _search(search_query)
# search_results_text = "\n\n".join(
# [
# f"{i}. {result.title}\n URL: {result.link}\n"
# + (f" Author: {result.author}\n" if result.author else "")
# + (
# f" Date: {result.published_date.strftime('%Y-%m-%d')}\n"
# if result.published_date
# else ""
# )
# + (f" Snippet: {result.snippet}\n" if result.snippet else "")
# for i, result in enumerate(search_results)
# ]
# )
# agent_decision_prompt = WEB_SEARCH_URL_SELECTION_PROMPT.build(
# search_query=search_query,
# base_question=base_question,
# search_results_text=search_results_text,
# )
# agent_decision = invoke_llm_json(
# llm=graph_config.tooling.fast_llm,
# prompt=create_question_prompt(
# assistant_system_prompt,
# agent_decision_prompt + (assistant_task_prompt or ""),
# ),
# schema=WebSearchAnswer,
# timeout_override=TF_DR_TIMEOUT_SHORT,
# )
# results_to_open = [
# (search_query, search_results[i])
# for i in agent_decision.urls_to_open_indices
# if i < len(search_results) and i >= 0
# ]
# return InternetSearchUpdate(
# results_to_open=results_to_open,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="internet_search",
# node_name="searching",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,52 +1,52 @@
from collections import defaultdict
# from collections import defaultdict
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
InternetSearchInput,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
dummy_inference_section_from_internet_search_result,
)
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebSearchResult,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
# InternetSearchInput,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
# dummy_inference_section_from_internet_search_result,
# )
# from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import SearchToolDelta
def dedup_urls(
state: InternetSearchInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> InternetSearchInput:
branch_questions_to_urls: dict[str, list[str]] = defaultdict(list)
unique_results_by_link: dict[str, WebSearchResult] = {}
for query, result in state.results_to_open:
branch_questions_to_urls[query].append(result.link)
if result.link not in unique_results_by_link:
unique_results_by_link[result.link] = result
# def dedup_urls(
# state: InternetSearchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> InternetSearchInput:
# branch_questions_to_urls: dict[str, list[str]] = defaultdict(list)
# unique_results_by_link: dict[str, WebSearchResult] = {}
# for query, result in state.results_to_open:
# branch_questions_to_urls[query].append(result.link)
# if result.link not in unique_results_by_link:
# unique_results_by_link[result.link] = result
unique_results = list(unique_results_by_link.values())
dummy_docs_inference_sections = [
dummy_inference_section_from_internet_search_result(doc)
for doc in unique_results
]
# unique_results = list(unique_results_by_link.values())
# dummy_docs_inference_sections = [
# dummy_inference_section_from_internet_search_result(doc)
# for doc in unique_results
# ]
write_custom_event(
state.current_step_nr,
SearchToolDelta(
queries=[],
documents=convert_inference_sections_to_search_docs(
dummy_docs_inference_sections, is_internet=True
),
),
writer,
)
# write_custom_event(
# state.current_step_nr,
# SearchToolDelta(
# queries=[],
# documents=convert_inference_sections_to_search_docs(
# dummy_docs_inference_sections, is_internet=True
# ),
# ),
# writer,
# )
return InternetSearchInput(
results_to_open=[],
branch_questions_to_urls=branch_questions_to_urls,
)
# return InternetSearchInput(
# results_to_open=[],
# branch_questions_to_urls=branch_questions_to_urls,
# )

View File

@@ -1,69 +1,69 @@
from datetime import datetime
from typing import cast
# from datetime import datetime
# from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
get_default_content_provider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchInput
from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchUpdate
from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
dummy_inference_section_from_internet_content,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
# get_default_provider,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchInput
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchUpdate
# from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
# inference_section_from_internet_page_scrape,
# )
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.context.search.models import InferenceSection
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def web_fetch(
state: FetchInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> FetchUpdate:
"""
LangGraph node to fetch content from URLs and process the results.
"""
# def web_fetch(
# state: FetchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> FetchUpdate:
# """
# LangGraph node to fetch content from URLs and process the results.
# """
node_start_time = datetime.now()
# node_start_time = datetime.now()
if not state.available_tools:
raise ValueError("available_tools is not set")
# if not state.available_tools:
# raise ValueError("available_tools is not set")
graph_config = cast(GraphConfig, config["metadata"]["config"])
# graph_config = cast(GraphConfig, config["metadata"]["config"])
if graph_config.inputs.persona is None:
raise ValueError("persona is not set")
# if graph_config.inputs.persona is None:
# raise ValueError("persona is not set")
provider = get_default_content_provider()
if provider is None:
raise ValueError("No web content provider found")
# provider = get_default_provider()
# if provider is None:
# raise ValueError("No web search provider found")
retrieved_docs: list[InferenceSection] = []
try:
retrieved_docs = [
dummy_inference_section_from_internet_content(result)
for result in provider.contents(state.urls_to_open)
]
except Exception as e:
logger.exception(e)
# retrieved_docs: list[InferenceSection] = []
# try:
# retrieved_docs = [
# inference_section_from_internet_page_scrape(result)
# for result in provider.contents(state.urls_to_open)
# ]
# except Exception as e:
# logger.error(f"Error fetching URLs: {e}")
if not retrieved_docs:
logger.warning("No content retrieved from URLs")
# if not retrieved_docs:
# logger.warning("No content retrieved from URLs")
return FetchUpdate(
raw_documents=retrieved_docs,
log_messages=[
get_langgraph_node_log_string(
graph_component="internet_search",
node_name="fetching",
node_start_time=node_start_time,
)
],
)
# return FetchUpdate(
# raw_documents=retrieved_docs,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="internet_search",
# node_name="fetching",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,19 +1,19 @@
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchInput
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
InternetSearchInput,
)
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchInput
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
# InternetSearchInput,
# )
def collect_raw_docs(
state: FetchInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> InternetSearchInput:
raw_documents = state.raw_documents
# def collect_raw_docs(
# state: FetchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> InternetSearchInput:
# raw_documents = state.raw_documents
return InternetSearchInput(
raw_documents=raw_documents,
)
# return InternetSearchInput(
# raw_documents=raw_documents,
# )

View File

@@ -1,133 +1,132 @@
from datetime import datetime
from typing import cast
# from datetime import datetime
# from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import SearchAnswer
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.sub_agents.web_search.states import SummarizeInput
from onyx.agents.agent_search.dr.utils import extract_document_citations
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.context.search.models import InferenceSection
from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
from onyx.utils.logger import setup_logger
from onyx.utils.url import normalize_url
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.models import SearchAnswer
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import SummarizeInput
# from onyx.agents.agent_search.dr.utils import extract_document_citations
# from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.utils import create_question_prompt
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
# from onyx.context.search.models import InferenceSection
# from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
# from onyx.utils.logger import setup_logger
# from onyx.utils.url import normalize_url
logger = setup_logger()
# logger = setup_logger()
def is_summarize(
state: SummarizeInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
"""
LangGraph node to perform an internet search as part of the DR process.
"""
# def is_summarize(
# state: SummarizeInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> BranchUpdate:
# """
# LangGraph node to perform an internet search as part of the DR process.
# """
node_start_time = datetime.now()
# node_start_time = datetime.now()
# build branch iterations from fetch inputs
# Normalize URLs to handle mismatches from query parameters (e.g., ?activeTab=explore)
url_to_raw_document: dict[str, InferenceSection] = {}
for raw_document in state.raw_documents:
normalized_url = normalize_url(raw_document.center_chunk.semantic_identifier)
url_to_raw_document[normalized_url] = raw_document
# # build branch iterations from fetch inputs
# # Normalize URLs to handle mismatches from query parameters (e.g., ?activeTab=explore)
# url_to_raw_document: dict[str, InferenceSection] = {}
# for raw_document in state.raw_documents:
# normalized_url = normalize_url(raw_document.center_chunk.semantic_identifier)
# url_to_raw_document[normalized_url] = raw_document
# Normalize the URLs from branch_questions_to_urls as well
urls = [
normalize_url(url)
for url in state.branch_questions_to_urls[state.branch_question]
]
current_iteration = state.iteration_nr
graph_config = cast(GraphConfig, config["metadata"]["config"])
research_type = graph_config.behavior.research_type
if not state.available_tools:
raise ValueError("available_tools is not set")
is_tool_info = state.available_tools[state.tools_used[-1]]
# # Normalize the URLs from branch_questions_to_urls as well
# urls = [
# normalize_url(url)
# for url in state.branch_questions_to_urls[state.branch_question]
# ]
# current_iteration = state.iteration_nr
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# use_agentic_search = graph_config.behavior.use_agentic_search
# if not state.available_tools:
# raise ValueError("available_tools is not set")
# is_tool_info = state.available_tools[state.tools_used[-1]]
if research_type == ResearchType.DEEP:
cited_raw_documents = [url_to_raw_document[url] for url in urls]
document_texts = _create_document_texts(cited_raw_documents)
search_prompt = INTERNAL_SEARCH_PROMPTS[research_type].build(
search_query=state.branch_question,
base_question=graph_config.inputs.prompt_builder.raw_user_query,
document_text=document_texts,
)
assistant_system_prompt = state.assistant_system_prompt
assistant_task_prompt = state.assistant_task_prompt
search_answer_json = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
),
schema=SearchAnswer,
timeout_override=TF_DR_TIMEOUT_SHORT,
)
answer_string = search_answer_json.answer
claims = search_answer_json.claims or []
reasoning = search_answer_json.reasoning or ""
(
citation_numbers,
answer_string,
claims,
) = extract_document_citations(answer_string, claims)
cited_documents = {
citation_number: cited_raw_documents[citation_number - 1]
for citation_number in citation_numbers
}
# if use_agentic_search:
# cited_raw_documents = [url_to_raw_document[url] for url in urls]
# document_texts = _create_document_texts(cited_raw_documents)
# search_prompt = INTERNAL_SEARCH_PROMPTS[use_agentic_search].build(
# search_query=state.branch_question,
# base_question=graph_config.inputs.prompt_builder.raw_user_query,
# document_text=document_texts,
# )
# assistant_system_prompt = state.assistant_system_prompt
# assistant_task_prompt = state.assistant_task_prompt
# search_answer_json = invoke_llm_json(
# llm=graph_config.tooling.primary_llm,
# prompt=create_question_prompt(
# assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
# ),
# schema=SearchAnswer,
# timeout_override=TF_DR_TIMEOUT_SHORT,
# )
# answer_string = search_answer_json.answer
# claims = search_answer_json.claims or []
# reasoning = search_answer_json.reasoning or ""
# (
# citation_numbers,
# answer_string,
# claims,
# ) = extract_document_citations(answer_string, claims)
# cited_documents = {
# citation_number: cited_raw_documents[citation_number - 1]
# for citation_number in citation_numbers
# }
else:
answer_string = ""
reasoning = ""
claims = []
cited_raw_documents = [url_to_raw_document[url] for url in urls]
cited_documents = {
doc_num + 1: retrieved_doc
for doc_num, retrieved_doc in enumerate(cited_raw_documents)
}
# else:
# answer_string = ""
# reasoning = ""
# claims = []
# cited_raw_documents = [url_to_raw_document[url] for url in urls]
# cited_documents = {
# doc_num + 1: retrieved_doc
# for doc_num, retrieved_doc in enumerate(cited_raw_documents)
# }
return BranchUpdate(
branch_iteration_responses=[
IterationAnswer(
tool=is_tool_info.llm_path,
tool_id=is_tool_info.tool_id,
iteration_nr=current_iteration,
parallelization_nr=0,
question=state.branch_question,
answer=answer_string,
claims=claims,
cited_documents=cited_documents,
reasoning=reasoning,
additional_data=None,
)
],
log_messages=[
get_langgraph_node_log_string(
graph_component="internet_search",
node_name="summarizing",
node_start_time=node_start_time,
)
],
)
# return BranchUpdate(
# branch_iteration_responses=[
# IterationAnswer(
# tool=is_tool_info.llm_path,
# tool_id=is_tool_info.tool_id,
# iteration_nr=current_iteration,
# parallelization_nr=0,
# question=state.branch_question,
# answer=answer_string,
# claims=claims,
# cited_documents=cited_documents,
# reasoning=reasoning,
# additional_data=None,
# )
# ],
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="internet_search",
# node_name="summarizing",
# node_start_time=node_start_time,
# )
# ],
# )
def _create_document_texts(raw_documents: list[InferenceSection]) -> str:
document_texts_list = []
for doc_num, retrieved_doc in enumerate(raw_documents):
if not isinstance(retrieved_doc, InferenceSection):
raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
chunk_text = build_document_context(retrieved_doc, doc_num + 1)
document_texts_list.append(chunk_text)
return "\n\n".join(document_texts_list)
# def _create_document_texts(raw_documents: list[InferenceSection]) -> str:
# document_texts_list = []
# for doc_num, retrieved_doc in enumerate(raw_documents):
# if not isinstance(retrieved_doc, InferenceSection):
# raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
# chunk_text = build_document_context(retrieved_doc, doc_num + 1)
# document_texts_list.append(chunk_text)
# return "\n\n".join(document_texts_list)
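
The normalization above exists because the URL a page was fetched under and the URL recorded for a branch question can differ only by query parameters, so both sides of the lookup must be normalized before matching. The exact behaviour of onyx.utils.url.normalize_url is not shown in this diff; the sketch below is an assumed, simplified stand-in used purely to illustrate the idea.

# Assumed, simplified stand-in for normalize_url (the real implementation is not in this diff).
from urllib.parse import urlsplit, urlunsplit


def normalize_url_sketch(url: str) -> str:
    # Drop the query string and fragment so ".../page?activeTab=explore"
    # and ".../page" map to the same dictionary key.
    parts = urlsplit(url)
    return urlunsplit((parts.scheme, parts.netloc, parts.path.rstrip("/"), "", ""))


fetched = "https://example.com/project/page"
recorded = "https://example.com/project/page?activeTab=explore"
assert normalize_url_sketch(fetched) == normalize_url_sketch(recorded)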

View File

@@ -1,56 +1,56 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def is_reducer(
state: SubAgentMainState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
"""
LangGraph node to perform an internet search as part of the DR process.
"""
# def is_reducer(
# state: SubAgentMainState,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> SubAgentUpdate:
# """
# LangGraph node to perform an internet search as part of the DR process.
# """
node_start_time = datetime.now()
# node_start_time = datetime.now()
branch_updates = state.branch_iteration_responses
current_iteration = state.iteration_nr
current_step_nr = state.current_step_nr
# branch_updates = state.branch_iteration_responses
# current_iteration = state.iteration_nr
# current_step_nr = state.current_step_nr
new_updates = [
update for update in branch_updates if update.iteration_nr == current_iteration
]
# new_updates = [
# update for update in branch_updates if update.iteration_nr == current_iteration
# ]
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
# write_custom_event(
# current_step_nr,
# SectionEnd(),
# writer,
# )
current_step_nr += 1
# current_step_nr += 1
return SubAgentUpdate(
iteration_responses=new_updates,
current_step_nr=current_step_nr,
log_messages=[
get_langgraph_node_log_string(
graph_component="internet_search",
node_name="consolidation",
node_start_time=node_start_time,
)
],
)
# return SubAgentUpdate(
# iteration_responses=new_updates,
# current_step_nr=current_step_nr,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="internet_search",
# node_name="consolidation",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,79 +1,79 @@
from collections.abc import Hashable
# from collections.abc import Hashable
from langgraph.types import Send
# from langgraph.types import Send
from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchInput
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
InternetSearchInput,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.states import SummarizeInput
# from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchInput
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
# InternetSearchInput,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import SummarizeInput
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
return [
Send(
"search",
InternetSearchInput(
iteration_nr=state.iteration_nr,
current_step_nr=state.current_step_nr,
parallelization_nr=parallelization_nr,
query_list=[query],
branch_question=query,
context="",
tools_used=state.tools_used,
available_tools=state.available_tools,
assistant_system_prompt=state.assistant_system_prompt,
assistant_task_prompt=state.assistant_task_prompt,
results_to_open=[],
),
)
for parallelization_nr, query in enumerate(
state.query_list[:MAX_DR_PARALLEL_SEARCH]
)
]
# def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
# return [
# Send(
# "search",
# InternetSearchInput(
# iteration_nr=state.iteration_nr,
# current_step_nr=state.current_step_nr,
# parallelization_nr=parallelization_nr,
# query_list=[query],
# branch_question=query,
# context="",
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# assistant_system_prompt=state.assistant_system_prompt,
# assistant_task_prompt=state.assistant_task_prompt,
# results_to_open=[],
# ),
# )
# for parallelization_nr, query in enumerate(
# state.query_list[:MAX_DR_PARALLEL_SEARCH]
# )
# ]
def fetch_router(state: InternetSearchInput) -> list[Send | Hashable]:
branch_questions_to_urls = state.branch_questions_to_urls
return [
Send(
"fetch",
FetchInput(
iteration_nr=state.iteration_nr,
urls_to_open=[url],
tools_used=state.tools_used,
available_tools=state.available_tools,
assistant_system_prompt=state.assistant_system_prompt,
assistant_task_prompt=state.assistant_task_prompt,
current_step_nr=state.current_step_nr,
branch_questions_to_urls=branch_questions_to_urls,
raw_documents=state.raw_documents,
),
)
for url in set(
url for urls in branch_questions_to_urls.values() for url in urls
)
]
# def fetch_router(state: InternetSearchInput) -> list[Send | Hashable]:
# branch_questions_to_urls = state.branch_questions_to_urls
# return [
# Send(
# "fetch",
# FetchInput(
# iteration_nr=state.iteration_nr,
# urls_to_open=[url],
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# assistant_system_prompt=state.assistant_system_prompt,
# assistant_task_prompt=state.assistant_task_prompt,
# current_step_nr=state.current_step_nr,
# branch_questions_to_urls=branch_questions_to_urls,
# raw_documents=state.raw_documents,
# ),
# )
# for url in set(
# url for urls in branch_questions_to_urls.values() for url in urls
# )
# ]
def summarize_router(state: InternetSearchInput) -> list[Send | Hashable]:
branch_questions_to_urls = state.branch_questions_to_urls
return [
Send(
"summarize",
SummarizeInput(
iteration_nr=state.iteration_nr,
raw_documents=state.raw_documents,
branch_questions_to_urls=branch_questions_to_urls,
branch_question=branch_question,
tools_used=state.tools_used,
available_tools=state.available_tools,
assistant_system_prompt=state.assistant_system_prompt,
assistant_task_prompt=state.assistant_task_prompt,
current_step_nr=state.current_step_nr,
),
)
for branch_question in branch_questions_to_urls.keys()
]
# def summarize_router(state: InternetSearchInput) -> list[Send | Hashable]:
# branch_questions_to_urls = state.branch_questions_to_urls
# return [
# Send(
# "summarize",
# SummarizeInput(
# iteration_nr=state.iteration_nr,
# raw_documents=state.raw_documents,
# branch_questions_to_urls=branch_questions_to_urls,
# branch_question=branch_question,
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# assistant_system_prompt=state.assistant_system_prompt,
# assistant_task_prompt=state.assistant_task_prompt,
# current_step_nr=state.current_step_nr,
# ),
# )
# for branch_question in branch_questions_to_urls.keys()
# ]
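
The routers above rely on LangGraph's Send API to fan work out: each Send names a target node and a per-branch input state, so a single router call can spawn several parallel invocations of that node, whose results are later merged by the reducer annotations on the state. A minimal, self-contained sketch of the pattern (generic names, not onyx code) follows.

# Generic sketch of the Send fan-out pattern used by branching_router / fetch_router /
# summarize_router above. Node and field names are illustrative.
from operator import add
from typing import Annotated, TypedDict

from langgraph.graph import END, START, StateGraph
from langgraph.types import Send


class FanOutState(TypedDict):
    queries: list[str]
    answers: Annotated[list[str], add]  # concatenated across parallel branches


def branch(state: FanOutState) -> dict:
    return {"answers": []}  # no-op update; exists only to anchor the conditional edge


def router(state: FanOutState) -> list[Send]:
    # One Send per query -> one parallel "work" invocation per query.
    return [Send("work", {"queries": [q], "answers": []}) for q in state["queries"]]


def work(state: FanOutState) -> dict:
    return {"answers": [f"handled: {state['queries'][0]}"]}


graph = StateGraph(FanOutState)
graph.add_node("branch", branch)
graph.add_node("work", work)
graph.add_edge(START, "branch")
graph.add_conditional_edges("branch", router)
graph.add_edge("work", END)
print(graph.compile().invoke({"queries": ["a", "b"], "answers": []}))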

View File

@@ -1,84 +1,84 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_1_branch import (
is_branch,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_2_search import (
web_search,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_3_dedup_urls import (
dedup_urls,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_4_fetch import (
web_fetch,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_5_collect_raw_docs import (
collect_raw_docs,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_6_summarize import (
is_summarize,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_7_reduce import (
is_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_conditional_edges import (
fetch_router,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_conditional_edges import (
summarize_router,
)
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_1_branch import (
# is_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_2_search import (
# web_search,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_3_dedup_urls import (
# dedup_urls,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_4_fetch import (
# web_fetch,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_5_collect_raw_docs import (
# collect_raw_docs,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_6_summarize import (
# is_summarize,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_7_reduce import (
# is_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_conditional_edges import (
# branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_conditional_edges import (
# fetch_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_conditional_edges import (
# summarize_router,
# )
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def dr_ws_graph_builder() -> StateGraph:
"""
LangGraph graph builder for Internet Search Sub-Agent
"""
# def dr_ws_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for Internet Search Sub-Agent
# """
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
# graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
### Add nodes ###
# ### Add nodes ###
graph.add_node("branch", is_branch)
# graph.add_node("branch", is_branch)
graph.add_node("search", web_search)
# graph.add_node("search", web_search)
graph.add_node("dedup_urls", dedup_urls)
# graph.add_node("dedup_urls", dedup_urls)
graph.add_node("fetch", web_fetch)
# graph.add_node("fetch", web_fetch)
graph.add_node("collect_raw_docs", collect_raw_docs)
# graph.add_node("collect_raw_docs", collect_raw_docs)
graph.add_node("summarize", is_summarize)
# graph.add_node("summarize", is_summarize)
graph.add_node("reducer", is_reducer)
# graph.add_node("reducer", is_reducer)
### Add edges ###
# ### Add edges ###
graph.add_edge(start_key=START, end_key="branch")
# graph.add_edge(start_key=START, end_key="branch")
graph.add_conditional_edges("branch", branching_router)
# graph.add_conditional_edges("branch", branching_router)
graph.add_edge(start_key="search", end_key="dedup_urls")
# graph.add_edge(start_key="search", end_key="dedup_urls")
graph.add_conditional_edges("dedup_urls", fetch_router)
# graph.add_conditional_edges("dedup_urls", fetch_router)
graph.add_edge(start_key="fetch", end_key="collect_raw_docs")
# graph.add_edge(start_key="fetch", end_key="collect_raw_docs")
graph.add_conditional_edges("collect_raw_docs", summarize_router)
# graph.add_conditional_edges("collect_raw_docs", summarize_router)
graph.add_edge(start_key="summarize", end_key="reducer")
# graph.add_edge(start_key="summarize", end_key="reducer")
graph.add_edge(start_key="reducer", end_key=END)
# graph.add_edge(start_key="reducer", end_key=END)
return graph
# return graph
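
For context, the builder above only wires sub-agent nodes into a LangGraph StateGraph and returns it uncompiled; callers are expected to compile and invoke it. A minimal sketch of the same pattern, assuming langgraph is installed and using illustrative node names rather than the repository's own:

from typing import TypedDict
from langgraph.graph import StateGraph, START, END

class SearchState(TypedDict):
    query: str
    urls: list[str]

def search(state: SearchState) -> dict:
    # hypothetical node: pretend we ran a web search
    return {"urls": [f"https://example.com/?q={state['query']}"]}

def fetch(state: SearchState) -> dict:
    # hypothetical node: pretend we fetched the pages
    return {"urls": state["urls"]}

def router(state: SearchState) -> str:
    # conditional edge: skip fetching when nothing was found
    return "fetch" if state["urls"] else END

def ws_graph_builder() -> StateGraph:
    graph = StateGraph(SearchState)
    graph.add_node("search", search)
    graph.add_node("fetch", fetch)
    graph.add_edge(START, "search")
    graph.add_conditional_edges("search", router)
    graph.add_edge("fetch", END)
    return graph

# the builder returns an uncompiled graph; compile it before invoking
app = ws_graph_builder().compile()
print(app.invoke({"query": "onyx", "urls": []}))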

View File

@@ -1,47 +1,53 @@
from abc import ABC
from abc import abstractmethod
from collections.abc import Sequence
from datetime import datetime
# from abc import ABC
# from abc import abstractmethod
# from collections.abc import Sequence
# from datetime import datetime
# from enum import Enum
from pydantic import BaseModel
from pydantic import field_validator
# from pydantic import BaseModel
# from pydantic import field_validator
from onyx.utils.url import normalize_url
# from onyx.utils.url import normalize_url
class WebSearchResult(BaseModel):
title: str
link: str
snippet: str | None = None
author: str | None = None
published_date: datetime | None = None
# class ProviderType(Enum):
# """Enum for internet search provider types"""
@field_validator("link")
@classmethod
def normalize_link(cls, v: str) -> str:
return normalize_url(v)
# GOOGLE = "google"
# EXA = "exa"
class WebContent(BaseModel):
title: str
link: str
full_content: str
published_date: datetime | None = None
scrape_successful: bool = True
# class WebSearchResult(BaseModel):
# title: str
# link: str
# snippet: str | None = None
# author: str | None = None
# published_date: datetime | None = None
@field_validator("link")
@classmethod
def normalize_link(cls, v: str) -> str:
return normalize_url(v)
# @field_validator("link")
# @classmethod
# def normalize_link(cls, v: str) -> str:
# return normalize_url(v)
class WebContentProvider(ABC):
@abstractmethod
def contents(self, urls: Sequence[str]) -> list[WebContent]:
pass
# class WebContent(BaseModel):
# title: str
# link: str
# full_content: str
# published_date: datetime | None = None
# scrape_successful: bool = True
# @field_validator("link")
# @classmethod
# def normalize_link(cls, v: str) -> str:
# return normalize_url(v)
class WebSearchProvider(WebContentProvider):
@abstractmethod
def search(self, query: str) -> Sequence[WebSearchResult]:
pass
# class WebSearchProvider(ABC):
# @abstractmethod
# def search(self, query: str) -> Sequence[WebSearchResult]:
# pass
# @abstractmethod
# def contents(self, urls: Sequence[str]) -> list[WebContent]:
# pass
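
These abstract base classes define the contract a provider must satisfy: contents() for fetching page bodies and search() for running queries. A minimal stub implementation against the interface shown above, reusing the WebSearchResult, WebContent, and WebSearchProvider classes from this file; the StaticSearchProvider class itself is hypothetical and for illustration only:

from collections.abc import Sequence

class StaticSearchProvider(WebSearchProvider):
    """Hypothetical provider that returns canned results; useful in tests."""

    def search(self, query: str) -> Sequence[WebSearchResult]:
        return [
            WebSearchResult(
                title=f"Result for {query}",
                link="https://example.com/result",
                snippet="canned snippet",
            )
        ]

    def contents(self, urls: Sequence[str]) -> list[WebContent]:
        return [
            WebContent(
                title="Example",
                link=url,
                full_content="canned page body",
                scrape_successful=True,
            )
            for url in urls
        ]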

View File

@@ -1,199 +1,19 @@
from typing import Any
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.exa_client import (
ExaClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.firecrawl_client import (
FIRECRAWL_SCRAPE_URL,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.firecrawl_client import (
FirecrawlClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.google_pse_client import (
GooglePSEClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.onyx_web_crawler_client import (
OnyxWebCrawlerClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.serper_client import (
SerperClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebContentProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchProvider,
)
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.web_search import fetch_active_web_content_provider
from onyx.db.web_search import fetch_active_web_search_provider
from onyx.utils.logger import setup_logger
from shared_configs.enums import WebContentProviderType
from shared_configs.enums import WebSearchProviderType
logger = setup_logger()
# from onyx.agents.agent_search.dr.sub_agents.web_search.clients.exa_client import (
# ExaClient,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.clients.serper_client import (
# SerperClient,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebSearchProvider,
# )
# from onyx.configs.chat_configs import EXA_API_KEY
# from onyx.configs.chat_configs import SERPER_API_KEY
def build_search_provider_from_config(
*,
provider_type: WebSearchProviderType,
api_key: str | None,
config: dict[str, str] | None,
provider_name: str = "web_search_provider",
) -> WebSearchProvider | None:
provider_type_value = provider_type.value
try:
provider_type_enum = WebSearchProviderType(provider_type_value)
except ValueError:
logger.error(
f"Unknown web search provider type '{provider_type_value}'. "
"Skipping provider initialization."
)
return None
# All web search providers require an API key
if not api_key:
raise ValueError(
f"Web search provider '{provider_name}' is missing an API key."
)
assert api_key is not None
config = config or {}
if provider_type_enum == WebSearchProviderType.EXA:
return ExaClient(api_key=api_key)
if provider_type_enum == WebSearchProviderType.SERPER:
return SerperClient(api_key=api_key)
if provider_type_enum == WebSearchProviderType.GOOGLE_PSE:
search_engine_id = (
config.get("search_engine_id")
or config.get("cx")
or config.get("search_engine")
)
if not search_engine_id:
raise ValueError(
"Google PSE provider requires a search engine id (cx) in addition to the API key."
)
assert search_engine_id is not None
try:
num_results = int(config.get("num_results", 10))
except (TypeError, ValueError):
raise ValueError(
"Invalid value for Google PSE 'num_results'; expected integer."
)
try:
timeout_seconds = int(config.get("timeout_seconds", 10))
except (TypeError, ValueError):
raise ValueError(
"Invalid value for Google PSE 'timeout_seconds'; expected integer."
)
return GooglePSEClient(
api_key=api_key,
search_engine_id=search_engine_id,
num_results=num_results,
timeout_seconds=timeout_seconds,
)
logger.error(
f"Unhandled web search provider type '{provider_type_value}'. "
"Skipping provider initialization."
)
return None
def _build_search_provider(provider_model: Any) -> WebSearchProvider | None:
return build_search_provider_from_config(
provider_type=WebSearchProviderType(provider_model.provider_type),
api_key=provider_model.api_key,
config=provider_model.config or {},
provider_name=provider_model.name,
)
def build_content_provider_from_config(
*,
provider_type: WebContentProviderType,
api_key: str | None,
config: dict[str, str] | None,
provider_name: str = "web_content_provider",
) -> WebContentProvider | None:
provider_type_value = provider_type.value
try:
provider_type_enum = WebContentProviderType(provider_type_value)
except ValueError:
logger.error(
f"Unknown web content provider type '{provider_type_value}'. "
"Skipping provider initialization."
)
return None
if provider_type_enum == WebContentProviderType.ONYX_WEB_CRAWLER:
config = config or {}
timeout_value = config.get("timeout_seconds", 15)
try:
timeout_seconds = int(timeout_value)
except (TypeError, ValueError):
raise ValueError(
"Invalid value for Onyx Web Crawler 'timeout_seconds'; expected integer."
)
return OnyxWebCrawlerClient(timeout_seconds=timeout_seconds)
if provider_type_enum == WebContentProviderType.FIRECRAWL:
if not api_key:
raise ValueError("Firecrawl content provider requires an API key.")
assert api_key is not None
config = config or {}
timeout_seconds_str = config.get("timeout_seconds")
if timeout_seconds_str is None:
timeout_seconds = 10
else:
try:
timeout_seconds = int(timeout_seconds_str)
except (TypeError, ValueError):
raise ValueError(
"Invalid value for Firecrawl 'timeout_seconds'; expected integer."
)
return FirecrawlClient(
api_key=api_key,
base_url=config.get("base_url") or FIRECRAWL_SCRAPE_URL,
timeout_seconds=timeout_seconds,
)
logger.error(
f"Unhandled web content provider type '{provider_type_value}'. "
"Skipping provider initialization."
)
return None
def _build_content_provider(provider_model: Any) -> WebContentProvider | None:
return build_content_provider_from_config(
provider_type=WebContentProviderType(provider_model.provider_type),
api_key=provider_model.api_key,
config=provider_model.config or {},
provider_name=provider_model.name,
)
def get_default_provider() -> WebSearchProvider | None:
with get_session_with_current_tenant() as db_session:
provider_model = fetch_active_web_search_provider(db_session)
if provider_model is None:
return None
return _build_search_provider(provider_model)
def get_default_content_provider() -> WebContentProvider | None:
with get_session_with_current_tenant() as db_session:
provider_model = fetch_active_web_content_provider(db_session)
if provider_model:
provider = _build_content_provider(provider_model)
if provider:
return provider
# Fall back to built-in Onyx crawler when nothing is configured.
try:
return OnyxWebCrawlerClient()
except Exception as exc: # pragma: no cover - defensive
logger.error(f"Failed to initialize default Onyx crawler: {exc}")
return None
# def get_default_provider() -> WebSearchProvider | None:
# if EXA_API_KEY:
# return ExaClient()
# if SERPER_API_KEY:
# return SerperClient()
# return None
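
For reference, a sketch of how the factory above might be called directly. The keyword arguments follow the signature shown in the code; the API key and config values here are made up, and config values are passed as strings to match the declared dict[str, str] type:

from shared_configs.enums import WebSearchProviderType

provider = build_search_provider_from_config(
    provider_type=WebSearchProviderType.GOOGLE_PSE,
    api_key="fake-api-key",  # illustrative only
    config={"search_engine_id": "0123456789", "num_results": "5"},
    provider_name="example_pse",
)
if provider is not None:
    results = provider.search("onyx chat history migration")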

View File

@@ -1,37 +1,37 @@
from operator import add
from typing import Annotated
# from operator import add
# from typing import Annotated
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.context.search.models import InferenceSection
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebSearchResult,
# )
# from onyx.context.search.models import InferenceSection
class InternetSearchInput(SubAgentInput):
results_to_open: Annotated[list[tuple[str, WebSearchResult]], add] = []
parallelization_nr: int = 0
branch_question: Annotated[str, lambda x, y: y] = ""
branch_questions_to_urls: Annotated[dict[str, list[str]], lambda x, y: y] = {}
raw_documents: Annotated[list[InferenceSection], add] = []
# class InternetSearchInput(SubAgentInput):
# results_to_open: Annotated[list[tuple[str, WebSearchResult]], add] = []
# parallelization_nr: int = 0
# branch_question: Annotated[str, lambda x, y: y] = ""
# branch_questions_to_urls: Annotated[dict[str, list[str]], lambda x, y: y] = {}
# raw_documents: Annotated[list[InferenceSection], add] = []
class InternetSearchUpdate(LoggerUpdate):
results_to_open: Annotated[list[tuple[str, WebSearchResult]], add] = []
# class InternetSearchUpdate(LoggerUpdate):
# results_to_open: Annotated[list[tuple[str, WebSearchResult]], add] = []
class FetchInput(SubAgentInput):
urls_to_open: Annotated[list[str], add] = []
branch_questions_to_urls: dict[str, list[str]]
raw_documents: Annotated[list[InferenceSection], add] = []
# class FetchInput(SubAgentInput):
# urls_to_open: Annotated[list[str], add] = []
# branch_questions_to_urls: dict[str, list[str]]
# raw_documents: Annotated[list[InferenceSection], add] = []
class FetchUpdate(LoggerUpdate):
raw_documents: Annotated[list[InferenceSection], add] = []
# class FetchUpdate(LoggerUpdate):
# raw_documents: Annotated[list[InferenceSection], add] = []
class SummarizeInput(SubAgentInput):
raw_documents: Annotated[list[InferenceSection], add] = []
branch_questions_to_urls: dict[str, list[str]]
branch_question: str
# class SummarizeInput(SubAgentInput):
# raw_documents: Annotated[list[InferenceSection], add] = []
# branch_questions_to_urls: dict[str, list[str]]
# branch_question: str
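
The Annotated[..., add] fields above are LangGraph reducers: when several branches write to the same state key, the annotated operator decides how the updates merge (operator.add concatenates lists) instead of overwriting, while the lambda x, y: y annotations keep only the latest value. A small self-contained sketch of that behavior, with illustrative node names:

from operator import add
from typing import Annotated, TypedDict
from langgraph.graph import StateGraph, START, END

class FanOutState(TypedDict):
    # list updates from parallel branches are concatenated, not overwritten
    raw_documents: Annotated[list[str], add]

def branch_a(state: FanOutState) -> dict:
    return {"raw_documents": ["doc from branch a"]}

def branch_b(state: FanOutState) -> dict:
    return {"raw_documents": ["doc from branch b"]}

g = StateGraph(FanOutState)
g.add_node("a", branch_a)
g.add_node("b", branch_b)
g.add_edge(START, "a")
g.add_edge(START, "b")
g.add_edge("a", END)
g.add_edge("b", END)

final = g.compile().invoke({"raw_documents": []})
# both branch updates survive, e.g. ["doc from branch a", "doc from branch b"]
print(final["raw_documents"])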

View File

@@ -1,99 +1,99 @@
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebContent,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.chat.models import DOCUMENT_CITATION_NUMBER_EMPTY_VALUE
from onyx.chat.models import LlmDoc
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebContent,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebSearchResult,
# )
# from onyx.chat.models import DOCUMENT_CITATION_NUMBER_EMPTY_VALUE
# from onyx.chat.models import LlmDoc
# from onyx.configs.constants import DocumentSource
# from onyx.context.search.models import InferenceChunk
# from onyx.context.search.models import InferenceSection
def truncate_search_result_content(content: str, max_chars: int = 10000) -> str:
"""Truncate search result content to a maximum number of characters"""
if len(content) <= max_chars:
return content
return content[:max_chars] + "..."
# def truncate_search_result_content(content: str, max_chars: int = 10000) -> str:
# """Truncate search result content to a maximum number of characters"""
# if len(content) <= max_chars:
# return content
# return content[:max_chars] + "..."
def dummy_inference_section_from_internet_content(
result: WebContent,
) -> InferenceSection:
truncated_content = truncate_search_result_content(result.full_content)
return InferenceSection(
center_chunk=InferenceChunk(
chunk_id=0,
blurb=result.title,
content=truncated_content,
source_links={0: result.link},
section_continuation=False,
document_id="INTERNET_SEARCH_DOC_" + result.link,
source_type=DocumentSource.WEB,
semantic_identifier=result.link,
title=result.title,
boost=1,
recency_bias=1.0,
score=1.0,
hidden=(not result.scrape_successful),
metadata={},
match_highlights=[],
doc_summary=truncated_content,
chunk_context=truncated_content,
updated_at=result.published_date,
image_file_id=None,
),
chunks=[],
combined_content=truncated_content,
)
# def inference_section_from_internet_page_scrape(
# result: WebContent,
# ) -> InferenceSection:
# truncated_content = truncate_search_result_content(result.full_content)
# return InferenceSection(
# center_chunk=InferenceChunk(
# chunk_id=0,
# blurb=result.title,
# content=truncated_content,
# source_links={0: result.link},
# section_continuation=False,
# document_id="INTERNET_SEARCH_DOC_" + result.link,
# source_type=DocumentSource.WEB,
# semantic_identifier=result.link,
# title=result.title,
# boost=1,
# recency_bias=1.0,
# score=1.0,
# hidden=(not result.scrape_successful),
# metadata={},
# match_highlights=[],
# doc_summary=truncated_content,
# chunk_context=truncated_content,
# updated_at=result.published_date,
# image_file_id=None,
# ),
# chunks=[],
# combined_content=truncated_content,
# )
def dummy_inference_section_from_internet_search_result(
result: WebSearchResult,
) -> InferenceSection:
return InferenceSection(
center_chunk=InferenceChunk(
chunk_id=0,
blurb=result.title,
content="",
source_links={0: result.link},
section_continuation=False,
document_id="INTERNET_SEARCH_DOC_" + result.link,
source_type=DocumentSource.WEB,
semantic_identifier=result.link,
title=result.title,
boost=1,
recency_bias=1.0,
score=1.0,
hidden=False,
metadata={},
match_highlights=[],
doc_summary="",
chunk_context="",
updated_at=result.published_date,
image_file_id=None,
),
chunks=[],
combined_content="",
)
# def dummy_inference_section_from_internet_search_result(
# result: WebSearchResult,
# ) -> InferenceSection:
# return InferenceSection(
# center_chunk=InferenceChunk(
# chunk_id=0,
# blurb=result.title,
# content="",
# source_links={0: result.link},
# section_continuation=False,
# document_id="INTERNET_SEARCH_DOC_" + result.link,
# source_type=DocumentSource.WEB,
# semantic_identifier=result.link,
# title=result.title,
# boost=1,
# recency_bias=1.0,
# score=1.0,
# hidden=False,
# metadata={},
# match_highlights=[],
# doc_summary="",
# chunk_context="",
# updated_at=result.published_date,
# image_file_id=None,
# ),
# chunks=[],
# combined_content="",
# )
def llm_doc_from_web_content(web_content: WebContent) -> LlmDoc:
"""Create an LlmDoc from WebContent with the INTERNET_SEARCH_DOC_ prefix"""
return LlmDoc(
# TODO: Is this what we want to do for document_id? We're kind of overloading it since it
# should ideally correspond to a document in the database. But I guess if you're calling this
# function you know it won't be in the database.
document_id="INTERNET_SEARCH_DOC_" + web_content.link,
content=truncate_search_result_content(web_content.full_content),
blurb=web_content.link,
semantic_identifier=web_content.link,
source_type=DocumentSource.WEB,
metadata={},
link=web_content.link,
document_citation_number=DOCUMENT_CITATION_NUMBER_EMPTY_VALUE,
updated_at=web_content.published_date,
source_links={},
match_highlights=[],
)
# def llm_doc_from_web_content(web_content: WebContent) -> LlmDoc:
# """Create an LlmDoc from WebContent with the INTERNET_SEARCH_DOC_ prefix"""
# return LlmDoc(
# # TODO: Is this what we want to do for document_id? We're kind of overloading it since it
# # should ideally correspond to a document in the database. But I guess if you're calling this
# # function you know it won't be in the database.
# document_id="INTERNET_SEARCH_DOC_" + web_content.link,
# content=truncate_search_result_content(web_content.full_content),
# blurb=web_content.link,
# semantic_identifier=web_content.link,
# source_type=DocumentSource.WEB,
# metadata={},
# link=web_content.link,
# document_citation_number=DOCUMENT_CITATION_NUMBER_EMPTY_VALUE,
# updated_at=web_content.published_date,
# source_links={},
# match_highlights=[],
# )
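
A quick worked example of the truncation helper defined above (the input strings are illustrative): content at or under the limit passes through unchanged, anything longer is cut to max_chars and suffixed with an ellipsis.

assert truncate_search_result_content("short text", max_chars=20) == "short text"
assert truncate_search_result_content("x" * 25, max_chars=20) == "x" * 20 + "..."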

View File

@@ -1,277 +1,277 @@
import copy
import re
# import copy
# import re
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
# from langchain.schema.messages import BaseMessage
# from langchain.schema.messages import HumanMessage
from onyx.agents.agent_search.dr.models import AggregatedDRContext
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_section_list,
)
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import SavedSearchDoc
from onyx.context.search.models import SearchDoc
from onyx.tools.tool_implementations.web_search.web_search_tool import (
WebSearchTool,
)
# from onyx.agents.agent_search.dr.models import AggregatedDRContext
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
# from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
# from onyx.agents.agent_search.shared_graph_utils.operators import (
# dedup_inference_section_list,
# )
# from onyx.context.search.models import InferenceSection
# from onyx.context.search.models import SavedSearchDoc
# from onyx.context.search.models import SearchDoc
# from onyx.tools.tool_implementations.web_search.web_search_tool import (
# WebSearchTool,
# )
CITATION_PREFIX = "CITE:"
# CITATION_PREFIX = "CITE:"
def extract_document_citations(
answer: str, claims: list[str]
) -> tuple[list[int], str, list[str]]:
"""
Finds all citations of the form [1], [2, 3], etc. and returns the list of cited indices,
as well as the answer and claims with the citations replaced with [<CITATION_PREFIX>1],
etc., to help with citation deduplication later on.
"""
citations: set[int] = set()
# def extract_document_citations(
# answer: str, claims: list[str]
# ) -> tuple[list[int], str, list[str]]:
# """
# Finds all citations of the form [1], [2, 3], etc. and returns the list of cited indices,
# as well as the answer and claims with the citations replaced with [<CITATION_PREFIX>1],
# etc., to help with citation deduplication later on.
# """
# citations: set[int] = set()
# Pattern to match both single citations [1] and multiple citations [1, 2, 3]
# This regex matches:
# - \[(\d+)\] for single citations like [1]
# - \[(\d+(?:,\s*\d+)*)\] for multiple citations like [1, 2, 3]
pattern = re.compile(r"\[(\d+(?:,\s*\d+)*)\]")
# # Pattern to match both single citations [1] and multiple citations [1, 2, 3]
# # This regex matches:
# # - \[(\d+)\] for single citations like [1]
# # - \[(\d+(?:,\s*\d+)*)\] for multiple citations like [1, 2, 3]
# pattern = re.compile(r"\[(\d+(?:,\s*\d+)*)\]")
def _extract_and_replace(match: re.Match[str]) -> str:
numbers = [int(num) for num in match.group(1).split(",")]
citations.update(numbers)
return "".join(f"[{CITATION_PREFIX}{num}]" for num in numbers)
# def _extract_and_replace(match: re.Match[str]) -> str:
# numbers = [int(num) for num in match.group(1).split(",")]
# citations.update(numbers)
# return "".join(f"[{CITATION_PREFIX}{num}]" for num in numbers)
new_answer = pattern.sub(_extract_and_replace, answer)
new_claims = [pattern.sub(_extract_and_replace, claim) for claim in claims]
# new_answer = pattern.sub(_extract_and_replace, answer)
# new_claims = [pattern.sub(_extract_and_replace, claim) for claim in claims]
return list(citations), new_answer, new_claims
# return list(citations), new_answer, new_claims
def aggregate_context(
iteration_responses: list[IterationAnswer], include_documents: bool = True
) -> AggregatedDRContext:
"""
Converts the iteration response into a single string with unified citations.
For example,
it 1: the answer is x [3][4]. {3: doc_abc, 4: doc_xyz}
it 2: blah blah [1, 3]. {1: doc_xyz, 3: doc_pqr}
Output:
it 1: the answer is x [1][2].
it 2: blah blah [2][3]
[1]: doc_xyz
[2]: doc_abc
[3]: doc_pqr
"""
# dedupe and merge inference section contents
unrolled_inference_sections: list[InferenceSection] = []
is_internet_marker_dict: dict[str, bool] = {}
for iteration_response in sorted(
iteration_responses,
key=lambda x: (x.iteration_nr, x.parallelization_nr),
):
# def aggregate_context(
# iteration_responses: list[IterationAnswer], include_documents: bool = True
# ) -> AggregatedDRContext:
# """
# Converts the iteration response into a single string with unified citations.
# For example,
# it 1: the answer is x [3][4]. {3: doc_abc, 4: doc_xyz}
# it 2: blah blah [1, 3]. {1: doc_xyz, 3: doc_pqr}
# Output:
# it 1: the answer is x [1][2].
# it 2: blah blah [2][3]
# [1]: doc_xyz
# [2]: doc_abc
# [3]: doc_pqr
# """
# # dedupe and merge inference section contents
# unrolled_inference_sections: list[InferenceSection] = []
# is_internet_marker_dict: dict[str, bool] = {}
# for iteration_response in sorted(
# iteration_responses,
# key=lambda x: (x.iteration_nr, x.parallelization_nr),
# ):
iteration_tool = iteration_response.tool
is_internet = iteration_tool == WebSearchTool._NAME
# iteration_tool = iteration_response.tool
# is_internet = iteration_tool == WebSearchTool._NAME
for cited_doc in iteration_response.cited_documents.values():
unrolled_inference_sections.append(cited_doc)
if cited_doc.center_chunk.document_id not in is_internet_marker_dict:
is_internet_marker_dict[cited_doc.center_chunk.document_id] = (
is_internet
)
cited_doc.center_chunk.score = None # None means maintain order
# for cited_doc in iteration_response.cited_documents.values():
# unrolled_inference_sections.append(cited_doc)
# if cited_doc.center_chunk.document_id not in is_internet_marker_dict:
# is_internet_marker_dict[cited_doc.center_chunk.document_id] = (
# is_internet
# )
# cited_doc.center_chunk.score = None # None means maintain order
global_documents = dedup_inference_section_list(unrolled_inference_sections)
# global_documents = dedup_inference_section_list(unrolled_inference_sections)
global_citations = {
doc.center_chunk.document_id: i for i, doc in enumerate(global_documents, 1)
}
# global_citations = {
# doc.center_chunk.document_id: i for i, doc in enumerate(global_documents, 1)
# }
# build output string
output_strings: list[str] = []
global_iteration_responses: list[IterationAnswer] = []
# # build output string
# output_strings: list[str] = []
# global_iteration_responses: list[IterationAnswer] = []
for iteration_response in sorted(
iteration_responses,
key=lambda x: (x.iteration_nr, x.parallelization_nr),
):
# add basic iteration info
output_strings.append(
f"Iteration: {iteration_response.iteration_nr}, "
f"Question {iteration_response.parallelization_nr}"
)
output_strings.append(f"Tool: {iteration_response.tool}")
output_strings.append(f"Question: {iteration_response.question}")
# for iteration_response in sorted(
# iteration_responses,
# key=lambda x: (x.iteration_nr, x.parallelization_nr),
# ):
# # add basic iteration info
# output_strings.append(
# f"Iteration: {iteration_response.iteration_nr}, "
# f"Question {iteration_response.parallelization_nr}"
# )
# output_strings.append(f"Tool: {iteration_response.tool}")
# output_strings.append(f"Question: {iteration_response.question}")
# get answer and claims with global citations
answer_str = iteration_response.answer
claims = iteration_response.claims or []
# # get answer and claims with global citations
# answer_str = iteration_response.answer
# claims = iteration_response.claims or []
iteration_citations: list[int] = []
for local_number, cited_doc in iteration_response.cited_documents.items():
global_number = global_citations[cited_doc.center_chunk.document_id]
# translate local citations to global citations
answer_str = answer_str.replace(
f"[{CITATION_PREFIX}{local_number}]", f"[{global_number}]"
)
claims = [
claim.replace(
f"[{CITATION_PREFIX}{local_number}]", f"[{global_number}]"
)
for claim in claims
]
iteration_citations.append(global_number)
# iteration_citations: list[int] = []
# for local_number, cited_doc in iteration_response.cited_documents.items():
# global_number = global_citations[cited_doc.center_chunk.document_id]
# # translate local citations to global citations
# answer_str = answer_str.replace(
# f"[{CITATION_PREFIX}{local_number}]", f"[{global_number}]"
# )
# claims = [
# claim.replace(
# f"[{CITATION_PREFIX}{local_number}]", f"[{global_number}]"
# )
# for claim in claims
# ]
# iteration_citations.append(global_number)
# add answer, claims, and citation info
if answer_str:
output_strings.append(f"Answer: {answer_str}")
if claims:
output_strings.append(
"Claims: " + "".join(f"\n - {claim}" for claim in claims or [])
or "No claims provided"
)
if not answer_str and not claims:
output_strings.append(
"Retrieved documents: "
+ (
"".join(
f"[{global_number}]"
for global_number in sorted(iteration_citations)
)
or "No documents retrieved"
)
)
output_strings.append("\n---\n")
# # add answer, claims, and citation info
# if answer_str:
# output_strings.append(f"Answer: {answer_str}")
# if claims:
# output_strings.append(
# "Claims: " + "".join(f"\n - {claim}" for claim in claims or [])
# or "No claims provided"
# )
# if not answer_str and not claims:
# output_strings.append(
# "Retrieved documents: "
# + (
# "".join(
# f"[{global_number}]"
# for global_number in sorted(iteration_citations)
# )
# or "No documents retrieved"
# )
# )
# output_strings.append("\n---\n")
# save global iteration response
iteration_response_copy = iteration_response.model_copy()
iteration_response_copy.answer = answer_str
iteration_response_copy.claims = claims
iteration_response_copy.cited_documents = {
global_citations[doc.center_chunk.document_id]: doc
for doc in iteration_response.cited_documents.values()
}
global_iteration_responses.append(iteration_response_copy)
# # save global iteration response
# iteration_response_copy = iteration_response.model_copy()
# iteration_response_copy.answer = answer_str
# iteration_response_copy.claims = claims
# iteration_response_copy.cited_documents = {
# global_citations[doc.center_chunk.document_id]: doc
# for doc in iteration_response.cited_documents.values()
# }
# global_iteration_responses.append(iteration_response_copy)
# add document contents if requested
if include_documents:
if global_documents:
output_strings.append("Cited document contents:")
for doc in global_documents:
output_strings.append(
build_document_context(
doc, global_citations[doc.center_chunk.document_id]
)
)
output_strings.append("\n---\n")
# # add document contents if requested
# if include_documents:
# if global_documents:
# output_strings.append("Cited document contents:")
# for doc in global_documents:
# output_strings.append(
# build_document_context(
# doc, global_citations[doc.center_chunk.document_id]
# )
# )
# output_strings.append("\n---\n")
return AggregatedDRContext(
context="\n".join(output_strings),
cited_documents=global_documents,
is_internet_marker_dict=is_internet_marker_dict,
global_iteration_responses=global_iteration_responses,
)
# return AggregatedDRContext(
# context="\n".join(output_strings),
# cited_documents=global_documents,
# is_internet_marker_dict=is_internet_marker_dict,
# global_iteration_responses=global_iteration_responses,
# )
def get_chat_history_string(chat_history: list[BaseMessage], max_messages: int) -> str:
"""
Get the chat history (up to max_messages) as a string.
"""
# get past max_messages USER, ASSISTANT message pairs
# def get_chat_history_string(chat_history: list[BaseMessage], max_messages: int) -> str:
# """
# Get the chat history (up to max_messages) as a string.
# """
# # get past max_messages USER, ASSISTANT message pairs
past_messages = chat_history[-max_messages * 2 :]
filtered_past_messages = copy.deepcopy(past_messages)
# past_messages = chat_history[-max_messages * 2 :]
# filtered_past_messages = copy.deepcopy(past_messages)
for past_message_number, past_message in enumerate(past_messages):
# for past_message_number, past_message in enumerate(past_messages):
if isinstance(past_message.content, list):
removal_indices = []
for content_piece_number, content_piece in enumerate(past_message.content):
if (
isinstance(content_piece, dict)
and content_piece.get("type") != "text"
):
removal_indices.append(content_piece_number)
# if isinstance(past_message.content, list):
# removal_indices = []
# for content_piece_number, content_piece in enumerate(past_message.content):
# if (
# isinstance(content_piece, dict)
# and content_piece.get("type") != "text"
# ):
# removal_indices.append(content_piece_number)
# Only rebuild the content list if there are items to remove
if removal_indices:
filtered_past_messages[past_message_number].content = [
content_piece
for content_piece_number, content_piece in enumerate(
past_message.content
)
if content_piece_number not in removal_indices
]
# # Only rebuild the content list if there are items to remove
# if removal_indices:
# filtered_past_messages[past_message_number].content = [
# content_piece
# for content_piece_number, content_piece in enumerate(
# past_message.content
# )
# if content_piece_number not in removal_indices
# ]
else:
continue
# else:
# continue
return (
"...\n" if len(chat_history) > len(filtered_past_messages) else ""
) + "\n".join(
("user" if isinstance(msg, HumanMessage) else "you")
+ f": {str(msg.content).strip()}"
for msg in filtered_past_messages
)
# return (
# "...\n" if len(chat_history) > len(filtered_past_messages) else ""
# ) + "\n".join(
# ("user" if isinstance(msg, HumanMessage) else "you")
# + f": {str(msg.content).strip()}"
# for msg in filtered_past_messages
# )
def get_prompt_question(
question: str, clarification: OrchestrationClarificationInfo | None
) -> str:
if clarification:
clarification_question = clarification.clarification_question
clarification_response = clarification.clarification_response
return (
f"Initial User Question: {question}\n"
f"(Clarification Question: {clarification_question}\n"
f"User Response: {clarification_response})"
)
# def get_prompt_question(
# question: str, clarification: OrchestrationClarificationInfo | None
# ) -> str:
# if clarification:
# clarification_question = clarification.clarification_question
# clarification_response = clarification.clarification_response
# return (
# f"Initial User Question: {question}\n"
# f"(Clarification Question: {clarification_question}\n"
# f"User Response: {clarification_response})"
# )
return question
# return question
def create_tool_call_string(tool_name: str, query_list: list[str]) -> str:
"""
Create a string representation of the tool call.
"""
questions_str = "\n - ".join(query_list)
return f"Tool: {tool_name}\n\nQuestions:\n{questions_str}"
# def create_tool_call_string(tool_name: str, query_list: list[str]) -> str:
# """
# Create a string representation of the tool call.
# """
# questions_str = "\n - ".join(query_list)
# return f"Tool: {tool_name}\n\nQuestions:\n{questions_str}"
def parse_plan_to_dict(plan_text: str) -> dict[str, str]:
# Convert plan string to numbered dict format
if not plan_text:
return {}
# def parse_plan_to_dict(plan_text: str) -> dict[str, str]:
# # Convert plan string to numbered dict format
# if not plan_text:
# return {}
# Split by numbered items (1., 2., 3., etc. or 1), 2), 3), etc.)
parts = re.split(r"(\d+[.)])", plan_text)
plan_dict = {}
# # Split by numbered items (1., 2., 3., etc. or 1), 2), 3), etc.)
# parts = re.split(r"(\d+[.)])", plan_text)
# plan_dict = {}
for i in range(
1, len(parts), 2
): # Skip empty first part, then take number and text pairs
if i + 1 < len(parts):
number = parts[i].rstrip(".)") # Remove the dot or parenthesis
text = parts[i + 1].strip()
if text: # Only add if there's actual content
plan_dict[number] = text
# for i in range(
# 1, len(parts), 2
# ): # Skip empty first part, then take number and text pairs
# if i + 1 < len(parts):
# number = parts[i].rstrip(".)") # Remove the dot or parenthesis
# text = parts[i + 1].strip()
# if text: # Only add if there's actual content
# plan_dict[number] = text
return plan_dict
# return plan_dict
def convert_inference_sections_to_search_docs(
inference_sections: list[InferenceSection],
is_internet: bool = False,
) -> list[SavedSearchDoc]:
# Convert InferenceSections to SavedSearchDocs
search_docs = SearchDoc.from_chunks_or_sections(inference_sections)
for search_doc in search_docs:
search_doc.is_internet = is_internet
# def convert_inference_sections_to_search_docs(
# inference_sections: list[InferenceSection],
# is_internet: bool = False,
# ) -> list[SavedSearchDoc]:
# # Convert InferenceSections to SavedSearchDocs
# search_docs = SearchDoc.from_chunks_or_sections(inference_sections)
# for search_doc in search_docs:
# search_doc.is_internet = is_internet
retrieved_saved_search_docs = [
SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
for search_doc in search_docs
]
return retrieved_saved_search_docs
# retrieved_saved_search_docs = [
# SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
# for search_doc in search_docs
# ]
# return retrieved_saved_search_docs
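
The CITE: placeholder trick above exists so that a per-iteration citation like [3] can be tagged once and later rewritten to a global number without accidentally rewriting some other [3] that was already renumbered. A condensed, self-contained sketch of that first step, with the regex and prefix taken from the code above and made-up example text:

import re

CITATION_PREFIX = "CITE:"
pattern = re.compile(r"\[(\d+(?:,\s*\d+)*)\]")

def tag_citations(text: str) -> tuple[set[int], str]:
    cited: set[int] = set()

    def _repl(match: re.Match[str]) -> str:
        numbers = [int(num) for num in match.group(1).split(",")]
        cited.update(numbers)
        return "".join(f"[{CITATION_PREFIX}{num}]" for num in numbers)

    return cited, pattern.sub(_repl, text)

cited, tagged = tag_citations("the answer is x [3][4]. blah blah [1, 3].")
print(cited)   # {1, 3, 4}
print(tagged)  # the answer is x [CITE:3][CITE:4]. blah blah [CITE:1][CITE:3].
# aggregate_context then maps each local [CITE:n] to one global number per document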

View File

@@ -1,87 +1,87 @@
from collections.abc import Hashable
from datetime import datetime
from enum import Enum
# from collections.abc import Hashable
# from datetime import datetime
# from enum import Enum
from langgraph.types import Send
# from langgraph.types import Send
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import KGSearchType
from onyx.agents.agent_search.kb_search.states import KGSourceDivisionType
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import ResearchObjectInput
from onyx.configs.kg_configs import KG_MAX_DECOMPOSITION_SEGMENTS
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
# from onyx.agents.agent_search.kb_search.states import KGSearchType
# from onyx.agents.agent_search.kb_search.states import KGSourceDivisionType
# from onyx.agents.agent_search.kb_search.states import MainState
# from onyx.agents.agent_search.kb_search.states import ResearchObjectInput
# from onyx.configs.kg_configs import KG_MAX_DECOMPOSITION_SEGMENTS
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
class KGAnalysisPath(str, Enum):
PROCESS_KG_ONLY_ANSWERS = "process_kg_only_answers"
CONSTRUCT_DEEP_SEARCH_FILTERS = "construct_deep_search_filters"
# class KGAnalysisPath(str, Enum):
# PROCESS_KG_ONLY_ANSWERS = "process_kg_only_answers"
# CONSTRUCT_DEEP_SEARCH_FILTERS = "construct_deep_search_filters"
def simple_vs_search(
state: MainState,
) -> str:
# def simple_vs_search(
# state: MainState,
# ) -> str:
identified_strategy = state.updated_strategy or state.strategy
# identified_strategy = state.updated_strategy or state.strategy
if (
identified_strategy == KGAnswerStrategy.DEEP
or state.search_type == KGSearchType.SEARCH
):
return KGAnalysisPath.CONSTRUCT_DEEP_SEARCH_FILTERS.value
else:
return KGAnalysisPath.PROCESS_KG_ONLY_ANSWERS.value
# if (
# identified_strategy == KGAnswerStrategy.DEEP
# or state.search_type == KGSearchType.SEARCH
# ):
# return KGAnalysisPath.CONSTRUCT_DEEP_SEARCH_FILTERS.value
# else:
# return KGAnalysisPath.PROCESS_KG_ONLY_ANSWERS.value
def research_individual_object(
state: MainState,
) -> list[Send | Hashable] | str:
edge_start_time = datetime.now()
# def research_individual_object(
# state: MainState,
# ) -> list[Send | Hashable] | str:
# edge_start_time = datetime.now()
assert state.div_con_entities is not None
assert state.broken_down_question is not None
assert state.vespa_filter_results is not None
# assert state.div_con_entities is not None
# assert state.broken_down_question is not None
# assert state.vespa_filter_results is not None
if (
state.search_type == KGSearchType.SQL
and state.strategy == KGAnswerStrategy.DEEP
):
# if (
# state.search_type == KGSearchType.SQL
# and state.strategy == KGAnswerStrategy.DEEP
# ):
if state.source_filters and state.source_division:
segments = state.source_filters
segment_type = KGSourceDivisionType.SOURCE.value
else:
segments = state.div_con_entities
segment_type = KGSourceDivisionType.ENTITY.value
# if state.source_filters and state.source_division:
# segments = state.source_filters
# segment_type = KGSourceDivisionType.SOURCE.value
# else:
# segments = state.div_con_entities
# segment_type = KGSourceDivisionType.ENTITY.value
if segments and (len(segments) > KG_MAX_DECOMPOSITION_SEGMENTS):
logger.debug(f"Too many sources ({len(segments)}), usingfiltered search")
return "filtered_search"
# if segments and (len(segments) > KG_MAX_DECOMPOSITION_SEGMENTS):
# logger.debug(f"Too many sources ({len(segments)}), usingfiltered search")
# return "filtered_search"
return [
Send(
"process_individual_deep_search",
ResearchObjectInput(
research_nr=research_nr + 1,
entity=entity,
broken_down_question=state.broken_down_question,
vespa_filter_results=state.vespa_filter_results,
source_division=state.source_division,
source_entity_filters=state.source_filters,
segment_type=segment_type,
log_messages=[
f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
],
step_results=[],
),
)
for research_nr, entity in enumerate(segments)
]
elif state.search_type == KGSearchType.SEARCH:
return "filtered_search"
else:
raise ValueError(
f"Invalid combination of search type: {state.search_type} and strategy: {state.strategy}"
)
# return [
# Send(
# "process_individual_deep_search",
# ResearchObjectInput(
# research_nr=research_nr + 1,
# entity=entity,
# broken_down_question=state.broken_down_question,
# vespa_filter_results=state.vespa_filter_results,
# source_division=state.source_division,
# source_entity_filters=state.source_filters,
# segment_type=segment_type,
# log_messages=[
# f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
# ],
# step_results=[],
# ),
# )
# for research_nr, entity in enumerate(segments)
# ]
# elif state.search_type == KGSearchType.SEARCH:
# return "filtered_search"
# else:
# raise ValueError(
# f"Invalid combination of search type: {state.search_type} and strategy: {state.strategy}"
# )
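
research_individual_object above returns a list of Send objects, LangGraph's map-reduce primitive: each Send schedules the target node once with its own payload, and the workers' partial updates are merged back into the parent state through its reducers. A minimal sketch of the pattern under those assumptions, with illustrative names:

from operator import add
from typing import Annotated, TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.types import Send

class KGState(TypedDict):
    entities: list[str]
    findings: Annotated[list[str], add]

def fan_out(state: KGState) -> list[Send]:
    # one Send per entity; each worker invocation gets its own payload
    return [Send("research", {"entity": entity}) for entity in state["entities"]]

def research(payload: dict) -> dict:
    # the payload passed to Send arrives as this node's input
    return {"findings": [f"researched {payload['entity']}"]}

g = StateGraph(KGState)
g.add_node("research", research)
g.add_conditional_edges(START, fan_out, ["research"])
g.add_edge("research", END)

final = g.compile().invoke(
    {"entities": ["ACCOUNT::acme", "ACCOUNT::globex"], "findings": []}
)
print(final["findings"])  # e.g. ["researched ACCOUNT::acme", "researched ACCOUNT::globex"]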

View File

@@ -1,143 +1,143 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from onyx.agents.agent_search.kb_search.conditional_edges import (
research_individual_object,
)
from onyx.agents.agent_search.kb_search.conditional_edges import simple_vs_search
from onyx.agents.agent_search.kb_search.nodes.a1_extract_ert import extract_ert
from onyx.agents.agent_search.kb_search.nodes.a2_analyze import analyze
from onyx.agents.agent_search.kb_search.nodes.a3_generate_simple_sql import (
generate_simple_sql,
)
from onyx.agents.agent_search.kb_search.nodes.b1_construct_deep_search_filters import (
construct_deep_search_filters,
)
from onyx.agents.agent_search.kb_search.nodes.b2p_process_individual_deep_search import (
process_individual_deep_search,
)
from onyx.agents.agent_search.kb_search.nodes.b2s_filtered_search import filtered_search
from onyx.agents.agent_search.kb_search.nodes.b3_consolidate_individual_deep_search import (
consolidate_individual_deep_search,
)
from onyx.agents.agent_search.kb_search.nodes.c1_process_kg_only_answers import (
process_kg_only_answers,
)
from onyx.agents.agent_search.kb_search.nodes.d1_generate_answer import generate_answer
from onyx.agents.agent_search.kb_search.nodes.d2_logging_node import log_data
from onyx.agents.agent_search.kb_search.states import MainInput
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.kb_search.conditional_edges import (
# research_individual_object,
# )
# from onyx.agents.agent_search.kb_search.conditional_edges import simple_vs_search
# from onyx.agents.agent_search.kb_search.nodes.a1_extract_ert import extract_ert
# from onyx.agents.agent_search.kb_search.nodes.a2_analyze import analyze
# from onyx.agents.agent_search.kb_search.nodes.a3_generate_simple_sql import (
# generate_simple_sql,
# )
# from onyx.agents.agent_search.kb_search.nodes.b1_construct_deep_search_filters import (
# construct_deep_search_filters,
# )
# from onyx.agents.agent_search.kb_search.nodes.b2p_process_individual_deep_search import (
# process_individual_deep_search,
# )
# from onyx.agents.agent_search.kb_search.nodes.b2s_filtered_search import filtered_search
# from onyx.agents.agent_search.kb_search.nodes.b3_consolidate_individual_deep_search import (
# consolidate_individual_deep_search,
# )
# from onyx.agents.agent_search.kb_search.nodes.c1_process_kg_only_answers import (
# process_kg_only_answers,
# )
# from onyx.agents.agent_search.kb_search.nodes.d1_generate_answer import generate_answer
# from onyx.agents.agent_search.kb_search.nodes.d2_logging_node import log_data
# from onyx.agents.agent_search.kb_search.states import MainInput
# from onyx.agents.agent_search.kb_search.states import MainState
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def kb_graph_builder() -> StateGraph:
"""
LangGraph graph builder for the knowledge graph search process.
"""
# def kb_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for the knowledge graph search process.
# """
graph = StateGraph(
state_schema=MainState,
input=MainInput,
)
# graph = StateGraph(
# state_schema=MainState,
# input=MainInput,
# )
### Add nodes ###
# ### Add nodes ###
graph.add_node(
"extract_ert",
extract_ert,
)
# graph.add_node(
# "extract_ert",
# extract_ert,
# )
graph.add_node(
"generate_simple_sql",
generate_simple_sql,
)
# graph.add_node(
# "generate_simple_sql",
# generate_simple_sql,
# )
graph.add_node(
"filtered_search",
filtered_search,
)
# graph.add_node(
# "filtered_search",
# filtered_search,
# )
graph.add_node(
"analyze",
analyze,
)
# graph.add_node(
# "analyze",
# analyze,
# )
graph.add_node(
"generate_answer",
generate_answer,
)
# graph.add_node(
# "generate_answer",
# generate_answer,
# )
graph.add_node(
"log_data",
log_data,
)
# graph.add_node(
# "log_data",
# log_data,
# )
graph.add_node(
"construct_deep_search_filters",
construct_deep_search_filters,
)
# graph.add_node(
# "construct_deep_search_filters",
# construct_deep_search_filters,
# )
graph.add_node(
"process_individual_deep_search",
process_individual_deep_search,
)
# graph.add_node(
# "process_individual_deep_search",
# process_individual_deep_search,
# )
graph.add_node(
"consolidate_individual_deep_search",
consolidate_individual_deep_search,
)
# graph.add_node(
# "consolidate_individual_deep_search",
# consolidate_individual_deep_search,
# )
graph.add_node("process_kg_only_answers", process_kg_only_answers)
# graph.add_node("process_kg_only_answers", process_kg_only_answers)
### Add edges ###
# ### Add edges ###
graph.add_edge(start_key=START, end_key="extract_ert")
# graph.add_edge(start_key=START, end_key="extract_ert")
graph.add_edge(
start_key="extract_ert",
end_key="analyze",
)
# graph.add_edge(
# start_key="extract_ert",
# end_key="analyze",
# )
graph.add_edge(
start_key="analyze",
end_key="generate_simple_sql",
)
# graph.add_edge(
# start_key="analyze",
# end_key="generate_simple_sql",
# )
graph.add_conditional_edges("generate_simple_sql", simple_vs_search)
# graph.add_conditional_edges("generate_simple_sql", simple_vs_search)
graph.add_edge(start_key="process_kg_only_answers", end_key="generate_answer")
# graph.add_edge(start_key="process_kg_only_answers", end_key="generate_answer")
graph.add_conditional_edges(
source="construct_deep_search_filters",
path=research_individual_object,
path_map=["process_individual_deep_search", "filtered_search"],
)
# graph.add_conditional_edges(
# source="construct_deep_search_filters",
# path=research_individual_object,
# path_map=["process_individual_deep_search", "filtered_search"],
# )
graph.add_edge(
start_key="process_individual_deep_search",
end_key="consolidate_individual_deep_search",
)
# graph.add_edge(
# start_key="process_individual_deep_search",
# end_key="consolidate_individual_deep_search",
# )
graph.add_edge(
start_key="consolidate_individual_deep_search", end_key="generate_answer"
)
# graph.add_edge(
# start_key="consolidate_individual_deep_search", end_key="generate_answer"
# )
graph.add_edge(
start_key="filtered_search",
end_key="generate_answer",
)
# graph.add_edge(
# start_key="filtered_search",
# end_key="generate_answer",
# )
graph.add_edge(
start_key="generate_answer",
end_key="log_data",
)
# graph.add_edge(
# start_key="generate_answer",
# end_key="log_data",
# )
graph.add_edge(
start_key="log_data",
end_key=END,
)
# graph.add_edge(
# start_key="log_data",
# end_key=END,
# )
return graph
# return graph

View File

@@ -1,244 +1,244 @@
import re
# import re
from onyx.agents.agent_search.kb_search.models import KGEntityDocInfo
from onyx.agents.agent_search.kb_search.models import KGExpandedGraphObjects
from onyx.agents.agent_search.kb_search.states import SubQuestionAnswerResults
from onyx.agents.agent_search.kb_search.step_definitions import (
KG_SEARCH_STEP_DESCRIPTIONS,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.chat.models import LlmDoc
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
from onyx.db.document import get_kg_doc_info_for_entity_name
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.entities import get_document_id_for_entity
from onyx.db.entities import get_entity_name
from onyx.db.entity_type import get_entity_types
from onyx.kg.utils.formatting_utils import make_entity_id
from onyx.kg.utils.formatting_utils import split_relationship_id
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.kb_search.models import KGEntityDocInfo
# from onyx.agents.agent_search.kb_search.models import KGExpandedGraphObjects
# from onyx.agents.agent_search.kb_search.states import SubQuestionAnswerResults
# from onyx.agents.agent_search.kb_search.step_definitions import (
# KG_SEARCH_STEP_DESCRIPTIONS,
# )
# from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
# from onyx.chat.models import LlmDoc
# from onyx.context.search.models import InferenceChunk
# from onyx.context.search.models import InferenceSection
# from onyx.db.document import get_kg_doc_info_for_entity_name
# from onyx.db.engine.sql_engine import get_session_with_current_tenant
# from onyx.db.entities import get_document_id_for_entity
# from onyx.db.entities import get_entity_name
# from onyx.db.entity_type import get_entity_types
# from onyx.kg.utils.formatting_utils import make_entity_id
# from onyx.kg.utils.formatting_utils import split_relationship_id
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def _check_entities_disconnected(
current_entities: list[str], current_relationships: list[str]
) -> bool:
"""
Check if all entities in current_entities are disconnected via the given relationships.
Relationships are in the format: source_entity__relationship_name__target_entity
# def _check_entities_disconnected(
# current_entities: list[str], current_relationships: list[str]
# ) -> bool:
# """
# Check if all entities in current_entities are disconnected via the given relationships.
# Relationships are in the format: source_entity__relationship_name__target_entity
Args:
current_entities: List of entity IDs to check connectivity for
current_relationships: List of relationships in format source__relationship__target
# Args:
# current_entities: List of entity IDs to check connectivity for
# current_relationships: List of relationships in format source__relationship__target
Returns:
bool: True if all entities are disconnected, False otherwise
"""
if not current_entities:
return True
# Returns:
# bool: True if all entities are disconnected, False otherwise
# """
# if not current_entities:
# return True
# Create a graph representation using adjacency list
graph: dict[str, set[str]] = {entity: set() for entity in current_entities}
# # Create a graph representation using adjacency list
# graph: dict[str, set[str]] = {entity: set() for entity in current_entities}
# Build the graph from relationships
for relationship in current_relationships:
try:
source, _, target = split_relationship_id(relationship)
if source in graph and target in graph:
graph[source].add(target)
# Add reverse edge to capture that we do also have a relationship in the other direction,
# albeit not quite the same one.
graph[target].add(source)
except ValueError:
raise ValueError(f"Invalid relationship format: {relationship}")
# # Build the graph from relationships
# for relationship in current_relationships:
# try:
# source, _, target = split_relationship_id(relationship)
# if source in graph and target in graph:
# graph[source].add(target)
# # Add reverse edge to capture that we do also have a relationship in the other direction,
# # albeit not quite the same one.
# graph[target].add(source)
# except ValueError:
# raise ValueError(f"Invalid relationship format: {relationship}")
# Use BFS to check if all entities are connected
visited: set[str] = set()
start_entity = current_entities[0]
# # Use BFS to check if all entities are connected
# visited: set[str] = set()
# start_entity = current_entities[0]
def _bfs(start: str) -> None:
queue = [start]
visited.add(start)
while queue:
current = queue.pop(0)
for neighbor in graph[current]:
if neighbor not in visited:
visited.add(neighbor)
queue.append(neighbor)
# def _bfs(start: str) -> None:
# queue = [start]
# visited.add(start)
# while queue:
# current = queue.pop(0)
# for neighbor in graph[current]:
# if neighbor not in visited:
# visited.add(neighbor)
# queue.append(neighbor)
# Start BFS from the first entity
_bfs(start_entity)
# # Start BFS from the first entity
# _bfs(start_entity)
logger.debug(f"Number of visited entities: {len(visited)}")
# logger.debug(f"Number of visited entities: {len(visited)}")
# Check if all current_entities are in visited
return not all(entity in visited for entity in current_entities)
# # Check if all current_entities are in visited
# return not all(entity in visited for entity in current_entities)
def create_minimal_connected_query_graph(
entities: list[str], relationships: list[str], max_depth: int = 1
) -> KGExpandedGraphObjects:
"""
TODO: Implement this. For now we'll trust the SQL generation to do the right thing.
Return the original entities and relationships.
"""
return KGExpandedGraphObjects(entities=entities, relationships=relationships)
# def create_minimal_connected_query_graph(
# entities: list[str], relationships: list[str], max_depth: int = 1
# ) -> KGExpandedGraphObjects:
# """
# TODO: Implement this. For now we'll trust the SQL generation to do the right thing.
# Return the original entities and relationships.
# """
# return KGExpandedGraphObjects(entities=entities, relationships=relationships)
def get_doc_information_for_entity(entity_id_name: str) -> KGEntityDocInfo:
"""
Get document information for an entity, including its semantic name and document details.
"""
if "::" not in entity_id_name:
return KGEntityDocInfo(
doc_id=None,
doc_semantic_id=None,
doc_link=None,
semantic_entity_name=entity_id_name,
semantic_linked_entity_name=entity_id_name,
)
# def get_doc_information_for_entity(entity_id_name: str) -> KGEntityDocInfo:
# """
# Get document information for an entity, including its semantic name and document details.
# """
# if "::" not in entity_id_name:
# return KGEntityDocInfo(
# doc_id=None,
# doc_semantic_id=None,
# doc_link=None,
# semantic_entity_name=entity_id_name,
# semantic_linked_entity_name=entity_id_name,
# )
entity_type, entity_name = map(str.strip, entity_id_name.split("::", 1))
# entity_type, entity_name = map(str.strip, entity_id_name.split("::", 1))
with get_session_with_current_tenant() as db_session:
entity_document_id = get_document_id_for_entity(db_session, entity_id_name)
if entity_document_id:
return get_kg_doc_info_for_entity_name(
db_session, entity_document_id, entity_type
)
else:
entity_actual_name = get_entity_name(db_session, entity_id_name)
# with get_session_with_current_tenant() as db_session:
# entity_document_id = get_document_id_for_entity(db_session, entity_id_name)
# if entity_document_id:
# return get_kg_doc_info_for_entity_name(
# db_session, entity_document_id, entity_type
# )
# else:
# entity_actual_name = get_entity_name(db_session, entity_id_name)
return KGEntityDocInfo(
doc_id=None,
doc_semantic_id=None,
doc_link=None,
semantic_entity_name=f"{entity_type} {entity_actual_name or entity_id_name}",
semantic_linked_entity_name=f"{entity_type} {entity_actual_name or entity_id_name}",
)
# return KGEntityDocInfo(
# doc_id=None,
# doc_semantic_id=None,
# doc_link=None,
# semantic_entity_name=f"{entity_type} {entity_actual_name or entity_id_name}",
# semantic_linked_entity_name=f"{entity_type} {entity_actual_name or entity_id_name}",
# )
def rename_entities_in_answer(answer: str) -> str:
"""
Process entity references in the answer string by:
1. Extracting all strings matching <str>:<str> or <str>: <str> patterns
2. Looking up these references in the entity table
3. Replacing valid references with their corresponding values
# def rename_entities_in_answer(answer: str) -> str:
# """
# Process entity references in the answer string by:
# 1. Extracting all strings matching <str>:<str> or <str>: <str> patterns
# 2. Looking up these references in the entity table
# 3. Replacing valid references with their corresponding values
Args:
answer: The input string containing potential entity references
# Args:
# answer: The input string containing potential entity references
Returns:
str: The processed string with entity references replaced
"""
logger.debug(f"Input answer: {answer}")
# Returns:
# str: The processed string with entity references replaced
# """
# logger.debug(f"Input answer: {answer}")
# Clean up any spaces around ::
answer = re.sub(r"::\s+", "::", answer)
# # Clean up any spaces around ::
# answer = re.sub(r"::\s+", "::", answer)
# Pattern to match entity_type::entity_name, with optional quotes
pattern = r"(?:')?([a-zA-Z0-9-]+)::([a-zA-Z0-9]+)(?:')?"
# # Pattern to match entity_type::entity_name, with optional quotes
# pattern = r"(?:')?([a-zA-Z0-9-]+)::([a-zA-Z0-9]+)(?:')?"
matches = list(re.finditer(pattern, answer))
logger.debug(f"Found {len(matches)} matches")
# matches = list(re.finditer(pattern, answer))
# logger.debug(f"Found {len(matches)} matches")
# get active entity types
with get_session_with_current_tenant() as db_session:
active_entity_types = [
x.id_name for x in get_entity_types(db_session, active=True)
]
logger.debug(f"Active entity types: {active_entity_types}")
# # get active entity types
# with get_session_with_current_tenant() as db_session:
# active_entity_types = [
# x.id_name for x in get_entity_types(db_session, active=True)
# ]
# logger.debug(f"Active entity types: {active_entity_types}")
# Create dictionary for processed references
processed_refs = {}
# # Create dictionary for processed references
# processed_refs = {}
for match in matches:
entity_type = match.group(1).upper().strip()
entity_name = match.group(2).strip()
potential_entity_id_name = make_entity_id(entity_type, entity_name)
# for match in matches:
# entity_type = match.group(1).upper().strip()
# entity_name = match.group(2).strip()
# potential_entity_id_name = make_entity_id(entity_type, entity_name)
if entity_type not in active_entity_types:
continue
# if entity_type not in active_entity_types:
# continue
replacement_candidate = get_doc_information_for_entity(potential_entity_id_name)
# replacement_candidate = get_doc_information_for_entity(potential_entity_id_name)
if replacement_candidate.doc_id:
# Store both the original match and the entity_id_name for replacement
processed_refs[match.group(0)] = (
replacement_candidate.semantic_linked_entity_name
)
else:
processed_refs[match.group(0)] = replacement_candidate.semantic_entity_name
# if replacement_candidate.doc_id:
# # Store both the original match and the entity_id_name for replacement
# processed_refs[match.group(0)] = (
# replacement_candidate.semantic_linked_entity_name
# )
# else:
# processed_refs[match.group(0)] = replacement_candidate.semantic_entity_name
# Replace all references in the answer
for ref, replacement in processed_refs.items():
answer = answer.replace(ref, replacement)
logger.debug(f"Replaced {ref} with {replacement}")
# # Replace all references in the answer
# for ref, replacement in processed_refs.items():
# answer = answer.replace(ref, replacement)
# logger.debug(f"Replaced {ref} with {replacement}")
return answer
# return answer
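To make the matching behavior above concrete, a small runnable sketch using the same regular expressions on a fabricated answer string (the entity references are hypothetical, and the database lookup step is omitted):

import re

answer = "The latest update came from 'JIRA::ENG2243' and account:: Acme."
answer = re.sub(r"::\s+", "::", answer)  # normalize ":: name" to "::name"
pattern = r"(?:')?([a-zA-Z0-9-]+)::([a-zA-Z0-9]+)(?:')?"
for match in re.finditer(pattern, answer):
    print(match.group(1).upper().strip(), match.group(2).strip())
# prints "JIRA ENG2243" then "ACCOUNT Acme"; each reference would then be resolved
# via get_doc_information_for_entity before being substituted back into the answer.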
def build_document_context(
document: InferenceSection | LlmDoc, document_number: int
) -> str:
"""
Build a context string for a document.
"""
# def build_document_context(
# document: InferenceSection | LlmDoc, document_number: int
# ) -> str:
# """
# Build a context string for a document.
# """
metadata_list: list[str] = []
document_content: str | None = None
info_source: InferenceChunk | LlmDoc | None = None
info_content: str | None = None
# metadata_list: list[str] = []
# document_content: str | None = None
# info_source: InferenceChunk | LlmDoc | None = None
# info_content: str | None = None
if isinstance(document, InferenceSection):
info_source = document.center_chunk
info_content = document.combined_content
elif isinstance(document, LlmDoc):
info_source = document
info_content = document.content
# if isinstance(document, InferenceSection):
# info_source = document.center_chunk
# info_content = document.combined_content
# elif isinstance(document, LlmDoc):
# info_source = document
# info_content = document.content
for key, value in info_source.metadata.items():
metadata_list.append(f" - {key}: {value}")
# for key, value in info_source.metadata.items():
# metadata_list.append(f" - {key}: {value}")
if metadata_list:
metadata_str = "- Document Metadata:\n" + "\n".join(metadata_list)
else:
metadata_str = ""
# if metadata_list:
# metadata_str = "- Document Metadata:\n" + "\n".join(metadata_list)
# else:
# metadata_str = ""
# Construct document header with number and semantic identifier
doc_header = f"Document {str(document_number)}: {info_source.semantic_identifier}"
# # Construct document header with number and semantic identifier
# doc_header = f"Document {str(document_number)}: {info_source.semantic_identifier}"
# Combine all parts with proper spacing
document_content = f"{doc_header}\n\n{metadata_str}\n\n{info_content}"
# # Combine all parts with proper spacing
# document_content = f"{doc_header}\n\n{metadata_str}\n\n{info_content}"
return document_content
# return document_content
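A hedged sketch of the string this helper assembles, using stand-in values rather than a real InferenceSection or LlmDoc (the identifier, metadata, and content below are invented):

doc_header = "Document 1: Q3 planning notes"
metadata_str = "- Document Metadata:\n - owner: jane\n - updated: 2025-10-01"
info_content = "The Q3 plan focuses on reliability work."
document_content = f"{doc_header}\n\n{metadata_str}\n\n{info_content}"
print(document_content)
# header, metadata block, and content are separated by blank lines, matching the
# f-string in the function above.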
def get_near_empty_step_results(
step_number: int,
step_answer: str,
verified_reranked_documents: list[InferenceSection] = [],
) -> SubQuestionAnswerResults:
"""
Get near-empty step results from a list of step results.
"""
return SubQuestionAnswerResults(
question=KG_SEARCH_STEP_DESCRIPTIONS[step_number].description,
question_id="0_" + str(step_number),
answer=step_answer,
verified_high_quality=True,
sub_query_retrieval_results=[],
verified_reranked_documents=verified_reranked_documents,
context_documents=[],
cited_documents=[],
sub_question_retrieval_stats=AgentChunkRetrievalStats(
verified_count=None,
verified_avg_scores=None,
rejected_count=None,
rejected_avg_scores=None,
verified_doc_chunk_ids=[],
dismissed_doc_chunk_ids=[],
),
)
# def get_near_empty_step_results(
# step_number: int,
# step_answer: str,
# verified_reranked_documents: list[InferenceSection] = [],
# ) -> SubQuestionAnswerResults:
# """
# Get near-empty step results from a list of step results.
# """
# return SubQuestionAnswerResults(
# question=KG_SEARCH_STEP_DESCRIPTIONS[step_number].description,
# question_id="0_" + str(step_number),
# answer=step_answer,
# verified_high_quality=True,
# sub_query_retrieval_results=[],
# verified_reranked_documents=verified_reranked_documents,
# context_documents=[],
# cited_documents=[],
# sub_question_retrieval_stats=AgentChunkRetrievalStats(
# verified_count=None,
# verified_avg_scores=None,
# rejected_count=None,
# rejected_avg_scores=None,
# verified_doc_chunk_ids=[],
# dismissed_doc_chunk_ids=[],
# ),
# )
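A usage sketch of the helper above, as the extraction nodes below appear to use it (the step number and answer text are illustrative only):

placeholder = get_near_empty_step_results(
    step_number=1,
    step_answer="Entities and relationships have been extracted from the query.",
)
# All retrieval fields stay empty so the UI can render the numbered step immediately,
# with only the canned step description and the answer text filled in.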


@@ -1,55 +1,55 @@
from pydantic import BaseModel
# from pydantic import BaseModel
from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import KGRelationshipDetection
from onyx.agents.agent_search.kb_search.states import KGSearchType
from onyx.agents.agent_search.kb_search.states import YesNoEnum
# from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
# from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
# from onyx.agents.agent_search.kb_search.states import KGRelationshipDetection
# from onyx.agents.agent_search.kb_search.states import KGSearchType
# from onyx.agents.agent_search.kb_search.states import YesNoEnum
class KGQuestionEntityExtractionResult(BaseModel):
entities: list[str]
time_filter: str | None
# class KGQuestionEntityExtractionResult(BaseModel):
# entities: list[str]
# time_filter: str | None
class KGViewNames(BaseModel):
allowed_docs_view_name: str
kg_relationships_view_name: str
kg_entity_view_name: str
# class KGViewNames(BaseModel):
# allowed_docs_view_name: str
# kg_relationships_view_name: str
# kg_entity_view_name: str
class KGAnswerApproach(BaseModel):
search_type: KGSearchType
search_strategy: KGAnswerStrategy
relationship_detection: KGRelationshipDetection
format: KGAnswerFormat
broken_down_question: str | None = None
divide_and_conquer: YesNoEnum | None = None
# class KGAnswerApproach(BaseModel):
# search_type: KGSearchType
# search_strategy: KGAnswerStrategy
# relationship_detection: KGRelationshipDetection
# format: KGAnswerFormat
# broken_down_question: str | None = None
# divide_and_conquer: YesNoEnum | None = None
class KGQuestionRelationshipExtractionResult(BaseModel):
relationships: list[str]
# class KGQuestionRelationshipExtractionResult(BaseModel):
# relationships: list[str]
class KGQuestionExtractionResult(BaseModel):
entities: list[str]
relationships: list[str]
time_filter: str | None
# class KGQuestionExtractionResult(BaseModel):
# entities: list[str]
# relationships: list[str]
# time_filter: str | None
class KGExpandedGraphObjects(BaseModel):
entities: list[str]
relationships: list[str]
# class KGExpandedGraphObjects(BaseModel):
# entities: list[str]
# relationships: list[str]
class KGSteps(BaseModel):
description: str
activities: list[str]
# class KGSteps(BaseModel):
# description: str
# activities: list[str]
class KGEntityDocInfo(BaseModel):
doc_id: str | None
doc_semantic_id: str | None
doc_link: str | None
semantic_entity_name: str
semantic_linked_entity_name: str
# class KGEntityDocInfo(BaseModel):
# doc_id: str | None
# doc_semantic_id: str | None
# doc_link: str | None
# semantic_entity_name: str
# semantic_linked_entity_name: str
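These models are parsed straight from LLM output with pydantic's model_validate_json, as the nodes below do. A minimal self-contained sketch, redefining one model locally and using a fabricated JSON blob:

from pydantic import BaseModel, ValidationError

class KGQuestionEntityExtractionResult(BaseModel):
    entities: list[str]
    time_filter: str | None

raw = '{"entities": ["ACCOUNT::acme"], "time_filter": null}'  # hypothetical LLM output
try:
    result = KGQuestionEntityExtractionResult.model_validate_json(raw)
except ValidationError:
    result = KGQuestionEntityExtractionResult(entities=[], time_filter=None)
print(result.entities)  # -> ['ACCOUNT::acme']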


@@ -1,255 +1,255 @@
from datetime import datetime
from typing import cast
# from datetime import datetime
# from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from pydantic import ValidationError
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
# from pydantic import ValidationError
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.models import KGQuestionEntityExtractionResult
from onyx.agents.agent_search.kb_search.models import (
KGQuestionRelationshipExtractionResult,
)
from onyx.agents.agent_search.kb_search.states import EntityRelationshipExtractionUpdate
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.kg_configs import KG_ENTITY_EXTRACTION_TIMEOUT
from onyx.configs.kg_configs import KG_RELATIONSHIP_EXTRACTION_TIMEOUT
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.kg_temp_view import create_views
from onyx.db.kg_temp_view import get_user_view_names
from onyx.db.relationships import get_allowed_relationship_type_pairs
from onyx.kg.utils.extraction_utils import get_entity_types_str
from onyx.kg.utils.extraction_utils import get_relationship_types_str
from onyx.prompts.kg_prompts import QUERY_ENTITY_EXTRACTION_PROMPT
from onyx.prompts.kg_prompts import QUERY_RELATIONSHIP_EXTRACTION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from shared_configs.contextvars import get_current_tenant_id
# from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
# from onyx.agents.agent_search.kb_search.models import KGQuestionEntityExtractionResult
# from onyx.agents.agent_search.kb_search.models import (
# KGQuestionRelationshipExtractionResult,
# )
# from onyx.agents.agent_search.kb_search.states import EntityRelationshipExtractionUpdate
# from onyx.agents.agent_search.kb_search.states import MainState
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.configs.kg_configs import KG_ENTITY_EXTRACTION_TIMEOUT
# from onyx.configs.kg_configs import KG_RELATIONSHIP_EXTRACTION_TIMEOUT
# from onyx.db.engine.sql_engine import get_session_with_current_tenant
# from onyx.db.kg_temp_view import create_views
# from onyx.db.kg_temp_view import get_user_view_names
# from onyx.db.relationships import get_allowed_relationship_type_pairs
# from onyx.kg.utils.extraction_utils import get_entity_types_str
# from onyx.kg.utils.extraction_utils import get_relationship_types_str
# from onyx.prompts.kg_prompts import QUERY_ENTITY_EXTRACTION_PROMPT
# from onyx.prompts.kg_prompts import QUERY_RELATIONSHIP_EXTRACTION_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
# from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
# logger = setup_logger()
def extract_ert(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> EntityRelationshipExtractionUpdate:
"""
LangGraph node to start the agentic search process.
"""
# def extract_ert(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> EntityRelationshipExtractionUpdate:
# """
# LangGraph node to start the agentic search process.
# """
# recheck KG enablement at outset KG graph
# # recheck KG enablement at outset KG graph
if not config["metadata"]["config"].behavior.kg_config_settings.KG_ENABLED:
logger.error("KG approach is not enabled, the KG agent flow cannot run.")
raise ValueError("KG approach is not enabled, the KG agent flow cannot run.")
# if not config["metadata"]["config"].behavior.kg_config_settings.KG_ENABLED:
# logger.error("KG approach is not enabled, the KG agent flow cannot run.")
# raise ValueError("KG approach is not enabled, the KG agent flow cannot run.")
_KG_STEP_NR = 1
# _KG_STEP_NR = 1
node_start_time = datetime.now()
# node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
# graph_config = cast(GraphConfig, config["metadata"]["config"])
if graph_config.tooling.search_tool is None:
raise ValueError("Search tool is not set")
elif graph_config.tooling.search_tool.user is None:
raise ValueError("User is not set")
else:
user_email = graph_config.tooling.search_tool.user.email
user_name = user_email.split("@")[0] or "unknown"
# if graph_config.tooling.search_tool is None:
# raise ValueError("Search tool is not set")
# elif graph_config.tooling.search_tool.user is None:
# raise ValueError("User is not set")
# else:
# user_email = graph_config.tooling.search_tool.user.email
# user_name = user_email.split("@")[0] or "unknown"
# first four lines are duplicated from generate_initial_answer
question = state.question
today_date = datetime.now().strftime("%A, %Y-%m-%d")
# # first four lines duplicates from generate_initial_answer
# question = state.question
# today_date = datetime.now().strftime("%A, %Y-%m-%d")
all_entity_types = get_entity_types_str(active=True)
all_relationship_types = get_relationship_types_str(active=True)
# all_entity_types = get_entity_types_str(active=True)
# all_relationship_types = get_relationship_types_str(active=True)
# Create temporary views. TODO: move into parallel step, if ultimately materialized
tenant_id = get_current_tenant_id()
kg_views = get_user_view_names(user_email, tenant_id)
with get_session_with_current_tenant() as db_session:
create_views(
db_session,
tenant_id=tenant_id,
user_email=user_email,
allowed_docs_view_name=kg_views.allowed_docs_view_name,
kg_relationships_view_name=kg_views.kg_relationships_view_name,
kg_entity_view_name=kg_views.kg_entity_view_name,
)
# # Create temporary views. TODO: move into parallel step, if ultimately materialized
# tenant_id = get_current_tenant_id()
# kg_views = get_user_view_names(user_email, tenant_id)
# with get_session_with_current_tenant() as db_session:
# create_views(
# db_session,
# tenant_id=tenant_id,
# user_email=user_email,
# allowed_docs_view_name=kg_views.allowed_docs_view_name,
# kg_relationships_view_name=kg_views.kg_relationships_view_name,
# kg_entity_view_name=kg_views.kg_entity_view_name,
# )
### get the entities, terms, and filters
# ### get the entities, terms, and filters
query_extraction_pre_prompt = QUERY_ENTITY_EXTRACTION_PROMPT.format(
entity_types=all_entity_types,
relationship_types=all_relationship_types,
)
# query_extraction_pre_prompt = QUERY_ENTITY_EXTRACTION_PROMPT.format(
# entity_types=all_entity_types,
# relationship_types=all_relationship_types,
# )
query_extraction_prompt = (
query_extraction_pre_prompt.replace("---content---", question)
.replace("---today_date---", today_date)
.replace("---user_name---", f"EMPLOYEE:{user_name}")
.replace("{{", "{")
.replace("}}", "}")
)
# query_extraction_prompt = (
# query_extraction_pre_prompt.replace("---content---", question)
# .replace("---today_date---", today_date)
# .replace("---user_name---", f"EMPLOYEE:{user_name}")
# .replace("{{", "{")
# .replace("}}", "}")
# )
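A small sketch of the placeholder substitution used above, with a made-up template string; the real QUERY_ENTITY_EXTRACTION_PROMPT is not reproduced here:

template = "Today is ---today_date---. Extract entities from: ---content--- {{\"entities\": []}}"
prompt = (
    template.replace("---content---", "Which accounts did jane touch last week?")
    .replace("---today_date---", "Monday, 2025-11-24")
    .replace("{{", "{")
    .replace("}}", "}")
)
print(prompt)
# double braces collapse to single braces so literal JSON examples in the template
# survive the formatting step.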
msg = [
HumanMessage(
content=query_extraction_prompt,
)
]
primary_llm = graph_config.tooling.primary_llm
# Grader
try:
llm_response = run_with_timeout(
KG_ENTITY_EXTRACTION_TIMEOUT,
primary_llm.invoke_langchain,
prompt=msg,
timeout_override=15,
max_tokens=300,
)
# msg = [
# HumanMessage(
# content=query_extraction_prompt,
# )
# ]
# primary_llm = graph_config.tooling.primary_llm
# # Grader
# try:
# llm_response = run_with_timeout(
# KG_ENTITY_EXTRACTION_TIMEOUT,
# primary_llm.invoke_langchain,
# prompt=msg,
# timeout_override=15,
# max_tokens=300,
# )
cleaned_response = (
str(llm_response.content)
.replace("{{", "{")
.replace("}}", "}")
.replace("```json\n", "")
.replace("\n```", "")
.replace("\n", "")
)
first_bracket = cleaned_response.find("{")
last_bracket = cleaned_response.rfind("}")
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
# cleaned_response = (
# str(llm_response.content)
# .replace("{{", "{")
# .replace("}}", "}")
# .replace("```json\n", "")
# .replace("\n```", "")
# .replace("\n", "")
# )
# first_bracket = cleaned_response.find("{")
# last_bracket = cleaned_response.rfind("}")
# cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
entity_extraction_result = KGQuestionEntityExtractionResult.model_validate_json(
cleaned_response
)
except ValidationError:
logger.error("Failed to parse LLM response as JSON in Entity Extraction")
entity_extraction_result = KGQuestionEntityExtractionResult(
entities=[], time_filter=""
)
except Exception as e:
logger.error(f"Error in extract_ert: {e}")
entity_extraction_result = KGQuestionEntityExtractionResult(
entities=[], time_filter=""
)
# entity_extraction_result = KGQuestionEntityExtractionResult.model_validate_json(
# cleaned_response
# )
# except ValidationError:
# logger.error("Failed to parse LLM response as JSON in Entity Extraction")
# entity_extraction_result = KGQuestionEntityExtractionResult(
# entities=[], time_filter=""
# )
# except Exception as e:
# logger.error(f"Error in extract_ert: {e}")
# entity_extraction_result = KGQuestionEntityExtractionResult(
# entities=[], time_filter=""
# )
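To illustrate the clean-up just above, a hedged, runnable sketch on a fabricated LLM reply that wraps its JSON in a markdown fence:

llm_content = 'Sure! ```json\n{"entities": ["ACCOUNT::acme"], "time_filter": null}\n``` Hope that helps.'
cleaned = (
    llm_content.replace("{{", "{")
    .replace("}}", "}")
    .replace("```json\n", "")
    .replace("\n```", "")
    .replace("\n", "")
)
first_bracket = cleaned.find("{")
last_bracket = cleaned.rfind("}")
cleaned = cleaned[first_bracket : last_bracket + 1]
print(cleaned)  # -> '{"entities": ["ACCOUNT::acme"], "time_filter": null}'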
# remove the attribute filters from the entities for the purpose of relationship extraction
entities_no_attributes = [
entity.split("--")[0] for entity in entity_extraction_result.entities
]
ert_entities_string = f"Entities: {entities_no_attributes}\n"
# # remove the attribute filters from the entities for the purpose of relationship extraction
# entities_no_attributes = [
# entity.split("--")[0] for entity in entity_extraction_result.entities
# ]
# ert_entities_string = f"Entities: {entities_no_attributes}\n"
### get the relationships
# ### get the relationships
# find the relationship types that match the extracted entity types
# # find the relationship types that match the extracted entity types
with get_session_with_current_tenant() as db_session:
allowed_relationship_pairs = get_allowed_relationship_type_pairs(
db_session, entity_extraction_result.entities
)
# with get_session_with_current_tenant() as db_session:
# allowed_relationship_pairs = get_allowed_relationship_type_pairs(
# db_session, entity_extraction_result.entities
# )
query_relationship_extraction_prompt = (
QUERY_RELATIONSHIP_EXTRACTION_PROMPT.replace("---question---", question)
.replace("---today_date---", today_date)
.replace(
"---relationship_type_options---",
" - " + "\n - ".join(allowed_relationship_pairs),
)
.replace("---identified_entities---", ert_entities_string)
.replace("---entity_types---", all_entity_types)
.replace("{{", "{")
.replace("}}", "}")
)
# query_relationship_extraction_prompt = (
# QUERY_RELATIONSHIP_EXTRACTION_PROMPT.replace("---question---", question)
# .replace("---today_date---", today_date)
# .replace(
# "---relationship_type_options---",
# " - " + "\n - ".join(allowed_relationship_pairs),
# )
# .replace("---identified_entities---", ert_entities_string)
# .replace("---entity_types---", all_entity_types)
# .replace("{{", "{")
# .replace("}}", "}")
# )
msg = [
HumanMessage(
content=query_relationship_extraction_prompt,
)
]
primary_llm = graph_config.tooling.primary_llm
# Grader
try:
llm_response = run_with_timeout(
KG_RELATIONSHIP_EXTRACTION_TIMEOUT,
primary_llm.invoke_langchain,
prompt=msg,
timeout_override=15,
max_tokens=300,
)
# msg = [
# HumanMessage(
# content=query_relationship_extraction_prompt,
# )
# ]
# primary_llm = graph_config.tooling.primary_llm
# # Grader
# try:
# llm_response = run_with_timeout(
# KG_RELATIONSHIP_EXTRACTION_TIMEOUT,
# primary_llm.invoke_langchain,
# prompt=msg,
# timeout_override=15,
# max_tokens=300,
# )
cleaned_response = (
str(llm_response.content)
.replace("{{", "{")
.replace("}}", "}")
.replace("```json\n", "")
.replace("\n```", "")
.replace("\n", "")
)
first_bracket = cleaned_response.find("{")
last_bracket = cleaned_response.rfind("}")
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
cleaned_response = cleaned_response.replace("{{", '{"')
cleaned_response = cleaned_response.replace("}}", '"}')
# cleaned_response = (
# str(llm_response.content)
# .replace("{{", "{")
# .replace("}}", "}")
# .replace("```json\n", "")
# .replace("\n```", "")
# .replace("\n", "")
# )
# first_bracket = cleaned_response.find("{")
# last_bracket = cleaned_response.rfind("}")
# cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
# cleaned_response = cleaned_response.replace("{{", '{"')
# cleaned_response = cleaned_response.replace("}}", '"}')
try:
relationship_extraction_result = (
KGQuestionRelationshipExtractionResult.model_validate_json(
cleaned_response
)
)
except ValidationError:
logger.error(
"Failed to parse LLM response as JSON in Relationship Extraction"
)
relationship_extraction_result = KGQuestionRelationshipExtractionResult(
relationships=[],
)
except Exception as e:
logger.error(f"Error in extract_ert: {e}")
relationship_extraction_result = KGQuestionRelationshipExtractionResult(
relationships=[],
)
# try:
# relationship_extraction_result = (
# KGQuestionRelationshipExtractionResult.model_validate_json(
# cleaned_response
# )
# )
# except ValidationError:
# logger.error(
# "Failed to parse LLM response as JSON in Relationship Extraction"
# )
# relationship_extraction_result = KGQuestionRelationshipExtractionResult(
# relationships=[],
# )
# except Exception as e:
# logger.error(f"Error in extract_ert: {e}")
# relationship_extraction_result = KGQuestionRelationshipExtractionResult(
# relationships=[],
# )
## STEP 1
# Stream answer pieces out to the UI for Step 1
# ## STEP 1
# # Stream answer pieces out to the UI for Step 1
extracted_entity_string = " \n ".join(
[x.split("--")[0] for x in entity_extraction_result.entities]
)
extracted_relationship_string = " \n ".join(
relationship_extraction_result.relationships
)
# extracted_entity_string = " \n ".join(
# [x.split("--")[0] for x in entity_extraction_result.entities]
# )
# extracted_relationship_string = " \n ".join(
# relationship_extraction_result.relationships
# )
step_answer = f"""Entities and relationships have been extracted from query - \n \
Entities: {extracted_entity_string} - \n Relationships: {extracted_relationship_string}"""
# step_answer = f"""Entities and relationships have been extracted from query - \n \
# Entities: {extracted_entity_string} - \n Relationships: {extracted_relationship_string}"""
return EntityRelationshipExtractionUpdate(
entities_types_str=all_entity_types,
relationship_types_str=all_relationship_types,
extracted_entities_w_attributes=entity_extraction_result.entities,
extracted_entities_no_attributes=entities_no_attributes,
extracted_relationships=relationship_extraction_result.relationships,
time_filter=entity_extraction_result.time_filter,
kg_doc_temp_view_name=kg_views.allowed_docs_view_name,
kg_rel_temp_view_name=kg_views.kg_relationships_view_name,
kg_entity_temp_view_name=kg_views.kg_entity_view_name,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="extract entities terms",
node_start_time=node_start_time,
)
],
step_results=[
get_near_empty_step_results(
step_number=_KG_STEP_NR,
step_answer=step_answer,
verified_reranked_documents=[],
)
],
)
# return EntityRelationshipExtractionUpdate(
# entities_types_str=all_entity_types,
# relationship_types_str=all_relationship_types,
# extracted_entities_w_attributes=entity_extraction_result.entities,
# extracted_entities_no_attributes=entities_no_attributes,
# extracted_relationships=relationship_extraction_result.relationships,
# time_filter=entity_extraction_result.time_filter,
# kg_doc_temp_view_name=kg_views.allowed_docs_view_name,
# kg_rel_temp_view_name=kg_views.kg_relationships_view_name,
# kg_entity_temp_view_name=kg_views.kg_entity_view_name,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="main",
# node_name="extract entities terms",
# node_start_time=node_start_time,
# )
# ],
# step_results=[
# get_near_empty_step_results(
# step_number=_KG_STEP_NR,
# step_answer=step_answer,
# verified_reranked_documents=[],
# )
# ],
# )


@@ -1,312 +1,312 @@
from datetime import datetime
from typing import cast
# from datetime import datetime
# from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.graph_utils import (
create_minimal_connected_query_graph,
)
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.models import KGAnswerApproach
from onyx.agents.agent_search.kb_search.states import AnalysisUpdate
from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import KGRelationshipDetection
from onyx.agents.agent_search.kb_search.states import KGSearchType
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import YesNoEnum
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.kg_configs import KG_STRATEGY_GENERATION_TIMEOUT
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.entities import get_document_id_for_entity
from onyx.kg.clustering.normalizations import normalize_entities
from onyx.kg.clustering.normalizations import normalize_relationships
from onyx.kg.utils.formatting_utils import split_relationship_id
from onyx.prompts.kg_prompts import STRATEGY_GENERATION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# from onyx.agents.agent_search.kb_search.graph_utils import (
# create_minimal_connected_query_graph,
# )
# from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
# from onyx.agents.agent_search.kb_search.models import KGAnswerApproach
# from onyx.agents.agent_search.kb_search.states import AnalysisUpdate
# from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
# from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
# from onyx.agents.agent_search.kb_search.states import KGRelationshipDetection
# from onyx.agents.agent_search.kb_search.states import KGSearchType
# from onyx.agents.agent_search.kb_search.states import MainState
# from onyx.agents.agent_search.kb_search.states import YesNoEnum
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.configs.kg_configs import KG_STRATEGY_GENERATION_TIMEOUT
# from onyx.db.engine.sql_engine import get_session_with_current_tenant
# from onyx.db.entities import get_document_id_for_entity
# from onyx.kg.clustering.normalizations import normalize_entities
# from onyx.kg.clustering.normalizations import normalize_relationships
# from onyx.kg.utils.formatting_utils import split_relationship_id
# from onyx.prompts.kg_prompts import STRATEGY_GENERATION_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
# logger = setup_logger()
def _articulate_normalizations(
entity_normalization_map: dict[str, str],
relationship_normalization_map: dict[str, str],
) -> str:
# def _articulate_normalizations(
# entity_normalization_map: dict[str, str],
# relationship_normalization_map: dict[str, str],
# ) -> str:
remark_list: list[str] = []
# remark_list: list[str] = []
if entity_normalization_map:
remark_list.append("\n Entities:")
for extracted_entity, normalized_entity in entity_normalization_map.items():
remark_list.append(f" - {extracted_entity} -> {normalized_entity}")
# if entity_normalization_map:
# remark_list.append("\n Entities:")
# for extracted_entity, normalized_entity in entity_normalization_map.items():
# remark_list.append(f" - {extracted_entity} -> {normalized_entity}")
if relationship_normalization_map:
remark_list.append(" \n Relationships:")
for (
extracted_relationship,
normalized_relationship,
) in relationship_normalization_map.items():
remark_list.append(
f" - {extracted_relationship} -> {normalized_relationship}"
)
# if relationship_normalization_map:
# remark_list.append(" \n Relationships:")
# for (
# extracted_relationship,
# normalized_relationship,
# ) in relationship_normalization_map.items():
# remark_list.append(
# f" - {extracted_relationship} -> {normalized_relationship}"
# )
return " \n ".join(remark_list)
# return " \n ".join(remark_list)
def _get_fully_connected_entities(
entities: list[str], relationships: list[str]
) -> list[str]:
"""
Analyze the connectedness of the entities and relationships.
"""
# Build a dictionary to track connections for each entity
entity_connections: dict[str, set[str]] = {entity: set() for entity in entities}
# def _get_fully_connected_entities(
# entities: list[str], relationships: list[str]
# ) -> list[str]:
# """
# Analyze the connectedness of the entities and relationships.
# """
# # Build a dictionary to track connections for each entity
# entity_connections: dict[str, set[str]] = {entity: set() for entity in entities}
# Parse relationships to build connection graph
for relationship in relationships:
# Split relationship into parts. Test for proper formatting just in case.
# Should never be an error though at this point.
parts = split_relationship_id(relationship)
if len(parts) != 3:
raise ValueError(f"Invalid relationship: {relationship}")
# # Parse relationships to build connection graph
# for relationship in relationships:
# # Split relationship into parts. Test for proper formatting just in case.
# # Should never be an error though at this point.
# parts = split_relationship_id(relationship)
# if len(parts) != 3:
# raise ValueError(f"Invalid relationship: {relationship}")
entity1 = parts[0]
entity2 = parts[2]
# entity1 = parts[0]
# entity2 = parts[2]
# Add bidirectional connections
if entity1 in entity_connections:
entity_connections[entity1].add(entity2)
if entity2 in entity_connections:
entity_connections[entity2].add(entity1)
# # Add bidirectional connections
# if entity1 in entity_connections:
# entity_connections[entity1].add(entity2)
# if entity2 in entity_connections:
# entity_connections[entity2].add(entity1)
# Find entities connected to all others
fully_connected_entities = []
all_entities = set(entities)
# # Find entities connected to all others
# fully_connected_entities = []
# all_entities = set(entities)
for entity, connections in entity_connections.items():
# Check if this entity is connected to all other entities
if connections == all_entities - {entity}:
fully_connected_entities.append(entity)
# for entity, connections in entity_connections.items():
# # Check if this entity is connected to all other entities
# if connections == all_entities - {entity}:
# fully_connected_entities.append(entity)
return fully_connected_entities
# return fully_connected_entities
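A hedged worked example of the connectivity check above, assuming split_relationship_id yields (source_entity, relationship_type, target_entity) triples; the entities and relationships are fabricated and the logic is restated standalone:

entities = ["ACCOUNT::acme", "EMPLOYEE::jane", "TICKET::eng1"]
relationship_triples = [
    ("ACCOUNT::acme", "owned_by", "EMPLOYEE::jane"),
    ("ACCOUNT::acme", "tracked_in", "TICKET::eng1"),
]
connections = {entity: set() for entity in entities}
for source, _, target in relationship_triples:
    connections[source].add(target)
    connections[target].add(source)
fully_connected = [
    entity for entity, linked in connections.items() if linked == set(entities) - {entity}
]
print(fully_connected)  # -> ['ACCOUNT::acme'], the only entity linked to every other one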
def _check_for_single_doc(
normalized_entities: list[str],
raw_entities: list[str],
normalized_relationship_strings: list[str],
raw_relationships: list[str],
normalized_time_filter: str | None,
) -> str | None:
"""
Check if the query is for a single document, like 'Summarize ticket ENG-2243K'.
None is returned if the query is not for a single document.
"""
if (
len(normalized_entities) == 1
and len(raw_entities) == 1
and len(normalized_relationship_strings) == 0
and len(raw_relationships) == 0
and normalized_time_filter is None
):
with get_session_with_current_tenant() as db_session:
single_doc_id = get_document_id_for_entity(
db_session, normalized_entities[0]
)
else:
single_doc_id = None
return single_doc_id
# def _check_for_single_doc(
# normalized_entities: list[str],
# raw_entities: list[str],
# normalized_relationship_strings: list[str],
# raw_relationships: list[str],
# normalized_time_filter: str | None,
# ) -> str | None:
# """
# Check if the query is for a single document, like 'Summarize ticket ENG-2243K'.
# None is returned if the query is not for a single document.
# """
# if (
# len(normalized_entities) == 1
# and len(raw_entities) == 1
# and len(normalized_relationship_strings) == 0
# and len(raw_relationships) == 0
# and normalized_time_filter is None
# ):
# with get_session_with_current_tenant() as db_session:
# single_doc_id = get_document_id_for_entity(
# db_session, normalized_entities[0]
# )
# else:
# single_doc_id = None
# return single_doc_id
def analyze(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> AnalysisUpdate:
"""
LangGraph node to start the agentic search process.
"""
# def analyze(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> AnalysisUpdate:
# """
# LangGraph node to start the agentic search process.
# """
_KG_STEP_NR = 2
# _KG_STEP_NR = 2
node_start_time = datetime.now()
# node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = state.question
entities = (
state.extracted_entities_no_attributes
) # attribute knowledge is not required for this step
relationships = state.extracted_relationships
time_filter = state.time_filter
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# question = state.question
# entities = (
# state.extracted_entities_no_attributes
# ) # attribute knowledge is not required for this step
# relationships = state.extracted_relationships
# time_filter = state.time_filter
## STEP 2 - stream out goals
# ## STEP 2 - stream out goals
# Continue with node
# # Continue with node
normalized_entities = normalize_entities(
entities,
state.extracted_entities_w_attributes,
allowed_docs_temp_view_name=state.kg_doc_temp_view_name,
)
# normalized_entities = normalize_entities(
# entities,
# state.extracted_entities_w_attributes,
# allowed_docs_temp_view_name=state.kg_doc_temp_view_name,
# )
normalized_relationships = normalize_relationships(
relationships, normalized_entities.entity_normalization_map
)
normalized_time_filter = time_filter
# normalized_relationships = normalize_relationships(
# relationships, normalized_entities.entity_normalization_map
# )
# normalized_time_filter = time_filter
# If single-doc inquiry, send to single-doc processing directly
# # If single-doc inquiry, send to single-doc processing directly
single_doc_id = _check_for_single_doc(
normalized_entities=normalized_entities.entities,
raw_entities=entities,
normalized_relationship_strings=normalized_relationships.relationships,
raw_relationships=relationships,
normalized_time_filter=normalized_time_filter,
)
# single_doc_id = _check_for_single_doc(
# normalized_entities=normalized_entities.entities,
# raw_entities=entities,
# normalized_relationship_strings=normalized_relationships.relationships,
# raw_relationships=relationships,
# normalized_time_filter=normalized_time_filter,
# )
# Expand the entities and relationships to make sure that entities are connected
# # Expand the entities and relationships to make sure that entities are connected
graph_expansion = create_minimal_connected_query_graph(
normalized_entities.entities,
normalized_relationships.relationships,
max_depth=2,
)
# graph_expansion = create_minimal_connected_query_graph(
# normalized_entities.entities,
# normalized_relationships.relationships,
# max_depth=2,
# )
query_graph_entities = graph_expansion.entities
query_graph_relationships = graph_expansion.relationships
# query_graph_entities = graph_expansion.entities
# query_graph_relationships = graph_expansion.relationships
# Evaluate whether a search needs to be done after identifying all entities and relationships
# # Evaluate whether a search needs to be done after identifying all entities and relationships
strategy_generation_prompt = (
STRATEGY_GENERATION_PROMPT.replace(
"---entities---", "\n".join(query_graph_entities)
)
.replace("---relationships---", "\n".join(query_graph_relationships))
.replace("---possible_entities---", state.entities_types_str)
.replace("---possible_relationships---", state.relationship_types_str)
.replace("---question---", question)
)
# strategy_generation_prompt = (
# STRATEGY_GENERATION_PROMPT.replace(
# "---entities---", "\n".join(query_graph_entities)
# )
# .replace("---relationships---", "\n".join(query_graph_relationships))
# .replace("---possible_entities---", state.entities_types_str)
# .replace("---possible_relationships---", state.relationship_types_str)
# .replace("---question---", question)
# )
msg = [
HumanMessage(
content=strategy_generation_prompt,
)
]
primary_llm = graph_config.tooling.primary_llm
# Grader
try:
llm_response = run_with_timeout(
KG_STRATEGY_GENERATION_TIMEOUT,
# fast_llm.invoke,
primary_llm.invoke_langchain,
prompt=msg,
timeout_override=5,
max_tokens=100,
)
# msg = [
# HumanMessage(
# content=strategy_generation_prompt,
# )
# ]
# primary_llm = graph_config.tooling.primary_llm
# # Grader
# try:
# llm_response = run_with_timeout(
# KG_STRATEGY_GENERATION_TIMEOUT,
# # fast_llm.invoke,
# primary_llm.invoke_langchain,
# prompt=msg,
# timeout_override=5,
# max_tokens=100,
# )
cleaned_response = (
str(llm_response.content)
.replace("```json\n", "")
.replace("\n```", "")
.replace("\n", "")
)
first_bracket = cleaned_response.find("{")
last_bracket = cleaned_response.rfind("}")
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
# cleaned_response = (
# str(llm_response.content)
# .replace("```json\n", "")
# .replace("\n```", "")
# .replace("\n", "")
# )
# first_bracket = cleaned_response.find("{")
# last_bracket = cleaned_response.rfind("}")
# cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
try:
approach_extraction_result = KGAnswerApproach.model_validate_json(
cleaned_response
)
search_type = approach_extraction_result.search_type
search_strategy = approach_extraction_result.search_strategy
relationship_detection = (
approach_extraction_result.relationship_detection.value
)
output_format = approach_extraction_result.format
broken_down_question = approach_extraction_result.broken_down_question
divide_and_conquer = approach_extraction_result.divide_and_conquer
except ValueError:
logger.error(
"Failed to parse LLM response as JSON in Entity-Term Extraction"
)
search_type = KGSearchType.SEARCH
search_strategy = KGAnswerStrategy.DEEP
relationship_detection = KGRelationshipDetection.RELATIONSHIPS.value
output_format = KGAnswerFormat.TEXT
broken_down_question = None
divide_and_conquer = YesNoEnum.NO
if search_strategy is None or output_format is None:
raise ValueError(f"Invalid strategy: {cleaned_response}")
# try:
# approach_extraction_result = KGAnswerApproach.model_validate_json(
# cleaned_response
# )
# search_type = approach_extraction_result.search_type
# search_strategy = approach_extraction_result.search_strategy
# relationship_detection = (
# approach_extraction_result.relationship_detection.value
# )
# output_format = approach_extraction_result.format
# broken_down_question = approach_extraction_result.broken_down_question
# divide_and_conquer = approach_extraction_result.divide_and_conquer
# except ValueError:
# logger.error(
# "Failed to parse LLM response as JSON in Entity-Term Extraction"
# )
# search_type = KGSearchType.SEARCH
# search_strategy = KGAnswerStrategy.DEEP
# relationship_detection = KGRelationshipDetection.RELATIONSHIPS.value
# output_format = KGAnswerFormat.TEXT
# broken_down_question = None
# divide_and_conquer = YesNoEnum.NO
# if search_strategy is None or output_format is None:
# raise ValueError(f"Invalid strategy: {cleaned_response}")
except Exception as e:
logger.error(f"Error in strategy generation: {e}")
raise e
# except Exception as e:
# logger.error(f"Error in strategy generation: {e}")
# raise e
# Stream out relevant results
# # Stream out relevant results
if single_doc_id:
search_strategy = (
KGAnswerStrategy.DEEP
) # if a single doc is identified, we will want to look at the details.
# if single_doc_id:
# search_strategy = (
# KGAnswerStrategy.DEEP
# ) # if a single doc is identified, we will want to look at the details.
step_answer = f"Strategy and format have been extracted from query. Strategy: {search_strategy.value}, \
Format: {output_format.value}, Broken down question: {broken_down_question}"
# step_answer = f"Strategy and format have been extracted from query. Strategy: {search_strategy.value}, \
# Format: {output_format.value}, Broken down question: {broken_down_question}"
extraction_detected_relationships = len(query_graph_relationships) > 0
if extraction_detected_relationships:
query_type = KGRelationshipDetection.RELATIONSHIPS.value
# extraction_detected_relationships = len(query_graph_relationships) > 0
# if extraction_detected_relationships:
# query_type = KGRelationshipDetection.RELATIONSHIPS.value
if extraction_detected_relationships:
logger.warning(
"Fyi - Extraction detected relationships: "
f"{extraction_detected_relationships}, "
f"but relationship detection: {relationship_detection}"
)
else:
query_type = KGRelationshipDetection.NO_RELATIONSHIPS.value
# if extraction_detected_relationships:
# logger.warning(
# "Fyi - Extraction detected relationships: "
# f"{extraction_detected_relationships}, "
# f"but relationship detection: {relationship_detection}"
# )
# else:
# query_type = KGRelationshipDetection.NO_RELATIONSHIPS.value
# End node
# # End node
return AnalysisUpdate(
normalized_core_entities=normalized_entities.entities,
normalized_core_relationships=normalized_relationships.relationships,
entity_normalization_map=normalized_entities.entity_normalization_map,
relationship_normalization_map=normalized_relationships.relationship_normalization_map,
query_graph_entities_no_attributes=query_graph_entities,
query_graph_entities_w_attributes=normalized_entities.entities_w_attributes,
query_graph_relationships=query_graph_relationships,
normalized_terms=[], # TODO: remove fully later
normalized_time_filter=normalized_time_filter,
strategy=search_strategy,
broken_down_question=broken_down_question,
output_format=output_format,
divide_and_conquer=divide_and_conquer,
single_doc_id=single_doc_id,
search_type=search_type,
query_type=query_type,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="analyze",
node_start_time=node_start_time,
)
],
step_results=[
get_near_empty_step_results(
step_number=_KG_STEP_NR,
step_answer=step_answer,
verified_reranked_documents=[],
)
],
remarks=[
_articulate_normalizations(
entity_normalization_map=normalized_entities.entity_normalization_map,
relationship_normalization_map=normalized_relationships.relationship_normalization_map,
)
],
)
# return AnalysisUpdate(
# normalized_core_entities=normalized_entities.entities,
# normalized_core_relationships=normalized_relationships.relationships,
# entity_normalization_map=normalized_entities.entity_normalization_map,
# relationship_normalization_map=normalized_relationships.relationship_normalization_map,
# query_graph_entities_no_attributes=query_graph_entities,
# query_graph_entities_w_attributes=normalized_entities.entities_w_attributes,
# query_graph_relationships=query_graph_relationships,
# normalized_terms=[], # TODO: remove fully later
# normalized_time_filter=normalized_time_filter,
# strategy=search_strategy,
# broken_down_question=broken_down_question,
# output_format=output_format,
# divide_and_conquer=divide_and_conquer,
# single_doc_id=single_doc_id,
# search_type=search_type,
# query_type=query_type,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="main",
# node_name="analyze",
# node_start_time=node_start_time,
# )
# ],
# step_results=[
# get_near_empty_step_results(
# step_number=_KG_STEP_NR,
# step_answer=step_answer,
# verified_reranked_documents=[],
# )
# ],
# remarks=[
# _articulate_normalizations(
# entity_normalization_map=normalized_entities.entity_normalization_map,
# relationship_normalization_map=normalized_relationships.relationship_normalization_map,
# )
# ],
# )


@@ -1,185 +1,185 @@
from datetime import datetime
from typing import cast
# from datetime import datetime
# from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.states import DeepSearchFilterUpdate
from onyx.agents.agent_search.kb_search.states import KGFilterConstructionResults
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.kg_configs import KG_FILTER_CONSTRUCTION_TIMEOUT
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.entity_type import get_entity_types_with_grounded_source_name
from onyx.kg.utils.formatting_utils import make_entity_id
from onyx.prompts.kg_prompts import SEARCH_FILTER_CONSTRUCTION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# from onyx.agents.agent_search.kb_search.states import DeepSearchFilterUpdate
# from onyx.agents.agent_search.kb_search.states import KGFilterConstructionResults
# from onyx.agents.agent_search.kb_search.states import MainState
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.configs.kg_configs import KG_FILTER_CONSTRUCTION_TIMEOUT
# from onyx.db.engine.sql_engine import get_session_with_current_tenant
# from onyx.db.entity_type import get_entity_types_with_grounded_source_name
# from onyx.kg.utils.formatting_utils import make_entity_id
# from onyx.prompts.kg_prompts import SEARCH_FILTER_CONSTRUCTION_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
# logger = setup_logger()
def construct_deep_search_filters(
state: MainState, config: RunnableConfig, writer: StreamWriter
) -> DeepSearchFilterUpdate:
"""
LangGraph node to start the agentic search process.
"""
# def construct_deep_search_filters(
# state: MainState, config: RunnableConfig, writer: StreamWriter
# ) -> DeepSearchFilterUpdate:
# """
# LangGraph node to start the agentic search process.
# """
node_start_time = datetime.now()
# node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = state.question
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# question = state.question
entities_types_str = state.entities_types_str
entities = state.query_graph_entities_no_attributes
relationships = state.query_graph_relationships
simple_sql_query = state.sql_query
simple_sql_results = state.sql_query_results
source_document_results = state.source_document_results
if simple_sql_results:
simple_sql_results_str = "\n".join([str(x) for x in simple_sql_results])
else:
simple_sql_results_str = "(no SQL results generated)"
if source_document_results:
source_document_results_str = "\n".join(
[str(x) for x in source_document_results]
)
else:
source_document_results_str = "(no source document results generated)"
# entities_types_str = state.entities_types_str
# entities = state.query_graph_entities_no_attributes
# relationships = state.query_graph_relationships
# simple_sql_query = state.sql_query
# simple_sql_results = state.sql_query_results
# source_document_results = state.source_document_results
# if simple_sql_results:
# simple_sql_results_str = "\n".join([str(x) for x in simple_sql_results])
# else:
# simple_sql_results_str = "(no SQL results generated)"
# if source_document_results:
# source_document_results_str = "\n".join(
# [str(x) for x in source_document_results]
# )
# else:
# source_document_results_str = "(no source document results generated)"
search_filter_construction_prompt = (
SEARCH_FILTER_CONSTRUCTION_PROMPT.replace(
"---entity_type_descriptions---",
entities_types_str,
)
.replace(
"---entity_filters---",
"\n".join(entities),
)
.replace(
"---relationship_filters---",
"\n".join(relationships),
)
.replace(
"---sql_query---",
simple_sql_query or "(no SQL generated)",
)
.replace(
"---sql_results---",
simple_sql_results_str or "(no SQL results generated)",
)
.replace(
"---source_document_results---",
source_document_results_str or "(no source document results generated)",
)
.replace(
"---question---",
question,
)
)
# search_filter_construction_prompt = (
# SEARCH_FILTER_CONSTRUCTION_PROMPT.replace(
# "---entity_type_descriptions---",
# entities_types_str,
# )
# .replace(
# "---entity_filters---",
# "\n".join(entities),
# )
# .replace(
# "---relationship_filters---",
# "\n".join(relationships),
# )
# .replace(
# "---sql_query---",
# simple_sql_query or "(no SQL generated)",
# )
# .replace(
# "---sql_results---",
# simple_sql_results_str or "(no SQL results generated)",
# )
# .replace(
# "---source_document_results---",
# source_document_results_str or "(no source document results generated)",
# )
# .replace(
# "---question---",
# question,
# )
# )
msg = [
HumanMessage(
content=search_filter_construction_prompt,
)
]
llm = graph_config.tooling.primary_llm
# Grader
try:
llm_response = run_with_timeout(
KG_FILTER_CONSTRUCTION_TIMEOUT,
llm.invoke_langchain,
prompt=msg,
timeout_override=15,
max_tokens=1400,
)
# msg = [
# HumanMessage(
# content=search_filter_construction_prompt,
# )
# ]
# llm = graph_config.tooling.primary_llm
# # Grader
# try:
# llm_response = run_with_timeout(
# KG_FILTER_CONSTRUCTION_TIMEOUT,
# llm.invoke_langchain,
# prompt=msg,
# timeout_override=15,
# max_tokens=1400,
# )
cleaned_response = (
str(llm_response.content)
.replace("```json\n", "")
.replace("\n```", "")
.replace("\n", "")
)
first_bracket = cleaned_response.find("{")
last_bracket = cleaned_response.rfind("}")
# cleaned_response = (
# str(llm_response.content)
# .replace("```json\n", "")
# .replace("\n```", "")
# .replace("\n", "")
# )
# first_bracket = cleaned_response.find("{")
# last_bracket = cleaned_response.rfind("}")
if last_bracket == -1 or first_bracket == -1:
raise ValueError("No valid JSON found in LLM response - no brackets found")
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
cleaned_response = cleaned_response.replace("{{", '{"')
cleaned_response = cleaned_response.replace("}}", '"}')
# if last_bracket == -1 or first_bracket == -1:
# raise ValueError("No valid JSON found in LLM response - no brackets found")
# cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
# cleaned_response = cleaned_response.replace("{{", '{"')
# cleaned_response = cleaned_response.replace("}}", '"}')
try:
# try:
filter_results = KGFilterConstructionResults.model_validate_json(
cleaned_response
)
except ValueError:
logger.error(
"Failed to parse LLM response as JSON in Entity-Term Extraction"
)
filter_results = KGFilterConstructionResults(
global_entity_filters=[],
global_relationship_filters=[],
local_entity_filters=[],
source_document_filters=[],
structure=[],
)
# filter_results = KGFilterConstructionResults.model_validate_json(
# cleaned_response
# )
# except ValueError:
# logger.error(
# "Failed to parse LLM response as JSON in Entity-Term Extraction"
# )
# filter_results = KGFilterConstructionResults(
# global_entity_filters=[],
# global_relationship_filters=[],
# local_entity_filters=[],
# source_document_filters=[],
# structure=[],
# )
except Exception as e:
logger.error(f"Error in construct_deep_search_filters: {e}")
filter_results = KGFilterConstructionResults(
global_entity_filters=[],
global_relationship_filters=[],
local_entity_filters=[],
source_document_filters=[],
structure=[],
)
# except Exception as e:
# logger.error(f"Error in construct_deep_search_filters: {e}")
# filter_results = KGFilterConstructionResults(
# global_entity_filters=[],
# global_relationship_filters=[],
# local_entity_filters=[],
# source_document_filters=[],
# structure=[],
# )
div_con_structure = filter_results.structure
# div_con_structure = filter_results.structure
logger.info(f"div_con_structure: {div_con_structure}")
# logger.info(f"div_con_structure: {div_con_structure}")
with get_session_with_current_tenant() as db_session:
double_grounded_entity_types = get_entity_types_with_grounded_source_name(
db_session
)
# with get_session_with_current_tenant() as db_session:
# double_grounded_entity_types = get_entity_types_with_grounded_source_name(
# db_session
# )
source_division = False
# source_division = False
if div_con_structure:
for entity_type in double_grounded_entity_types:
# entity_type is guaranteed to have grounded_source_name
if (
cast(str, entity_type.grounded_source_name).lower()
in div_con_structure[0].lower()
):
source_division = True
break
# if div_con_structure:
# for entity_type in double_grounded_entity_types:
# # entity_type is guaranteed to have grounded_source_name
# if (
# cast(str, entity_type.grounded_source_name).lower()
# in div_con_structure[0].lower()
# ):
# source_division = True
# break
return DeepSearchFilterUpdate(
vespa_filter_results=filter_results,
div_con_entities=div_con_structure,
source_division=source_division,
global_entity_filters=[
make_entity_id(global_filter, "*")
for global_filter in filter_results.global_entity_filters
],
global_relationship_filters=filter_results.global_relationship_filters,
local_entity_filters=filter_results.local_entity_filters,
source_filters=filter_results.source_document_filters,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="construct deep search filters",
node_start_time=node_start_time,
)
],
step_results=[],
)
# return DeepSearchFilterUpdate(
# vespa_filter_results=filter_results,
# div_con_entities=div_con_structure,
# source_division=source_division,
# global_entity_filters=[
# make_entity_id(global_filter, "*")
# for global_filter in filter_results.global_entity_filters
# ],
# global_relationship_filters=filter_results.global_relationship_filters,
# local_entity_filters=filter_results.local_entity_filters,
# source_filters=filter_results.source_document_filters,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="main",
# node_name="construct deep search filters",
# node_start_time=node_start_time,
# )
# ],
# step_results=[],
# )
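A sketch of the source-division check performed above, with fabricated values; in the real code grounded_source_name comes from the entity type table:

div_con_structure = ["jira tickets per account"]   # hypothetical divide-and-conquer structure
grounded_source_names = ["jira", "fireflies"]      # hypothetical grounded source names
source_division = bool(div_con_structure) and any(
    name.lower() in div_con_structure[0].lower() for name in grounded_source_names
)
print(source_division)  # -> True, because "jira" appears in the first structure entry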


@@ -1,168 +1,168 @@
import copy
from datetime import datetime
from typing import cast
# import copy
# from datetime import datetime
# from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.kb_search.ops import research
from onyx.agents.agent_search.kb_search.states import KGSourceDivisionType
from onyx.agents.agent_search.kb_search.states import ResearchObjectInput
from onyx.agents.agent_search.kb_search.states import ResearchObjectUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.chat.models import LlmDoc
from onyx.configs.kg_configs import KG_MAX_SEARCH_DOCUMENTS
from onyx.configs.kg_configs import KG_OBJECT_SOURCE_RESEARCH_TIMEOUT
from onyx.context.search.models import InferenceSection
from onyx.kg.utils.formatting_utils import split_entity_id
from onyx.prompts.kg_prompts import KG_OBJECT_SOURCE_RESEARCH_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
# from onyx.agents.agent_search.kb_search.ops import research
# from onyx.agents.agent_search.kb_search.states import KGSourceDivisionType
# from onyx.agents.agent_search.kb_search.states import ResearchObjectInput
# from onyx.agents.agent_search.kb_search.states import ResearchObjectUpdate
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
# trim_prompt_piece,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.chat.models import LlmDoc
# from onyx.configs.kg_configs import KG_MAX_SEARCH_DOCUMENTS
# from onyx.configs.kg_configs import KG_OBJECT_SOURCE_RESEARCH_TIMEOUT
# from onyx.context.search.models import InferenceSection
# from onyx.kg.utils.formatting_utils import split_entity_id
# from onyx.prompts.kg_prompts import KG_OBJECT_SOURCE_RESEARCH_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
# logger = setup_logger()
def process_individual_deep_search(
state: ResearchObjectInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> ResearchObjectUpdate:
"""
    LangGraph node to research an individual object (entity or source document) in the deep search path.
"""
# def process_individual_deep_search(
# state: ResearchObjectInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> ResearchObjectUpdate:
# """
# LangGraph node to research an individual object (entity or source document) in the deep search path.
# """
node_start_time = datetime.now()
# node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
search_tool = graph_config.tooling.search_tool
question = state.broken_down_question
segment_type = state.segment_type
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# search_tool = graph_config.tooling.search_tool
# question = state.broken_down_question
# segment_type = state.segment_type
object = state.entity.replace("::", ":: ").lower()
# object = state.entity.replace("::", ":: ").lower()
if not search_tool:
raise ValueError("search_tool is not provided")
# if not search_tool:
# raise ValueError("search_tool is not provided")
state.research_nr
# state.research_nr
if segment_type == KGSourceDivisionType.ENTITY.value:
# if segment_type == KGSourceDivisionType.ENTITY.value:
object_id = split_entity_id(object)[1].strip()
extended_question = f"{question} in regards to {object}"
source_filters = state.source_entity_filters
# object_id = split_entity_id(object)[1].strip()
# extended_question = f"{question} in regards to {object}"
# source_filters = state.source_entity_filters
# TODO: this does not really occur in V1. But needs to be changed for V2
raw_kg_entity_filters = copy.deepcopy(
list(
set((state.vespa_filter_results.global_entity_filters + [state.entity]))
)
)
# # TODO: this does not really occur in V1. But needs to be changed for V2
# raw_kg_entity_filters = copy.deepcopy(
# list(
# set((state.vespa_filter_results.global_entity_filters + [state.entity]))
# )
# )
kg_entity_filters = []
for raw_kg_entity_filter in raw_kg_entity_filters:
if "::" not in raw_kg_entity_filter:
raw_kg_entity_filter += "::*"
kg_entity_filters.append(raw_kg_entity_filter)
# kg_entity_filters = []
# for raw_kg_entity_filter in raw_kg_entity_filters:
# if "::" not in raw_kg_entity_filter:
# raw_kg_entity_filter += "::*"
# kg_entity_filters.append(raw_kg_entity_filter)
kg_relationship_filters = copy.deepcopy(
state.vespa_filter_results.global_relationship_filters
)
# kg_relationship_filters = copy.deepcopy(
# state.vespa_filter_results.global_relationship_filters
# )
logger.debug("Research for object: " + object)
logger.debug(f"kg_entity_filters: {kg_entity_filters}")
logger.debug(f"kg_relationship_filters: {kg_relationship_filters}")
# logger.debug("Research for object: " + object)
# logger.debug(f"kg_entity_filters: {kg_entity_filters}")
# logger.debug(f"kg_relationship_filters: {kg_relationship_filters}")
else:
# if we came through the entity view route, in KG V1 the state entity
# is the document to search for. No need to set other filters then.
object_id = state.entity # source doc in this case
extended_question = f"{question}"
source_filters = [object_id]
# else:
# # if we came through the entity view route, in KG V1 the state entity
# # is the document to search for. No need to set other filters then.
# object_id = state.entity # source doc in this case
# extended_question = f"{question}"
# source_filters = [object_id]
kg_entity_filters = None
kg_relationship_filters = None
# kg_entity_filters = None
# kg_relationship_filters = None
if source_filters and (len(source_filters) > KG_MAX_SEARCH_DOCUMENTS):
logger.debug(
f"Too many sources ({len(source_filters)}), setting to None and effectively filtered search"
)
source_filters = None
# if source_filters and (len(source_filters) > KG_MAX_SEARCH_DOCUMENTS):
# logger.debug(
# f"Too many sources ({len(source_filters)}), setting to None and effectively filtered search"
# )
# source_filters = None
retrieved_docs = research(
question=extended_question,
kg_entities=kg_entity_filters,
kg_relationships=kg_relationship_filters,
kg_sources=source_filters,
search_tool=search_tool,
)
# retrieved_docs = research(
# question=extended_question,
# kg_entities=kg_entity_filters,
# kg_relationships=kg_relationship_filters,
# kg_sources=source_filters,
# search_tool=search_tool,
# )
document_texts_list = []
# document_texts_list = []
for doc_num, retrieved_doc in enumerate(retrieved_docs):
if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
chunk_text = build_document_context(retrieved_doc, doc_num + 1)
document_texts_list.append(chunk_text)
# for doc_num, retrieved_doc in enumerate(retrieved_docs):
# if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
# raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
# chunk_text = build_document_context(retrieved_doc, doc_num + 1)
# document_texts_list.append(chunk_text)
document_texts = "\n\n".join(document_texts_list)
# document_texts = "\n\n".join(document_texts_list)
    # Build prompt
    # # Build prompt
kg_object_source_research_prompt = KG_OBJECT_SOURCE_RESEARCH_PROMPT.format(
question=extended_question,
document_text=document_texts,
)
# kg_object_source_research_prompt = KG_OBJECT_SOURCE_RESEARCH_PROMPT.format(
# question=extended_question,
# document_text=document_texts,
# )
# Run LLM
# # Run LLM
msg = [
HumanMessage(
content=trim_prompt_piece(
config=graph_config.tooling.primary_llm.config,
prompt_piece=kg_object_source_research_prompt,
reserved_str="",
),
)
]
primary_llm = graph_config.tooling.primary_llm
    # Invoke the LLM with a timeout
try:
llm_response = run_with_timeout(
KG_OBJECT_SOURCE_RESEARCH_TIMEOUT,
primary_llm.invoke_langchain,
prompt=msg,
timeout_override=KG_OBJECT_SOURCE_RESEARCH_TIMEOUT,
max_tokens=300,
)
# msg = [
# HumanMessage(
# content=trim_prompt_piece(
# config=graph_config.tooling.primary_llm.config,
# prompt_piece=kg_object_source_research_prompt,
# reserved_str="",
# ),
# )
# ]
# primary_llm = graph_config.tooling.primary_llm
    # # Invoke the LLM with a timeout
# try:
# llm_response = run_with_timeout(
# KG_OBJECT_SOURCE_RESEARCH_TIMEOUT,
# primary_llm.invoke_langchain,
# prompt=msg,
# timeout_override=KG_OBJECT_SOURCE_RESEARCH_TIMEOUT,
# max_tokens=300,
# )
object_research_results = str(llm_response.content).replace("```json\n", "")
# object_research_results = str(llm_response.content).replace("```json\n", "")
except Exception as e:
raise ValueError(f"Error in research_object_source: {e}")
# except Exception as e:
# raise ValueError(f"Error in research_object_source: {e}")
logger.debug("DivCon Step A2 - Object Source Research - completed for an object")
# logger.debug("DivCon Step A2 - Object Source Research - completed for an object")
return ResearchObjectUpdate(
research_object_results=[
{
"object": object.replace("::", ":: ").capitalize(),
"results": object_research_results,
}
],
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="process individual deep search",
node_start_time=node_start_time,
)
],
step_results=[],
)
# return ResearchObjectUpdate(
# research_object_results=[
# {
# "object": object.replace("::", ":: ").capitalize(),
# "results": object_research_results,
# }
# ],
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="main",
# node_name="process individual deep search",
# node_start_time=node_start_time,
# )
# ],
# step_results=[],
# )

View File

@@ -1,166 +1,166 @@
from datetime import datetime
from typing import cast
# from datetime import datetime
# from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.ops import research
from onyx.agents.agent_search.kb_search.states import ConsolidatedResearchUpdate
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.calculations import (
get_answer_generation_documents,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.kg_configs import KG_FILTERED_SEARCH_TIMEOUT
from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
from onyx.context.search.models import InferenceSection
from onyx.prompts.kg_prompts import KG_SEARCH_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
# from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
# from onyx.agents.agent_search.kb_search.ops import research
# from onyx.agents.agent_search.kb_search.states import ConsolidatedResearchUpdate
# from onyx.agents.agent_search.kb_search.states import MainState
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
# trim_prompt_piece,
# )
# from onyx.agents.agent_search.shared_graph_utils.calculations import (
# get_answer_generation_documents,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.configs.kg_configs import KG_FILTERED_SEARCH_TIMEOUT
# from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
# from onyx.context.search.models import InferenceSection
# from onyx.prompts.kg_prompts import KG_SEARCH_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
# logger = setup_logger()
def filtered_search(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ConsolidatedResearchUpdate:
"""
LangGraph node to do a filtered search.
"""
_KG_STEP_NR = 4
# def filtered_search(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> ConsolidatedResearchUpdate:
# """
# LangGraph node to do a filtered search.
# """
# _KG_STEP_NR = 4
node_start_time = datetime.now()
# node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
search_tool = graph_config.tooling.search_tool
question = state.question
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# search_tool = graph_config.tooling.search_tool
# question = state.question
if not search_tool:
raise ValueError("search_tool is not provided")
# if not search_tool:
# raise ValueError("search_tool is not provided")
if not state.vespa_filter_results:
raise ValueError("vespa_filter_results is not provided")
raw_kg_entity_filters = list(
set((state.vespa_filter_results.global_entity_filters))
)
# if not state.vespa_filter_results:
# raise ValueError("vespa_filter_results is not provided")
# raw_kg_entity_filters = list(
# set((state.vespa_filter_results.global_entity_filters))
# )
kg_entity_filters = []
for raw_kg_entity_filter in raw_kg_entity_filters:
if "::" not in raw_kg_entity_filter:
raw_kg_entity_filter += "::*"
kg_entity_filters.append(raw_kg_entity_filter)
# kg_entity_filters = []
# for raw_kg_entity_filter in raw_kg_entity_filters:
# if "::" not in raw_kg_entity_filter:
# raw_kg_entity_filter += "::*"
# kg_entity_filters.append(raw_kg_entity_filter)
kg_relationship_filters = state.vespa_filter_results.global_relationship_filters
# kg_relationship_filters = state.vespa_filter_results.global_relationship_filters
logger.debug("Starting filtered search")
logger.debug(f"kg_entity_filters: {kg_entity_filters}")
logger.debug(f"kg_relationship_filters: {kg_relationship_filters}")
# logger.debug("Starting filtered search")
# logger.debug(f"kg_entity_filters: {kg_entity_filters}")
# logger.debug(f"kg_relationship_filters: {kg_relationship_filters}")
retrieved_docs = cast(
list[InferenceSection],
research(
question=question,
kg_entities=kg_entity_filters,
kg_relationships=kg_relationship_filters,
kg_sources=None,
search_tool=search_tool,
inference_sections_only=True,
),
)
# retrieved_docs = cast(
# list[InferenceSection],
# research(
# question=question,
# kg_entities=kg_entity_filters,
# kg_relationships=kg_relationship_filters,
# kg_sources=None,
# search_tool=search_tool,
# inference_sections_only=True,
# ),
# )
source_link_dict = {
num + 1: doc.center_chunk.source_links[0]
for num, doc in enumerate(retrieved_docs)
if doc.center_chunk.source_links
}
# source_link_dict = {
# num + 1: doc.center_chunk.source_links[0]
# for num, doc in enumerate(retrieved_docs)
# if doc.center_chunk.source_links
# }
answer_generation_documents = get_answer_generation_documents(
relevant_docs=retrieved_docs,
context_documents=retrieved_docs,
original_question_docs=retrieved_docs,
max_docs=KG_RESEARCH_NUM_RETRIEVED_DOCS,
)
# answer_generation_documents = get_answer_generation_documents(
# relevant_docs=retrieved_docs,
# context_documents=retrieved_docs,
# original_question_docs=retrieved_docs,
# max_docs=KG_RESEARCH_NUM_RETRIEVED_DOCS,
# )
document_texts_list = []
# document_texts_list = []
for doc_num, retrieved_doc in enumerate(
answer_generation_documents.context_documents
):
chunk_text = build_document_context(retrieved_doc, doc_num + 1)
document_texts_list.append(chunk_text)
# for doc_num, retrieved_doc in enumerate(
# answer_generation_documents.context_documents
# ):
# chunk_text = build_document_context(retrieved_doc, doc_num + 1)
# document_texts_list.append(chunk_text)
document_texts = "\n\n".join(document_texts_list)
# document_texts = "\n\n".join(document_texts_list)
    # Build prompt
    # # Build prompt
datetime.now().strftime("%A, %Y-%m-%d")
# datetime.now().strftime("%A, %Y-%m-%d")
kg_object_source_research_prompt = KG_SEARCH_PROMPT.format(
question=question,
document_text=document_texts,
)
# kg_object_source_research_prompt = KG_SEARCH_PROMPT.format(
# question=question,
# document_text=document_texts,
# )
# Run LLM
# # Run LLM
msg = [
HumanMessage(
content=trim_prompt_piece(
config=graph_config.tooling.primary_llm.config,
prompt_piece=kg_object_source_research_prompt,
reserved_str="",
),
)
]
primary_llm = graph_config.tooling.primary_llm
llm = primary_llm
    # Invoke the LLM with a timeout
try:
llm_response = run_with_timeout(
KG_FILTERED_SEARCH_TIMEOUT,
llm.invoke_langchain,
prompt=msg,
timeout_override=30,
max_tokens=300,
)
# msg = [
# HumanMessage(
# content=trim_prompt_piece(
# config=graph_config.tooling.primary_llm.config,
# prompt_piece=kg_object_source_research_prompt,
# reserved_str="",
# ),
# )
# ]
# primary_llm = graph_config.tooling.primary_llm
# llm = primary_llm
    # # Invoke the LLM with a timeout
# try:
# llm_response = run_with_timeout(
# KG_FILTERED_SEARCH_TIMEOUT,
# llm.invoke_langchain,
# prompt=msg,
# timeout_override=30,
# max_tokens=300,
# )
filtered_search_answer = str(llm_response.content).replace("```json\n", "")
# filtered_search_answer = str(llm_response.content).replace("```json\n", "")
# TODO: make sure the citations look correct. Currently, they do not.
for source_link_num, source_link in source_link_dict.items():
if f"[{source_link_num}]" in filtered_search_answer:
filtered_search_answer = filtered_search_answer.replace(
f"[{source_link_num}]", f"[{source_link_num}]({source_link})"
)
# # TODO: make sure the citations look correct. Currently, they do not.
# for source_link_num, source_link in source_link_dict.items():
# if f"[{source_link_num}]" in filtered_search_answer:
# filtered_search_answer = filtered_search_answer.replace(
# f"[{source_link_num}]", f"[{source_link_num}]({source_link})"
# )
except Exception as e:
raise ValueError(f"Error in filtered_search: {e}")
# except Exception as e:
# raise ValueError(f"Error in filtered_search: {e}")
step_answer = "Filtered search is complete."
# step_answer = "Filtered search is complete."
return ConsolidatedResearchUpdate(
consolidated_research_object_results_str=filtered_search_answer,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="filtered search",
node_start_time=node_start_time,
)
],
step_results=[
get_near_empty_step_results(
step_number=_KG_STEP_NR,
step_answer=step_answer,
verified_reranked_documents=retrieved_docs,
)
],
)
# return ConsolidatedResearchUpdate(
# consolidated_research_object_results_str=filtered_search_answer,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="main",
# node_name="filtered search",
# node_start_time=node_start_time,
# )
# ],
# step_results=[
# get_near_empty_step_results(
# step_number=_KG_STEP_NR,
# step_answer=step_answer,
# verified_reranked_documents=retrieved_docs,
# )
# ],
# )

View File

@@ -1,54 +1,54 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import rename_entities_in_answer
from onyx.agents.agent_search.kb_search.states import ConsolidatedResearchUpdate
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
# from onyx.agents.agent_search.kb_search.graph_utils import rename_entities_in_answer
# from onyx.agents.agent_search.kb_search.states import ConsolidatedResearchUpdate
# from onyx.agents.agent_search.kb_search.states import MainState
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def consolidate_individual_deep_search(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ConsolidatedResearchUpdate:
"""
    LangGraph node to consolidate the individual deep search results.
"""
# def consolidate_individual_deep_search(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> ConsolidatedResearchUpdate:
# """
# LangGraph node to consolidate the individual deep search results.
# """
_KG_STEP_NR = 4
node_start_time = datetime.now()
# _KG_STEP_NR = 4
# node_start_time = datetime.now()
research_object_results = state.research_object_results
# research_object_results = state.research_object_results
consolidated_research_object_results_str = "\n".join(
[f"{x['object']}: {x['results']}" for x in research_object_results]
)
# consolidated_research_object_results_str = "\n".join(
# [f"{x['object']}: {x['results']}" for x in research_object_results]
# )
consolidated_research_object_results_str = rename_entities_in_answer(
consolidated_research_object_results_str
)
# consolidated_research_object_results_str = rename_entities_in_answer(
# consolidated_research_object_results_str
# )
step_answer = "All research is complete. Consolidating results..."
# step_answer = "All research is complete. Consolidating results..."
return ConsolidatedResearchUpdate(
consolidated_research_object_results_str=consolidated_research_object_results_str,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="consolidate individual deep search",
node_start_time=node_start_time,
)
],
step_results=[
get_near_empty_step_results(
step_number=_KG_STEP_NR, step_answer=step_answer
)
],
)
# return ConsolidatedResearchUpdate(
# consolidated_research_object_results_str=consolidated_research_object_results_str,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="main",
# node_name="consolidate individual deep search",
# node_start_time=node_start_time,
# )
# ],
# step_results=[
# get_near_empty_step_results(
# step_number=_KG_STEP_NR, step_answer=step_answer
# )
# ],
# )

View File

@@ -1,96 +1,96 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import ResultsDataUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.db.document import get_base_llm_doc_information
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
# from onyx.agents.agent_search.kb_search.states import MainState
# from onyx.agents.agent_search.kb_search.states import ResultsDataUpdate
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.db.document import get_base_llm_doc_information
# from onyx.db.engine.sql_engine import get_session_with_current_tenant
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def _get_formated_source_reference_results(
source_document_results: list[str] | None,
) -> str | None:
"""
    Generate formatted reference results from the source document results.
"""
# def _get_formated_source_reference_results(
# source_document_results: list[str] | None,
# ) -> str | None:
# """
# Generate formatted reference results from the source document results.
# """
if source_document_results is None:
return None
# if source_document_results is None:
# return None
# get all entities that correspond to an Onyx document
document_ids = source_document_results
# # get all entities that correspond to an Onyx document
# document_ids = source_document_results
with get_session_with_current_tenant() as session:
llm_doc_information_results = get_base_llm_doc_information(
session, document_ids
)
# with get_session_with_current_tenant() as session:
# llm_doc_information_results = get_base_llm_doc_information(
# session, document_ids
# )
if len(llm_doc_information_results) == 0:
return ""
# if len(llm_doc_information_results) == 0:
# return ""
return (
f"\n \n Here are {len(llm_doc_information_results)} supporting documents or examples: \n \n "
+ " \n \n ".join(llm_doc_information_results)
)
# return (
# f"\n \n Here are {len(llm_doc_information_results)} supporting documents or examples: \n \n "
# + " \n \n ".join(llm_doc_information_results)
# )
def process_kg_only_answers(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ResultsDataUpdate:
"""
    LangGraph node to prepare answers derived directly from the knowledge graph query results.
"""
# def process_kg_only_answers(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> ResultsDataUpdate:
# """
# LangGraph node to prepare answers derived directly from the knowledge graph query results.
# """
_KG_STEP_NR = 4
# _KG_STEP_NR = 4
node_start_time = datetime.now()
# node_start_time = datetime.now()
query_results = state.sql_query_results
source_document_results = state.source_document_results
# query_results = state.sql_query_results
# source_document_results = state.source_document_results
if query_results:
query_results_data_str = "\n".join(
str(query_result).replace("::", ":: ").capitalize()
for query_result in query_results
)
else:
logger.warning("No query results were found")
query_results_data_str = "(No query results were found)"
# if query_results:
# query_results_data_str = "\n".join(
# str(query_result).replace("::", ":: ").capitalize()
# for query_result in query_results
# )
# else:
# logger.warning("No query results were found")
# query_results_data_str = "(No query results were found)"
source_reference_result_str = _get_formated_source_reference_results(
source_document_results
)
# source_reference_result_str = _get_formated_source_reference_results(
# source_document_results
# )
## STEP 4 - same components as Step 1
# ## STEP 4 - same components as Step 1
step_answer = (
"No further research is needed, the answer is derived from the knowledge graph."
)
# step_answer = (
# "No further research is needed, the answer is derived from the knowledge graph."
# )
return ResultsDataUpdate(
query_results_data_str=query_results_data_str,
individualized_query_results_data_str="",
reference_results_str=source_reference_result_str,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="kg query results data processing",
node_start_time=node_start_time,
)
],
step_results=[
get_near_empty_step_results(
step_number=_KG_STEP_NR, step_answer=step_answer
)
],
)
# return ResultsDataUpdate(
# query_results_data_str=query_results_data_str,
# individualized_query_results_data_str="",
# reference_results_str=source_reference_result_str,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="main",
# node_name="kg query results data processing",
# node_start_time=node_start_time,
# )
# ],
# step_results=[
# get_near_empty_step_results(
# step_number=_KG_STEP_NR, step_answer=step_answer
# )
# ],
# )

View File

@@ -1,213 +1,187 @@
from datetime import datetime
from typing import cast
# from datetime import datetime
# from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.access.access import get_acl_for_user
from onyx.agents.agent_search.kb_search.graph_utils import rename_entities_in_answer
from onyx.agents.agent_search.kb_search.ops import research
from onyx.agents.agent_search.kb_search.states import FinalAnswerUpdate
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.calculations import (
get_answer_generation_documents,
)
from onyx.agents.agent_search.shared_graph_utils.llm import get_answer_from_llm
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
from onyx.configs.kg_configs import KG_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION
from onyx.context.search.enums import SearchType
from onyx.context.search.models import InferenceSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.prompts.kg_prompts import OUTPUT_FORMAT_NO_EXAMPLES_PROMPT
from onyx.prompts.kg_prompts import OUTPUT_FORMAT_NO_OVERALL_ANSWER_PROMPT
from onyx.tools.tool_implementations.search.search_tool import IndexFilters
from onyx.tools.tool_implementations.search.search_tool import SearchQueryInfo
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
from onyx.utils.logger import setup_logger
# from onyx.access.access import get_acl_for_user
# from onyx.agents.agent_search.kb_search.graph_utils import rename_entities_in_answer
# from onyx.agents.agent_search.kb_search.ops import research
# from onyx.agents.agent_search.kb_search.states import FinalAnswerUpdate
# from onyx.agents.agent_search.kb_search.states import MainState
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.calculations import (
# get_answer_generation_documents,
# )
# from onyx.agents.agent_search.shared_graph_utils.llm import get_answer_from_llm
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
# from onyx.configs.kg_configs import KG_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION
# from onyx.context.search.models import InferenceSection
# from onyx.db.engine.sql_engine import get_session_with_current_tenant
# from onyx.prompts.kg_prompts import OUTPUT_FORMAT_NO_EXAMPLES_PROMPT
# from onyx.prompts.kg_prompts import OUTPUT_FORMAT_NO_OVERALL_ANSWER_PROMPT
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def generate_answer(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> FinalAnswerUpdate:
"""
    LangGraph node to generate the final answer from the collected results.
"""
# def generate_answer(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> FinalAnswerUpdate:
# """
# LangGraph node to generate the final answer from the collected results.
# """
node_start_time = datetime.now()
# node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = state.question
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# question = state.question
final_answer: str | None = None
# final_answer: str | None = None
user = (
graph_config.tooling.search_tool.user
if graph_config.tooling.search_tool
else None
)
# user = (
# graph_config.tooling.search_tool.user
# if graph_config.tooling.search_tool
# else None
# )
if not user:
raise ValueError("User is not set")
# if not user:
# raise ValueError("User is not set")
search_tool = graph_config.tooling.search_tool
if search_tool is None:
raise ValueError("Search tool is not set")
# search_tool = graph_config.tooling.search_tool
# if search_tool is None:
# raise ValueError("Search tool is not set")
# Close out previous streams of steps
# # Close out previous streams of steps
# DECLARE STEPS DONE
# # DECLARE STEPS DONE
## MAIN ANSWER
# ## MAIN ANSWER
# identify whether documents have already been retrieved
# # identify whether documents have already been retrieved
retrieved_docs: list[InferenceSection] = []
for step_result in state.step_results:
retrieved_docs += step_result.verified_reranked_documents
# retrieved_docs: list[InferenceSection] = []
# for step_result in state.step_results:
# retrieved_docs += step_result.verified_reranked_documents
# if still needed, get a search done and send the results to the UI
# # if still needed, get a search done and send the results to the UI
if not retrieved_docs and state.source_document_results:
assert graph_config.tooling.search_tool is not None
retrieved_docs = cast(
list[InferenceSection],
research(
question=question,
kg_entities=[],
kg_relationships=[],
kg_sources=state.source_document_results[
:KG_RESEARCH_NUM_RETRIEVED_DOCS
],
search_tool=graph_config.tooling.search_tool,
kg_chunk_id_zero_only=True,
inference_sections_only=True,
),
)
# if not retrieved_docs and state.source_document_results:
# assert graph_config.tooling.search_tool is not None
# retrieved_docs = cast(
# list[InferenceSection],
# research(
# question=question,
# kg_entities=[],
# kg_relationships=[],
# kg_sources=state.source_document_results[
# :KG_RESEARCH_NUM_RETRIEVED_DOCS
# ],
# search_tool=graph_config.tooling.search_tool,
# kg_chunk_id_zero_only=True,
# inference_sections_only=True,
# ),
# )
answer_generation_documents = get_answer_generation_documents(
relevant_docs=retrieved_docs,
context_documents=retrieved_docs,
original_question_docs=retrieved_docs,
max_docs=KG_RESEARCH_NUM_RETRIEVED_DOCS,
)
# answer_generation_documents = get_answer_generation_documents(
# relevant_docs=retrieved_docs,
# context_documents=retrieved_docs,
# original_question_docs=retrieved_docs,
# max_docs=KG_RESEARCH_NUM_RETRIEVED_DOCS,
# )
relevance_list = relevance_from_docs(
answer_generation_documents.streaming_documents
)
# assert graph_config.tooling.search_tool is not None
assert graph_config.tooling.search_tool is not None
# with get_session_with_current_tenant() as graph_db_session:
# list(get_acl_for_user(user, graph_db_session))
with get_session_with_current_tenant() as graph_db_session:
user_acl = list(get_acl_for_user(user, graph_db_session))
# # continue with the answer generation
for tool_response in yield_search_responses(
query=question,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
search_query_info=SearchQueryInfo(
predicted_search=SearchType.KEYWORD,
            # acl is included here only for completeness; the search already happened and
            # we are just streaming out the results.
final_filters=IndexFilters(access_control_list=user_acl),
recency_bias_multiplier=1.0,
),
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
):
# original document streaming
pass
# output_format = (
# state.output_format.value
# if state.output_format
# else "<you be the judge how to best present the data>"
# )
# continue with the answer generation
# # if deep path was taken:
output_format = (
state.output_format.value
if state.output_format
else "<you be the judge how to best present the data>"
)
# consolidated_research_object_results_str = (
# state.consolidated_research_object_results_str
# )
# # reference_results_str = (
# # state.reference_results_str
# # ) # will not be part of LLM. Manually added to the answer
# if deep path was taken:
# # if simple path was taken:
# introductory_answer = state.query_results_data_str # from simple answer path only
# if consolidated_research_object_results_str:
# research_results = consolidated_research_object_results_str
# else:
# research_results = ""
consolidated_research_object_results_str = (
state.consolidated_research_object_results_str
)
# reference_results_str = (
# state.reference_results_str
# ) # will not be part of LLM. Manually added to the answer
# if introductory_answer:
# output_format_prompt = (
# OUTPUT_FORMAT_NO_EXAMPLES_PROMPT.replace("---question---", question)
# .replace(
# "---introductory_answer---",
# rename_entities_in_answer(introductory_answer),
# )
# .replace("---output_format---", str(output_format) if output_format else "")
# )
# elif research_results and consolidated_research_object_results_str:
# output_format_prompt = (
# OUTPUT_FORMAT_NO_EXAMPLES_PROMPT.replace("---question---", question)
# .replace(
# "---introductory_answer---",
# rename_entities_in_answer(consolidated_research_object_results_str),
# )
# .replace("---output_format---", str(output_format) if output_format else "")
# )
# elif research_results and not consolidated_research_object_results_str:
# output_format_prompt = (
# OUTPUT_FORMAT_NO_OVERALL_ANSWER_PROMPT.replace("---question---", question)
# .replace("---output_format---", str(output_format) if output_format else "")
# .replace(
# "---research_results---", rename_entities_in_answer(research_results)
# )
# )
# elif consolidated_research_object_results_str:
# output_format_prompt = (
# OUTPUT_FORMAT_NO_EXAMPLES_PROMPT.replace("---question---", question)
# .replace("---output_format---", str(output_format) if output_format else "")
# .replace(
# "---research_results---", rename_entities_in_answer(research_results)
# )
# )
# else:
# raise ValueError("No research results or introductory answer provided")
# if simple path was taken:
introductory_answer = state.query_results_data_str # from simple answer path only
if consolidated_research_object_results_str:
research_results = consolidated_research_object_results_str
else:
research_results = ""
# try:
if introductory_answer:
output_format_prompt = (
OUTPUT_FORMAT_NO_EXAMPLES_PROMPT.replace("---question---", question)
.replace(
"---introductory_answer---",
rename_entities_in_answer(introductory_answer),
)
.replace("---output_format---", str(output_format) if output_format else "")
)
elif research_results and consolidated_research_object_results_str:
output_format_prompt = (
OUTPUT_FORMAT_NO_EXAMPLES_PROMPT.replace("---question---", question)
.replace(
"---introductory_answer---",
rename_entities_in_answer(consolidated_research_object_results_str),
)
.replace("---output_format---", str(output_format) if output_format else "")
)
elif research_results and not consolidated_research_object_results_str:
output_format_prompt = (
OUTPUT_FORMAT_NO_OVERALL_ANSWER_PROMPT.replace("---question---", question)
.replace("---output_format---", str(output_format) if output_format else "")
.replace(
"---research_results---", rename_entities_in_answer(research_results)
)
)
elif consolidated_research_object_results_str:
output_format_prompt = (
OUTPUT_FORMAT_NO_EXAMPLES_PROMPT.replace("---question---", question)
.replace("---output_format---", str(output_format) if output_format else "")
.replace(
"---research_results---", rename_entities_in_answer(research_results)
)
)
else:
raise ValueError("No research results or introductory answer provided")
# final_answer = get_answer_from_llm(
# llm=graph_config.tooling.primary_llm,
# prompt=output_format_prompt,
# stream=False,
# json_string_flag=False,
# timeout_override=KG_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
# )
try:
# except Exception as e:
# raise ValueError(f"Could not generate the answer. Error {e}")
final_answer = get_answer_from_llm(
llm=graph_config.tooling.primary_llm,
prompt=output_format_prompt,
stream=False,
json_string_flag=False,
timeout_override=KG_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
)
except Exception as e:
raise ValueError(f"Could not generate the answer. Error {e}")
return FinalAnswerUpdate(
final_answer=final_answer,
retrieved_documents=answer_generation_documents.context_documents,
step_results=[],
remarks=[],
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="query completed",
node_start_time=node_start_time,
)
],
)
# return FinalAnswerUpdate(
# final_answer=final_answer,
# retrieved_documents=answer_generation_documents.context_documents,
# step_results=[],
# remarks=[],
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="main",
# node_name="query completed",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,60 +1,60 @@
from datetime import datetime
from typing import cast
# from datetime import datetime
# from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.states import MainOutput
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.db.chat import log_agent_sub_question_results
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.kb_search.states import MainOutput
# from onyx.agents.agent_search.kb_search.states import MainState
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.db.chat import log_agent_sub_question_results
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def log_data(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> MainOutput:
"""
    LangGraph node to log the step results and finalize the flow.
"""
# def log_data(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> MainOutput:
# """
# LangGraph node to log the step results and finalize the flow.
# """
node_start_time = datetime.now()
# node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
# graph_config = cast(GraphConfig, config["metadata"]["config"])
search_tool = graph_config.tooling.search_tool
if search_tool is None:
raise ValueError("Search tool is not set")
# search_tool = graph_config.tooling.search_tool
# if search_tool is None:
# raise ValueError("Search tool is not set")
# commit original db_session
# # commit original db_session
query_db_session = graph_config.persistence.db_session
query_db_session.commit()
# query_db_session = graph_config.persistence.db_session
# query_db_session.commit()
chat_session_id = graph_config.persistence.chat_session_id
primary_message_id = graph_config.persistence.message_id
sub_question_answer_results = state.step_results
# chat_session_id = graph_config.persistence.chat_session_id
# primary_message_id = graph_config.persistence.message_id
# sub_question_answer_results = state.step_results
log_agent_sub_question_results(
db_session=query_db_session,
chat_session_id=chat_session_id,
primary_message_id=primary_message_id,
sub_question_answer_results=sub_question_answer_results,
)
# log_agent_sub_question_results(
# db_session=query_db_session,
# chat_session_id=chat_session_id,
# primary_message_id=primary_message_id,
# sub_question_answer_results=sub_question_answer_results,
# )
return MainOutput(
final_answer=state.final_answer,
retrieved_documents=state.retrieved_documents,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="query completed",
node_start_time=node_start_time,
)
],
)
# return MainOutput(
# final_answer=state.final_answer,
# retrieved_documents=state.retrieved_documents,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="main",
# node_name="query completed",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,65 +1,47 @@
from datetime import datetime
from typing import cast
# from datetime import datetime
# from typing import cast
from onyx.chat.models import LlmDoc
from onyx.configs.constants import DocumentSource
from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
from onyx.context.search.models import InferenceSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
# from onyx.chat.models import LlmDoc
# from onyx.configs.constants import DocumentSource
# from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
# from onyx.context.search.models import InferenceSection
# from onyx.tools.models import SearchToolOverrideKwargs
# from onyx.tools.tool_implementations.search.search_tool import (
# FINAL_CONTEXT_DOCUMENTS_ID,
# )
# from onyx.tools.tool_implementations.search.search_tool import SearchTool
def research(
question: str,
search_tool: SearchTool,
document_sources: list[DocumentSource] | None = None,
time_cutoff: datetime | None = None,
kg_entities: list[str] | None = None,
kg_relationships: list[str] | None = None,
kg_terms: list[str] | None = None,
kg_sources: list[str] | None = None,
kg_chunk_id_zero_only: bool = False,
inference_sections_only: bool = False,
) -> list[LlmDoc] | list[InferenceSection]:
# new db session to avoid concurrency issues
# def research(
# question: str,
# search_tool: SearchTool,
# document_sources: list[DocumentSource] | None = None,
# time_cutoff: datetime | None = None,
# kg_entities: list[str] | None = None,
# kg_relationships: list[str] | None = None,
# kg_terms: list[str] | None = None,
# kg_sources: list[str] | None = None,
# kg_chunk_id_zero_only: bool = False,
# inference_sections_only: bool = False,
# ) -> list[LlmDoc] | list[InferenceSection]:
# # new db session to avoid concurrency issues
callback_container: list[list[InferenceSection]] = []
retrieved_docs: list[LlmDoc] | list[InferenceSection] = []
# retrieved_docs: list[LlmDoc] | list[InferenceSection] = []
with get_session_with_current_tenant() as db_session:
for tool_response in search_tool.run(
query=question,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=False,
alternate_db_session=db_session,
retrieved_sections_callback=callback_container.append,
skip_query_analysis=True,
document_sources=document_sources,
time_cutoff=time_cutoff,
kg_entities=kg_entities,
kg_relationships=kg_relationships,
kg_terms=kg_terms,
kg_sources=kg_sources,
kg_chunk_id_zero_only=kg_chunk_id_zero_only,
),
):
if (
inference_sections_only
and tool_response.id == "search_response_summary"
):
retrieved_docs = tool_response.response.top_sections[
:KG_RESEARCH_NUM_RETRIEVED_DOCS
]
retrieved_docs = cast(list[InferenceSection], retrieved_docs)
break
# get retrieved docs to send to the rest of the graph
elif tool_response.id == FINAL_CONTEXT_DOCUMENTS_ID:
retrieved_docs = cast(list[LlmDoc], tool_response.response)[
:KG_RESEARCH_NUM_RETRIEVED_DOCS
]
break
return retrieved_docs
# for tool_response in search_tool.run(
# query=question,
# override_kwargs=SearchToolOverrideKwargs(original_query=question),
# ):
# if inference_sections_only and tool_response.id == "search_response_summary":
# retrieved_docs = tool_response.response.top_sections[
# :KG_RESEARCH_NUM_RETRIEVED_DOCS
# ]
# retrieved_docs = cast(list[InferenceSection], retrieved_docs)
# break
# # get retrieved docs to send to the rest of the graph
# elif tool_response.id == FINAL_CONTEXT_DOCUMENTS_ID:
# retrieved_docs = cast(list[LlmDoc], tool_response.response)[
# :KG_RESEARCH_NUM_RETRIEVED_DOCS
# ]
# break
# return retrieved_docs

View File

@@ -1,191 +1,191 @@
from enum import Enum
from operator import add
from typing import Annotated
from typing import Any
from typing import Dict
from typing import TypedDict
# from enum import Enum
# from operator import add
# from typing import Annotated
# from typing import Any
# from typing import Dict
# from typing import TypedDict
from pydantic import BaseModel
# from pydantic import BaseModel
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
from onyx.agents.agent_search.shared_graph_utils.models import SubQuestionAnswerResults
from onyx.context.search.models import InferenceSection
# from onyx.agents.agent_search.core_state import CoreState
# from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
# from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
# from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
# from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
# from onyx.agents.agent_search.shared_graph_utils.models import SubQuestionAnswerResults
# from onyx.context.search.models import InferenceSection
### States ###
# ### States ###
class StepResults(BaseModel):
question: str
question_id: str
answer: str
sub_query_retrieval_results: list[QueryRetrievalResult]
verified_reranked_documents: list[InferenceSection]
context_documents: list[InferenceSection]
cited_documents: list[InferenceSection]
# class StepResults(BaseModel):
# question: str
# question_id: str
# answer: str
# sub_query_retrieval_results: list[QueryRetrievalResult]
# verified_reranked_documents: list[InferenceSection]
# context_documents: list[InferenceSection]
# cited_documents: list[InferenceSection]
class LoggerUpdate(BaseModel):
log_messages: Annotated[list[str], add] = []
step_results: Annotated[list[SubQuestionAnswerResults], add]
remarks: Annotated[list[str], add] = []
# class LoggerUpdate(BaseModel):
# log_messages: Annotated[list[str], add] = []
# step_results: Annotated[list[SubQuestionAnswerResults], add]
# remarks: Annotated[list[str], add] = []
class KGFilterConstructionResults(BaseModel):
global_entity_filters: list[str]
global_relationship_filters: list[str]
local_entity_filters: list[list[str]]
source_document_filters: list[str]
structure: list[str]
# class KGFilterConstructionResults(BaseModel):
# global_entity_filters: list[str]
# global_relationship_filters: list[str]
# local_entity_filters: list[list[str]]
# source_document_filters: list[str]
# structure: list[str]
class KGSearchType(Enum):
SEARCH = "SEARCH"
SQL = "SQL"
# class KGSearchType(Enum):
# SEARCH = "SEARCH"
# SQL = "SQL"
class KGAnswerStrategy(Enum):
DEEP = "DEEP"
SIMPLE = "SIMPLE"
# class KGAnswerStrategy(Enum):
# DEEP = "DEEP"
# SIMPLE = "SIMPLE"
class KGSourceDivisionType(Enum):
SOURCE = "SOURCE"
ENTITY = "ENTITY"
# class KGSourceDivisionType(Enum):
# SOURCE = "SOURCE"
# ENTITY = "ENTITY"
class KGRelationshipDetection(Enum):
RELATIONSHIPS = "RELATIONSHIPS"
NO_RELATIONSHIPS = "NO_RELATIONSHIPS"
# class KGRelationshipDetection(Enum):
# RELATIONSHIPS = "RELATIONSHIPS"
# NO_RELATIONSHIPS = "NO_RELATIONSHIPS"
class KGAnswerFormat(Enum):
LIST = "LIST"
TEXT = "TEXT"
# class KGAnswerFormat(Enum):
# LIST = "LIST"
# TEXT = "TEXT"
class YesNoEnum(str, Enum):
YES = "yes"
NO = "no"
# class YesNoEnum(str, Enum):
# YES = "yes"
# NO = "no"
class AnalysisUpdate(LoggerUpdate):
normalized_core_entities: list[str] = []
normalized_core_relationships: list[str] = []
entity_normalization_map: dict[str, str] = {}
relationship_normalization_map: dict[str, str] = {}
query_graph_entities_no_attributes: list[str] = []
query_graph_entities_w_attributes: list[str] = []
query_graph_relationships: list[str] = []
normalized_terms: list[str] = []
normalized_time_filter: str | None = None
strategy: KGAnswerStrategy | None = None
output_format: KGAnswerFormat | None = None
broken_down_question: str | None = None
divide_and_conquer: YesNoEnum | None = None
single_doc_id: str | None = None
search_type: KGSearchType | None = None
query_type: str | None = None
# class AnalysisUpdate(LoggerUpdate):
# normalized_core_entities: list[str] = []
# normalized_core_relationships: list[str] = []
# entity_normalization_map: dict[str, str] = {}
# relationship_normalization_map: dict[str, str] = {}
# query_graph_entities_no_attributes: list[str] = []
# query_graph_entities_w_attributes: list[str] = []
# query_graph_relationships: list[str] = []
# normalized_terms: list[str] = []
# normalized_time_filter: str | None = None
# strategy: KGAnswerStrategy | None = None
# output_format: KGAnswerFormat | None = None
# broken_down_question: str | None = None
# divide_and_conquer: YesNoEnum | None = None
# single_doc_id: str | None = None
# search_type: KGSearchType | None = None
# query_type: str | None = None
class SQLSimpleGenerationUpdate(LoggerUpdate):
sql_query: str | None = None
sql_query_results: list[Dict[Any, Any]] | None = None
individualized_sql_query: str | None = None
individualized_query_results: list[Dict[Any, Any]] | None = None
source_documents_sql: str | None = None
source_document_results: list[str] | None = None
updated_strategy: KGAnswerStrategy | None = None
# class SQLSimpleGenerationUpdate(LoggerUpdate):
# sql_query: str | None = None
# sql_query_results: list[Dict[Any, Any]] | None = None
# individualized_sql_query: str | None = None
# individualized_query_results: list[Dict[Any, Any]] | None = None
# source_documents_sql: str | None = None
# source_document_results: list[str] | None = None
# updated_strategy: KGAnswerStrategy | None = None
class ConsolidatedResearchUpdate(LoggerUpdate):
consolidated_research_object_results_str: str | None = None
# class ConsolidatedResearchUpdate(LoggerUpdate):
# consolidated_research_object_results_str: str | None = None
class DeepSearchFilterUpdate(LoggerUpdate):
vespa_filter_results: KGFilterConstructionResults | None = None
div_con_entities: list[str] | None = None
source_division: bool | None = None
global_entity_filters: list[str] | None = None
global_relationship_filters: list[str] | None = None
local_entity_filters: list[list[str]] | None = None
source_filters: list[str] | None = None
# class DeepSearchFilterUpdate(LoggerUpdate):
# vespa_filter_results: KGFilterConstructionResults | None = None
# div_con_entities: list[str] | None = None
# source_division: bool | None = None
# global_entity_filters: list[str] | None = None
# global_relationship_filters: list[str] | None = None
# local_entity_filters: list[list[str]] | None = None
# source_filters: list[str] | None = None
class ResearchObjectOutput(LoggerUpdate):
research_object_results: Annotated[list[dict[str, Any]], add] = []
# class ResearchObjectOutput(LoggerUpdate):
# research_object_results: Annotated[list[dict[str, Any]], add] = []
class EntityRelationshipExtractionUpdate(LoggerUpdate):
entities_types_str: str = ""
relationship_types_str: str = ""
extracted_entities_w_attributes: list[str] = []
extracted_entities_no_attributes: list[str] = []
extracted_relationships: list[str] = []
time_filter: str | None = None
kg_doc_temp_view_name: str | None = None
kg_rel_temp_view_name: str | None = None
kg_entity_temp_view_name: str | None = None
# class EntityRelationshipExtractionUpdate(LoggerUpdate):
# entities_types_str: str = ""
# relationship_types_str: str = ""
# extracted_entities_w_attributes: list[str] = []
# extracted_entities_no_attributes: list[str] = []
# extracted_relationships: list[str] = []
# time_filter: str | None = None
# kg_doc_temp_view_name: str | None = None
# kg_rel_temp_view_name: str | None = None
# kg_entity_temp_view_name: str | None = None
class ResultsDataUpdate(LoggerUpdate):
query_results_data_str: str | None = None
individualized_query_results_data_str: str | None = None
reference_results_str: str | None = None
# class ResultsDataUpdate(LoggerUpdate):
# query_results_data_str: str | None = None
# individualized_query_results_data_str: str | None = None
# reference_results_str: str | None = None
class ResearchObjectUpdate(LoggerUpdate):
research_object_results: Annotated[list[dict[str, Any]], add] = []
# class ResearchObjectUpdate(LoggerUpdate):
# research_object_results: Annotated[list[dict[str, Any]], add] = []
## Graph Input State
class MainInput(CoreState):
question: str
individual_flow: bool = True # used for UI display purposes
# ## Graph Input State
# class MainInput(CoreState):
# question: str
# individual_flow: bool = True # used for UI display purposes
class FinalAnswerUpdate(LoggerUpdate):
final_answer: str | None = None
retrieved_documents: list[InferenceSection] | None = None
# class FinalAnswerUpdate(LoggerUpdate):
# final_answer: str | None = None
# retrieved_documents: list[InferenceSection] | None = None
## Graph State
class MainState(
# This includes the core state
MainInput,
ToolChoiceInput,
ToolCallUpdate,
ToolChoiceUpdate,
EntityRelationshipExtractionUpdate,
AnalysisUpdate,
SQLSimpleGenerationUpdate,
ResultsDataUpdate,
ResearchObjectOutput,
DeepSearchFilterUpdate,
ResearchObjectUpdate,
ConsolidatedResearchUpdate,
FinalAnswerUpdate,
):
pass
# ## Graph State
# class MainState(
# # This includes the core state
# MainInput,
# ToolChoiceInput,
# ToolCallUpdate,
# ToolChoiceUpdate,
# EntityRelationshipExtractionUpdate,
# AnalysisUpdate,
# SQLSimpleGenerationUpdate,
# ResultsDataUpdate,
# ResearchObjectOutput,
# DeepSearchFilterUpdate,
# ResearchObjectUpdate,
# ConsolidatedResearchUpdate,
# FinalAnswerUpdate,
# ):
# pass
## Graph Output State - presently not used
class MainOutput(TypedDict):
log_messages: list[str]
final_answer: str | None
retrieved_documents: list[InferenceSection] | None
# ## Graph Output State - presently not used
# class MainOutput(TypedDict):
# log_messages: list[str]
# final_answer: str | None
# retrieved_documents: list[InferenceSection] | None
class ResearchObjectInput(LoggerUpdate):
research_nr: int
entity: str
broken_down_question: str
vespa_filter_results: KGFilterConstructionResults
source_division: bool | None
source_entity_filters: list[str] | None
segment_type: str
individual_flow: bool = True # used for UI display purposes
# class ResearchObjectInput(LoggerUpdate):
# research_nr: int
# entity: str
# broken_down_question: str
# vespa_filter_results: KGFilterConstructionResults
# source_division: bool | None
# source_entity_filters: list[str] | None
# segment_type: str
# individual_flow: bool = True # used for UI display purposes
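A minimal usage sketch of the graph input state defined above, assuming MainInput is importable from the states module shown here and that CoreState supplies the log_messages field (as the runner code later in this diff suggests); the question text is illustrative.

# Sketch: constructing the graph's input state. The question is a placeholder;
# individual_flow only affects UI display, per the field comment above.
graph_input = MainInput(
    log_messages=[],
    question="Which accounts renewed in Q3?",
    individual_flow=True,
)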

View File

@@ -1,33 +1,33 @@
from onyx.agents.agent_search.kb_search.models import KGSteps
# from onyx.agents.agent_search.kb_search.models import KGSteps
KG_SEARCH_STEP_DESCRIPTIONS: dict[int, KGSteps] = {
1: KGSteps(
description="Analyzing the question...",
activities=[
"Entities in Query",
"Relationships in Query",
"Terms in Query",
"Time Filters",
],
),
2: KGSteps(
description="Planning the response approach...",
activities=["Query Execution Strategy", "Answer Format"],
),
3: KGSteps(
description="Querying the Knowledge Graph...",
activities=[
"Knowledge Graph Query",
"Knowledge Graph Query Results",
"Query for Source Documents",
"Source Documents",
],
),
4: KGSteps(
description="Conducting further research on source documents...", activities=[]
),
}
# KG_SEARCH_STEP_DESCRIPTIONS: dict[int, KGSteps] = {
# 1: KGSteps(
# description="Analyzing the question...",
# activities=[
# "Entities in Query",
# "Relationships in Query",
# "Terms in Query",
# "Time Filters",
# ],
# ),
# 2: KGSteps(
# description="Planning the response approach...",
# activities=["Query Execution Strategy", "Answer Format"],
# ),
# 3: KGSteps(
# description="Querying the Knowledge Graph...",
# activities=[
# "Knowledge Graph Query",
# "Knowledge Graph Query Results",
# "Query for Source Documents",
# "Source Documents",
# ],
# ),
# 4: KGSteps(
# description="Conducting further research on source documents...", activities=[]
# ),
# }
BASIC_SEARCH_STEP_DESCRIPTIONS: dict[int, KGSteps] = {
1: KGSteps(description="Conducting a standard search...", activities=[]),
}
# BASIC_SEARCH_STEP_DESCRIPTIONS: dict[int, KGSteps] = {
# 1: KGSteps(description="Conducting a standard search...", activities=[]),
# }
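A quick sketch of reading the step metadata above; the field names follow the KGSteps constructor calls shown in the mapping.

# Sketch: look up the first KG search step and its expected activities.
step_one = KG_SEARCH_STEP_DESCRIPTIONS[1]
print(step_one.description)            # "Analyzing the question..."
print(", ".join(step_one.activities))  # "Entities in Query, Relationships in Query, ..."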

View File

@@ -1,88 +1,89 @@
from uuid import UUID
# from uuid import UUID
from pydantic import BaseModel
from pydantic import ConfigDict
from sqlalchemy.orm import Session
# from pydantic import BaseModel
# from sqlalchemy.orm import Session
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.context.search.models import RerankingDetails
from onyx.db.models import Persona
from onyx.file_store.utils import InMemoryChatFile
from onyx.kg.models import KGConfigSettings
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
# from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
# from onyx.context.search.models import RerankingDetails
# from onyx.db.models import Persona
# from onyx.file_store.utils import InMemoryChatFile
# from onyx.kg.models import KGConfigSettings
# from onyx.llm.interfaces import LLM
# from onyx.tools.force import ForceUseTool
# from onyx.tools.tool import Tool
# from onyx.tools.tool_implementations.search.search_tool import SearchTool
class GraphInputs(BaseModel):
"""Input data required for the graph execution"""
# class GraphInputs(BaseModel):
# """Input data required for the graph execution"""
persona: Persona | None = None
rerank_settings: RerankingDetails | None = None
prompt_builder: AnswerPromptBuilder
files: list[InMemoryChatFile] | None = None
structured_response_format: dict | None = None
project_instructions: str | None = None
# persona: Persona | None = None
# rerank_settings: RerankingDetails | None = None
# prompt_builder: AnswerPromptBuilder
# files: list[InMemoryChatFile] | None = None
# structured_response_format: dict | None = None
# project_instructions: str | None = None
model_config = ConfigDict(arbitrary_types_allowed=True)
# class Config:
# arbitrary_types_allowed = True
class GraphTooling(BaseModel):
"""Tools and LLMs available to the graph"""
# class GraphTooling(BaseModel):
# """Tools and LLMs available to the graph"""
primary_llm: LLM
fast_llm: LLM
search_tool: SearchTool | None = None
tools: list[Tool]
# Whether to force use of a tool, or to
# force tool args IF the tool is used
force_use_tool: ForceUseTool
using_tool_calling_llm: bool = False
# primary_llm: LLM
# fast_llm: LLM
# search_tool: SearchTool | None = None
# tools: list[Tool]
# # Whether to force use of a tool, or to
# # force tool args IF the tool is used
# force_use_tool: ForceUseTool
# using_tool_calling_llm: bool = False
model_config = ConfigDict(arbitrary_types_allowed=True)
# class Config:
# arbitrary_types_allowed = True
class GraphPersistence(BaseModel):
"""Configuration for data persistence"""
# class GraphPersistence(BaseModel):
# """Configuration for data persistence"""
chat_session_id: UUID
# The message ID of the to-be-created first agent message
# in response to the user message that triggered the Pro Search
message_id: int
# chat_session_id: UUID
# # The message ID of the to-be-created first agent message
# # in response to the user message that triggered the Pro Search
# message_id: int
# The database session the user and initial agent
# message were flushed to; only needed for agentic search
db_session: Session
# # The database session the user and initial agent
# # message were flushed to; only needed for agentic search
# db_session: Session
model_config = ConfigDict(arbitrary_types_allowed=True)
# class Config:
# arbitrary_types_allowed = True
class GraphSearchConfig(BaseModel):
"""Configuration controlling search behavior"""
# class GraphSearchConfig(BaseModel):
# """Configuration controlling search behavior"""
use_agentic_search: bool = False
# Whether to perform initial search to inform decomposition
perform_initial_search_decomposition: bool = True
# use_agentic_search: bool = False
# # Whether to perform initial search to inform decomposition
# perform_initial_search_decomposition: bool = True
# Whether to allow creation of refinement questions (and entity extraction, etc.)
allow_refinement: bool = True
skip_gen_ai_answer_generation: bool = False
allow_agent_reranking: bool = False
kg_config_settings: KGConfigSettings = KGConfigSettings()
research_type: ResearchType = ResearchType.THOUGHTFUL
# # Whether to allow creation of refinement questions (and entity extraction, etc.)
# allow_refinement: bool = True
# skip_gen_ai_answer_generation: bool = False
# allow_agent_reranking: bool = False
# kg_config_settings: KGConfigSettings = KGConfigSettings()
class GraphConfig(BaseModel):
"""
Main container for data needed for Langgraph execution
"""
# class GraphConfig(BaseModel):
# """
# Main container for data needed for Langgraph execution
# """
inputs: GraphInputs
tooling: GraphTooling
behavior: GraphSearchConfig
# Only needed for agentic search
persistence: GraphPersistence
# inputs: GraphInputs
# tooling: GraphTooling
# behavior: GraphSearchConfig
# # Only needed for agentic search
# persistence: GraphPersistence
model_config = ConfigDict(arbitrary_types_allowed=True)
# class Config:
# arbitrary_types_allowed = True
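A minimal sketch of the search-behavior config above; every value spelled out below mirrors a default from GraphSearchConfig, and research_type falls back to ResearchType.THOUGHTFUL unless overridden.

# Sketch: building a GraphSearchConfig with the defaults made explicit.
behavior = GraphSearchConfig(
    use_agentic_search=False,
    perform_initial_search_decomposition=True,
    allow_refinement=True,
    skip_gen_ai_answer_generation=False,
    allow_agent_reranking=False,
)
print(behavior.research_type)  # ResearchType.THOUGHTFUL by default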

View File

@@ -1,50 +1,50 @@
from pydantic import BaseModel
from pydantic import ConfigDict
# from pydantic import BaseModel
from onyx.chat.prompt_builder.schemas import PromptSnapshot
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
# from onyx.chat.prompt_builder.schemas import PromptSnapshot
# from onyx.tools.message import ToolCallSummary
# from onyx.tools.models import SearchToolOverrideKwargs
# from onyx.tools.models import ToolCallFinalResult
# from onyx.tools.models import ToolCallKickoff
# from onyx.tools.models import ToolResponse
# from onyx.tools.tool import Tool
# TODO: adapt the tool choice/tool call to allow for parallel tool calls by
# creating a subgraph that can be invoked in parallel via Send/Command APIs
class ToolChoiceInput(BaseModel):
should_stream_answer: bool = True
# default to the prompt builder from the config, but
# allow overrides for arbitrary tool calls
prompt_snapshot: PromptSnapshot | None = None
# # TODO: adapt the tool choice/tool call to allow for parallel tool calls by
# # creating a subgraph that can be invoked in parallel via Send/Command APIs
# class ToolChoiceInput(BaseModel):
# should_stream_answer: bool = True
# # default to the prompt builder from the config, but
# # allow overrides for arbitrary tool calls
# prompt_snapshot: PromptSnapshot | None = None
# names of tools to use for tool calling. Filters the tools available in the config
tools: list[str] = []
# # names of tools to use for tool calling. Filters the tools available in the config
# tools: list[str] = []
class ToolCallOutput(BaseModel):
tool_call_summary: ToolCallSummary
tool_call_kickoff: ToolCallKickoff
tool_call_responses: list[ToolResponse]
tool_call_final_result: ToolCallFinalResult
# class ToolCallOutput(BaseModel):
# tool_call_summary: ToolCallSummary
# tool_call_kickoff: ToolCallKickoff
# tool_call_responses: list[ToolResponse]
# tool_call_final_result: ToolCallFinalResult
class ToolCallUpdate(BaseModel):
tool_call_output: ToolCallOutput | None = None
# class ToolCallUpdate(BaseModel):
# tool_call_output: ToolCallOutput | None = None
class ToolChoice(BaseModel):
tool: Tool
tool_args: dict
id: str | None
search_tool_override_kwargs: SearchToolOverrideKwargs = SearchToolOverrideKwargs()
# class ToolChoice(BaseModel):
# tool: Tool
# tool_args: dict
# id: str | None
# search_tool_override_kwargs: SearchToolOverrideKwargs = SearchToolOverrideKwargs()
model_config = ConfigDict(arbitrary_types_allowed=True)
# class Config:
# arbitrary_types_allowed = True
class ToolChoiceUpdate(BaseModel):
tool_choice: ToolChoice | None = None
# class ToolChoiceUpdate(BaseModel):
# tool_choice: ToolChoice | None = None
class ToolChoiceState(ToolChoiceUpdate, ToolChoiceInput):
pass
# class ToolChoiceState(ToolChoiceUpdate, ToolChoiceInput):
# pass
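A minimal sketch of the tool-choice input state above; the tool names are illustrative placeholders rather than a claim about which in-code tool ids exist.

# Sketch: restricting the tools available to the tool-calling step.
choice_input = ToolChoiceInput(
    should_stream_answer=False,
    tools=["SearchTool", "OpenURLTool"],  # placeholder names
)
print(choice_input.model_dump())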

View File

@@ -1,93 +1,93 @@
from collections.abc import Iterable
from typing import cast
# from collections.abc import Iterable
# from typing import cast
from langchain_core.runnables.schema import CustomStreamEvent
from langchain_core.runnables.schema import StreamEvent
from langfuse.langchain import CallbackHandler
from langgraph.graph.state import CompiledStateGraph
# from langchain_core.runnables.schema import CustomStreamEvent
# from langchain_core.runnables.schema import StreamEvent
# from langfuse.langchain import CallbackHandler
# from langgraph.graph.state import CompiledStateGraph
from onyx.agents.agent_search.dc_search_analysis.graph_builder import (
divide_and_conquer_graph_builder,
)
from onyx.agents.agent_search.dc_search_analysis.states import MainInput as DCMainInput
from onyx.agents.agent_search.dr.graph_builder import dr_graph_builder
from onyx.agents.agent_search.dr.states import MainInput as DRMainInput
from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
from onyx.agents.agent_search.kb_search.states import MainInput as KBMainInput
from onyx.agents.agent_search.models import GraphConfig
from onyx.chat.models import AnswerStream
from onyx.configs.app_configs import LANGFUSE_PUBLIC_KEY
from onyx.configs.app_configs import LANGFUSE_SECRET_KEY
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dc_search_analysis.graph_builder import (
# divide_and_conquer_graph_builder,
# )
# from onyx.agents.agent_search.dc_search_analysis.states import MainInput as DCMainInput
# from onyx.agents.agent_search.dr.graph_builder import dr_graph_builder
# from onyx.agents.agent_search.dr.states import MainInput as DRMainInput
# from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
# from onyx.agents.agent_search.kb_search.states import MainInput as KBMainInput
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.chat.models import AnswerStream
# from onyx.configs.app_configs import LANGFUSE_PUBLIC_KEY
# from onyx.configs.app_configs import LANGFUSE_SECRET_KEY
# from onyx.server.query_and_chat.streaming_models import Packet
# from onyx.utils.logger import setup_logger
logger = setup_logger()
GraphInput = DCMainInput | KBMainInput | DRMainInput
# logger = setup_logger()
# GraphInput = DCMainInput | KBMainInput | DRMainInput
def manage_sync_streaming(
compiled_graph: CompiledStateGraph,
config: GraphConfig,
graph_input: GraphInput,
) -> Iterable[StreamEvent]:
message_id = config.persistence.message_id if config.persistence else None
callbacks: list[CallbackHandler] = []
if LANGFUSE_SECRET_KEY and LANGFUSE_PUBLIC_KEY:
callbacks.append(CallbackHandler())
for event in compiled_graph.stream(
stream_mode="custom",
input=graph_input,
config={
"metadata": {"config": config, "thread_id": str(message_id)},
"callbacks": callbacks, # type: ignore
},
):
yield cast(CustomStreamEvent, event)
# def manage_sync_streaming(
# compiled_graph: CompiledStateGraph,
# config: GraphConfig,
# graph_input: GraphInput,
# ) -> Iterable[StreamEvent]:
# message_id = config.persistence.message_id if config.persistence else None
# callbacks: list[CallbackHandler] = []
# if LANGFUSE_SECRET_KEY and LANGFUSE_PUBLIC_KEY:
# callbacks.append(CallbackHandler())
# for event in compiled_graph.stream(
# stream_mode="custom",
# input=graph_input,
# config={
# "metadata": {"config": config, "thread_id": str(message_id)},
# "callbacks": callbacks, # type: ignore
# },
# ):
# yield cast(CustomStreamEvent, event)
def run_graph(
compiled_graph: CompiledStateGraph,
config: GraphConfig,
input: GraphInput,
) -> AnswerStream:
# def run_graph(
# compiled_graph: CompiledStateGraph,
# config: GraphConfig,
# input: GraphInput,
# ) -> AnswerStream:
for event in manage_sync_streaming(
compiled_graph=compiled_graph, config=config, graph_input=input
):
# for event in manage_sync_streaming(
# compiled_graph=compiled_graph, config=config, graph_input=input
# ):
yield cast(Packet, event["data"])
# yield cast(Packet, event["data"])
def run_kb_graph(
config: GraphConfig,
) -> AnswerStream:
graph = kb_graph_builder()
compiled_graph = graph.compile()
input = KBMainInput(
log_messages=[], question=config.inputs.prompt_builder.raw_user_query
)
# def run_kb_graph(
# config: GraphConfig,
# ) -> AnswerStream:
# graph = kb_graph_builder()
# compiled_graph = graph.compile()
# input = KBMainInput(
# log_messages=[], question=config.inputs.prompt_builder.raw_user_query
# )
yield from run_graph(compiled_graph, config, input)
# yield from run_graph(compiled_graph, config, input)
def run_dr_graph(
config: GraphConfig,
) -> AnswerStream:
graph = dr_graph_builder()
compiled_graph = graph.compile()
input = DRMainInput(log_messages=[])
# def run_dr_graph(
# config: GraphConfig,
# ) -> AnswerStream:
# graph = dr_graph_builder()
# compiled_graph = graph.compile()
# input = DRMainInput(log_messages=[])
yield from run_graph(compiled_graph, config, input)
# yield from run_graph(compiled_graph, config, input)
def run_dc_graph(
config: GraphConfig,
) -> AnswerStream:
graph = divide_and_conquer_graph_builder()
compiled_graph = graph.compile()
input = DCMainInput(log_messages=[])
config.inputs.prompt_builder.raw_user_query = (
config.inputs.prompt_builder.raw_user_query.strip()
)
return run_graph(compiled_graph, config, input)
# def run_dc_graph(
# config: GraphConfig,
# ) -> AnswerStream:
# graph = divide_and_conquer_graph_builder()
# compiled_graph = graph.compile()
# input = DCMainInput(log_messages=[])
# config.inputs.prompt_builder.raw_user_query = (
# config.inputs.prompt_builder.raw_user_query.strip()
# )
# return run_graph(compiled_graph, config, input)
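A minimal sketch of driving one of the runners above; config is assumed to be a fully populated GraphConfig (inputs, tooling, behavior, persistence), and handle_packet is a hypothetical sink that forwards each streamed Packet to the client.

# Sketch: consume the deep-research runner's stream. `config` is assumed to be
# a fully built GraphConfig; `handle_packet` is a hypothetical sink (e.g. an
# SSE writer) and not part of the code above.
for packet in run_dr_graph(config):
    handle_packet(packet)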

View File

@@ -1,176 +1,176 @@
from langchain.schema import AIMessage
from langchain.schema import HumanMessage
from langchain.schema import SystemMessage
from langchain_core.messages.tool import ToolMessage
# from langchain.schema import AIMessage
# from langchain.schema import HumanMessage
# from langchain.schema import SystemMessage
# from langchain_core.messages.tool import ToolMessage
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.models import (
AgentPromptEnrichmentComponents,
)
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_persona_agent_prompt_expressions,
)
from onyx.agents.agent_search.shared_graph_utils.utils import remove_document_citations
from onyx.agents.agent_search.shared_graph_utils.utils import summarize_history
from onyx.configs.agent_configs import AGENT_MAX_STATIC_HISTORY_WORD_LENGTH
from onyx.configs.constants import MessageType
from onyx.context.search.models import InferenceSection
from onyx.llm.interfaces import LLMConfig
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.natural_language_processing.utils import tokenizer_trim_content
from onyx.prompts.agent_search import HISTORY_FRAMING_PROMPT
from onyx.prompts.agent_search import SUB_QUESTION_RAG_PROMPT
from onyx.prompts.prompt_utils import build_date_time_string
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.models import (
# AgentPromptEnrichmentComponents,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_persona_agent_prompt_expressions,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import remove_document_citations
# from onyx.agents.agent_search.shared_graph_utils.utils import summarize_history
# from onyx.configs.agent_configs import AGENT_MAX_STATIC_HISTORY_WORD_LENGTH
# from onyx.configs.constants import MessageType
# from onyx.context.search.models import InferenceSection
# from onyx.llm.interfaces import LLMConfig
# from onyx.natural_language_processing.utils import get_tokenizer
# from onyx.natural_language_processing.utils import tokenizer_trim_content
# from onyx.prompts.agent_search import HISTORY_FRAMING_PROMPT
# from onyx.prompts.agent_search import SUB_QUESTION_RAG_PROMPT
# from onyx.prompts.prompt_utils import build_date_time_string
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def build_sub_question_answer_prompt(
question: str,
original_question: str,
docs: list[InferenceSection],
persona_specification: str,
config: LLMConfig,
) -> list[SystemMessage | HumanMessage | AIMessage | ToolMessage]:
system_message = SystemMessage(
content=persona_specification,
)
# def build_sub_question_answer_prompt(
# question: str,
# original_question: str,
# docs: list[InferenceSection],
# persona_specification: str,
# config: LLMConfig,
# ) -> list[SystemMessage | HumanMessage | AIMessage | ToolMessage]:
# system_message = SystemMessage(
# content=persona_specification,
# )
date_str = build_date_time_string()
# date_str = build_date_time_string()
docs_str = format_docs(docs)
# docs_str = format_docs(docs)
docs_str = trim_prompt_piece(
config=config,
prompt_piece=docs_str,
reserved_str=SUB_QUESTION_RAG_PROMPT + question + original_question + date_str,
)
human_message = HumanMessage(
content=SUB_QUESTION_RAG_PROMPT.format(
question=question,
original_question=original_question,
context=docs_str,
date_prompt=date_str,
)
)
# docs_str = trim_prompt_piece(
# config=config,
# prompt_piece=docs_str,
# reserved_str=SUB_QUESTION_RAG_PROMPT + question + original_question + date_str,
# )
# human_message = HumanMessage(
# content=SUB_QUESTION_RAG_PROMPT.format(
# question=question,
# original_question=original_question,
# context=docs_str,
# date_prompt=date_str,
# )
# )
return [system_message, human_message]
# return [system_message, human_message]
def trim_prompt_piece(config: LLMConfig, prompt_piece: str, reserved_str: str) -> str:
# no need to trim if a conservative estimate of one token
# per character is already less than the max tokens
if len(prompt_piece) + len(reserved_str) < config.max_input_tokens:
return prompt_piece
# def trim_prompt_piece(config: LLMConfig, prompt_piece: str, reserved_str: str) -> str:
# # no need to trim if a conservative estimate of one token
# # per character is already less than the max tokens
# if len(prompt_piece) + len(reserved_str) < config.max_input_tokens:
# return prompt_piece
llm_tokenizer = get_tokenizer(
provider_type=config.model_provider,
model_name=config.model_name,
)
# llm_tokenizer = get_tokenizer(
# provider_type=config.model_provider,
# model_name=config.model_name,
# )
# slightly conservative trimming
return tokenizer_trim_content(
content=prompt_piece,
desired_length=config.max_input_tokens
- len(llm_tokenizer.encode(reserved_str)),
tokenizer=llm_tokenizer,
)
# # slightly conservative trimming
# return tokenizer_trim_content(
# content=prompt_piece,
# desired_length=config.max_input_tokens
# - len(llm_tokenizer.encode(reserved_str)),
# tokenizer=llm_tokenizer,
# )
def build_history_prompt(config: GraphConfig, question: str) -> str:
prompt_builder = config.inputs.prompt_builder
persona_base = get_persona_agent_prompt_expressions(
config.inputs.persona, db_session=config.persistence.db_session
).base_prompt
# def build_history_prompt(config: GraphConfig, question: str) -> str:
# prompt_builder = config.inputs.prompt_builder
# persona_base = get_persona_agent_prompt_expressions(
# config.inputs.persona, db_session=config.persistence.db_session
# ).base_prompt
if prompt_builder is None:
return ""
# if prompt_builder is None:
# return ""
if prompt_builder.single_message_history is not None:
history = prompt_builder.single_message_history
else:
history_components = []
previous_message_type = None
for message in prompt_builder.raw_message_history:
if message.message_type == MessageType.USER:
history_components.append(f"User: {message.message}\n")
previous_message_type = MessageType.USER
elif message.message_type == MessageType.ASSISTANT:
# Previously there could be multiple assistant messages in a row
# Now this is handled at the message history construction
assert previous_message_type is not MessageType.ASSISTANT
history_components.append(f"You/Agent: {message.message}\n")
previous_message_type = MessageType.ASSISTANT
else:
# Other message types are not included here; currently there should be no other message types
logger.error(
f"Unhandled message type: {message.message_type} with message: {message.message}"
)
continue
# if prompt_builder.single_message_history is not None:
# history = prompt_builder.single_message_history
# else:
# history_components = []
# previous_message_type = None
# for message in prompt_builder.raw_message_history:
# if message.message_type == MessageType.USER:
# history_components.append(f"User: {message.message}\n")
# previous_message_type = MessageType.USER
# elif message.message_type == MessageType.ASSISTANT:
# # Previously there could be multiple assistant messages in a row
# # Now this is handled at the message history construction
# assert previous_message_type is not MessageType.ASSISTANT
# history_components.append(f"You/Agent: {message.message}\n")
# previous_message_type = MessageType.ASSISTANT
# else:
# # Other message types are not included here; currently there should be no other message types
# logger.error(
# f"Unhandled message type: {message.message_type} with message: {message.message}"
# )
# continue
history = "\n".join(history_components)
history = remove_document_citations(history)
if len(history.split()) > AGENT_MAX_STATIC_HISTORY_WORD_LENGTH:
history = summarize_history(
history=history,
question=question,
persona_specification=persona_base,
llm=config.tooling.fast_llm,
)
# history = "\n".join(history_components)
# history = remove_document_citations(history)
# if len(history.split()) > AGENT_MAX_STATIC_HISTORY_WORD_LENGTH:
# history = summarize_history(
# history=history,
# question=question,
# persona_specification=persona_base,
# llm=config.tooling.fast_llm,
# )
return HISTORY_FRAMING_PROMPT.format(history=history) if history else ""
# return HISTORY_FRAMING_PROMPT.format(history=history) if history else ""
def get_prompt_enrichment_components(
config: GraphConfig,
) -> AgentPromptEnrichmentComponents:
persona_prompts = get_persona_agent_prompt_expressions(
config.inputs.persona, db_session=config.persistence.db_session
)
# def get_prompt_enrichment_components(
# config: GraphConfig,
# ) -> AgentPromptEnrichmentComponents:
# persona_prompts = get_persona_agent_prompt_expressions(
# config.inputs.persona, db_session=config.persistence.db_session
# )
history = build_history_prompt(config, config.inputs.prompt_builder.raw_user_query)
# history = build_history_prompt(config, config.inputs.prompt_builder.raw_user_query)
date_str = build_date_time_string()
# date_str = build_date_time_string()
return AgentPromptEnrichmentComponents(
persona_prompts=persona_prompts,
history=history,
date_str=date_str,
)
# return AgentPromptEnrichmentComponents(
# persona_prompts=persona_prompts,
# history=history,
# date_str=date_str,
# )
def binary_string_test(text: str, positive_value: str = "yes") -> bool:
"""
Tests if a string contains a positive value (case-insensitive).
# def binary_string_test(text: str, positive_value: str = "yes") -> bool:
# """
# Tests if a string contains a positive value (case-insensitive).
Args:
text: The string to test
positive_value: The value to look for (defaults to "yes")
# Args:
# text: The string to test
# positive_value: The value to look for (defaults to "yes")
Returns:
True if the positive value is found in the text
"""
return positive_value.lower() in text.lower()
# Returns:
# True if the positive value is found in the text
# """
# return positive_value.lower() in text.lower()
def binary_string_test_after_answer_separator(
text: str, positive_value: str = "yes", separator: str = "Answer:"
) -> bool:
"""
Tests if a string contains a positive value (case-insensitive).
# def binary_string_test_after_answer_separator(
# text: str, positive_value: str = "yes", separator: str = "Answer:"
# ) -> bool:
# """
# Tests if a string contains a positive value (case-insensitive).
Args:
text: The string to test
positive_value: The value to look for (defaults to "yes")
# Args:
# text: The string to test
# positive_value: The value to look for (defaults to "yes")
Returns:
True if the positive value is found in the text
"""
# Returns:
# True if the positive value is found in the text
# """
if separator not in text:
return False
relevant_text = text.split(f"{separator}")[-1]
# if separator not in text:
# return False
# relevant_text = text.split(f"{separator}")[-1]
return binary_string_test(relevant_text, positive_value)
# return binary_string_test(relevant_text, positive_value)
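A minimal usage sketch of the yes/no helpers above, on a made-up LLM response.

# Sketch: the plain test matches "yes" anywhere in the text, while the
# separator-aware variant only inspects what follows "Answer:".
llm_output = "Reasoning: yes, the policy covers it.\nAnswer: No"
print(binary_string_test(llm_output))                         # True  ("yes" appears in the reasoning)
print(binary_string_test_after_answer_separator(llm_output))  # False (only "No" follows "Answer:")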

View File

@@ -1,205 +1,205 @@
import numpy as np
# import numpy as np
from onyx.agents.agent_search.shared_graph_utils.models import AnswerGenerationDocuments
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitScoreMetrics
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_section_list,
)
from onyx.chat.models import SectionRelevancePiece
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.shared_graph_utils.models import AnswerGenerationDocuments
# from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitScoreMetrics
# from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
# from onyx.agents.agent_search.shared_graph_utils.operators import (
# dedup_inference_section_list,
# )
# from onyx.chat.models import SectionRelevancePiece
# from onyx.context.search.models import InferenceSection
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def unique_chunk_id(doc: InferenceSection) -> str:
return f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}"
# def unique_chunk_id(doc: InferenceSection) -> str:
# return f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}"
def calculate_rank_shift(list1: list, list2: list, top_n: int = 20) -> float:
shift = 0
for rank_first, doc_id in enumerate(list1[:top_n], 1):
try:
rank_second = list2.index(doc_id) + 1
except ValueError:
rank_second = len(list2) # Document not found in second list
# def calculate_rank_shift(list1: list, list2: list, top_n: int = 20) -> float:
# shift = 0
# for rank_first, doc_id in enumerate(list1[:top_n], 1):
# try:
# rank_second = list2.index(doc_id) + 1
# except ValueError:
# rank_second = len(list2) # Document not found in second list
shift += np.abs(rank_first - rank_second) / np.log(1 + rank_first * rank_second)
# shift += np.abs(rank_first - rank_second) / np.log(1 + rank_first * rank_second)
return shift / top_n
# return shift / top_n
def get_fit_scores(
pre_reranked_results: list[InferenceSection],
post_reranked_results: list[InferenceSection] | list[SectionRelevancePiece],
) -> RetrievalFitStats | None:
"""
Calculate retrieval metrics for search purposes
"""
# def get_fit_scores(
# pre_reranked_results: list[InferenceSection],
# post_reranked_results: list[InferenceSection] | list[SectionRelevancePiece],
# ) -> RetrievalFitStats | None:
# """
# Calculate retrieval metrics for search purposes
# """
if len(pre_reranked_results) == 0 or len(post_reranked_results) == 0:
return None
# if len(pre_reranked_results) == 0 or len(post_reranked_results) == 0:
# return None
ranked_sections = {
"initial": pre_reranked_results,
"reranked": post_reranked_results,
}
# ranked_sections = {
# "initial": pre_reranked_results,
# "reranked": post_reranked_results,
# }
fit_eval: RetrievalFitStats = RetrievalFitStats(
fit_score_lift=0,
rerank_effect=0,
fit_scores={
"initial": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
"reranked": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
},
)
# fit_eval: RetrievalFitStats = RetrievalFitStats(
# fit_score_lift=0,
# rerank_effect=0,
# fit_scores={
# "initial": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
# "reranked": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
# },
# )
for rank_type, docs in ranked_sections.items():
logger.debug(f"rank_type: {rank_type}")
# for rank_type, docs in ranked_sections.items():
# logger.debug(f"rank_type: {rank_type}")
for i in [1, 5, 10]:
fit_eval.fit_scores[rank_type].scores[str(i)] = (
sum(
[
float(doc.center_chunk.score)
for doc in docs[:i]
if isinstance(doc, InferenceSection)
and doc.center_chunk.score is not None
]
)
/ i
)
# for i in [1, 5, 10]:
# fit_eval.fit_scores[rank_type].scores[str(i)] = (
# sum(
# [
# float(doc.center_chunk.score)
# for doc in docs[:i]
# if isinstance(doc, InferenceSection)
# and doc.center_chunk.score is not None
# ]
# )
# / i
# )
fit_eval.fit_scores[rank_type].scores["fit_score"] = (
1
/ 3
* (
fit_eval.fit_scores[rank_type].scores["1"]
+ fit_eval.fit_scores[rank_type].scores["5"]
+ fit_eval.fit_scores[rank_type].scores["10"]
)
)
# fit_eval.fit_scores[rank_type].scores["fit_score"] = (
# 1
# / 3
# * (
# fit_eval.fit_scores[rank_type].scores["1"]
# + fit_eval.fit_scores[rank_type].scores["5"]
# + fit_eval.fit_scores[rank_type].scores["10"]
# )
# )
fit_eval.fit_scores[rank_type].scores["fit_score"] = fit_eval.fit_scores[
rank_type
].scores["1"]
# fit_eval.fit_scores[rank_type].scores["fit_score"] = fit_eval.fit_scores[
# rank_type
# ].scores["1"]
fit_eval.fit_scores[rank_type].chunk_ids = [
unique_chunk_id(doc) for doc in docs if isinstance(doc, InferenceSection)
]
# fit_eval.fit_scores[rank_type].chunk_ids = [
# unique_chunk_id(doc) for doc in docs if isinstance(doc, InferenceSection)
# ]
fit_eval.fit_score_lift = (
fit_eval.fit_scores["reranked"].scores["fit_score"]
/ fit_eval.fit_scores["initial"].scores["fit_score"]
)
# fit_eval.fit_score_lift = (
# fit_eval.fit_scores["reranked"].scores["fit_score"]
# / fit_eval.fit_scores["initial"].scores["fit_score"]
# )
fit_eval.rerank_effect = calculate_rank_shift(
fit_eval.fit_scores["initial"].chunk_ids,
fit_eval.fit_scores["reranked"].chunk_ids,
)
# fit_eval.rerank_effect = calculate_rank_shift(
# fit_eval.fit_scores["initial"].chunk_ids,
# fit_eval.fit_scores["reranked"].chunk_ids,
# )
return fit_eval
# return fit_eval
def get_answer_generation_documents(
relevant_docs: list[InferenceSection],
context_documents: list[InferenceSection],
original_question_docs: list[InferenceSection],
max_docs: int,
) -> AnswerGenerationDocuments:
"""
Create a deduplicated list of documents to stream, prioritizing relevant docs.
# def get_answer_generation_documents(
# relevant_docs: list[InferenceSection],
# context_documents: list[InferenceSection],
# original_question_docs: list[InferenceSection],
# max_docs: int,
# ) -> AnswerGenerationDocuments:
# """
# Create a deduplicated list of documents to stream, prioritizing relevant docs.
Args:
relevant_docs: Primary documents to include
context_documents: Additional context documents to append
original_question_docs: Original question documents to append
max_docs: Maximum number of documents to return
# Args:
# relevant_docs: Primary documents to include
# context_documents: Additional context documents to append
# original_question_docs: Original question documents to append
# max_docs: Maximum number of documents to return
Returns:
List of deduplicated documents, limited to max_docs
"""
# get relevant_doc ids
relevant_doc_ids = [doc.center_chunk.document_id for doc in relevant_docs]
# Returns:
# List of deduplicated documents, limited to max_docs
# """
# # get relevant_doc ids
# relevant_doc_ids = [doc.center_chunk.document_id for doc in relevant_docs]
# Start with relevant docs or fallback to original question docs
streaming_documents = relevant_docs.copy()
# # Start with relevant docs or fallback to original question docs
# streaming_documents = relevant_docs.copy()
# Use a set for O(1) lookups of document IDs
seen_doc_ids = {doc.center_chunk.document_id for doc in streaming_documents}
# # Use a set for O(1) lookups of document IDs
# seen_doc_ids = {doc.center_chunk.document_id for doc in streaming_documents}
# Combine additional documents to check in one iteration
additional_docs = context_documents + original_question_docs
for doc_idx, doc in enumerate(additional_docs):
doc_id = doc.center_chunk.document_id
if doc_id not in seen_doc_ids:
streaming_documents.append(doc)
seen_doc_ids.add(doc_id)
# # Combine additional documents to check in one iteration
# additional_docs = context_documents + original_question_docs
# for doc_idx, doc in enumerate(additional_docs):
# doc_id = doc.center_chunk.document_id
# if doc_id not in seen_doc_ids:
# streaming_documents.append(doc)
# seen_doc_ids.add(doc_id)
streaming_documents = dedup_inference_section_list(streaming_documents)
# streaming_documents = dedup_inference_section_list(streaming_documents)
relevant_streaming_docs = [
doc
for doc in streaming_documents
if doc.center_chunk.document_id in relevant_doc_ids
]
relevant_streaming_docs = dedup_sort_inference_section_list(relevant_streaming_docs)
# relevant_streaming_docs = [
# doc
# for doc in streaming_documents
# if doc.center_chunk.document_id in relevant_doc_ids
# ]
# relevant_streaming_docs = dedup_sort_inference_section_list(relevant_streaming_docs)
additional_streaming_docs = [
doc
for doc in streaming_documents
if doc.center_chunk.document_id not in relevant_doc_ids
]
additional_streaming_docs = dedup_sort_inference_section_list(
additional_streaming_docs
)
# additional_streaming_docs = [
# doc
# for doc in streaming_documents
# if doc.center_chunk.document_id not in relevant_doc_ids
# ]
# additional_streaming_docs = dedup_sort_inference_section_list(
# additional_streaming_docs
# )
for doc in additional_streaming_docs:
if doc.center_chunk.score:
doc.center_chunk.score += -2.0
else:
doc.center_chunk.score = -2.0
# for doc in additional_streaming_docs:
# if doc.center_chunk.score:
# doc.center_chunk.score += -2.0
# else:
# doc.center_chunk.score = -2.0
sorted_streaming_documents = relevant_streaming_docs + additional_streaming_docs
# sorted_streaming_documents = relevant_streaming_docs + additional_streaming_docs
return AnswerGenerationDocuments(
streaming_documents=sorted_streaming_documents[:max_docs],
context_documents=relevant_streaming_docs[:max_docs],
)
# return AnswerGenerationDocuments(
# streaming_documents=sorted_streaming_documents[:max_docs],
# context_documents=relevant_streaming_docs[:max_docs],
# )
def dedup_sort_inference_section_list(
sections: list[InferenceSection],
) -> list[InferenceSection]:
"""Deduplicates InferenceSections by document_id and sorts by score.
# def dedup_sort_inference_section_list(
# sections: list[InferenceSection],
# ) -> list[InferenceSection]:
# """Deduplicates InferenceSections by document_id and sorts by score.
Args:
sections: List of InferenceSections to deduplicate and sort
# Args:
# sections: List of InferenceSections to deduplicate and sort
Returns:
Deduplicated list of InferenceSections sorted by score in descending order
"""
# dedupe/merge with existing framework
sections = dedup_inference_section_list(sections)
# Returns:
# Deduplicated list of InferenceSections sorted by score in descending order
# """
# # dedupe/merge with existing framework
# sections = dedup_inference_section_list(sections)
# Use dict to deduplicate by document_id, keeping highest scored version
unique_sections: dict[str, InferenceSection] = {}
for section in sections:
doc_id = section.center_chunk.document_id
if doc_id not in unique_sections:
unique_sections[doc_id] = section
continue
# # Use dict to deduplicate by document_id, keeping highest scored version
# unique_sections: dict[str, InferenceSection] = {}
# for section in sections:
# doc_id = section.center_chunk.document_id
# if doc_id not in unique_sections:
# unique_sections[doc_id] = section
# continue
# Keep version with higher score
existing_score = unique_sections[doc_id].center_chunk.score or 0
new_score = section.center_chunk.score or 0
if new_score > existing_score:
unique_sections[doc_id] = section
# # Keep version with higher score
# existing_score = unique_sections[doc_id].center_chunk.score or 0
# new_score = section.center_chunk.score or 0
# if new_score > existing_score:
# unique_sections[doc_id] = section
# Sort by score in descending order, handling None scores
sorted_sections = sorted(
unique_sections.values(), key=lambda x: x.center_chunk.score or 0, reverse=True
)
# # Sort by score in descending order, handling None scores
# sorted_sections = sorted(
# unique_sections.values(), key=lambda x: x.center_chunk.score or 0, reverse=True
# )
return sorted_sections
# return sorted_sections
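A small sketch of the rank-shift metric above on toy data; the document ids are placeholders.

# Sketch: "a" keeps rank 1 while "b" and "c" swap, so the shift is small but
# non-zero; identical orderings would return 0.0.
initial_ids = ["a", "b", "c"]
reranked_ids = ["a", "c", "b"]
print(calculate_rank_shift(initial_ids, reranked_ids, top_n=3))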

View File

@@ -1,24 +1,24 @@
from enum import Enum
# from enum import Enum
AGENT_LLM_TIMEOUT_MESSAGE = "The agent timed out. Please try again."
AGENT_LLM_ERROR_MESSAGE = "The agent encountered an error. Please try again."
AGENT_LLM_RATELIMIT_MESSAGE = (
"The agent encountered a rate limit error. Please try again."
)
LLM_ANSWER_ERROR_MESSAGE = "The question was not answered due to an LLM error."
# AGENT_LLM_TIMEOUT_MESSAGE = "The agent timed out. Please try again."
# AGENT_LLM_ERROR_MESSAGE = "The agent encountered an error. Please try again."
# AGENT_LLM_RATELIMIT_MESSAGE = (
# "The agent encountered a rate limit error. Please try again."
# )
# LLM_ANSWER_ERROR_MESSAGE = "The question was not answered due to an LLM error."
AGENT_POSITIVE_VALUE_STR = "yes"
AGENT_NEGATIVE_VALUE_STR = "no"
# AGENT_POSITIVE_VALUE_STR = "yes"
# AGENT_NEGATIVE_VALUE_STR = "no"
AGENT_ANSWER_SEPARATOR = "Answer:"
# AGENT_ANSWER_SEPARATOR = "Answer:"
EMBEDDING_KEY = "embedding"
IS_KEYWORD_KEY = "is_keyword"
KEYWORDS_KEY = "keywords"
# EMBEDDING_KEY = "embedding"
# IS_KEYWORD_KEY = "is_keyword"
# KEYWORDS_KEY = "keywords"
class AgentLLMErrorType(str, Enum):
TIMEOUT = "timeout"
RATE_LIMIT = "rate_limit"
GENERAL_ERROR = "general_error"
# class AgentLLMErrorType(str, Enum):
# TIMEOUT = "timeout"
# RATE_LIMIT = "rate_limit"
# GENERAL_ERROR = "general_error"
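A short sketch pairing the error types with the user-facing messages defined above.

# Sketch: map each AgentLLMErrorType to the message shown to the user.
error_messages = {
    AgentLLMErrorType.TIMEOUT: AGENT_LLM_TIMEOUT_MESSAGE,
    AgentLLMErrorType.RATE_LIMIT: AGENT_LLM_RATELIMIT_MESSAGE,
    AgentLLMErrorType.GENERAL_ERROR: AGENT_LLM_ERROR_MESSAGE,
}
print(error_messages[AgentLLMErrorType.TIMEOUT])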

View File

@@ -1,243 +1,243 @@
import re
from datetime import datetime
from typing import cast
from typing import Literal
from typing import Type
from typing import TypeVar
# import re
# from datetime import datetime
# from typing import cast
# from typing import Literal
# from typing import Type
# from typing import TypeVar
from braintrust import traced
from langchain.schema.language_model import LanguageModelInput
from langchain_core.messages import HumanMessage
from langgraph.types import StreamWriter
from pydantic import BaseModel
# from braintrust import traced
# from langchain.schema.language_model import LanguageModelInput
# from langchain_core.messages import HumanMessage
# from langgraph.types import StreamWriter
# from pydantic import BaseModel
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.stream_processing.citation_processing import CitationProcessorGraph
from onyx.chat.stream_processing.citation_processing import LlmDoc
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
from onyx.server.query_and_chat.streaming_models import StreamingType
from onyx.utils.threadpool_concurrency import run_with_timeout
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.chat.stream_processing.citation_processing import CitationProcessorGraph
# from onyx.chat.stream_processing.citation_processing import LlmDoc
# from onyx.llm.interfaces import LLM
# from onyx.llm.interfaces import ToolChoiceOptions
# from onyx.server.query_and_chat.streaming_models import CitationInfo
# from onyx.server.query_and_chat.streaming_models import MessageDelta
# from onyx.server.query_and_chat.streaming_models import ReasoningDelta
# from onyx.server.query_and_chat.streaming_models import StreamingType
# from onyx.utils.threadpool_concurrency import run_with_timeout
SchemaType = TypeVar("SchemaType", bound=BaseModel)
# SchemaType = TypeVar("SchemaType", bound=BaseModel)
# match ```json{...}``` or ```{...}```
JSON_PATTERN = re.compile(r"```(?:json)?\s*(\{.*?\})\s*```", re.DOTALL)
# # match ```json{...}``` or ```{...}```
# JSON_PATTERN = re.compile(r"```(?:json)?\s*(\{.*?\})\s*```", re.DOTALL)
@traced(name="stream llm", type="llm")
def stream_llm_answer(
llm: LLM,
prompt: LanguageModelInput,
event_name: str,
writer: StreamWriter,
agent_answer_level: int,
agent_answer_question_num: int,
agent_answer_type: Literal["agent_level_answer", "agent_sub_answer"],
timeout_override: int | None = None,
max_tokens: int | None = None,
answer_piece: str | None = None,
ind: int | None = None,
context_docs: list[LlmDoc] | None = None,
replace_citations: bool = False,
) -> tuple[list[str], list[float], list[CitationInfo]]:
"""Stream the initial answer from the LLM.
# @traced(name="stream llm", type="llm")
# def stream_llm_answer(
# llm: LLM,
# prompt: LanguageModelInput,
# event_name: str,
# writer: StreamWriter,
# agent_answer_level: int,
# agent_answer_question_num: int,
# agent_answer_type: Literal["agent_level_answer", "agent_sub_answer"],
# timeout_override: int | None = None,
# max_tokens: int | None = None,
# answer_piece: str | None = None,
# ind: int | None = None,
# context_docs: list[LlmDoc] | None = None,
# replace_citations: bool = False,
# ) -> tuple[list[str], list[float], list[CitationInfo]]:
# """Stream the initial answer from the LLM.
Args:
llm: The LLM to use.
prompt: The prompt to use.
event_name: The name of the event to write.
writer: The writer to write to.
agent_answer_level: The level of the agent answer.
agent_answer_question_num: The question number within the level.
agent_answer_type: The type of answer ("agent_level_answer" or "agent_sub_answer").
timeout_override: The LLM timeout to use.
max_tokens: The LLM max tokens to use.
answer_piece: The type of answer piece to write.
ind: The index of the answer piece.
tools: The tools to use.
tool_choice: The tool choice to use.
structured_response_format: The structured response format to use.
# Args:
# llm: The LLM to use.
# prompt: The prompt to use.
# event_name: The name of the event to write.
# writer: The writer to write to.
# agent_answer_level: The level of the agent answer.
# agent_answer_question_num: The question number within the level.
# agent_answer_type: The type of answer ("agent_level_answer" or "agent_sub_answer").
# timeout_override: The LLM timeout to use.
# max_tokens: The LLM max tokens to use.
# answer_piece: The type of answer piece to write.
# ind: The index of the answer piece.
# tools: The tools to use.
# tool_choice: The tool choice to use.
# structured_response_format: The structured response format to use.
Returns:
A tuple of the response and the dispatch timings.
"""
response: list[str] = []
dispatch_timings: list[float] = []
citation_infos: list[CitationInfo] = []
# Returns:
# A tuple of the response and the dispatch timings.
# """
# response: list[str] = []
# dispatch_timings: list[float] = []
# citation_infos: list[CitationInfo] = []
if context_docs:
citation_processor = CitationProcessorGraph(
context_docs=context_docs,
)
else:
citation_processor = None
# if context_docs:
# citation_processor = CitationProcessorGraph(
# context_docs=context_docs,
# )
# else:
# citation_processor = None
for message in llm.stream_langchain(
prompt,
timeout_override=timeout_override,
max_tokens=max_tokens,
):
# for message in llm.stream_langchain(
# prompt,
# timeout_override=timeout_override,
# max_tokens=max_tokens,
# ):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
# # TODO: in principle, the answer here COULD contain images, but we don't support that yet
# content = message.content
# if not isinstance(content, str):
# raise ValueError(
# f"Expected content to be a string, but got {type(content)}"
# )
start_stream_token = datetime.now()
# start_stream_token = datetime.now()
if answer_piece == StreamingType.MESSAGE_DELTA.value:
if ind is None:
raise ValueError("index is required when answer_piece is message_delta")
# if answer_piece == StreamingType.MESSAGE_DELTA.value:
# if ind is None:
# raise ValueError("index is required when answer_piece is message_delta")
if citation_processor:
processed_token = citation_processor.process_token(content)
# if citation_processor:
# processed_token = citation_processor.process_token(content)
if isinstance(processed_token, tuple):
content = processed_token[0]
citation_infos.extend(processed_token[1])
elif isinstance(processed_token, str):
content = processed_token
else:
continue
# if isinstance(processed_token, tuple):
# content = processed_token[0]
# citation_infos.extend(processed_token[1])
# elif isinstance(processed_token, str):
# content = processed_token
# else:
# continue
write_custom_event(
ind,
MessageDelta(content=content),
writer,
)
# write_custom_event(
# ind,
# MessageDelta(content=content),
# writer,
# )
elif answer_piece == StreamingType.REASONING_DELTA.value:
if ind is None:
raise ValueError(
"index is required when answer_piece is reasoning_delta"
)
write_custom_event(
ind,
ReasoningDelta(reasoning=content),
writer,
)
# elif answer_piece == StreamingType.REASONING_DELTA.value:
# if ind is None:
# raise ValueError(
# "index is required when answer_piece is reasoning_delta"
# )
# write_custom_event(
# ind,
# ReasoningDelta(reasoning=content),
# writer,
# )
else:
raise ValueError(f"Invalid answer piece: {answer_piece}")
# else:
# raise ValueError(f"Invalid answer piece: {answer_piece}")
end_stream_token = datetime.now()
# end_stream_token = datetime.now()
dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
response.append(content)
# dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
# response.append(content)
return response, dispatch_timings, citation_infos
# return response, dispatch_timings, citation_infos
def invoke_llm_json(
llm: LLM,
prompt: LanguageModelInput,
schema: Type[SchemaType],
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> SchemaType:
"""
Invoke an LLM, forcing it to respond in a specified JSON format if possible,
and return an object of that schema.
"""
from litellm.utils import get_supported_openai_params, supports_response_schema
# def invoke_llm_json(
# llm: LLM,
# prompt: LanguageModelInput,
# schema: Type[SchemaType],
# tools: list[dict] | None = None,
# tool_choice: ToolChoiceOptions | None = None,
# timeout_override: int | None = None,
# max_tokens: int | None = None,
# ) -> SchemaType:
# """
# Invoke an LLM, forcing it to respond in a specified JSON format if possible,
# and return an object of that schema.
# """
# from litellm.utils import get_supported_openai_params, supports_response_schema
# check if the model supports response_format: json_schema
supports_json = "response_format" in (
get_supported_openai_params(llm.config.model_name, llm.config.model_provider)
or []
) and supports_response_schema(llm.config.model_name, llm.config.model_provider)
# # check if the model supports response_format: json_schema
# supports_json = "response_format" in (
# get_supported_openai_params(llm.config.model_name, llm.config.model_provider)
# or []
# ) and supports_response_schema(llm.config.model_name, llm.config.model_provider)
response_content = str(
llm.invoke_langchain(
prompt,
tools=tools,
tool_choice=tool_choice,
timeout_override=timeout_override,
max_tokens=max_tokens,
**cast(
dict, {"structured_response_format": schema} if supports_json else {}
),
).content
)
# response_content = str(
# llm.invoke_langchain(
# prompt,
# tools=tools,
# tool_choice=tool_choice,
# timeout_override=timeout_override,
# max_tokens=max_tokens,
# **cast(
# dict, {"structured_response_format": schema} if supports_json else {}
# ),
# ).content
# )
if not supports_json:
# remove newlines as they often lead to json decoding errors
response_content = response_content.replace("\n", " ")
# hope the prompt is structured so that a JSON object is produced...
json_block_match = JSON_PATTERN.search(response_content)
if json_block_match:
response_content = json_block_match.group(1)
else:
first_bracket = response_content.find("{")
last_bracket = response_content.rfind("}")
response_content = response_content[first_bracket : last_bracket + 1]
# if not supports_json:
# # remove newlines as they often lead to json decoding errors
# response_content = response_content.replace("\n", " ")
# # hope the prompt is structured so that a JSON object is produced...
# json_block_match = JSON_PATTERN.search(response_content)
# if json_block_match:
# response_content = json_block_match.group(1)
# else:
# first_bracket = response_content.find("{")
# last_bracket = response_content.rfind("}")
# response_content = response_content[first_bracket : last_bracket + 1]
return schema.model_validate_json(response_content)
# return schema.model_validate_json(response_content)
def get_answer_from_llm(
llm: LLM,
prompt: str,
timeout: int = 25,
timeout_override: int = 5,
max_tokens: int = 500,
stream: bool = False,
writer: StreamWriter = lambda _: None,
agent_answer_level: int = 0,
agent_answer_question_num: int = 0,
agent_answer_type: Literal[
"agent_sub_answer", "agent_level_answer"
] = "agent_level_answer",
json_string_flag: bool = False,
) -> str:
msg = [
HumanMessage(
content=prompt,
)
]
# def get_answer_from_llm(
# llm: LLM,
# prompt: str,
# timeout: int = 25,
# timeout_override: int = 5,
# max_tokens: int = 500,
# stream: bool = False,
# writer: StreamWriter = lambda _: None,
# agent_answer_level: int = 0,
# agent_answer_question_num: int = 0,
# agent_answer_type: Literal[
# "agent_sub_answer", "agent_level_answer"
# ] = "agent_level_answer",
# json_string_flag: bool = False,
# ) -> str:
# msg = [
# HumanMessage(
# content=prompt,
# )
# ]
if stream:
# TODO - adjust for new UI. This is currently not working for current UI/Basic Search
stream_response, _, _ = run_with_timeout(
timeout,
lambda: stream_llm_answer(
llm=llm,
prompt=msg,
event_name="sub_answers",
writer=writer,
agent_answer_level=agent_answer_level,
agent_answer_question_num=agent_answer_question_num,
agent_answer_type=agent_answer_type,
timeout_override=timeout_override,
max_tokens=max_tokens,
),
)
content = "".join(stream_response)
else:
llm_response = run_with_timeout(
timeout,
llm.invoke_langchain,
prompt=msg,
timeout_override=timeout_override,
max_tokens=max_tokens,
)
content = str(llm_response.content)
# if stream:
# # TODO - adjust for new UI. This is currently not working for current UI/Basic Search
# stream_response, _, _ = run_with_timeout(
# timeout,
# lambda: stream_llm_answer(
# llm=llm,
# prompt=msg,
# event_name="sub_answers",
# writer=writer,
# agent_answer_level=agent_answer_level,
# agent_answer_question_num=agent_answer_question_num,
# agent_answer_type=agent_answer_type,
# timeout_override=timeout_override,
# max_tokens=max_tokens,
# ),
# )
# content = "".join(stream_response)
# else:
# llm_response = run_with_timeout(
# timeout,
# llm.invoke_langchain,
# prompt=msg,
# timeout_override=timeout_override,
# max_tokens=max_tokens,
# )
# content = str(llm_response.content)
cleaned_response = content
if json_string_flag:
cleaned_response = (
str(content).replace("```json\n", "").replace("\n```", "").replace("\n", "")
)
first_bracket = cleaned_response.find("{")
last_bracket = cleaned_response.rfind("}")
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
# cleaned_response = content
# if json_string_flag:
# cleaned_response = (
# str(content).replace("```json\n", "").replace("\n```", "").replace("\n", "")
# )
# first_bracket = cleaned_response.find("{")
# last_bracket = cleaned_response.rfind("}")
# cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
return cleaned_response
# return cleaned_response
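A minimal sketch of the structured-output helper above; llm is assumed to be an already-configured onyx LLM instance, and the Verdict schema is illustrative.

from pydantic import BaseModel

class Verdict(BaseModel):
    reasoning: str
    decision: str

# Sketch: `llm` is assumed to be a configured onyx LLM. When the model lacks
# native JSON-schema support, the helper falls back to pulling a fenced
# ```json block (or the outermost braces) out of the raw completion text.
verdict = invoke_llm_json(
    llm=llm,
    prompt="Does the passage answer the question? Reply as JSON with keys reasoning and decision.",
    schema=Verdict,
)
print(verdict.decision)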

View File

@@ -1,166 +1,166 @@
from enum import Enum
from typing import Any
# from enum import Enum
# from typing import Any
from pydantic import BaseModel
# from pydantic import BaseModel
from onyx.agents.agent_search.deep_search.main.models import (
AgentAdditionalMetrics,
)
from onyx.agents.agent_search.deep_search.main.models import AgentBaseMetrics
from onyx.agents.agent_search.deep_search.main.models import (
AgentRefinedMetrics,
)
from onyx.agents.agent_search.deep_search.main.models import AgentTimings
from onyx.context.search.models import InferenceSection
from onyx.tools.models import SearchQueryInfo
# from onyx.agents.agent_search.deep_search.main.models import (
# AgentAdditionalMetrics,
# )
# from onyx.agents.agent_search.deep_search.main.models import AgentBaseMetrics
# from onyx.agents.agent_search.deep_search.main.models import (
# AgentRefinedMetrics,
# )
# from onyx.agents.agent_search.deep_search.main.models import AgentTimings
# from onyx.context.search.models import InferenceSection
# from onyx.tools.models import SearchQueryInfo
# Pydantic models for structured outputs
# class RewrittenQueries(BaseModel):
# rewritten_queries: list[str]
# # Pydantic models for structured outputs
# # class RewrittenQueries(BaseModel):
# # rewritten_queries: list[str]
# class BinaryDecision(BaseModel):
# decision: Literal["yes", "no"]
# # class BinaryDecision(BaseModel):
# # decision: Literal["yes", "no"]
# class BinaryDecisionWithReasoning(BaseModel):
# reasoning: str
# decision: Literal["yes", "no"]
# # class BinaryDecisionWithReasoning(BaseModel):
# # reasoning: str
# # decision: Literal["yes", "no"]
class RetrievalFitScoreMetrics(BaseModel):
scores: dict[str, float]
chunk_ids: list[str]
# class RetrievalFitScoreMetrics(BaseModel):
# scores: dict[str, float]
# chunk_ids: list[str]
class RetrievalFitStats(BaseModel):
fit_score_lift: float
rerank_effect: float
fit_scores: dict[str, RetrievalFitScoreMetrics]
# class RetrievalFitStats(BaseModel):
# fit_score_lift: float
# rerank_effect: float
# fit_scores: dict[str, RetrievalFitScoreMetrics]
# class AgentChunkScores(BaseModel):
# scores: dict[str, dict[str, list[int | float]]]
# # class AgentChunkScores(BaseModel):
# # scores: dict[str, dict[str, list[int | float]]]
class AgentChunkRetrievalStats(BaseModel):
verified_count: int | None = None
verified_avg_scores: float | None = None
rejected_count: int | None = None
rejected_avg_scores: float | None = None
verified_doc_chunk_ids: list[str] = []
dismissed_doc_chunk_ids: list[str] = []
# class AgentChunkRetrievalStats(BaseModel):
# verified_count: int | None = None
# verified_avg_scores: float | None = None
# rejected_count: int | None = None
# rejected_avg_scores: float | None = None
# verified_doc_chunk_ids: list[str] = []
# dismissed_doc_chunk_ids: list[str] = []
class InitialAgentResultStats(BaseModel):
sub_questions: dict[str, float | int | None]
original_question: dict[str, float | int | None]
agent_effectiveness: dict[str, float | int | None]
# class InitialAgentResultStats(BaseModel):
# sub_questions: dict[str, float | int | None]
# original_question: dict[str, float | int | None]
# agent_effectiveness: dict[str, float | int | None]
class AgentErrorLog(BaseModel):
error_message: str
error_type: str
error_result: str
# class AgentErrorLog(BaseModel):
# error_message: str
# error_type: str
# error_result: str
class RefinedAgentStats(BaseModel):
revision_doc_efficiency: float | None
revision_question_efficiency: float | None
# class RefinedAgentStats(BaseModel):
# revision_doc_efficiency: float | None
# revision_question_efficiency: float | None
class Term(BaseModel):
term_name: str = ""
term_type: str = ""
term_similar_to: list[str] = []
# class Term(BaseModel):
# term_name: str = ""
# term_type: str = ""
# term_similar_to: list[str] = []
### Models ###
# ### Models ###
class Entity(BaseModel):
entity_name: str = ""
entity_type: str = ""
# class Entity(BaseModel):
# entity_name: str = ""
# entity_type: str = ""
class Relationship(BaseModel):
relationship_name: str = ""
relationship_type: str = ""
relationship_entities: list[str] = []
# class Relationship(BaseModel):
# relationship_name: str = ""
# relationship_type: str = ""
# relationship_entities: list[str] = []
class EntityRelationshipTermExtraction(BaseModel):
entities: list[Entity] = []
relationships: list[Relationship] = []
terms: list[Term] = []
# class EntityRelationshipTermExtraction(BaseModel):
# entities: list[Entity] = []
# relationships: list[Relationship] = []
# terms: list[Term] = []
class EntityExtractionResult(BaseModel):
retrieved_entities_relationships: EntityRelationshipTermExtraction
# class EntityExtractionResult(BaseModel):
# retrieved_entities_relationships: EntityRelationshipTermExtraction
class QueryRetrievalResult(BaseModel):
query: str
retrieved_documents: list[InferenceSection]
stats: RetrievalFitStats | None
query_info: SearchQueryInfo | None
# class QueryRetrievalResult(BaseModel):
# query: str
# retrieved_documents: list[InferenceSection]
# stats: RetrievalFitStats | None
# query_info: SearchQueryInfo | None
class SubQuestionAnswerResults(BaseModel):
question: str
question_id: str
answer: str
verified_high_quality: bool
sub_query_retrieval_results: list[QueryRetrievalResult]
verified_reranked_documents: list[InferenceSection]
context_documents: list[InferenceSection]
cited_documents: list[InferenceSection]
sub_question_retrieval_stats: AgentChunkRetrievalStats
# class SubQuestionAnswerResults(BaseModel):
# question: str
# question_id: str
# answer: str
# verified_high_quality: bool
# sub_query_retrieval_results: list[QueryRetrievalResult]
# verified_reranked_documents: list[InferenceSection]
# context_documents: list[InferenceSection]
# cited_documents: list[InferenceSection]
# sub_question_retrieval_stats: AgentChunkRetrievalStats
class StructuredSubquestionDocuments(BaseModel):
cited_documents: list[InferenceSection]
context_documents: list[InferenceSection]
# class StructuredSubquestionDocuments(BaseModel):
# cited_documents: list[InferenceSection]
# context_documents: list[InferenceSection]
class CombinedAgentMetrics(BaseModel):
timings: AgentTimings
base_metrics: AgentBaseMetrics | None
refined_metrics: AgentRefinedMetrics
additional_metrics: AgentAdditionalMetrics
# class CombinedAgentMetrics(BaseModel):
# timings: AgentTimings
# base_metrics: AgentBaseMetrics | None
# refined_metrics: AgentRefinedMetrics
# additional_metrics: AgentAdditionalMetrics
class PersonaPromptExpressions(BaseModel):
contextualized_prompt: str
base_prompt: str | None
# class PersonaPromptExpressions(BaseModel):
# contextualized_prompt: str
# base_prompt: str | None
class AgentPromptEnrichmentComponents(BaseModel):
persona_prompts: PersonaPromptExpressions
history: str
date_str: str
# class AgentPromptEnrichmentComponents(BaseModel):
# persona_prompts: PersonaPromptExpressions
# history: str
# date_str: str
class LLMNodeErrorStrings(BaseModel):
timeout: str = "LLM Timeout Error"
rate_limit: str = "LLM Rate Limit Error"
general_error: str = "General LLM Error"
# class LLMNodeErrorStrings(BaseModel):
# timeout: str = "LLM Timeout Error"
# rate_limit: str = "LLM Rate Limit Error"
# general_error: str = "General LLM Error"
class AnswerGenerationDocuments(BaseModel):
streaming_documents: list[InferenceSection]
context_documents: list[InferenceSection]
# class AnswerGenerationDocuments(BaseModel):
# streaming_documents: list[InferenceSection]
# context_documents: list[InferenceSection]
BaseMessage_Content = str | list[str | dict[str, Any]]
# BaseMessage_Content = str | list[str | dict[str, Any]]
class QueryExpansionType(Enum):
KEYWORD = "keyword"
SEMANTIC = "semantic"
# class QueryExpansionType(Enum):
# KEYWORD = "keyword"
# SEMANTIC = "semantic"
class ReferenceResults(BaseModel):
citations: list[str]
general_entities: list[str]
# class ReferenceResults(BaseModel):
# citations: list[str]
# general_entities: list[str]
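For context, the classes commented out above are plain Pydantic models shared by the deep-search agent graph (retrieval stats, sub-question answers, prompt expressions). A minimal sketch of how one of them behaves, with the field definitions copied from the removed lines and the usage itself assumed for illustration (Pydantic v2):

from pydantic import BaseModel


class AgentChunkRetrievalStats(BaseModel):
    # Field definitions copied from the removed lines above; every stat is
    # optional and defaults to None or an empty list.
    verified_count: int | None = None
    verified_avg_scores: float | None = None
    rejected_count: int | None = None
    rejected_avg_scores: float | None = None
    verified_doc_chunk_ids: list[str] = []
    dismissed_doc_chunk_ids: list[str] = []


stats = AgentChunkRetrievalStats(verified_count=3, verified_avg_scores=0.82)
print(stats.model_dump())  # unset fields fall back to their defaults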

Some files were not shown because too many files have changed in this diff.