Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-16 23:35:46 +00:00)

Compare commits: 3 commits

| Author | SHA1 | Date |
|---|---|---|
| | 8a38fdf8a5 | |
| | 9155d4aa21 | |
| | b20591611a | |
.gitignore (vendored, 2 lines changed)
@@ -49,5 +49,7 @@ CLAUDE.md
# Local .terraform.lock.hcl file
.terraform.lock.hcl

node_modules

# MCP configs
.playwright-mcp
backend/alembic/versions/4f8a2b3c1d9e_add_open_url_tool.py (new file, 104 lines)
@@ -0,0 +1,104 @@
"""add_open_url_tool

Revision ID: 4f8a2b3c1d9e
Revises: a852cbe15577
Create Date: 2025-11-24 12:00:00.000000

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "4f8a2b3c1d9e"
down_revision = "a852cbe15577"
branch_labels = None
depends_on = None


OPEN_URL_TOOL = {
    "name": "OpenURLTool",
    "display_name": "Open URL",
    "description": (
        "The Open URL Action allows the agent to fetch and read contents of web pages."
    ),
    "in_code_tool_id": "OpenURLTool",
    "enabled": True,
}


def upgrade() -> None:
    conn = op.get_bind()

    # Check if tool already exists
    existing = conn.execute(
        sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
        {"in_code_tool_id": OPEN_URL_TOOL["in_code_tool_id"]},
    ).fetchone()

    if existing:
        tool_id = existing[0]
        # Update existing tool
        conn.execute(
            sa.text(
                """
                UPDATE tool
                SET name = :name,
                    display_name = :display_name,
                    description = :description
                WHERE in_code_tool_id = :in_code_tool_id
                """
            ),
            OPEN_URL_TOOL,
        )
    else:
        # Insert new tool
        conn.execute(
            sa.text(
                """
                INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled)
                VALUES (:name, :display_name, :description, :in_code_tool_id, :enabled)
                """
            ),
            OPEN_URL_TOOL,
        )
        # Get the newly inserted tool's id
        result = conn.execute(
            sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
            {"in_code_tool_id": OPEN_URL_TOOL["in_code_tool_id"]},
        ).fetchone()
        tool_id = result[0]  # type: ignore

    # Associate the tool with all existing personas
    # Get all persona IDs
    persona_ids = conn.execute(sa.text("SELECT id FROM persona")).fetchall()

    for (persona_id,) in persona_ids:
        # Check if association already exists
        exists = conn.execute(
            sa.text(
                """
                SELECT 1 FROM persona__tool
                WHERE persona_id = :persona_id AND tool_id = :tool_id
                """
            ),
            {"persona_id": persona_id, "tool_id": tool_id},
        ).fetchone()

        if not exists:
            conn.execute(
                sa.text(
                    """
                    INSERT INTO persona__tool (persona_id, tool_id)
                    VALUES (:persona_id, :tool_id)
                    """
                ),
                {"persona_id": persona_id, "tool_id": tool_id},
            )


def downgrade() -> None:
    # We don't remove the tool on downgrade since it's fine to have it around.
    # If we upgrade again, it will be a no-op.
    pass
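The upgrade above is written to be idempotent: an existing row is updated in place, and persona associations are only inserted when missing. A minimal sketch of how one might sanity-check that re-running it does not duplicate rows, assuming a standard Alembic setup for this repo (the connection URL is a placeholder, not something defined in this diff):

# Hypothetical check, run after `alembic upgrade head`; the URL below is a placeholder.
import sqlalchemy as sa

engine = sa.create_engine("postgresql+psycopg2://localhost/onyx")
with engine.connect() as conn:
    rows = conn.execute(
        sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'OpenURLTool'")
    ).fetchall()
    # The migration updates rather than re-inserts, so at most one row should exist.
    assert len(rows) <= 1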
backend/alembic/versions/a852cbe15577_new_chat_history.py (new file, 572 lines)
@@ -0,0 +1,572 @@
"""New Chat History

Revision ID: a852cbe15577
Revises: 6436661d5b65
Create Date: 2025-11-08 15:16:37.781308

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "a852cbe15577"
down_revision = "6436661d5b65"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Drop research agent tables (if they exist)
    op.execute("DROP TABLE IF EXISTS research_agent_iteration_sub_step CASCADE")
    op.execute("DROP TABLE IF EXISTS research_agent_iteration CASCADE")

    # Drop agent sub query and sub question tables (if they exist)
    op.execute("DROP TABLE IF EXISTS agent__sub_query__search_doc CASCADE")
    op.execute("DROP TABLE IF EXISTS agent__sub_query CASCADE")
    op.execute("DROP TABLE IF EXISTS agent__sub_question CASCADE")

    # Update ChatMessage table
    # Rename parent_message to parent_message_id and make it a foreign key (if not already done)
    conn = op.get_bind()
    result = conn.execute(
        sa.text(
            """
            SELECT column_name FROM information_schema.columns
            WHERE table_name = 'chat_message' AND column_name = 'parent_message'
            """
        )
    )
    if result.fetchone():
        op.alter_column(
            "chat_message", "parent_message", new_column_name="parent_message_id"
        )
        op.create_foreign_key(
            "fk_chat_message_parent_message_id",
            "chat_message",
            "chat_message",
            ["parent_message_id"],
            ["id"],
        )

    # Rename latest_child_message to latest_child_message_id and make it a foreign key (if not already done)
    result = conn.execute(
        sa.text(
            """
            SELECT column_name FROM information_schema.columns
            WHERE table_name = 'chat_message' AND column_name = 'latest_child_message'
            """
        )
    )
    if result.fetchone():
        op.alter_column(
            "chat_message",
            "latest_child_message",
            new_column_name="latest_child_message_id",
        )
        op.create_foreign_key(
            "fk_chat_message_latest_child_message_id",
            "chat_message",
            "chat_message",
            ["latest_child_message_id"],
            ["id"],
        )

    # Add reasoning_tokens column (if not exists)
    result = conn.execute(
        sa.text(
            """
            SELECT column_name FROM information_schema.columns
            WHERE table_name = 'chat_message' AND column_name = 'reasoning_tokens'
            """
        )
    )
    if not result.fetchone():
        op.add_column(
            "chat_message", sa.Column("reasoning_tokens", sa.Text(), nullable=True)
        )

    # Drop columns no longer needed (if they exist)
    for col in [
        "rephrased_query",
        "alternate_assistant_id",
        "overridden_model",
        "is_agentic",
        "refined_answer_improvement",
        "research_type",
        "research_plan",
        "research_answer_purpose",
    ]:
        result = conn.execute(
            sa.text(
                f"""
                SELECT column_name FROM information_schema.columns
                WHERE table_name = 'chat_message' AND column_name = '{col}'
                """
            )
        )
        if result.fetchone():
            op.drop_column("chat_message", col)

    # Update ToolCall table
    # Add chat_session_id column (if not exists)
    result = conn.execute(
        sa.text(
            """
            SELECT column_name FROM information_schema.columns
            WHERE table_name = 'tool_call' AND column_name = 'chat_session_id'
            """
        )
    )
    if not result.fetchone():
        op.add_column(
            "tool_call",
            sa.Column("chat_session_id", postgresql.UUID(as_uuid=True), nullable=False),
        )
        op.create_foreign_key(
            "fk_tool_call_chat_session_id",
            "tool_call",
            "chat_session",
            ["chat_session_id"],
            ["id"],
        )

    # Rename message_id to parent_chat_message_id and make nullable (if not already done)
    result = conn.execute(
        sa.text(
            """
            SELECT column_name FROM information_schema.columns
            WHERE table_name = 'tool_call' AND column_name = 'message_id'
            """
        )
    )
    if result.fetchone():
        op.alter_column(
            "tool_call",
            "message_id",
            new_column_name="parent_chat_message_id",
            nullable=True,
        )

    # Add parent_tool_call_id (if not exists)
    result = conn.execute(
        sa.text(
            """
            SELECT column_name FROM information_schema.columns
            WHERE table_name = 'tool_call' AND column_name = 'parent_tool_call_id'
            """
        )
    )
    if not result.fetchone():
        op.add_column(
            "tool_call", sa.Column("parent_tool_call_id", sa.Integer(), nullable=True)
        )
        op.create_foreign_key(
            "fk_tool_call_parent_tool_call_id",
            "tool_call",
            "tool_call",
            ["parent_tool_call_id"],
            ["id"],
        )
        op.drop_constraint("uq_tool_call_message_id", "tool_call", type_="unique")

    # Add turn_number, tool_id (if not exists)
    for col_name in ["turn_number", "tool_id"]:
        result = conn.execute(
            sa.text(
                f"""
                SELECT column_name FROM information_schema.columns
                WHERE table_name = 'tool_call' AND column_name = '{col_name}'
                """
            )
        )
        if not result.fetchone():
            op.add_column(
                "tool_call",
                sa.Column(col_name, sa.Integer(), nullable=False, server_default="0"),
            )

    # Add tool_call_id as String (if not exists)
    result = conn.execute(
        sa.text(
            """
            SELECT column_name FROM information_schema.columns
            WHERE table_name = 'tool_call' AND column_name = 'tool_call_id'
            """
        )
    )
    if not result.fetchone():
        op.add_column(
            "tool_call",
            sa.Column("tool_call_id", sa.String(), nullable=False, server_default=""),
        )

    # Add reasoning_tokens (if not exists)
    result = conn.execute(
        sa.text(
            """
            SELECT column_name FROM information_schema.columns
            WHERE table_name = 'tool_call' AND column_name = 'reasoning_tokens'
            """
        )
    )
    if not result.fetchone():
        op.add_column(
            "tool_call", sa.Column("reasoning_tokens", sa.Text(), nullable=True)
        )

    # Rename tool_arguments to tool_call_arguments (if not already done)
    result = conn.execute(
        sa.text(
            """
            SELECT column_name FROM information_schema.columns
            WHERE table_name = 'tool_call' AND column_name = 'tool_arguments'
            """
        )
    )
    if result.fetchone():
        op.alter_column(
            "tool_call", "tool_arguments", new_column_name="tool_call_arguments"
        )

    # Rename tool_result to tool_call_response and change type from JSONB to Text (if not already done)
    result = conn.execute(
        sa.text(
            """
            SELECT column_name, data_type FROM information_schema.columns
            WHERE table_name = 'tool_call' AND column_name = 'tool_result'
            """
        )
    )
    tool_result_row = result.fetchone()
    if tool_result_row:
        op.alter_column(
            "tool_call", "tool_result", new_column_name="tool_call_response"
        )
        # Change type from JSONB to Text
        op.execute(
            sa.text(
                """
                ALTER TABLE tool_call
                ALTER COLUMN tool_call_response TYPE TEXT
                USING tool_call_response::text
                """
            )
        )
    else:
        # Check if tool_call_response already exists and is JSONB, then convert to Text
        result = conn.execute(
            sa.text(
                """
                SELECT data_type FROM information_schema.columns
                WHERE table_name = 'tool_call' AND column_name = 'tool_call_response'
                """
            )
        )
        tool_call_response_row = result.fetchone()
        if tool_call_response_row and tool_call_response_row[0] == "jsonb":
            op.execute(
                sa.text(
                    """
                    ALTER TABLE tool_call
                    ALTER COLUMN tool_call_response TYPE TEXT
                    USING tool_call_response::text
                    """
                )
            )

    # Add tool_call_tokens (if not exists)
    result = conn.execute(
        sa.text(
            """
            SELECT column_name FROM information_schema.columns
            WHERE table_name = 'tool_call' AND column_name = 'tool_call_tokens'
            """
        )
    )
    if not result.fetchone():
        op.add_column(
            "tool_call",
            sa.Column(
                "tool_call_tokens", sa.Integer(), nullable=False, server_default="0"
            ),
        )

    # Add generated_images column for image generation tool replay (if not exists)
    result = conn.execute(
        sa.text(
            """
            SELECT column_name FROM information_schema.columns
            WHERE table_name = 'tool_call' AND column_name = 'generated_images'
            """
        )
    )
    if not result.fetchone():
        op.add_column(
            "tool_call",
            sa.Column("generated_images", postgresql.JSONB(), nullable=True),
        )

    # Drop tool_name column (if exists)
    result = conn.execute(
        sa.text(
            """
            SELECT column_name FROM information_schema.columns
            WHERE table_name = 'tool_call' AND column_name = 'tool_name'
            """
        )
    )
    if result.fetchone():
        op.drop_column("tool_call", "tool_name")

    # Create tool_call__search_doc association table (if not exists)
    result = conn.execute(
        sa.text(
            """
            SELECT table_name FROM information_schema.tables
            WHERE table_name = 'tool_call__search_doc'
            """
        )
    )
    if not result.fetchone():
        op.create_table(
            "tool_call__search_doc",
            sa.Column("tool_call_id", sa.Integer(), nullable=False),
            sa.Column("search_doc_id", sa.Integer(), nullable=False),
            sa.ForeignKeyConstraint(
                ["tool_call_id"], ["tool_call.id"], ondelete="CASCADE"
            ),
            sa.ForeignKeyConstraint(
                ["search_doc_id"], ["search_doc.id"], ondelete="CASCADE"
            ),
            sa.PrimaryKeyConstraint("tool_call_id", "search_doc_id"),
        )

    # Add replace_base_system_prompt to persona table (if not exists)
    result = conn.execute(
        sa.text(
            """
            SELECT column_name FROM information_schema.columns
            WHERE table_name = 'persona' AND column_name = 'replace_base_system_prompt'
            """
        )
    )
    if not result.fetchone():
        op.add_column(
            "persona",
            sa.Column(
                "replace_base_system_prompt",
                sa.Boolean(),
                nullable=False,
                server_default="false",
            ),
        )


def downgrade() -> None:
    # Reverse persona changes
    op.drop_column("persona", "replace_base_system_prompt")

    # Drop tool_call__search_doc association table
    op.execute("DROP TABLE IF EXISTS tool_call__search_doc CASCADE")

    # Reverse ToolCall changes
    op.add_column("tool_call", sa.Column("tool_name", sa.String(), nullable=False))
    op.drop_column("tool_call", "tool_id")
    op.drop_column("tool_call", "tool_call_tokens")
    op.drop_column("tool_call", "generated_images")
    # Change tool_call_response back to JSONB before renaming
    op.execute(
        sa.text(
            """
            ALTER TABLE tool_call
            ALTER COLUMN tool_call_response TYPE JSONB
            USING tool_call_response::jsonb
            """
        )
    )
    op.alter_column("tool_call", "tool_call_response", new_column_name="tool_result")
    op.alter_column(
        "tool_call", "tool_call_arguments", new_column_name="tool_arguments"
    )
    op.drop_column("tool_call", "reasoning_tokens")
    op.drop_column("tool_call", "tool_call_id")
    op.drop_column("tool_call", "turn_number")
    op.drop_constraint(
        "fk_tool_call_parent_tool_call_id", "tool_call", type_="foreignkey"
    )
    op.drop_column("tool_call", "parent_tool_call_id")
    op.alter_column(
        "tool_call",
        "parent_chat_message_id",
        new_column_name="message_id",
        nullable=False,
    )
    op.drop_constraint("fk_tool_call_chat_session_id", "tool_call", type_="foreignkey")
    op.drop_column("tool_call", "chat_session_id")

    op.add_column(
        "chat_message",
        sa.Column(
            "research_answer_purpose",
            sa.Enum("INTRO", "DEEP_DIVE", name="researchanswerpurpose"),
            nullable=True,
        ),
    )
    op.add_column(
        "chat_message", sa.Column("research_plan", postgresql.JSONB(), nullable=True)
    )
    op.add_column(
        "chat_message",
        sa.Column(
            "research_type",
            sa.Enum("SIMPLE", "DEEP", name="researchtype"),
            nullable=True,
        ),
    )
    op.add_column(
        "chat_message",
        sa.Column("refined_answer_improvement", sa.Boolean(), nullable=True),
    )
    op.add_column(
        "chat_message",
        sa.Column("is_agentic", sa.Boolean(), nullable=False, server_default="false"),
    )
    op.add_column(
        "chat_message", sa.Column("overridden_model", sa.String(), nullable=True)
    )
    op.add_column(
        "chat_message", sa.Column("alternate_assistant_id", sa.Integer(), nullable=True)
    )
    op.add_column(
        "chat_message", sa.Column("rephrased_query", sa.Text(), nullable=True)
    )
    op.drop_column("chat_message", "reasoning_tokens")
    op.drop_constraint(
        "fk_chat_message_latest_child_message_id", "chat_message", type_="foreignkey"
    )
    op.alter_column(
        "chat_message",
        "latest_child_message_id",
        new_column_name="latest_child_message",
    )
    op.drop_constraint(
        "fk_chat_message_parent_message_id", "chat_message", type_="foreignkey"
    )
    op.alter_column(
        "chat_message", "parent_message_id", new_column_name="parent_message"
    )

    # Recreate agent sub question and sub query tables
    op.create_table(
        "agent__sub_question",
        sa.Column("id", sa.Integer(), primary_key=True),
        sa.Column("primary_question_id", sa.Integer(), nullable=False),
        sa.Column("chat_session_id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("sub_question", sa.Text(), nullable=False),
        sa.Column("level", sa.Integer(), nullable=False),
        sa.Column("level_question_num", sa.Integer(), nullable=False),
        sa.Column(
            "time_created",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column("sub_answer", sa.Text(), nullable=False),
        sa.Column("sub_question_doc_results", postgresql.JSONB(), nullable=False),
        sa.ForeignKeyConstraint(
            ["primary_question_id"], ["chat_message.id"], ondelete="CASCADE"
        ),
        sa.ForeignKeyConstraint(["chat_session_id"], ["chat_session.id"]),
        sa.PrimaryKeyConstraint("id"),
    )

    op.create_table(
        "agent__sub_query",
        sa.Column("id", sa.Integer(), primary_key=True),
        sa.Column("parent_question_id", sa.Integer(), nullable=False),
        sa.Column("chat_session_id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("sub_query", sa.Text(), nullable=False),
        sa.Column(
            "time_created",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.ForeignKeyConstraint(
            ["parent_question_id"], ["agent__sub_question.id"], ondelete="CASCADE"
        ),
        sa.ForeignKeyConstraint(["chat_session_id"], ["chat_session.id"]),
        sa.PrimaryKeyConstraint("id"),
    )

    op.create_table(
        "agent__sub_query__search_doc",
        sa.Column("sub_query_id", sa.Integer(), nullable=False),
        sa.Column("search_doc_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["sub_query_id"], ["agent__sub_query.id"], ondelete="CASCADE"
        ),
        sa.ForeignKeyConstraint(["search_doc_id"], ["search_doc.id"]),
        sa.PrimaryKeyConstraint("sub_query_id", "search_doc_id"),
    )

    # Recreate research agent tables
    op.create_table(
        "research_agent_iteration",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column("primary_question_id", sa.Integer(), nullable=False),
        sa.Column("iteration_nr", sa.Integer(), nullable=False),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column("purpose", sa.String(), nullable=True),
        sa.Column("reasoning", sa.String(), nullable=True),
        sa.ForeignKeyConstraint(
            ["primary_question_id"], ["chat_message.id"], ondelete="CASCADE"
        ),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint(
            "primary_question_id",
            "iteration_nr",
            name="_research_agent_iteration_unique_constraint",
        ),
    )

    op.create_table(
        "research_agent_iteration_sub_step",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column("primary_question_id", sa.Integer(), nullable=False),
        sa.Column("iteration_nr", sa.Integer(), nullable=False),
        sa.Column("iteration_sub_step_nr", sa.Integer(), nullable=False),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column("sub_step_instructions", sa.String(), nullable=True),
        sa.Column("sub_step_tool_id", sa.Integer(), nullable=True),
        sa.Column("reasoning", sa.String(), nullable=True),
        sa.Column("sub_answer", sa.String(), nullable=True),
        sa.Column("cited_doc_results", postgresql.JSONB(), nullable=False),
        sa.Column("claims", postgresql.JSONB(), nullable=True),
        sa.Column("is_web_fetch", sa.Boolean(), nullable=True),
        sa.Column("queries", postgresql.JSONB(), nullable=True),
        sa.Column("generated_images", postgresql.JSONB(), nullable=True),
        sa.Column("additional_data", postgresql.JSONB(), nullable=True),
        sa.ForeignKeyConstraint(
            ["primary_question_id", "iteration_nr"],
            [
                "research_agent_iteration.primary_question_id",
                "research_agent_iteration.iteration_nr",
            ],
            ondelete="CASCADE",
        ),
        sa.ForeignKeyConstraint(["sub_step_tool_id"], ["tool.id"], ondelete="SET NULL"),
        sa.PrimaryKeyConstraint("id"),
    )
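This migration repeats the same information_schema existence check before every rename, add, or drop so that it can be re-run against databases left in different intermediate states. A small helper capturing that pattern might look like the sketch below (the migration above inlines the query each time rather than using a helper, so this is illustrative only):

# Hypothetical helper, not part of the migration above; it just names the pattern the file repeats.
import sqlalchemy as sa
from sqlalchemy.engine import Connection


def column_exists(conn: Connection, table: str, column: str) -> bool:
    # True if the column is already present, which is what makes each step idempotent.
    return (
        conn.execute(
            sa.text(
                "SELECT 1 FROM information_schema.columns "
                "WHERE table_name = :table AND column_name = :column"
            ),
            {"table": table, "column": column},
        ).fetchone()
        is not None
    )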
@@ -199,10 +199,7 @@ def fetch_persona_message_analytics(
            ChatMessage.chat_session_id == ChatSession.id,
        )
        .where(
            or_(
                ChatMessage.alternate_assistant_id == persona_id,
                ChatSession.persona_id == persona_id,
            ),
            ChatSession.persona_id == persona_id,
            ChatMessage.time_sent >= start,
            ChatMessage.time_sent <= end,
            ChatMessage.message_type == MessageType.ASSISTANT,
@@ -231,10 +228,7 @@ def fetch_persona_unique_users(
            ChatMessage.chat_session_id == ChatSession.id,
        )
        .where(
            or_(
                ChatMessage.alternate_assistant_id == persona_id,
                ChatSession.persona_id == persona_id,
            ),
            ChatSession.persona_id == persona_id,
            ChatMessage.time_sent >= start,
            ChatMessage.time_sent <= end,
            ChatMessage.message_type == MessageType.ASSISTANT,
@@ -265,10 +259,7 @@ def fetch_assistant_message_analytics(
            ChatMessage.chat_session_id == ChatSession.id,
        )
        .where(
            or_(
                ChatMessage.alternate_assistant_id == assistant_id,
                ChatSession.persona_id == assistant_id,
            ),
            ChatSession.persona_id == assistant_id,
            ChatMessage.time_sent >= start,
            ChatMessage.time_sent <= end,
            ChatMessage.message_type == MessageType.ASSISTANT,
@@ -299,10 +290,7 @@ def fetch_assistant_unique_users(
            ChatMessage.chat_session_id == ChatSession.id,
        )
        .where(
            or_(
                ChatMessage.alternate_assistant_id == assistant_id,
                ChatSession.persona_id == assistant_id,
            ),
            ChatSession.persona_id == assistant_id,
            ChatMessage.time_sent >= start,
            ChatMessage.time_sent <= end,
            ChatMessage.message_type == MessageType.ASSISTANT,
@@ -332,10 +320,7 @@ def fetch_assistant_unique_users_total(
            ChatMessage.chat_session_id == ChatSession.id,
        )
        .where(
            or_(
                ChatMessage.alternate_assistant_id == assistant_id,
                ChatSession.persona_id == assistant_id,
            ),
            ChatSession.persona_id == assistant_id,
            ChatMessage.time_sent >= start,
            ChatMessage.time_sent <= end,
            ChatMessage.message_type == MessageType.ASSISTANT,
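All five hunks above make the same change: the or_ clause that also matched ChatMessage.alternate_assistant_id is dropped (that column is removed by the new_chat_history migration earlier in this diff), leaving only the session-level persona filter. A sketch of the resulting filter, reusing the model and variable names these functions already reference (the surrounding select/join boilerplate is assumed, not shown in the hunks):

# Sketch only; ChatMessage, ChatSession, MessageType, persona_id, start, end come from the repo code above.
from sqlalchemy import select

stmt = (
    select(ChatMessage)
    .join(ChatSession, ChatMessage.chat_session_id == ChatSession.id)
    .where(
        ChatSession.persona_id == persona_id,
        ChatMessage.time_sent >= start,
        ChatMessage.time_sent <= end,
        ChatMessage.message_type == MessageType.ASSISTANT,
    )
)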
@@ -55,18 +55,7 @@ def get_empty_chat_messages_entries__paginated(

        # Get assistant name (from session persona, or alternate if specified)
        assistant_name = None
        if message.alternate_assistant_id:
            # If there's an alternate assistant, we need to fetch it
            from onyx.db.models import Persona

            alternate_persona = (
                db_session.query(Persona)
                .filter(Persona.id == message.alternate_assistant_id)
                .first()
            )
            if alternate_persona:
                assistant_name = alternate_persona.name
        elif chat_session.persona:
        if chat_session.persona:
            assistant_name = chat_session.persona.name

        message_skeletons.append(
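With alternate_assistant_id gone from ChatMessage, the alternate-persona lookup in this hunk is deleted and the name is taken straight from the session's persona. Roughly, the surviving logic reduces to:

# Sketch of the surviving logic from the hunk above.
assistant_name = chat_session.persona.name if chat_session.persona else None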
@@ -9,7 +9,7 @@ from ee.onyx.server.query_and_chat.models import (
)
from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import create_chat_history_chain
from onyx.chat.models import ChatBasicResponse
from onyx.chat.process_message import gather_stream
from onyx.chat.process_message import stream_chat_message_objects
@@ -69,7 +69,7 @@ def handle_simplified_chat_message(
    chat_session_id = chat_message_req.chat_session_id

    try:
        parent_message, _ = create_chat_chain(
        parent_message, _ = create_chat_history_chain(
            chat_session_id=chat_session_id, db_session=db_session
        )
    except Exception:
@@ -8,10 +8,29 @@ from pydantic import model_validator

from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import BasicChunkRequest
from onyx.context.search.models import ChunkContext
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import RetrievalDetails
from onyx.server.manage.models import StandardAnswer
from onyx.server.query_and_chat.streaming_models import SubQuestionIdentifier


class StandardAnswerRequest(BaseModel):
    message: str
    slack_bot_categories: list[str]


class StandardAnswerResponse(BaseModel):
    standard_answers: list[StandardAnswer] = Field(default_factory=list)


class DocumentSearchRequest(BasicChunkRequest):
    user_selected_filters: BaseFilters | None = None


class DocumentSearchResponse(BaseModel):
    top_documents: list[InferenceChunk]


class BasicCreateChatMessageRequest(ChunkContext):
@@ -71,17 +90,17 @@ class SimpleDoc(BaseModel):
    metadata: dict | None


class AgentSubQuestion(SubQuestionIdentifier):
class AgentSubQuestion(BaseModel):
    sub_question: str
    document_ids: list[str]


class AgentAnswer(SubQuestionIdentifier):
class AgentAnswer(BaseModel):
    answer: str
    answer_type: Literal["agent_sub_answer", "agent_level_answer"]


class AgentSubQuery(SubQuestionIdentifier):
class AgentSubQuery(BaseModel):
    sub_query: str
    query_id: int

@@ -127,12 +146,3 @@ class AgentSubQuery(SubQuestionIdentifier):
            sorted(level_question_dict.items(), key=lambda x: (x is None, x))
        )
        return sorted_dict


class StandardAnswerRequest(BaseModel):
    message: str
    slack_bot_categories: list[str]


class StandardAnswerResponse(BaseModel):
    standard_answers: list[StandardAnswer] = Field(default_factory=list)
@@ -24,7 +24,7 @@ from onyx.auth.users import current_admin_user
from onyx.auth.users import get_display_email
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.background.task_utils import construct_query_history_report_name
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import create_chat_history_chain
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import FileType
@@ -123,10 +123,9 @@ def snapshot_from_chat_session(
) -> ChatSessionSnapshot | None:
    try:
        # Older chats may not have the right structure
        last_message, messages = create_chat_chain(
        messages = create_chat_history_chain(
            chat_session_id=chat_session.id, db_session=db_session
        )
        messages.append(last_message)
    except RuntimeError:
        return None
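Both EE call sites above swap create_chat_chain for create_chat_history_chain with the same chat_session_id/db_session keyword arguments. Judging only from the snapshot_from_chat_session hunk, the new helper hands back the chat history as a single list, so the separate append of the last message disappears. A rough before/after sketch of that call site, using the names shown in the hunk (the helper's exact signature is not part of this diff):

# Before (removed above): the chain came back split into the last message plus earlier messages.
last_message, messages = create_chat_chain(
    chat_session_id=chat_session.id, db_session=db_session
)
messages.append(last_message)

# After (added above): the history chain is returned as one list.
messages = create_chat_history_chain(
    chat_session_id=chat_session.id, db_session=db_session
)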
@@ -1,365 +1,309 @@
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
# import json
|
||||
# from collections.abc import Callable
|
||||
# from collections.abc import Iterator
|
||||
# from collections.abc import Sequence
|
||||
# from dataclasses import dataclass
|
||||
# from typing import Any
|
||||
|
||||
import onyx.tracing.framework._error_tracing as _error_tracing
|
||||
from onyx.agents.agent_framework.models import RunItemStreamEvent
|
||||
from onyx.agents.agent_framework.models import StreamEvent
|
||||
from onyx.agents.agent_framework.models import ToolCallOutputStreamItem
|
||||
from onyx.agents.agent_framework.models import ToolCallStreamItem
|
||||
from onyx.llm.interfaces import LanguageModelInput
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import ToolChoiceOptions
|
||||
from onyx.llm.message_types import ChatCompletionMessage
|
||||
from onyx.llm.message_types import ToolCall
|
||||
from onyx.llm.model_response import ModelResponseStream
|
||||
from onyx.tools.tool import RunContextWrapper
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tracing.framework.create import agent_span
|
||||
from onyx.tracing.framework.create import function_span
|
||||
from onyx.tracing.framework.create import generation_span
|
||||
from onyx.tracing.framework.spans import SpanError
|
||||
# from onyx.agents.agent_framework.models import RunItemStreamEvent
|
||||
# from onyx.agents.agent_framework.models import StreamEvent
|
||||
# from onyx.agents.agent_framework.models import ToolCallStreamItem
|
||||
# from onyx.llm.interfaces import LanguageModelInput
|
||||
# from onyx.llm.interfaces import LLM
|
||||
# from onyx.llm.interfaces import ToolChoiceOptions
|
||||
# from onyx.llm.message_types import ChatCompletionMessage
|
||||
# from onyx.llm.message_types import ToolCall
|
||||
# from onyx.llm.model_response import ModelResponseStream
|
||||
# from onyx.tools.tool import Tool
|
||||
# from onyx.tracing.framework.create import agent_span
|
||||
# from onyx.tracing.framework.create import generation_span
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryResult:
|
||||
stream: Iterator[StreamEvent]
|
||||
new_messages_stateful: list[ChatCompletionMessage]
|
||||
# @dataclass
|
||||
# class QueryResult:
|
||||
# stream: Iterator[StreamEvent]
|
||||
# new_messages_stateful: list[ChatCompletionMessage]
|
||||
|
||||
|
||||
def _serialize_tool_output(output: Any) -> str:
|
||||
if isinstance(output, str):
|
||||
return output
|
||||
try:
|
||||
return json.dumps(output)
|
||||
except TypeError:
|
||||
return str(output)
|
||||
# def _serialize_tool_output(output: Any) -> str:
|
||||
# if isinstance(output, str):
|
||||
# return output
|
||||
# try:
|
||||
# return json.dumps(output)
|
||||
# except TypeError:
|
||||
# return str(output)
|
||||
|
||||
|
||||
def _parse_tool_calls_from_message_content(
|
||||
content: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Parse JSON content that represents tool call instructions."""
|
||||
try:
|
||||
parsed_content = json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
# def _parse_tool_calls_from_message_content(
|
||||
# content: str,
|
||||
# ) -> list[dict[str, Any]]:
|
||||
# """Parse JSON content that represents tool call instructions."""
|
||||
# try:
|
||||
# parsed_content = json.loads(content)
|
||||
# except json.JSONDecodeError:
|
||||
# return []
|
||||
|
||||
if isinstance(parsed_content, dict):
|
||||
candidates = [parsed_content]
|
||||
elif isinstance(parsed_content, list):
|
||||
candidates = [item for item in parsed_content if isinstance(item, dict)]
|
||||
else:
|
||||
return []
|
||||
# if isinstance(parsed_content, dict):
|
||||
# candidates = [parsed_content]
|
||||
# elif isinstance(parsed_content, list):
|
||||
# candidates = [item for item in parsed_content if isinstance(item, dict)]
|
||||
# else:
|
||||
# return []
|
||||
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
# tool_calls: list[dict[str, Any]] = []
|
||||
|
||||
for candidate in candidates:
|
||||
name = candidate.get("name")
|
||||
arguments = candidate.get("arguments")
|
||||
# for candidate in candidates:
|
||||
# name = candidate.get("name")
|
||||
# arguments = candidate.get("arguments")
|
||||
|
||||
if not isinstance(name, str) or arguments is None:
|
||||
continue
|
||||
# if not isinstance(name, str) or arguments is None:
|
||||
# continue
|
||||
|
||||
if not isinstance(arguments, dict):
|
||||
continue
|
||||
# if not isinstance(arguments, dict):
|
||||
# continue
|
||||
|
||||
call_id = candidate.get("id")
|
||||
arguments_str = json.dumps(arguments)
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": call_id,
|
||||
"name": name,
|
||||
"arguments": arguments_str,
|
||||
}
|
||||
)
|
||||
# call_id = candidate.get("id")
|
||||
# arguments_str = json.dumps(arguments)
|
||||
# tool_calls.append(
|
||||
# {
|
||||
# "id": call_id,
|
||||
# "name": name,
|
||||
# "arguments": arguments_str,
|
||||
# }
|
||||
# )
|
||||
|
||||
return tool_calls
|
||||
# return tool_calls
|
||||
|
||||
|
||||
def _try_convert_content_to_tool_calls_for_non_tool_calling_llms(
|
||||
tool_calls_in_progress: dict[int, dict[str, Any]],
|
||||
content_parts: list[str],
|
||||
structured_response_format: dict | None,
|
||||
next_synthetic_tool_call_id: Callable[[], str],
|
||||
) -> None:
|
||||
"""Populate tool_calls_in_progress when a non-tool-calling LLM returns JSON content describing tool calls."""
|
||||
if tool_calls_in_progress or not content_parts or structured_response_format:
|
||||
return
|
||||
# def _try_convert_content_to_tool_calls_for_non_tool_calling_llms(
|
||||
# tool_calls_in_progress: dict[int, dict[str, Any]],
|
||||
# content_parts: list[str],
|
||||
# structured_response_format: dict | None,
|
||||
# next_synthetic_tool_call_id: Callable[[], str],
|
||||
# ) -> None:
|
||||
# """Populate tool_calls_in_progress when a non-tool-calling LLM returns JSON content describing tool calls."""
|
||||
# if tool_calls_in_progress or not content_parts or structured_response_format:
|
||||
# return
|
||||
|
||||
tool_calls_from_content = _parse_tool_calls_from_message_content(
|
||||
"".join(content_parts)
|
||||
)
|
||||
# tool_calls_from_content = _parse_tool_calls_from_message_content(
|
||||
# "".join(content_parts)
|
||||
# )
|
||||
|
||||
if not tool_calls_from_content:
|
||||
return
|
||||
# if not tool_calls_from_content:
|
||||
# return
|
||||
|
||||
content_parts.clear()
|
||||
# content_parts.clear()
|
||||
|
||||
for index, tool_call_data in enumerate(tool_calls_from_content):
|
||||
call_id = tool_call_data["id"] or next_synthetic_tool_call_id()
|
||||
tool_calls_in_progress[index] = {
|
||||
"id": call_id,
|
||||
"name": tool_call_data["name"],
|
||||
"arguments": tool_call_data["arguments"],
|
||||
}
|
||||
# for index, tool_call_data in enumerate(tool_calls_from_content):
|
||||
# call_id = tool_call_data["id"] or next_synthetic_tool_call_id()
|
||||
# tool_calls_in_progress[index] = {
|
||||
# "id": call_id,
|
||||
# "name": tool_call_data["name"],
|
||||
# "arguments": tool_call_data["arguments"],
|
||||
# }
|
||||
|
||||
|
||||
def _update_tool_call_with_delta(
|
||||
tool_calls_in_progress: dict[int, dict[str, Any]],
|
||||
tool_call_delta: Any,
|
||||
) -> None:
|
||||
index = tool_call_delta.index
|
||||
# def _update_tool_call_with_delta(
|
||||
# tool_calls_in_progress: dict[int, dict[str, Any]],
|
||||
# tool_call_delta: Any,
|
||||
# ) -> None:
|
||||
# index = tool_call_delta.index
|
||||
|
||||
if index not in tool_calls_in_progress:
|
||||
tool_calls_in_progress[index] = {
|
||||
"id": None,
|
||||
"name": None,
|
||||
"arguments": "",
|
||||
}
|
||||
# if index not in tool_calls_in_progress:
|
||||
# tool_calls_in_progress[index] = {
|
||||
# "id": None,
|
||||
# "name": None,
|
||||
# "arguments": "",
|
||||
# }
|
||||
|
||||
if tool_call_delta.id:
|
||||
tool_calls_in_progress[index]["id"] = tool_call_delta.id
|
||||
# if tool_call_delta.id:
|
||||
# tool_calls_in_progress[index]["id"] = tool_call_delta.id
|
||||
|
||||
if tool_call_delta.function:
|
||||
if tool_call_delta.function.name:
|
||||
tool_calls_in_progress[index]["name"] = tool_call_delta.function.name
|
||||
# if tool_call_delta.function:
|
||||
# if tool_call_delta.function.name:
|
||||
# tool_calls_in_progress[index]["name"] = tool_call_delta.function.name
|
||||
|
||||
if tool_call_delta.function.arguments:
|
||||
tool_calls_in_progress[index][
|
||||
"arguments"
|
||||
] += tool_call_delta.function.arguments
|
||||
# if tool_call_delta.function.arguments:
|
||||
# tool_calls_in_progress[index][
|
||||
# "arguments"
|
||||
# ] += tool_call_delta.function.arguments
|
||||
|
||||
|
||||
def query(
|
||||
llm_with_default_settings: LLM,
|
||||
messages: LanguageModelInput,
|
||||
tools: Sequence[Tool],
|
||||
context: Any,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
) -> QueryResult:
|
||||
tool_definitions = [tool.tool_definition() for tool in tools]
|
||||
tools_by_name = {tool.name: tool for tool in tools}
|
||||
# def query(
|
||||
# llm_with_default_settings: LLM,
|
||||
# messages: LanguageModelInput,
|
||||
# tools: Sequence[Tool],
|
||||
# context: Any,
|
||||
# tool_choice: ToolChoiceOptions | None = None,
|
||||
# structured_response_format: dict | None = None,
|
||||
# ) -> QueryResult:
|
||||
# tool_definitions = [tool.tool_definition() for tool in tools]
|
||||
# tools_by_name = {tool.name: tool for tool in tools}
|
||||
|
||||
new_messages_stateful: list[ChatCompletionMessage] = []
|
||||
# new_messages_stateful: list[ChatCompletionMessage] = []
|
||||
|
||||
current_span = agent_span(
|
||||
name="agent_framework_query",
|
||||
output_type="dict" if structured_response_format else "str",
|
||||
)
|
||||
current_span.start(mark_as_current=True)
|
||||
current_span.span_data.tools = [t.name for t in tools]
|
||||
# current_span = agent_span(
|
||||
# name="agent_framework_query",
|
||||
# output_type="dict" if structured_response_format else "str",
|
||||
# )
|
||||
# current_span.start(mark_as_current=True)
|
||||
# current_span.span_data.tools = [t.name for t in tools]
|
||||
|
||||
def stream_generator() -> Iterator[StreamEvent]:
|
||||
message_started = False
|
||||
reasoning_started = False
|
||||
# def stream_generator() -> Iterator[StreamEvent]:
|
||||
# message_started = False
|
||||
# reasoning_started = False
|
||||
|
||||
tool_calls_in_progress: dict[int, dict[str, Any]] = {}
|
||||
# tool_calls_in_progress: dict[int, dict[str, Any]] = {}
|
||||
|
||||
content_parts: list[str] = []
|
||||
# content_parts: list[str] = []
|
||||
|
||||
synthetic_tool_call_counter = 0
|
||||
# synthetic_tool_call_counter = 0
|
||||
|
||||
def _next_synthetic_tool_call_id() -> str:
|
||||
nonlocal synthetic_tool_call_counter
|
||||
call_id = f"synthetic_tool_call_{synthetic_tool_call_counter}"
|
||||
synthetic_tool_call_counter += 1
|
||||
return call_id
|
||||
# def _next_synthetic_tool_call_id() -> str:
|
||||
# nonlocal synthetic_tool_call_counter
|
||||
# call_id = f"synthetic_tool_call_{synthetic_tool_call_counter}"
|
||||
# synthetic_tool_call_counter += 1
|
||||
# return call_id
|
||||
|
||||
with generation_span( # type: ignore[misc]
|
||||
model=llm_with_default_settings.config.model_name,
|
||||
model_config={
|
||||
"base_url": str(llm_with_default_settings.config.api_base or ""),
|
||||
"model_impl": "litellm",
|
||||
},
|
||||
) as span_generation:
|
||||
# Only set input if messages is a sequence (not a string)
|
||||
# ChatCompletionMessage TypedDicts are compatible with Mapping[str, Any] at runtime
|
||||
if isinstance(messages, Sequence) and not isinstance(messages, str):
|
||||
# Convert ChatCompletionMessage sequence to Sequence[Mapping[str, Any]]
|
||||
span_generation.span_data.input = [dict(msg) for msg in messages] # type: ignore[assignment]
|
||||
for chunk in llm_with_default_settings.stream(
|
||||
prompt=messages,
|
||||
tools=tool_definitions,
|
||||
tool_choice=tool_choice,
|
||||
structured_response_format=structured_response_format,
|
||||
):
|
||||
assert isinstance(chunk, ModelResponseStream)
|
||||
usage = getattr(chunk, "usage", None)
|
||||
if usage:
|
||||
span_generation.span_data.usage = {
|
||||
"input_tokens": usage.prompt_tokens,
|
||||
"output_tokens": usage.completion_tokens,
|
||||
"cache_read_input_tokens": usage.cache_read_input_tokens,
|
||||
"cache_creation_input_tokens": usage.cache_creation_input_tokens,
|
||||
}
|
||||
# with generation_span( # type: ignore[misc]
|
||||
# model=llm_with_default_settings.config.model_name,
|
||||
# model_config={
|
||||
# "base_url": str(llm_with_default_settings.config.api_base or ""),
|
||||
# "model_impl": "litellm",
|
||||
# },
|
||||
# ) as span_generation:
|
||||
# # Only set input if messages is a sequence (not a string)
|
||||
# # ChatCompletionMessage TypedDicts are compatible with Mapping[str, Any] at runtime
|
||||
# if isinstance(messages, Sequence) and not isinstance(messages, str):
|
||||
# # Convert ChatCompletionMessage sequence to Sequence[Mapping[str, Any]]
|
||||
# span_generation.span_data.input = [dict(msg) for msg in messages] # type: ignore[assignment]
|
||||
# for chunk in llm_with_default_settings.stream(
|
||||
# prompt=messages,
|
||||
# tools=tool_definitions,
|
||||
# tool_choice=tool_choice,
|
||||
# structured_response_format=structured_response_format,
|
||||
# ):
|
||||
# assert isinstance(chunk, ModelResponseStream)
|
||||
# usage = getattr(chunk, "usage", None)
|
||||
# if usage:
|
||||
# span_generation.span_data.usage = {
|
||||
# "input_tokens": usage.prompt_tokens,
|
||||
# "output_tokens": usage.completion_tokens,
|
||||
# "cache_read_input_tokens": usage.cache_read_input_tokens,
|
||||
# "cache_creation_input_tokens": usage.cache_creation_input_tokens,
|
||||
# }
|
||||
|
||||
delta = chunk.choice.delta
|
||||
finish_reason = chunk.choice.finish_reason
|
||||
# delta = chunk.choice.delta
|
||||
# finish_reason = chunk.choice.finish_reason
|
||||
|
||||
if delta.reasoning_content:
|
||||
if not reasoning_started:
|
||||
yield RunItemStreamEvent(type="reasoning_start")
|
||||
reasoning_started = True
|
||||
# if delta.reasoning_content:
|
||||
# if not reasoning_started:
|
||||
# yield RunItemStreamEvent(type="reasoning_start")
|
||||
# reasoning_started = True
|
||||
|
||||
if delta.content:
|
||||
if reasoning_started:
|
||||
yield RunItemStreamEvent(type="reasoning_done")
|
||||
reasoning_started = False
|
||||
content_parts.append(delta.content)
|
||||
if not message_started:
|
||||
yield RunItemStreamEvent(type="message_start")
|
||||
message_started = True
|
||||
# if delta.content:
|
||||
# if reasoning_started:
|
||||
# yield RunItemStreamEvent(type="reasoning_done")
|
||||
# reasoning_started = False
|
||||
# content_parts.append(delta.content)
|
||||
# if not message_started:
|
||||
# yield RunItemStreamEvent(type="message_start")
|
||||
# message_started = True
|
||||
|
||||
if delta.tool_calls:
|
||||
if reasoning_started:
|
||||
yield RunItemStreamEvent(type="reasoning_done")
|
||||
reasoning_started = False
|
||||
if message_started:
|
||||
yield RunItemStreamEvent(type="message_done")
|
||||
message_started = False
|
||||
# if delta.tool_calls:
|
||||
# if reasoning_started:
|
||||
# yield RunItemStreamEvent(type="reasoning_done")
|
||||
# reasoning_started = False
|
||||
# if message_started:
|
||||
# yield RunItemStreamEvent(type="message_done")
|
||||
# message_started = False
|
||||
|
||||
for tool_call_delta in delta.tool_calls:
|
||||
_update_tool_call_with_delta(
|
||||
tool_calls_in_progress, tool_call_delta
|
||||
)
|
||||
# for tool_call_delta in delta.tool_calls:
|
||||
# _update_tool_call_with_delta(
|
||||
# tool_calls_in_progress, tool_call_delta
|
||||
# )
|
||||
|
||||
yield chunk
|
||||
# yield chunk
|
||||
|
||||
if not finish_reason:
|
||||
continue
|
||||
# if not finish_reason:
|
||||
# continue
|
||||
|
||||
if reasoning_started:
|
||||
yield RunItemStreamEvent(type="reasoning_done")
|
||||
reasoning_started = False
|
||||
if message_started:
|
||||
yield RunItemStreamEvent(type="message_done")
|
||||
message_started = False
|
||||
# if reasoning_started:
|
||||
# yield RunItemStreamEvent(type="reasoning_done")
|
||||
# reasoning_started = False
|
||||
# if message_started:
|
||||
# yield RunItemStreamEvent(type="message_done")
|
||||
# message_started = False
|
||||
|
||||
if tool_choice != "none":
|
||||
_try_convert_content_to_tool_calls_for_non_tool_calling_llms(
|
||||
tool_calls_in_progress,
|
||||
content_parts,
|
||||
structured_response_format,
|
||||
_next_synthetic_tool_call_id,
|
||||
)
|
||||
# if tool_choice != "none":
|
||||
# _try_convert_content_to_tool_calls_for_non_tool_calling_llms(
|
||||
# tool_calls_in_progress,
|
||||
# content_parts,
|
||||
# structured_response_format,
|
||||
# _next_synthetic_tool_call_id,
|
||||
# )
|
||||
|
||||
if content_parts:
|
||||
new_messages_stateful.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "".join(content_parts),
|
||||
}
|
||||
)
|
||||
span_generation.span_data.output = new_messages_stateful
|
||||
# if content_parts:
|
||||
# new_messages_stateful.append(
|
||||
# {
|
||||
# "role": "assistant",
|
||||
# "content": "".join(content_parts),
|
||||
# }
|
||||
# )
|
||||
# span_generation.span_data.output = new_messages_stateful
|
||||
|
||||
# Execute tool calls outside of the stream loop and generation_span
|
||||
if tool_calls_in_progress:
|
||||
sorted_tool_calls = sorted(tool_calls_in_progress.items())
|
||||
# # Execute tool calls outside of the stream loop and generation_span
|
||||
# if tool_calls_in_progress:
|
||||
# sorted_tool_calls = sorted(tool_calls_in_progress.items())
|
||||
|
||||
# Build tool calls for the message and execute tools
|
||||
assistant_tool_calls: list[ToolCall] = []
|
||||
tool_outputs: dict[str, str] = {}
|
||||
# # Build tool calls for the message and execute tools
|
||||
# assistant_tool_calls: list[ToolCall] = []
|
||||
|
||||
for _, tool_call_data in sorted_tool_calls:
|
||||
call_id = tool_call_data["id"]
|
||||
name = tool_call_data["name"]
|
||||
arguments_str = tool_call_data["arguments"]
|
||||
# for _, tool_call_data in sorted_tool_calls:
|
||||
# call_id = tool_call_data["id"]
|
||||
# name = tool_call_data["name"]
|
||||
# arguments_str = tool_call_data["arguments"]
|
||||
|
||||
if call_id is None or name is None:
|
||||
continue
|
||||
# if call_id is None or name is None:
|
||||
# continue
|
||||
|
||||
assistant_tool_calls.append(
|
||||
{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"arguments": arguments_str,
|
||||
},
|
||||
}
|
||||
)
|
||||
# assistant_tool_calls.append(
|
||||
# {
|
||||
# "id": call_id,
|
||||
# "type": "function",
|
||||
# "function": {
|
||||
# "name": name,
|
||||
# "arguments": arguments_str,
|
||||
# },
|
||||
# }
|
||||
# )
|
||||
|
||||
yield RunItemStreamEvent(
|
||||
type="tool_call",
|
||||
details=ToolCallStreamItem(
|
||||
call_id=call_id,
|
||||
name=name,
|
||||
arguments=arguments_str,
|
||||
),
|
||||
)
|
||||
# yield RunItemStreamEvent(
|
||||
# type="tool_call",
|
||||
# details=ToolCallStreamItem(
|
||||
# call_id=call_id,
|
||||
# name=name,
|
||||
# arguments=arguments_str,
|
||||
# ),
|
||||
# )
|
||||
|
||||
if name in tools_by_name:
|
||||
tool = tools_by_name[name]
|
||||
arguments = json.loads(arguments_str)
|
||||
# if name in tools_by_name:
|
||||
# tools_by_name[name]
|
||||
# json.loads(arguments_str)
|
||||
|
||||
run_context = RunContextWrapper(context=context)
|
||||
# run_context = RunContextWrapper(context=context)
|
||||
|
||||
# TODO: Instead of executing sequentially, execute in parallel
|
||||
# In practice, it's not a must right now since we don't use parallel
|
||||
# tool calls, so kicking the can down the road for now.
|
||||
with function_span(tool.name) as span_fn:
|
||||
span_fn.span_data.input = arguments
|
||||
try:
|
||||
output = tool.run_v2(run_context, **arguments)
|
||||
tool_outputs[call_id] = _serialize_tool_output(output)
|
||||
span_fn.span_data.output = output
|
||||
except Exception as e:
|
||||
_error_tracing.attach_error_to_current_span(
|
||||
SpanError(
|
||||
message="Error running tool",
|
||||
data={"tool_name": tool.name, "error": str(e)},
|
||||
)
|
||||
)
|
||||
# Treat the error as the tool output so the framework can continue
|
||||
error_output = f"Error: {str(e)}"
|
||||
tool_outputs[call_id] = error_output
|
||||
output = error_output
|
||||
# TODO: Instead of executing sequentially, execute in parallel
|
||||
# In practice, it's not a must right now since we don't use parallel
|
||||
# tool calls, so kicking the can down the road for now.
|
||||
|
||||
yield RunItemStreamEvent(
|
||||
type="tool_call_output",
|
||||
details=ToolCallOutputStreamItem(
|
||||
call_id=call_id,
|
||||
output=output,
|
||||
),
|
||||
)
|
||||
else:
|
||||
not_found_output = f"Tool {name} not found"
|
||||
tool_outputs[call_id] = _serialize_tool_output(not_found_output)
|
||||
yield RunItemStreamEvent(
|
||||
type="tool_call_output",
|
||||
details=ToolCallOutputStreamItem(
|
||||
call_id=call_id,
|
||||
output=not_found_output,
|
||||
),
|
||||
)
|
||||
# TODO broken for now, no need for a run_v2
|
||||
# output = tool.run_v2(run_context, **arguments)
|
||||
|
||||
new_messages_stateful.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": assistant_tool_calls,
|
||||
}
|
||||
)
|
||||
|
||||
for _, tool_call_data in sorted_tool_calls:
|
||||
call_id = tool_call_data["id"]
|
||||
|
||||
if call_id in tool_outputs:
|
||||
new_messages_stateful.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": tool_outputs[call_id],
|
||||
"tool_call_id": call_id,
|
||||
}
|
||||
)
|
||||
current_span.finish(reset_current=True)
|
||||
|
||||
return QueryResult(
|
||||
stream=stream_generator(),
|
||||
new_messages_stateful=new_messages_stateful,
|
||||
)
|
||||
# yield RunItemStreamEvent(
|
||||
# type="tool_call_output",
|
||||
# details=ToolCallOutputStreamItem(
|
||||
# call_id=call_id,
|
||||
# output=output,
|
||||
# ),
|
||||
# )
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
# from operator import add
|
||||
# from typing import Annotated
|
||||
|
||||
from pydantic import BaseModel
|
||||
# from pydantic import BaseModel
|
||||
|
||||
|
||||
class CoreState(BaseModel):
|
||||
"""
|
||||
This is the core state that is shared across all subgraphs.
|
||||
"""
|
||||
# class CoreState(BaseModel):
|
||||
# """
|
||||
# This is the core state that is shared across all subgraphs.
|
||||
# """
|
||||
|
||||
log_messages: Annotated[list[str], add] = []
|
||||
current_step_nr: int = 1
|
||||
# log_messages: Annotated[list[str], add] = []
|
||||
# current_step_nr: int = 1
|
||||
|
||||
|
||||
class SubgraphCoreState(BaseModel):
|
||||
"""
|
||||
This is the core state that is shared across all subgraphs.
|
||||
"""
|
||||
# class SubgraphCoreState(BaseModel):
|
||||
# """
|
||||
# This is the core state that is shared across all subgraphs.
|
||||
# """
|
||||
|
||||
log_messages: Annotated[list[str], add] = []
|
||||
# log_messages: Annotated[list[str], add] = []
|
||||
|
||||
@@ -1,62 +1,62 @@
|
||||
from collections.abc import Hashable
|
||||
from typing import cast
|
||||
# from collections.abc import Hashable
|
||||
# from typing import cast

from langchain_core.runnables.config import RunnableConfig
from langgraph.types import Send
# from langchain_core.runnables.config import RunnableConfig
# from langgraph.types import Send

from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
from onyx.agents.agent_search.dc_search_analysis.states import (
    ObjectResearchInformationUpdate,
)
from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
from onyx.agents.agent_search.dc_search_analysis.states import (
    SearchSourcesObjectsUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
# from onyx.agents.agent_search.dc_search_analysis.states import (
#     ObjectResearchInformationUpdate,
# )
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
# from onyx.agents.agent_search.dc_search_analysis.states import (
#     SearchSourcesObjectsUpdate,
# )
# from onyx.agents.agent_search.models import GraphConfig


def parallel_object_source_research_edge(
    state: SearchSourcesObjectsUpdate, config: RunnableConfig
) -> list[Send | Hashable]:
    """
    LangGraph edge to parallelize the research for an individual object and source
    """
# def parallel_object_source_research_edge(
#     state: SearchSourcesObjectsUpdate, config: RunnableConfig
# ) -> list[Send | Hashable]:
#     """
#     LangGraph edge to parallelize the research for an individual object and source
#     """

    search_objects = state.analysis_objects
    search_sources = state.analysis_sources
#     search_objects = state.analysis_objects
#     search_sources = state.analysis_sources

    object_source_combinations = [
        (object, source) for object in search_objects for source in search_sources
    ]
#     object_source_combinations = [
#         (object, source) for object in search_objects for source in search_sources
#     ]

    return [
        Send(
            "research_object_source",
            ObjectSourceInput(
                object_source_combination=object_source_combination,
                log_messages=[],
            ),
        )
        for object_source_combination in object_source_combinations
    ]
#     return [
#         Send(
#             "research_object_source",
#             ObjectSourceInput(
#                 object_source_combination=object_source_combination,
#                 log_messages=[],
#             ),
#         )
#         for object_source_combination in object_source_combinations
#     ]


def parallel_object_research_consolidation_edge(
    state: ObjectResearchInformationUpdate, config: RunnableConfig
) -> list[Send | Hashable]:
    """
    LangGraph edge to parallelize the research for an individual object and source
    """
    cast(GraphConfig, config["metadata"]["config"])
    object_research_information_results = state.object_research_information_results
# def parallel_object_research_consolidation_edge(
#     state: ObjectResearchInformationUpdate, config: RunnableConfig
# ) -> list[Send | Hashable]:
#     """
#     LangGraph edge to parallelize the research for an individual object and source
#     """
#     cast(GraphConfig, config["metadata"]["config"])
#     object_research_information_results = state.object_research_information_results

    return [
        Send(
            "consolidate_object_research",
            ObjectInformationInput(
                object_information=object_information,
                log_messages=[],
            ),
        )
        for object_information in object_research_information_results
    ]
#     return [
#         Send(
#             "consolidate_object_research",
#             ObjectInformationInput(
#                 object_information=object_information,
#                 log_messages=[],
#             ),
#         )
#         for object_information in object_research_information_results
#     ]
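Note: both edge functions above fan work out with LangGraph's Send primitive. The following is a minimal, self-contained sketch of that mechanism; the state shape, node name, and sample data are invented for illustration and are not part of this change.

from operator import add
from typing import Annotated, TypedDict

from langgraph.graph import END, START, StateGraph
from langgraph.types import Send


class FanOutState(TypedDict):
    items: list[str]
    results: Annotated[list[str], add]  # partial results from parallel branches are appended


def fan_out(state: FanOutState) -> list[Send]:
    # one Send per item, analogous to the object/source combinations above
    return [Send("work", {"items": [item], "results": []}) for item in state["items"]]


def work(state: FanOutState) -> dict:
    return {"results": [f"processed {state['items'][0]}"]}


builder = StateGraph(FanOutState)
builder.add_node("work", work)
builder.add_conditional_edges(START, fan_out, ["work"])
builder.add_edge("work", END)

print(builder.compile().invoke({"items": ["a", "b"], "results": []}))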
@@ -1,103 +1,103 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph

from onyx.agents.agent_search.dc_search_analysis.edges import (
    parallel_object_research_consolidation_edge,
)
from onyx.agents.agent_search.dc_search_analysis.edges import (
    parallel_object_source_research_edge,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a1_search_objects import (
    search_objects,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a2_research_object_source import (
    research_object_source,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a3_structure_research_by_object import (
    structure_research_by_object,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a4_consolidate_object_research import (
    consolidate_object_research,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a5_consolidate_research import (
    consolidate_research,
)
from onyx.agents.agent_search.dc_search_analysis.states import MainInput
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dc_search_analysis.edges import (
#     parallel_object_research_consolidation_edge,
# )
# from onyx.agents.agent_search.dc_search_analysis.edges import (
#     parallel_object_source_research_edge,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a1_search_objects import (
#     search_objects,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a2_research_object_source import (
#     research_object_source,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a3_structure_research_by_object import (
#     structure_research_by_object,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a4_consolidate_object_research import (
#     consolidate_object_research,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a5_consolidate_research import (
#     consolidate_research,
# )
# from onyx.agents.agent_search.dc_search_analysis.states import MainInput
# from onyx.agents.agent_search.dc_search_analysis.states import MainState
# from onyx.utils.logger import setup_logger

logger = setup_logger()
# logger = setup_logger()

test_mode = False
# test_mode = False


def divide_and_conquer_graph_builder(test_mode: bool = False) -> StateGraph:
    """
    LangGraph graph builder for the knowledge graph search process.
    """
# def divide_and_conquer_graph_builder(test_mode: bool = False) -> StateGraph:
#     """
#     LangGraph graph builder for the knowledge graph search process.
#     """

    graph = StateGraph(
        state_schema=MainState,
        input=MainInput,
    )
#     graph = StateGraph(
#         state_schema=MainState,
#         input=MainInput,
#     )

    ### Add nodes ###
#     ### Add nodes ###

    graph.add_node(
        "search_objects",
        search_objects,
    )
#     graph.add_node(
#         "search_objects",
#         search_objects,
#     )

    graph.add_node(
        "structure_research_by_source",
        structure_research_by_object,
    )
#     graph.add_node(
#         "structure_research_by_source",
#         structure_research_by_object,
#     )

    graph.add_node(
        "research_object_source",
        research_object_source,
    )
#     graph.add_node(
#         "research_object_source",
#         research_object_source,
#     )

    graph.add_node(
        "consolidate_object_research",
        consolidate_object_research,
    )
#     graph.add_node(
#         "consolidate_object_research",
#         consolidate_object_research,
#     )

    graph.add_node(
        "consolidate_research",
        consolidate_research,
    )
#     graph.add_node(
#         "consolidate_research",
#         consolidate_research,
#     )

    ### Add edges ###
#     ### Add edges ###

    graph.add_edge(start_key=START, end_key="search_objects")
#     graph.add_edge(start_key=START, end_key="search_objects")

    graph.add_conditional_edges(
        source="search_objects",
        path=parallel_object_source_research_edge,
        path_map=["research_object_source"],
    )
#     graph.add_conditional_edges(
#         source="search_objects",
#         path=parallel_object_source_research_edge,
#         path_map=["research_object_source"],
#     )

    graph.add_edge(
        start_key="research_object_source",
        end_key="structure_research_by_source",
    )
#     graph.add_edge(
#         start_key="research_object_source",
#         end_key="structure_research_by_source",
#     )

    graph.add_conditional_edges(
        source="structure_research_by_source",
        path=parallel_object_research_consolidation_edge,
        path_map=["consolidate_object_research"],
    )
#     graph.add_conditional_edges(
#         source="structure_research_by_source",
#         path=parallel_object_research_consolidation_edge,
#         path_map=["consolidate_object_research"],
#     )

    graph.add_edge(
        start_key="consolidate_object_research",
        end_key="consolidate_research",
    )
#     graph.add_edge(
#         start_key="consolidate_object_research",
#         end_key="consolidate_research",
#     )

    graph.add_edge(
        start_key="consolidate_research",
        end_key=END,
    )
#     graph.add_edge(
#         start_key="consolidate_research",
#         end_key=END,
#     )

    return graph
#     return graph
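Note: the nodes wired up above read their GraphConfig back out of the per-run RunnableConfig via config["metadata"]["config"]. A rough, hypothetical sketch of that round trip follows; the state shape and metadata payload are invented, and the real Onyx wiring passes a full GraphConfig rather than a string.

from typing import TypedDict

from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, START, StateGraph


class DemoState(TypedDict):
    answer: str


def read_config(state: DemoState, config: RunnableConfig) -> dict:
    # whatever was supplied under "metadata" at invoke time is visible to the node
    return {"answer": str(config["metadata"]["config"])}


builder = StateGraph(DemoState)
builder.add_node("read_config", read_config)
builder.add_edge(START, "read_config")
builder.add_edge("read_config", END)

print(builder.compile().invoke({"answer": ""}, config={"metadata": {"config": "example"}}))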
@@ -1,146 +1,146 @@
from typing import cast
# from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter

from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.ops import research
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.agents.agent_search.dc_search_analysis.states import (
    SearchSourcesObjectsUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    trim_prompt_piece,
)
from onyx.prompts.agents.dc_prompts import DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT
from onyx.prompts.agents.dc_prompts import DC_OBJECT_SEPARATOR
from onyx.prompts.agents.dc_prompts import DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT
from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
# from onyx.agents.agent_search.dc_search_analysis.ops import research
# from onyx.agents.agent_search.dc_search_analysis.states import MainState
# from onyx.agents.agent_search.dc_search_analysis.states import (
#     SearchSourcesObjectsUpdate,
# )
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
#     trim_prompt_piece,
# )
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_SEPARATOR
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT
# from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout

logger = setup_logger()
# logger = setup_logger()


def search_objects(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> SearchSourcesObjectsUpdate:
    """
    LangGraph node to start the agentic search process.
    """
# def search_objects(
#     state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> SearchSourcesObjectsUpdate:
#     """
#     LangGraph node to start the agentic search process.
#     """

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = graph_config.inputs.prompt_builder.raw_user_query
    search_tool = graph_config.tooling.search_tool
#     graph_config = cast(GraphConfig, config["metadata"]["config"])
#     question = graph_config.inputs.prompt_builder.raw_user_query
#     search_tool = graph_config.tooling.search_tool

    if search_tool is None or graph_config.inputs.persona is None:
        raise ValueError("Search tool and persona must be provided for DivCon search")
#     if search_tool is None or graph_config.inputs.persona is None:
#         raise ValueError("Search tool and persona must be provided for DivCon search")

    try:
        instructions = graph_config.inputs.persona.system_prompt or ""
#     try:
#         instructions = graph_config.inputs.persona.system_prompt or ""

        agent_1_instructions = extract_section(
            instructions, "Agent Step 1:", "Agent Step 2:"
        )
        if agent_1_instructions is None:
            raise ValueError("Agent 1 instructions not found")
#         agent_1_instructions = extract_section(
#             instructions, "Agent Step 1:", "Agent Step 2:"
#         )
#         if agent_1_instructions is None:
#             raise ValueError("Agent 1 instructions not found")

        agent_1_base_data = extract_section(instructions, "|Start Data|", "|End Data|")
#         agent_1_base_data = extract_section(instructions, "|Start Data|", "|End Data|")

        agent_1_task = extract_section(
            agent_1_instructions, "Task:", "Independent Research Sources:"
        )
        if agent_1_task is None:
            raise ValueError("Agent 1 task not found")
#         agent_1_task = extract_section(
#             agent_1_instructions, "Task:", "Independent Research Sources:"
#         )
#         if agent_1_task is None:
#             raise ValueError("Agent 1 task not found")

        agent_1_independent_sources_str = extract_section(
            agent_1_instructions, "Independent Research Sources:", "Output Objective:"
        )
        if agent_1_independent_sources_str is None:
            raise ValueError("Agent 1 Independent Research Sources not found")
#         agent_1_independent_sources_str = extract_section(
#             agent_1_instructions, "Independent Research Sources:", "Output Objective:"
#         )
#         if agent_1_independent_sources_str is None:
#             raise ValueError("Agent 1 Independent Research Sources not found")

        document_sources = strings_to_document_sources(
            [
                x.strip().lower()
                for x in agent_1_independent_sources_str.split(DC_OBJECT_SEPARATOR)
            ]
        )
#         document_sources = strings_to_document_sources(
#             [
#                 x.strip().lower()
#                 for x in agent_1_independent_sources_str.split(DC_OBJECT_SEPARATOR)
#             ]
#         )

        agent_1_output_objective = extract_section(
            agent_1_instructions, "Output Objective:"
        )
        if agent_1_output_objective is None:
            raise ValueError("Agent 1 output objective not found")
#         agent_1_output_objective = extract_section(
#             agent_1_instructions, "Output Objective:"
#         )
#         if agent_1_output_objective is None:
#             raise ValueError("Agent 1 output objective not found")

    except Exception as e:
        raise ValueError(
            f"Agent 1 instructions not found or not formatted correctly: {e}"
        )
#     except Exception as e:
#         raise ValueError(
#             f"Agent 1 instructions not found or not formatted correctly: {e}"
#         )

    # Extract objects
#     # Extract objects

    if agent_1_base_data is None:
        # Retrieve chunks for objects
#     if agent_1_base_data is None:
#         # Retrieve chunks for objects

        retrieved_docs = research(question, search_tool)[:10]
#         retrieved_docs = research(question, search_tool)[:10]

        document_texts_list = []
        for doc_num, doc in enumerate(retrieved_docs):
            chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
            document_texts_list.append(chunk_text)
#         document_texts_list = []
#         for doc_num, doc in enumerate(retrieved_docs):
#             chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
#             document_texts_list.append(chunk_text)

        document_texts = "\n\n".join(document_texts_list)
#         document_texts = "\n\n".join(document_texts_list)

        dc_object_extraction_prompt = DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT.format(
            question=question,
            task=agent_1_task,
            document_text=document_texts,
            objects_of_interest=agent_1_output_objective,
        )
    else:
        dc_object_extraction_prompt = DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT.format(
            question=question,
            task=agent_1_task,
            base_data=agent_1_base_data,
            objects_of_interest=agent_1_output_objective,
        )
#         dc_object_extraction_prompt = DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT.format(
#             question=question,
#             task=agent_1_task,
#             document_text=document_texts,
#             objects_of_interest=agent_1_output_objective,
#         )
#     else:
#         dc_object_extraction_prompt = DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT.format(
#             question=question,
#             task=agent_1_task,
#             base_data=agent_1_base_data,
#             objects_of_interest=agent_1_output_objective,
#         )

    msg = [
        HumanMessage(
            content=trim_prompt_piece(
                config=graph_config.tooling.primary_llm.config,
                prompt_piece=dc_object_extraction_prompt,
                reserved_str="",
            ),
        )
    ]
    primary_llm = graph_config.tooling.primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            30,
            primary_llm.invoke_langchain,
            prompt=msg,
            timeout_override=30,
            max_tokens=300,
        )
#     msg = [
#         HumanMessage(
#             content=trim_prompt_piece(
#                 config=graph_config.tooling.primary_llm.config,
#                 prompt_piece=dc_object_extraction_prompt,
#                 reserved_str="",
#             ),
#         )
#     ]
#     primary_llm = graph_config.tooling.primary_llm
#     # Grader
#     try:
#         llm_response = run_with_timeout(
#             30,
#             primary_llm.invoke_langchain,
#             prompt=msg,
#             timeout_override=30,
#             max_tokens=300,
#         )

        cleaned_response = (
            str(llm_response.content)
            .replace("```json\n", "")
            .replace("\n```", "")
            .replace("\n", "")
        )
        cleaned_response = cleaned_response.split("OBJECTS:")[1]
        object_list = [x.strip() for x in cleaned_response.split(";")]
#         cleaned_response = (
#             str(llm_response.content)
#             .replace("```json\n", "")
#             .replace("\n```", "")
#             .replace("\n", "")
#         )
#         cleaned_response = cleaned_response.split("OBJECTS:")[1]
#         object_list = [x.strip() for x in cleaned_response.split(";")]

    except Exception as e:
        raise ValueError(f"Error in search_objects: {e}")
#     except Exception as e:
#         raise ValueError(f"Error in search_objects: {e}")

    return SearchSourcesObjectsUpdate(
        analysis_objects=object_list,
        analysis_sources=document_sources,
        log_messages=["Agent 1 Task done"],
    )
#     return SearchSourcesObjectsUpdate(
#         analysis_objects=object_list,
#         analysis_sources=document_sources,
#         log_messages=["Agent 1 Task done"],
#     )
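Note: search_objects expects the LLM reply to contain an "OBJECTS:" section with semicolon-separated entries. A small standalone illustration of that parsing with made-up data (not part of this change):

# invented sample reply, mirroring the cleanup done in search_objects above
raw_llm_reply = "Some preamble\nOBJECTS: Acme Corp; Globex; Initech"

cleaned = raw_llm_reply.replace("```json\n", "").replace("\n```", "").replace("\n", "")
objects_part = cleaned.split("OBJECTS:")[1]
object_list = [x.strip() for x in objects_part.split(";")]

print(object_list)  # ['Acme Corp', 'Globex', 'Initech']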
@@ -1,180 +1,180 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import cast
# from datetime import datetime
# from datetime import timedelta
# from datetime import timezone
# from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter

from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.ops import research
from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
from onyx.agents.agent_search.dc_search_analysis.states import (
    ObjectSourceResearchUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    trim_prompt_piece,
)
from onyx.prompts.agents.dc_prompts import DC_OBJECT_SOURCE_RESEARCH_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
# from onyx.agents.agent_search.dc_search_analysis.ops import research
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
# from onyx.agents.agent_search.dc_search_analysis.states import (
#     ObjectSourceResearchUpdate,
# )
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
#     trim_prompt_piece,
# )
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_SOURCE_RESEARCH_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout

logger = setup_logger()
# logger = setup_logger()


def research_object_source(
    state: ObjectSourceInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> ObjectSourceResearchUpdate:
    """
    LangGraph node to start the agentic search process.
    """
    datetime.now()
# def research_object_source(
#     state: ObjectSourceInput,
#     config: RunnableConfig,
#     writer: StreamWriter = lambda _: None,
# ) -> ObjectSourceResearchUpdate:
#     """
#     LangGraph node to start the agentic search process.
#     """
#     datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    search_tool = graph_config.tooling.search_tool
    question = graph_config.inputs.prompt_builder.raw_user_query
    object, document_source = state.object_source_combination
#     graph_config = cast(GraphConfig, config["metadata"]["config"])
#     search_tool = graph_config.tooling.search_tool
#     question = graph_config.inputs.prompt_builder.raw_user_query
#     object, document_source = state.object_source_combination

    if search_tool is None or graph_config.inputs.persona is None:
        raise ValueError("Search tool and persona must be provided for DivCon search")
#     if search_tool is None or graph_config.inputs.persona is None:
#         raise ValueError("Search tool and persona must be provided for DivCon search")

    try:
        instructions = graph_config.inputs.persona.system_prompt or ""
#     try:
#         instructions = graph_config.inputs.persona.system_prompt or ""

        agent_2_instructions = extract_section(
            instructions, "Agent Step 2:", "Agent Step 3:"
        )
        if agent_2_instructions is None:
            raise ValueError("Agent 2 instructions not found")
#         agent_2_instructions = extract_section(
#             instructions, "Agent Step 2:", "Agent Step 3:"
#         )
#         if agent_2_instructions is None:
#             raise ValueError("Agent 2 instructions not found")

        agent_2_task = extract_section(
            agent_2_instructions, "Task:", "Independent Research Sources:"
        )
        if agent_2_task is None:
            raise ValueError("Agent 2 task not found")
#         agent_2_task = extract_section(
#             agent_2_instructions, "Task:", "Independent Research Sources:"
#         )
#         if agent_2_task is None:
#             raise ValueError("Agent 2 task not found")

        agent_2_time_cutoff = extract_section(
            agent_2_instructions, "Time Cutoff:", "Research Topics:"
        )
#         agent_2_time_cutoff = extract_section(
#             agent_2_instructions, "Time Cutoff:", "Research Topics:"
#         )

        agent_2_research_topics = extract_section(
            agent_2_instructions, "Research Topics:", "Output Objective"
        )
#         agent_2_research_topics = extract_section(
#             agent_2_instructions, "Research Topics:", "Output Objective"
#         )

        agent_2_output_objective = extract_section(
            agent_2_instructions, "Output Objective:"
        )
        if agent_2_output_objective is None:
            raise ValueError("Agent 2 output objective not found")
#         agent_2_output_objective = extract_section(
#             agent_2_instructions, "Output Objective:"
#         )
#         if agent_2_output_objective is None:
#             raise ValueError("Agent 2 output objective not found")

    except Exception:
        raise ValueError(
            "Agent 1 instructions not found or not formatted correctly: {e}"
        )
#     except Exception:
#         raise ValueError(
#             "Agent 1 instructions not found or not formatted correctly: {e}"
#         )

    # Populate prompt
#     # Populate prompt

    # Retrieve chunks for objects
#     # Retrieve chunks for objects

    if agent_2_time_cutoff is not None and agent_2_time_cutoff.strip() != "":
        if agent_2_time_cutoff.strip().endswith("d"):
            try:
                days = int(agent_2_time_cutoff.strip()[:-1])
                agent_2_source_start_time = datetime.now(timezone.utc) - timedelta(
                    days=days
                )
            except ValueError:
                raise ValueError(
                    f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
                )
        else:
            raise ValueError(
                f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
            )
    else:
        agent_2_source_start_time = None
#     if agent_2_time_cutoff is not None and agent_2_time_cutoff.strip() != "":
#         if agent_2_time_cutoff.strip().endswith("d"):
#             try:
#                 days = int(agent_2_time_cutoff.strip()[:-1])
#                 agent_2_source_start_time = datetime.now(timezone.utc) - timedelta(
#                     days=days
#                 )
#             except ValueError:
#                 raise ValueError(
#                     f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
#                 )
#         else:
#             raise ValueError(
#                 f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
#             )
#     else:
#         agent_2_source_start_time = None

    document_sources = [document_source] if document_source else None
#     document_sources = [document_source] if document_source else None

    if len(question.strip()) > 0:
        research_area = f"{question} for {object}"
    elif agent_2_research_topics and len(agent_2_research_topics.strip()) > 0:
        research_area = f"{agent_2_research_topics} for {object}"
    else:
        research_area = object
#     if len(question.strip()) > 0:
#         research_area = f"{question} for {object}"
#     elif agent_2_research_topics and len(agent_2_research_topics.strip()) > 0:
#         research_area = f"{agent_2_research_topics} for {object}"
#     else:
#         research_area = object

    retrieved_docs = research(
        question=research_area,
        search_tool=search_tool,
        document_sources=document_sources,
        time_cutoff=agent_2_source_start_time,
    )
#     retrieved_docs = research(
#         question=research_area,
#         search_tool=search_tool,
#         document_sources=document_sources,
#         time_cutoff=agent_2_source_start_time,
#     )

    # Generate document text
#     # Generate document text

    document_texts_list = []
    for doc_num, doc in enumerate(retrieved_docs):
        chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
        document_texts_list.append(chunk_text)
#     document_texts_list = []
#     for doc_num, doc in enumerate(retrieved_docs):
#         chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
#         document_texts_list.append(chunk_text)

    document_texts = "\n\n".join(document_texts_list)
#     document_texts = "\n\n".join(document_texts_list)

    # Built prompt
#     # Built prompt

    today = datetime.now().strftime("%A, %Y-%m-%d")
#     today = datetime.now().strftime("%A, %Y-%m-%d")

    dc_object_source_research_prompt = (
        DC_OBJECT_SOURCE_RESEARCH_PROMPT.format(
            today=today,
            question=question,
            task=agent_2_task,
            document_text=document_texts,
            format=agent_2_output_objective,
        )
        .replace("---object---", object)
        .replace("---source---", document_source.value)
    )
#     dc_object_source_research_prompt = (
#         DC_OBJECT_SOURCE_RESEARCH_PROMPT.format(
#             today=today,
#             question=question,
#             task=agent_2_task,
#             document_text=document_texts,
#             format=agent_2_output_objective,
#         )
#         .replace("---object---", object)
#         .replace("---source---", document_source.value)
#     )

    # Run LLM
#     # Run LLM

    msg = [
        HumanMessage(
            content=trim_prompt_piece(
                config=graph_config.tooling.primary_llm.config,
                prompt_piece=dc_object_source_research_prompt,
                reserved_str="",
            ),
        )
    ]
    primary_llm = graph_config.tooling.primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            30,
            primary_llm.invoke_langchain,
            prompt=msg,
            timeout_override=30,
            max_tokens=300,
        )
#     msg = [
#         HumanMessage(
#             content=trim_prompt_piece(
#                 config=graph_config.tooling.primary_llm.config,
#                 prompt_piece=dc_object_source_research_prompt,
#                 reserved_str="",
#             ),
#         )
#     ]
#     primary_llm = graph_config.tooling.primary_llm
#     # Grader
#     try:
#         llm_response = run_with_timeout(
#             30,
#             primary_llm.invoke_langchain,
#             prompt=msg,
#             timeout_override=30,
#             max_tokens=300,
#         )

        cleaned_response = str(llm_response.content).replace("```json\n", "")
        cleaned_response = cleaned_response.split("RESEARCH RESULTS:")[1]
        object_research_results = {
            "object": object,
            "source": document_source.value,
            "research_result": cleaned_response,
        }
#         cleaned_response = str(llm_response.content).replace("```json\n", "")
#         cleaned_response = cleaned_response.split("RESEARCH RESULTS:")[1]
#         object_research_results = {
#             "object": object,
#             "source": document_source.value,
#             "research_result": cleaned_response,
#         }

    except Exception as e:
        raise ValueError(f"Error in research_object_source: {e}")
#     except Exception as e:
#         raise ValueError(f"Error in research_object_source: {e}")

    logger.debug("DivCon Step A2 - Object Source Research - completed for an object")
#     logger.debug("DivCon Step A2 - Object Source Research - completed for an object")

    return ObjectSourceResearchUpdate(
        object_source_research_results=[object_research_results],
        log_messages=["Agent Step 2 done for one object"],
    )
#     return ObjectSourceResearchUpdate(
#         object_source_research_results=[object_research_results],
#         log_messages=["Agent Step 2 done for one object"],
#     )
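Note: research_object_source accepts a relative time cutoff in the form "<number>d" (days). A standalone illustration of that convention with an invented value (not part of this change):

from datetime import datetime, timedelta, timezone

agent_2_time_cutoff = "30d"  # invented; taken from the "Time Cutoff:" section of the persona prompt

if agent_2_time_cutoff.strip().endswith("d"):
    days = int(agent_2_time_cutoff.strip()[:-1])
    source_start_time = datetime.now(timezone.utc) - timedelta(days=days)
else:
    source_start_time = None

print(source_start_time)  # roughly 30 days before now, in UTC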
@@ -1,48 +1,48 @@
from collections import defaultdict
from typing import Dict
from typing import List
# from collections import defaultdict
# from typing import Dict
# from typing import List

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter

from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.agents.agent_search.dc_search_analysis.states import (
    ObjectResearchInformationUpdate,
)
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dc_search_analysis.states import MainState
# from onyx.agents.agent_search.dc_search_analysis.states import (
#     ObjectResearchInformationUpdate,
# )
# from onyx.utils.logger import setup_logger

logger = setup_logger()
# logger = setup_logger()


def structure_research_by_object(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ObjectResearchInformationUpdate:
    """
    LangGraph node to start the agentic search process.
    """
# def structure_research_by_object(
#     state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> ObjectResearchInformationUpdate:
#     """
#     LangGraph node to start the agentic search process.
#     """

    object_source_research_results = state.object_source_research_results
#     object_source_research_results = state.object_source_research_results

    object_research_information_results: List[Dict[str, str]] = []
    object_research_information_results_list: Dict[str, List[str]] = defaultdict(list)
#     object_research_information_results: List[Dict[str, str]] = []
#     object_research_information_results_list: Dict[str, List[str]] = defaultdict(list)

    for object_source_research in object_source_research_results:
        object = object_source_research["object"]
        source = object_source_research["source"]
        research_result = object_source_research["research_result"]
#     for object_source_research in object_source_research_results:
#         object = object_source_research["object"]
#         source = object_source_research["source"]
#         research_result = object_source_research["research_result"]

        object_research_information_results_list[object].append(
            f"Source: {source}\n{research_result}"
        )
#         object_research_information_results_list[object].append(
#             f"Source: {source}\n{research_result}"
#         )

    for object, information in object_research_information_results_list.items():
        object_research_information_results.append(
            {"object": object, "information": "\n".join(information)}
        )
#     for object, information in object_research_information_results_list.items():
#         object_research_information_results.append(
#             {"object": object, "information": "\n".join(information)}
#         )

    logger.debug("DivCon Step A3 - Object Research Information Structuring - completed")
#     logger.debug("DivCon Step A3 - Object Research Information Structuring - completed")

    return ObjectResearchInformationUpdate(
        object_research_information_results=object_research_information_results,
        log_messages=["A3 - Object Research Information structured"],
    )
#     return ObjectResearchInformationUpdate(
#         object_research_information_results=object_research_information_results,
#         log_messages=["A3 - Object Research Information structured"],
#     )
@@ -1,103 +1,103 @@
from typing import cast
# from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter

from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
from onyx.agents.agent_search.dc_search_analysis.states import ObjectResearchUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    trim_prompt_piece,
)
from onyx.prompts.agents.dc_prompts import DC_OBJECT_CONSOLIDATION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectResearchUpdate
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
#     trim_prompt_piece,
# )
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_CONSOLIDATION_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout

logger = setup_logger()
# logger = setup_logger()


def consolidate_object_research(
    state: ObjectInformationInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> ObjectResearchUpdate:
    """
    LangGraph node to start the agentic search process.
    """
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    search_tool = graph_config.tooling.search_tool
    question = graph_config.inputs.prompt_builder.raw_user_query
# def consolidate_object_research(
#     state: ObjectInformationInput,
#     config: RunnableConfig,
#     writer: StreamWriter = lambda _: None,
# ) -> ObjectResearchUpdate:
#     """
#     LangGraph node to start the agentic search process.
#     """
#     graph_config = cast(GraphConfig, config["metadata"]["config"])
#     search_tool = graph_config.tooling.search_tool
#     question = graph_config.inputs.prompt_builder.raw_user_query

    if search_tool is None or graph_config.inputs.persona is None:
        raise ValueError("Search tool and persona must be provided for DivCon search")
#     if search_tool is None or graph_config.inputs.persona is None:
#         raise ValueError("Search tool and persona must be provided for DivCon search")

    instructions = graph_config.inputs.persona.system_prompt or ""
#     instructions = graph_config.inputs.persona.system_prompt or ""

    agent_4_instructions = extract_section(
        instructions, "Agent Step 4:", "Agent Step 5:"
    )
    if agent_4_instructions is None:
        raise ValueError("Agent 4 instructions not found")
    agent_4_output_objective = extract_section(
        agent_4_instructions, "Output Objective:"
    )
    if agent_4_output_objective is None:
        raise ValueError("Agent 4 output objective not found")
#     agent_4_instructions = extract_section(
#         instructions, "Agent Step 4:", "Agent Step 5:"
#     )
#     if agent_4_instructions is None:
#         raise ValueError("Agent 4 instructions not found")
#     agent_4_output_objective = extract_section(
#         agent_4_instructions, "Output Objective:"
#     )
#     if agent_4_output_objective is None:
#         raise ValueError("Agent 4 output objective not found")

    object_information = state.object_information
#     object_information = state.object_information

    object = object_information["object"]
    information = object_information["information"]
#     object = object_information["object"]
#     information = object_information["information"]

    # Create a prompt for the object consolidation
#     # Create a prompt for the object consolidation

    dc_object_consolidation_prompt = DC_OBJECT_CONSOLIDATION_PROMPT.format(
        question=question,
        object=object,
        information=information,
        format=agent_4_output_objective,
    )
#     dc_object_consolidation_prompt = DC_OBJECT_CONSOLIDATION_PROMPT.format(
#         question=question,
#         object=object,
#         information=information,
#         format=agent_4_output_objective,
#     )

    # Run LLM
#     # Run LLM

    msg = [
        HumanMessage(
            content=trim_prompt_piece(
                config=graph_config.tooling.primary_llm.config,
                prompt_piece=dc_object_consolidation_prompt,
                reserved_str="",
            ),
        )
    ]
    primary_llm = graph_config.tooling.primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            30,
            primary_llm.invoke_langchain,
            prompt=msg,
            timeout_override=30,
            max_tokens=300,
        )
#     msg = [
#         HumanMessage(
#             content=trim_prompt_piece(
#                 config=graph_config.tooling.primary_llm.config,
#                 prompt_piece=dc_object_consolidation_prompt,
#                 reserved_str="",
#             ),
#         )
#     ]
#     primary_llm = graph_config.tooling.primary_llm
#     # Grader
#     try:
#         llm_response = run_with_timeout(
#             30,
#             primary_llm.invoke_langchain,
#             prompt=msg,
#             timeout_override=30,
#             max_tokens=300,
#         )

        cleaned_response = str(llm_response.content).replace("```json\n", "")
        consolidated_information = cleaned_response.split("INFORMATION:")[1]
#         cleaned_response = str(llm_response.content).replace("```json\n", "")
#         consolidated_information = cleaned_response.split("INFORMATION:")[1]

    except Exception as e:
        raise ValueError(f"Error in consolidate_object_research: {e}")
#     except Exception as e:
#         raise ValueError(f"Error in consolidate_object_research: {e}")

    object_research_results = {
        "object": object,
        "research_result": consolidated_information,
    }
#     object_research_results = {
#         "object": object,
#         "research_result": consolidated_information,
#     }

    logger.debug(
        "DivCon Step A4 - Object Research Consolidation - completed for an object"
    )
#     logger.debug(
#         "DivCon Step A4 - Object Research Consolidation - completed for an object"
#     )

    return ObjectResearchUpdate(
        object_research_results=[object_research_results],
        log_messages=["Agent Source Consilidation done"],
    )
#     return ObjectResearchUpdate(
#         object_research_results=[object_research_results],
#         log_messages=["Agent Source Consilidation done"],
#     )
@@ -1,127 +1,127 @@
from typing import cast
# from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter

from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.agents.agent_search.dc_search_analysis.states import ResearchUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.prompts.agents.dc_prompts import DC_FORMATTING_NO_BASE_DATA_PROMPT
from onyx.prompts.agents.dc_prompts import DC_FORMATTING_WITH_BASE_DATA_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
# from onyx.agents.agent_search.dc_search_analysis.states import MainState
# from onyx.agents.agent_search.dc_search_analysis.states import ResearchUpdate
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
#     trim_prompt_piece,
# )
# from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
# from onyx.prompts.agents.dc_prompts import DC_FORMATTING_NO_BASE_DATA_PROMPT
# from onyx.prompts.agents.dc_prompts import DC_FORMATTING_WITH_BASE_DATA_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout

logger = setup_logger()
# logger = setup_logger()


def consolidate_research(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ResearchUpdate:
    """
    LangGraph node to start the agentic search process.
    """
# def consolidate_research(
#     state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> ResearchUpdate:
#     """
#     LangGraph node to start the agentic search process.
#     """

    graph_config = cast(GraphConfig, config["metadata"]["config"])
#     graph_config = cast(GraphConfig, config["metadata"]["config"])

    search_tool = graph_config.tooling.search_tool
#     search_tool = graph_config.tooling.search_tool

    if search_tool is None or graph_config.inputs.persona is None:
        raise ValueError("Search tool and persona must be provided for DivCon search")
#     if search_tool is None or graph_config.inputs.persona is None:
#         raise ValueError("Search tool and persona must be provided for DivCon search")

    # Populate prompt
    instructions = graph_config.inputs.persona.system_prompt or ""
#     # Populate prompt
#     instructions = graph_config.inputs.persona.system_prompt or ""

    try:
        agent_5_instructions = extract_section(
            instructions, "Agent Step 5:", "Agent End"
        )
        if agent_5_instructions is None:
            raise ValueError("Agent 5 instructions not found")
        agent_5_base_data = extract_section(instructions, "|Start Data|", "|End Data|")
        agent_5_task = extract_section(
            agent_5_instructions, "Task:", "Independent Research Sources:"
        )
        if agent_5_task is None:
            raise ValueError("Agent 5 task not found")
        agent_5_output_objective = extract_section(
            agent_5_instructions, "Output Objective:"
        )
        if agent_5_output_objective is None:
            raise ValueError("Agent 5 output objective not found")
    except ValueError as e:
        raise ValueError(
            f"Instructions for Agent Step 5 were not properly formatted: {e}"
        )
#     try:
#         agent_5_instructions = extract_section(
#             instructions, "Agent Step 5:", "Agent End"
#         )
#         if agent_5_instructions is None:
#             raise ValueError("Agent 5 instructions not found")
#         agent_5_base_data = extract_section(instructions, "|Start Data|", "|End Data|")
#         agent_5_task = extract_section(
#             agent_5_instructions, "Task:", "Independent Research Sources:"
#         )
#         if agent_5_task is None:
#             raise ValueError("Agent 5 task not found")
#         agent_5_output_objective = extract_section(
#             agent_5_instructions, "Output Objective:"
#         )
#         if agent_5_output_objective is None:
#             raise ValueError("Agent 5 output objective not found")
#     except ValueError as e:
#         raise ValueError(
#             f"Instructions for Agent Step 5 were not properly formatted: {e}"
#         )

    research_result_list = []
#     research_result_list = []

    if agent_5_task.strip() == "*concatenate*":
        object_research_results = state.object_research_results
#     if agent_5_task.strip() == "*concatenate*":
#         object_research_results = state.object_research_results

        for object_research_result in object_research_results:
            object = object_research_result["object"]
            research_result = object_research_result["research_result"]
            research_result_list.append(f"Object: {object}\n\n{research_result}")
#         for object_research_result in object_research_results:
#             object = object_research_result["object"]
#             research_result = object_research_result["research_result"]
#             research_result_list.append(f"Object: {object}\n\n{research_result}")

        research_results = "\n\n".join(research_result_list)
#         research_results = "\n\n".join(research_result_list)

    else:
        raise NotImplementedError("Only '*concatenate*' is currently supported")
#     else:
#         raise NotImplementedError("Only '*concatenate*' is currently supported")

    # Create a prompt for the object consolidation
#     # Create a prompt for the object consolidation

    if agent_5_base_data is None:
        dc_formatting_prompt = DC_FORMATTING_NO_BASE_DATA_PROMPT.format(
            text=research_results,
            format=agent_5_output_objective,
        )
    else:
        dc_formatting_prompt = DC_FORMATTING_WITH_BASE_DATA_PROMPT.format(
            base_data=agent_5_base_data,
            text=research_results,
            format=agent_5_output_objective,
        )
#     if agent_5_base_data is None:
#         dc_formatting_prompt = DC_FORMATTING_NO_BASE_DATA_PROMPT.format(
#             text=research_results,
#             format=agent_5_output_objective,
#         )
#     else:
#         dc_formatting_prompt = DC_FORMATTING_WITH_BASE_DATA_PROMPT.format(
#             base_data=agent_5_base_data,
#             text=research_results,
#             format=agent_5_output_objective,
#         )

    # Run LLM
#     # Run LLM

    msg = [
        HumanMessage(
            content=trim_prompt_piece(
                config=graph_config.tooling.primary_llm.config,
                prompt_piece=dc_formatting_prompt,
                reserved_str="",
            ),
        )
    ]
#     msg = [
#         HumanMessage(
#             content=trim_prompt_piece(
#                 config=graph_config.tooling.primary_llm.config,
#                 prompt_piece=dc_formatting_prompt,
#                 reserved_str="",
#             ),
#         )
#     ]

    try:
        _ = run_with_timeout(
            60,
            lambda: stream_llm_answer(
                llm=graph_config.tooling.primary_llm,
                prompt=msg,
                event_name="initial_agent_answer",
                writer=writer,
                agent_answer_level=0,
                agent_answer_question_num=0,
                agent_answer_type="agent_level_answer",
                timeout_override=30,
                max_tokens=None,
            ),
        )
#     try:
#         _ = run_with_timeout(
#             60,
#             lambda: stream_llm_answer(
#                 llm=graph_config.tooling.primary_llm,
#                 prompt=msg,
#                 event_name="initial_agent_answer",
#                 writer=writer,
#                 agent_answer_level=0,
#                 agent_answer_question_num=0,
#                 agent_answer_type="agent_level_answer",
#                 timeout_override=30,
#                 max_tokens=None,
#             ),
#         )

    except Exception as e:
        raise ValueError(f"Error in consolidate_research: {e}")
#     except Exception as e:
#         raise ValueError(f"Error in consolidate_research: {e}")

    logger.debug("DivCon Step A5 - Final Generation - completed")
#     logger.debug("DivCon Step A5 - Final Generation - completed")

    return ResearchUpdate(
        research_results=research_results,
        log_messages=["Agent Source Consilidation done"],
    )
#     return ResearchUpdate(
#         research_results=research_results,
#         log_messages=["Agent Source Consilidation done"],
#     )
@@ -1,61 +1,50 @@
from datetime import datetime
from typing import cast
# from datetime import datetime
# from typing import cast

from onyx.chat.models import LlmDoc
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
    FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
# from onyx.chat.models import LlmDoc
# from onyx.configs.constants import DocumentSource
# from onyx.tools.models import SearchToolOverrideKwargs
# from onyx.tools.tool_implementations.search.search_tool import (
#     FINAL_CONTEXT_DOCUMENTS_ID,
# )
# from onyx.tools.tool_implementations.search.search_tool import SearchTool


def research(
    question: str,
    search_tool: SearchTool,
    document_sources: list[DocumentSource] | None = None,
    time_cutoff: datetime | None = None,
) -> list[LlmDoc]:
    # new db session to avoid concurrency issues
# def research(
#     question: str,
#     search_tool: SearchTool,
#     document_sources: list[DocumentSource] | None = None,
#     time_cutoff: datetime | None = None,
# ) -> list[LlmDoc]:
#     # new db session to avoid concurrency issues

    callback_container: list[list[InferenceSection]] = []
    retrieved_docs: list[LlmDoc] = []
#     retrieved_docs: list[LlmDoc] = []

    with get_session_with_current_tenant() as db_session:
        for tool_response in search_tool.run(
            query=question,
            override_kwargs=SearchToolOverrideKwargs(
                force_no_rerank=False,
                alternate_db_session=db_session,
                retrieved_sections_callback=callback_container.append,
                skip_query_analysis=True,
                document_sources=document_sources,
                time_cutoff=time_cutoff,
            ),
        ):
            # get retrieved docs to send to the rest of the graph
            if tool_response.id == FINAL_CONTEXT_DOCUMENTS_ID:
                retrieved_docs = cast(list[LlmDoc], tool_response.response)[:10]
                break
    return retrieved_docs
#     for tool_response in search_tool.run(
#         query=question,
#         override_kwargs=SearchToolOverrideKwargs(original_query=question),
#     ):
#         # get retrieved docs to send to the rest of the graph
#         if tool_response.id == FINAL_CONTEXT_DOCUMENTS_ID:
#             retrieved_docs = cast(list[LlmDoc], tool_response.response)[:10]
#             break
#     return retrieved_docs


def extract_section(
    text: str, start_marker: str, end_marker: str | None = None
) -> str | None:
    """Extract text between markers, returning None if markers not found"""
    parts = text.split(start_marker)
# def extract_section(
#     text: str, start_marker: str, end_marker: str | None = None
# ) -> str | None:
#     """Extract text between markers, returning None if markers not found"""
#     parts = text.split(start_marker)

    if len(parts) == 1:
        return None
#     if len(parts) == 1:
#         return None

    after_start = parts[1].strip()
#     after_start = parts[1].strip()

    if not end_marker:
        return after_start
#     if not end_marker:
#         return after_start

    extract = after_start.split(end_marker)[0]
#     extract = after_start.split(end_marker)[0]

    return extract.strip()
#     return extract.strip()
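Note: extract_section is the helper the agent nodes use to slice the per-step blocks out of a persona system prompt. A usage sketch with an invented prompt follows; it assumes the ops module above is importable and is not part of this change.

from onyx.agents.agent_search.dc_search_analysis.ops import extract_section

# invented persona prompt fragment, following the "Agent Step N:" convention used above
instructions = (
    "Agent Step 1: Task: find objects "
    "Independent Research Sources: web "
    "Output Objective: a list "
    "Agent Step 2: ..."
)

step_1 = extract_section(instructions, "Agent Step 1:", "Agent Step 2:")
task = extract_section(step_1, "Task:", "Independent Research Sources:")

print(task)  # "find objects"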
@@ -1,72 +1,72 @@
from operator import add
from typing import Annotated
from typing import Dict
from typing import TypedDict
# from operator import add
# from typing import Annotated
# from typing import Dict
# from typing import TypedDict

from pydantic import BaseModel
# from pydantic import BaseModel

from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.configs.constants import DocumentSource
# from onyx.agents.agent_search.core_state import CoreState
# from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
# from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
# from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
# from onyx.configs.constants import DocumentSource


### States ###
class LoggerUpdate(BaseModel):
    log_messages: Annotated[list[str], add] = []
# ### States ###
# class LoggerUpdate(BaseModel):
#     log_messages: Annotated[list[str], add] = []


class SearchSourcesObjectsUpdate(LoggerUpdate):
    analysis_objects: list[str] = []
    analysis_sources: list[DocumentSource] = []
# class SearchSourcesObjectsUpdate(LoggerUpdate):
#     analysis_objects: list[str] = []
#     analysis_sources: list[DocumentSource] = []


class ObjectSourceInput(LoggerUpdate):
    object_source_combination: tuple[str, DocumentSource]
# class ObjectSourceInput(LoggerUpdate):
#     object_source_combination: tuple[str, DocumentSource]


class ObjectSourceResearchUpdate(LoggerUpdate):
    object_source_research_results: Annotated[list[Dict[str, str]], add] = []
# class ObjectSourceResearchUpdate(LoggerUpdate):
#     object_source_research_results: Annotated[list[Dict[str, str]], add] = []


class ObjectInformationInput(LoggerUpdate):
    object_information: Dict[str, str]
# class ObjectInformationInput(LoggerUpdate):
#     object_information: Dict[str, str]


class ObjectResearchInformationUpdate(LoggerUpdate):
    object_research_information_results: Annotated[list[Dict[str, str]], add] = []
# class ObjectResearchInformationUpdate(LoggerUpdate):
#     object_research_information_results: Annotated[list[Dict[str, str]], add] = []


class ObjectResearchUpdate(LoggerUpdate):
    object_research_results: Annotated[list[Dict[str, str]], add] = []
# class ObjectResearchUpdate(LoggerUpdate):
#     object_research_results: Annotated[list[Dict[str, str]], add] = []


class ResearchUpdate(LoggerUpdate):
    research_results: str | None = None
# class ResearchUpdate(LoggerUpdate):
#     research_results: str | None = None


## Graph Input State
class MainInput(CoreState):
    pass
# ## Graph Input State
# class MainInput(CoreState):
#     pass


## Graph State
class MainState(
    # This includes the core state
    MainInput,
    ToolChoiceInput,
    ToolCallUpdate,
    ToolChoiceUpdate,
    SearchSourcesObjectsUpdate,
    ObjectSourceResearchUpdate,
    ObjectResearchInformationUpdate,
    ObjectResearchUpdate,
    ResearchUpdate,
):
    pass
# ## Graph State
# class MainState(
#     # This includes the core state
#     MainInput,
#     ToolChoiceInput,
#     ToolCallUpdate,
#     ToolChoiceUpdate,
#     SearchSourcesObjectsUpdate,
#     ObjectSourceResearchUpdate,
#     ObjectResearchInformationUpdate,
#     ObjectResearchUpdate,
#     ResearchUpdate,
# ):
#     pass


## Graph Output State - presently not used
class MainOutput(TypedDict):
    log_messages: list[str]
# ## Graph Output State - presently not used
# class MainOutput(TypedDict):
#     log_messages: list[str]
@@ -1,36 +1,36 @@
from pydantic import BaseModel
# from pydantic import BaseModel


class RefinementSubQuestion(BaseModel):
    sub_question: str
    sub_question_id: str
    verified: bool
    answered: bool
    answer: str
# class RefinementSubQuestion(BaseModel):
# sub_question: str
# sub_question_id: str
# verified: bool
# answered: bool
# answer: str


class AgentTimings(BaseModel):
    base_duration_s: float | None
    refined_duration_s: float | None
    full_duration_s: float | None
# class AgentTimings(BaseModel):
# base_duration_s: float | None
# refined_duration_s: float | None
# full_duration_s: float | None


class AgentBaseMetrics(BaseModel):
    num_verified_documents_total: int | None
    num_verified_documents_core: int | None
    verified_avg_score_core: float | None
    num_verified_documents_base: int | float | None
    verified_avg_score_base: float | None = None
    base_doc_boost_factor: float | None = None
    support_boost_factor: float | None = None
    duration_s: float | None = None
# class AgentBaseMetrics(BaseModel):
# num_verified_documents_total: int | None
# num_verified_documents_core: int | None
# verified_avg_score_core: float | None
# num_verified_documents_base: int | float | None
# verified_avg_score_base: float | None = None
# base_doc_boost_factor: float | None = None
# support_boost_factor: float | None = None
# duration_s: float | None = None


class AgentRefinedMetrics(BaseModel):
    refined_doc_boost_factor: float | None = None
    refined_question_boost_factor: float | None = None
    duration_s: float | None = None
# class AgentRefinedMetrics(BaseModel):
# refined_doc_boost_factor: float | None = None
# refined_question_boost_factor: float | None = None
# duration_s: float | None = None


class AgentAdditionalMetrics(BaseModel):
    pass
# class AgentAdditionalMetrics(BaseModel):
# pass

@@ -1,61 +1,61 @@
|
||||
from collections.abc import Hashable
|
||||
# from collections.abc import Hashable
|
||||
|
||||
from langgraph.graph import END
|
||||
from langgraph.types import Send
|
||||
# from langgraph.graph import END
|
||||
# from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import DRPath
|
||||
from onyx.agents.agent_search.dr.states import MainState
|
||||
# from onyx.agents.agent_search.dr.enums import DRPath
|
||||
# from onyx.agents.agent_search.dr.states import MainState
|
||||
|
||||
|
||||
def decision_router(state: MainState) -> list[Send | Hashable] | DRPath | str:
|
||||
if not state.tools_used:
|
||||
raise IndexError("state.tools_used cannot be empty")
|
||||
# def decision_router(state: MainState) -> list[Send | Hashable] | DRPath | str:
|
||||
# if not state.tools_used:
|
||||
# raise IndexError("state.tools_used cannot be empty")
|
||||
|
||||
# next_tool is either a generic tool name or a DRPath string
|
||||
next_tool_name = state.tools_used[-1]
|
||||
# # next_tool is either a generic tool name or a DRPath string
|
||||
# next_tool_name = state.tools_used[-1]
|
||||
|
||||
available_tools = state.available_tools
|
||||
if not available_tools:
|
||||
raise ValueError("No tool is available. This should not happen.")
|
||||
# available_tools = state.available_tools
|
||||
# if not available_tools:
|
||||
# raise ValueError("No tool is available. This should not happen.")
|
||||
|
||||
if next_tool_name in available_tools:
|
||||
next_tool_path = available_tools[next_tool_name].path
|
||||
elif next_tool_name == DRPath.END.value:
|
||||
return END
|
||||
elif next_tool_name == DRPath.LOGGER.value:
|
||||
return DRPath.LOGGER
|
||||
elif next_tool_name == DRPath.CLOSER.value:
|
||||
return DRPath.CLOSER
|
||||
else:
|
||||
return DRPath.ORCHESTRATOR
|
||||
# if next_tool_name in available_tools:
|
||||
# next_tool_path = available_tools[next_tool_name].path
|
||||
# elif next_tool_name == DRPath.END.value:
|
||||
# return END
|
||||
# elif next_tool_name == DRPath.LOGGER.value:
|
||||
# return DRPath.LOGGER
|
||||
# elif next_tool_name == DRPath.CLOSER.value:
|
||||
# return DRPath.CLOSER
|
||||
# else:
|
||||
# return DRPath.ORCHESTRATOR
|
||||
|
||||
# handle invalid paths
|
||||
if next_tool_path == DRPath.CLARIFIER:
|
||||
raise ValueError("CLARIFIER is not a valid path during iteration")
|
||||
# # handle invalid paths
|
||||
# if next_tool_path == DRPath.CLARIFIER:
|
||||
# raise ValueError("CLARIFIER is not a valid path during iteration")
|
||||
|
||||
# handle tool calls without a query
|
||||
if (
|
||||
next_tool_path
|
||||
in (
|
||||
DRPath.INTERNAL_SEARCH,
|
||||
DRPath.WEB_SEARCH,
|
||||
DRPath.KNOWLEDGE_GRAPH,
|
||||
DRPath.IMAGE_GENERATION,
|
||||
)
|
||||
and len(state.query_list) == 0
|
||||
):
|
||||
return DRPath.CLOSER
|
||||
# # handle tool calls without a query
|
||||
# if (
|
||||
# next_tool_path
|
||||
# in (
|
||||
# DRPath.INTERNAL_SEARCH,
|
||||
# DRPath.WEB_SEARCH,
|
||||
# DRPath.KNOWLEDGE_GRAPH,
|
||||
# DRPath.IMAGE_GENERATION,
|
||||
# )
|
||||
# and len(state.query_list) == 0
|
||||
# ):
|
||||
# return DRPath.CLOSER
|
||||
|
||||
return next_tool_path
|
||||
# return next_tool_path
|
||||
|
||||
|
||||
def completeness_router(state: MainState) -> DRPath | str:
|
||||
if not state.tools_used:
|
||||
raise IndexError("tools_used cannot be empty")
|
||||
# def completeness_router(state: MainState) -> DRPath | str:
|
||||
# if not state.tools_used:
|
||||
# raise IndexError("tools_used cannot be empty")
|
||||
|
||||
# go to closer if path is CLOSER or no queries
|
||||
next_path = state.tools_used[-1]
|
||||
# # go to closer if path is CLOSER or no queries
|
||||
# next_path = state.tools_used[-1]
|
||||
|
||||
if next_path == DRPath.ORCHESTRATOR.value:
|
||||
return DRPath.ORCHESTRATOR
|
||||
return DRPath.LOGGER
|
||||
# if next_path == DRPath.ORCHESTRATOR.value:
|
||||
# return DRPath.ORCHESTRATOR
|
||||
# return DRPath.LOGGER
|
||||
|
||||
@@ -1,31 +1,27 @@
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchType
# from onyx.agents.agent_search.dr.enums import DRPath

MAX_CHAT_HISTORY_MESSAGES = (
    3  # note: actual count is x2 to account for user and assistant messages
)
# MAX_CHAT_HISTORY_MESSAGES = (
# 3 # note: actual count is x2 to account for user and assistant messages
# )

MAX_DR_PARALLEL_SEARCH = 4
# MAX_DR_PARALLEL_SEARCH = 4

# TODO: test more, generally not needed/adds unnecessary iterations
MAX_NUM_CLOSER_SUGGESTIONS = (
    0  # how many times the closer can send back to the orchestrator
)
# # TODO: test more, generally not needed/adds unnecessary iterations
# MAX_NUM_CLOSER_SUGGESTIONS = (
# 0 # how many times the closer can send back to the orchestrator
# )

CLARIFICATION_REQUEST_PREFIX = "PLEASE CLARIFY:"
HIGH_LEVEL_PLAN_PREFIX = "The Plan:"
# CLARIFICATION_REQUEST_PREFIX = "PLEASE CLARIFY:"
# HIGH_LEVEL_PLAN_PREFIX = "The Plan:"

AVERAGE_TOOL_COSTS: dict[DRPath, float] = {
    DRPath.INTERNAL_SEARCH: 1.0,
    DRPath.KNOWLEDGE_GRAPH: 2.0,
    DRPath.WEB_SEARCH: 1.5,
    DRPath.IMAGE_GENERATION: 3.0,
    DRPath.GENERIC_TOOL: 1.5,  # TODO: see todo in OrchestratorTool
    DRPath.CLOSER: 0.0,
}
# AVERAGE_TOOL_COSTS: dict[DRPath, float] = {
# DRPath.INTERNAL_SEARCH: 1.0,
# DRPath.KNOWLEDGE_GRAPH: 2.0,
# DRPath.WEB_SEARCH: 1.5,
# DRPath.IMAGE_GENERATION: 3.0,
# DRPath.GENERIC_TOOL: 1.5, # TODO: see todo in OrchestratorTool
# DRPath.CLOSER: 0.0,
# }

DR_TIME_BUDGET_BY_TYPE = {
    ResearchType.THOUGHTFUL: 3.0,
    ResearchType.DEEP: 12.0,
    ResearchType.FAST: 0.5,
}
# # Default time budget for agentic search (when use_agentic_search is True)
# DR_TIME_BUDGET_DEFAULT = 12.0

@@ -1,112 +1,111 @@
|
||||
from datetime import datetime
|
||||
# from datetime import datetime
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import DRPath
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import DRPromptPurpose
|
||||
from onyx.agents.agent_search.dr.models import OrchestratorTool
|
||||
from onyx.prompts.dr_prompts import GET_CLARIFICATION_PROMPT
|
||||
from onyx.prompts.dr_prompts import KG_TYPES_DESCRIPTIONS
|
||||
from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
|
||||
from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
|
||||
from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
|
||||
from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
|
||||
from onyx.prompts.dr_prompts import ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
|
||||
from onyx.prompts.dr_prompts import TOOL_DIFFERENTIATION_HINTS
|
||||
from onyx.prompts.dr_prompts import TOOL_QUESTION_HINTS
|
||||
from onyx.prompts.prompt_template import PromptTemplate
|
||||
# from onyx.agents.agent_search.dr.enums import DRPath
|
||||
# from onyx.agents.agent_search.dr.models import DRPromptPurpose
|
||||
# from onyx.agents.agent_search.dr.models import OrchestratorTool
|
||||
# from onyx.prompts.dr_prompts import GET_CLARIFICATION_PROMPT
|
||||
# from onyx.prompts.dr_prompts import KG_TYPES_DESCRIPTIONS
|
||||
# from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
|
||||
# from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
|
||||
# from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
|
||||
# from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
|
||||
# from onyx.prompts.dr_prompts import ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
|
||||
# from onyx.prompts.dr_prompts import TOOL_DIFFERENTIATION_HINTS
|
||||
# from onyx.prompts.dr_prompts import TOOL_QUESTION_HINTS
|
||||
# from onyx.prompts.prompt_template import PromptTemplate
|
||||
|
||||
|
||||
def get_dr_prompt_orchestration_templates(
|
||||
purpose: DRPromptPurpose,
|
||||
research_type: ResearchType,
|
||||
available_tools: dict[str, OrchestratorTool],
|
||||
entity_types_string: str | None = None,
|
||||
relationship_types_string: str | None = None,
|
||||
reasoning_result: str | None = None,
|
||||
tool_calls_string: str | None = None,
|
||||
) -> PromptTemplate:
|
||||
available_tools = available_tools or {}
|
||||
tool_names = list(available_tools.keys())
|
||||
tool_description_str = "\n\n".join(
|
||||
f"- {tool_name}: {tool.description}"
|
||||
for tool_name, tool in available_tools.items()
|
||||
)
|
||||
tool_cost_str = "\n".join(
|
||||
f"{tool_name}: {tool.cost}" for tool_name, tool in available_tools.items()
|
||||
)
|
||||
# def get_dr_prompt_orchestration_templates(
|
||||
# purpose: DRPromptPurpose,
|
||||
# use_agentic_search: bool,
|
||||
# available_tools: dict[str, OrchestratorTool],
|
||||
# entity_types_string: str | None = None,
|
||||
# relationship_types_string: str | None = None,
|
||||
# reasoning_result: str | None = None,
|
||||
# tool_calls_string: str | None = None,
|
||||
# ) -> PromptTemplate:
|
||||
# available_tools = available_tools or {}
|
||||
# tool_names = list(available_tools.keys())
|
||||
# tool_description_str = "\n\n".join(
|
||||
# f"- {tool_name}: {tool.description}"
|
||||
# for tool_name, tool in available_tools.items()
|
||||
# )
|
||||
# tool_cost_str = "\n".join(
|
||||
# f"{tool_name}: {tool.cost}" for tool_name, tool in available_tools.items()
|
||||
# )
|
||||
|
||||
tool_differentiations: list[str] = [
|
||||
TOOL_DIFFERENTIATION_HINTS[(tool_1, tool_2)]
|
||||
for tool_1 in available_tools
|
||||
for tool_2 in available_tools
|
||||
if (tool_1, tool_2) in TOOL_DIFFERENTIATION_HINTS
|
||||
]
|
||||
tool_differentiation_hint_string = (
|
||||
"\n".join(tool_differentiations) or "(No differentiating hints available)"
|
||||
)
|
||||
# TODO: add tool deliniation pairs for custom tools as well
|
||||
# tool_differentiations: list[str] = [
|
||||
# TOOL_DIFFERENTIATION_HINTS[(tool_1, tool_2)]
|
||||
# for tool_1 in available_tools
|
||||
# for tool_2 in available_tools
|
||||
# if (tool_1, tool_2) in TOOL_DIFFERENTIATION_HINTS
|
||||
# ]
|
||||
# tool_differentiation_hint_string = (
|
||||
# "\n".join(tool_differentiations) or "(No differentiating hints available)"
|
||||
# )
|
||||
# # TODO: add tool deliniation pairs for custom tools as well
|
||||
|
||||
tool_question_hint_string = (
|
||||
"\n".join(
|
||||
"- " + TOOL_QUESTION_HINTS[tool]
|
||||
for tool in available_tools
|
||||
if tool in TOOL_QUESTION_HINTS
|
||||
)
|
||||
or "(No examples available)"
|
||||
)
|
||||
# tool_question_hint_string = (
|
||||
# "\n".join(
|
||||
# "- " + TOOL_QUESTION_HINTS[tool]
|
||||
# for tool in available_tools
|
||||
# if tool in TOOL_QUESTION_HINTS
|
||||
# )
|
||||
# or "(No examples available)"
|
||||
# )
|
||||
|
||||
if DRPath.KNOWLEDGE_GRAPH.value in available_tools and (
|
||||
entity_types_string or relationship_types_string
|
||||
):
|
||||
# if DRPath.KNOWLEDGE_GRAPH.value in available_tools and (
|
||||
# entity_types_string or relationship_types_string
|
||||
# ):
|
||||
|
||||
kg_types_descriptions = KG_TYPES_DESCRIPTIONS.build(
|
||||
possible_entities=entity_types_string or "",
|
||||
possible_relationships=relationship_types_string or "",
|
||||
)
|
||||
else:
|
||||
kg_types_descriptions = "(The Knowledge Graph is not used.)"
|
||||
# kg_types_descriptions = KG_TYPES_DESCRIPTIONS.build(
|
||||
# possible_entities=entity_types_string or "",
|
||||
# possible_relationships=relationship_types_string or "",
|
||||
# )
|
||||
# else:
|
||||
# kg_types_descriptions = "(The Knowledge Graph is not used.)"
|
||||
|
||||
if purpose == DRPromptPurpose.PLAN:
|
||||
if research_type == ResearchType.THOUGHTFUL:
|
||||
raise ValueError("plan generation is not supported for FAST time budget")
|
||||
base_template = ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
|
||||
# if purpose == DRPromptPurpose.PLAN:
|
||||
# if not use_agentic_search:
|
||||
# raise ValueError("plan generation is only supported for agentic search")
|
||||
# base_template = ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
|
||||
|
||||
elif purpose == DRPromptPurpose.NEXT_STEP_REASONING:
|
||||
if research_type == ResearchType.THOUGHTFUL:
|
||||
base_template = ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
|
||||
else:
|
||||
raise ValueError(
|
||||
"reasoning is not separately required for DEEP time budget"
|
||||
)
|
||||
# elif purpose == DRPromptPurpose.NEXT_STEP_REASONING:
|
||||
# if not use_agentic_search:
|
||||
# base_template = ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
|
||||
# else:
|
||||
# raise ValueError(
|
||||
# "reasoning is not separately required for agentic search"
|
||||
# )
|
||||
|
||||
elif purpose == DRPromptPurpose.NEXT_STEP_PURPOSE:
|
||||
base_template = ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
|
||||
# elif purpose == DRPromptPurpose.NEXT_STEP_PURPOSE:
|
||||
# base_template = ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
|
||||
|
||||
elif purpose == DRPromptPurpose.NEXT_STEP:
|
||||
if research_type == ResearchType.THOUGHTFUL:
|
||||
base_template = ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
|
||||
else:
|
||||
base_template = ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
|
||||
# elif purpose == DRPromptPurpose.NEXT_STEP:
|
||||
# if not use_agentic_search:
|
||||
# base_template = ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
|
||||
# else:
|
||||
# base_template = ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
|
||||
|
||||
elif purpose == DRPromptPurpose.CLARIFICATION:
|
||||
if research_type == ResearchType.THOUGHTFUL:
|
||||
raise ValueError("clarification is not supported for FAST time budget")
|
||||
base_template = GET_CLARIFICATION_PROMPT
|
||||
# elif purpose == DRPromptPurpose.CLARIFICATION:
|
||||
# if not use_agentic_search:
|
||||
# raise ValueError("clarification is only supported for agentic search")
|
||||
# base_template = GET_CLARIFICATION_PROMPT
|
||||
|
||||
else:
|
||||
# for mypy, clearly a mypy bug
|
||||
raise ValueError(f"Invalid purpose: {purpose}")
|
||||
# else:
|
||||
# # for mypy, clearly a mypy bug
|
||||
# raise ValueError(f"Invalid purpose: {purpose}")
|
||||
|
||||
return base_template.partial_build(
|
||||
num_available_tools=str(len(tool_names)),
|
||||
available_tools=", ".join(tool_names),
|
||||
tool_choice_options=" or ".join(tool_names),
|
||||
current_time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
kg_types_descriptions=kg_types_descriptions,
|
||||
tool_descriptions=tool_description_str,
|
||||
tool_differentiation_hints=tool_differentiation_hint_string,
|
||||
tool_question_hints=tool_question_hint_string,
|
||||
average_tool_costs=tool_cost_str,
|
||||
reasoning_result=reasoning_result or "(No reasoning result provided.)",
|
||||
tool_calls_string=tool_calls_string or "(No tool calls provided.)",
|
||||
)
|
||||
# return base_template.partial_build(
|
||||
# num_available_tools=str(len(tool_names)),
|
||||
# available_tools=", ".join(tool_names),
|
||||
# tool_choice_options=" or ".join(tool_names),
|
||||
# current_time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
# kg_types_descriptions=kg_types_descriptions,
|
||||
# tool_descriptions=tool_description_str,
|
||||
# tool_differentiation_hints=tool_differentiation_hint_string,
|
||||
# tool_question_hints=tool_question_hint_string,
|
||||
# average_tool_costs=tool_cost_str,
|
||||
# reasoning_result=reasoning_result or "(No reasoning result provided.)",
|
||||
# tool_calls_string=tool_calls_string or "(No tool calls provided.)",
|
||||
# )
|
||||
|
||||
@@ -1,32 +1,22 @@
from enum import Enum
# from enum import Enum


class ResearchType(str, Enum):
    """Research type options for agent search operations"""
# class ResearchAnswerPurpose(str, Enum):
# """Research answer purpose options for agent search operations"""

# BASIC = "BASIC"
    LEGACY_AGENTIC = "LEGACY_AGENTIC"  # only used for legacy agentic search migrations
    THOUGHTFUL = "THOUGHTFUL"
    DEEP = "DEEP"
    FAST = "FAST"
# ANSWER = "ANSWER"
# CLARIFICATION_REQUEST = "CLARIFICATION_REQUEST"


class ResearchAnswerPurpose(str, Enum):
    """Research answer purpose options for agent search operations"""

    ANSWER = "ANSWER"
    CLARIFICATION_REQUEST = "CLARIFICATION_REQUEST"


class DRPath(str, Enum):
    CLARIFIER = "Clarifier"
    ORCHESTRATOR = "Orchestrator"
    INTERNAL_SEARCH = "Internal Search"
    GENERIC_TOOL = "Generic Tool"
    KNOWLEDGE_GRAPH = "Knowledge Graph Search"
    WEB_SEARCH = "Web Search"
    IMAGE_GENERATION = "Image Generation"
    GENERIC_INTERNAL_TOOL = "Generic Internal Tool"
    CLOSER = "Closer"
    LOGGER = "Logger"
    END = "End"
# class DRPath(str, Enum):
# CLARIFIER = "Clarifier"
# ORCHESTRATOR = "Orchestrator"
# INTERNAL_SEARCH = "Internal Search"
# GENERIC_TOOL = "Generic Tool"
# KNOWLEDGE_GRAPH = "Knowledge Graph Search"
# WEB_SEARCH = "Web Search"
# IMAGE_GENERATION = "Image Generation"
# GENERIC_INTERNAL_TOOL = "Generic Internal Tool"
# CLOSER = "Closer"
# LOGGER = "Logger"
# END = "End"

@@ -1,88 +1,88 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
# from langgraph.graph import END
|
||||
# from langgraph.graph import START
|
||||
# from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.dr.conditional_edges import completeness_router
|
||||
from onyx.agents.agent_search.dr.conditional_edges import decision_router
|
||||
from onyx.agents.agent_search.dr.enums import DRPath
|
||||
from onyx.agents.agent_search.dr.nodes.dr_a0_clarification import clarifier
|
||||
from onyx.agents.agent_search.dr.nodes.dr_a1_orchestrator import orchestrator
|
||||
from onyx.agents.agent_search.dr.nodes.dr_a2_closer import closer
|
||||
from onyx.agents.agent_search.dr.nodes.dr_a3_logger import logging
|
||||
from onyx.agents.agent_search.dr.states import MainInput
|
||||
from onyx.agents.agent_search.dr.states import MainState
|
||||
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_graph_builder import (
|
||||
dr_basic_search_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_graph_builder import (
|
||||
dr_custom_tool_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_graph_builder import (
|
||||
dr_generic_internal_tool_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_graph_builder import (
|
||||
dr_image_generation_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_graph_builder import (
|
||||
dr_kg_search_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_graph_builder import (
|
||||
dr_ws_graph_builder,
|
||||
)
|
||||
# from onyx.agents.agent_search.dr.conditional_edges import completeness_router
|
||||
# from onyx.agents.agent_search.dr.conditional_edges import decision_router
|
||||
# from onyx.agents.agent_search.dr.enums import DRPath
|
||||
# from onyx.agents.agent_search.dr.nodes.dr_a0_clarification import clarifier
|
||||
# from onyx.agents.agent_search.dr.nodes.dr_a1_orchestrator import orchestrator
|
||||
# from onyx.agents.agent_search.dr.nodes.dr_a2_closer import closer
|
||||
# from onyx.agents.agent_search.dr.nodes.dr_a3_logger import logging
|
||||
# from onyx.agents.agent_search.dr.states import MainInput
|
||||
# from onyx.agents.agent_search.dr.states import MainState
|
||||
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_graph_builder import (
|
||||
# dr_basic_search_graph_builder,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_graph_builder import (
|
||||
# dr_custom_tool_graph_builder,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_graph_builder import (
|
||||
# dr_generic_internal_tool_graph_builder,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_graph_builder import (
|
||||
# dr_image_generation_graph_builder,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_graph_builder import (
|
||||
# dr_kg_search_graph_builder,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_graph_builder import (
|
||||
# dr_ws_graph_builder,
|
||||
# )
|
||||
|
||||
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import search
|
||||
# # from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import search
|
||||
|
||||
|
||||
def dr_graph_builder() -> StateGraph:
|
||||
"""
|
||||
LangGraph graph builder for the deep research agent.
|
||||
"""
|
||||
# def dr_graph_builder() -> StateGraph:
|
||||
# """
|
||||
# LangGraph graph builder for the deep research agent.
|
||||
# """
|
||||
|
||||
graph = StateGraph(state_schema=MainState, input=MainInput)
|
||||
# graph = StateGraph(state_schema=MainState, input=MainInput)
|
||||
|
||||
### Add nodes ###
|
||||
# ### Add nodes ###
|
||||
|
||||
graph.add_node(DRPath.CLARIFIER, clarifier)
|
||||
# graph.add_node(DRPath.CLARIFIER, clarifier)
|
||||
|
||||
graph.add_node(DRPath.ORCHESTRATOR, orchestrator)
|
||||
# graph.add_node(DRPath.ORCHESTRATOR, orchestrator)
|
||||
|
||||
basic_search_graph = dr_basic_search_graph_builder().compile()
|
||||
graph.add_node(DRPath.INTERNAL_SEARCH, basic_search_graph)
|
||||
# basic_search_graph = dr_basic_search_graph_builder().compile()
|
||||
# graph.add_node(DRPath.INTERNAL_SEARCH, basic_search_graph)
|
||||
|
||||
kg_search_graph = dr_kg_search_graph_builder().compile()
|
||||
graph.add_node(DRPath.KNOWLEDGE_GRAPH, kg_search_graph)
|
||||
# kg_search_graph = dr_kg_search_graph_builder().compile()
|
||||
# graph.add_node(DRPath.KNOWLEDGE_GRAPH, kg_search_graph)
|
||||
|
||||
internet_search_graph = dr_ws_graph_builder().compile()
|
||||
graph.add_node(DRPath.WEB_SEARCH, internet_search_graph)
|
||||
# internet_search_graph = dr_ws_graph_builder().compile()
|
||||
# graph.add_node(DRPath.WEB_SEARCH, internet_search_graph)
|
||||
|
||||
image_generation_graph = dr_image_generation_graph_builder().compile()
|
||||
graph.add_node(DRPath.IMAGE_GENERATION, image_generation_graph)
|
||||
# image_generation_graph = dr_image_generation_graph_builder().compile()
|
||||
# graph.add_node(DRPath.IMAGE_GENERATION, image_generation_graph)
|
||||
|
||||
custom_tool_graph = dr_custom_tool_graph_builder().compile()
|
||||
graph.add_node(DRPath.GENERIC_TOOL, custom_tool_graph)
|
||||
# custom_tool_graph = dr_custom_tool_graph_builder().compile()
|
||||
# graph.add_node(DRPath.GENERIC_TOOL, custom_tool_graph)
|
||||
|
||||
generic_internal_tool_graph = dr_generic_internal_tool_graph_builder().compile()
|
||||
graph.add_node(DRPath.GENERIC_INTERNAL_TOOL, generic_internal_tool_graph)
|
||||
# generic_internal_tool_graph = dr_generic_internal_tool_graph_builder().compile()
|
||||
# graph.add_node(DRPath.GENERIC_INTERNAL_TOOL, generic_internal_tool_graph)
|
||||
|
||||
graph.add_node(DRPath.CLOSER, closer)
|
||||
graph.add_node(DRPath.LOGGER, logging)
|
||||
# graph.add_node(DRPath.CLOSER, closer)
|
||||
# graph.add_node(DRPath.LOGGER, logging)
|
||||
|
||||
### Add edges ###
|
||||
# ### Add edges ###
|
||||
|
||||
graph.add_edge(start_key=START, end_key=DRPath.CLARIFIER)
|
||||
# graph.add_edge(start_key=START, end_key=DRPath.CLARIFIER)
|
||||
|
||||
graph.add_conditional_edges(DRPath.CLARIFIER, decision_router)
|
||||
# graph.add_conditional_edges(DRPath.CLARIFIER, decision_router)
|
||||
|
||||
graph.add_conditional_edges(DRPath.ORCHESTRATOR, decision_router)
|
||||
# graph.add_conditional_edges(DRPath.ORCHESTRATOR, decision_router)
|
||||
|
||||
graph.add_edge(start_key=DRPath.INTERNAL_SEARCH, end_key=DRPath.ORCHESTRATOR)
|
||||
graph.add_edge(start_key=DRPath.KNOWLEDGE_GRAPH, end_key=DRPath.ORCHESTRATOR)
|
||||
graph.add_edge(start_key=DRPath.WEB_SEARCH, end_key=DRPath.ORCHESTRATOR)
|
||||
graph.add_edge(start_key=DRPath.IMAGE_GENERATION, end_key=DRPath.ORCHESTRATOR)
|
||||
graph.add_edge(start_key=DRPath.GENERIC_TOOL, end_key=DRPath.ORCHESTRATOR)
|
||||
graph.add_edge(start_key=DRPath.GENERIC_INTERNAL_TOOL, end_key=DRPath.ORCHESTRATOR)
|
||||
# graph.add_edge(start_key=DRPath.INTERNAL_SEARCH, end_key=DRPath.ORCHESTRATOR)
|
||||
# graph.add_edge(start_key=DRPath.KNOWLEDGE_GRAPH, end_key=DRPath.ORCHESTRATOR)
|
||||
# graph.add_edge(start_key=DRPath.WEB_SEARCH, end_key=DRPath.ORCHESTRATOR)
|
||||
# graph.add_edge(start_key=DRPath.IMAGE_GENERATION, end_key=DRPath.ORCHESTRATOR)
|
||||
# graph.add_edge(start_key=DRPath.GENERIC_TOOL, end_key=DRPath.ORCHESTRATOR)
|
||||
# graph.add_edge(start_key=DRPath.GENERIC_INTERNAL_TOOL, end_key=DRPath.ORCHESTRATOR)
|
||||
|
||||
graph.add_conditional_edges(DRPath.CLOSER, completeness_router)
|
||||
graph.add_edge(start_key=DRPath.LOGGER, end_key=END)
|
||||
# graph.add_conditional_edges(DRPath.CLOSER, completeness_router)
|
||||
# graph.add_edge(start_key=DRPath.LOGGER, end_key=END)
|
||||
|
||||
return graph
|
||||
# return graph
|
||||
|
||||
@@ -1,131 +1,131 @@
|
||||
from enum import Enum
|
||||
# from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
# from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import DRPath
|
||||
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
|
||||
GeneratedImage,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.tools.tool import Tool
|
||||
# from onyx.agents.agent_search.dr.enums import DRPath
|
||||
# from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
|
||||
# GeneratedImage,
|
||||
# )
|
||||
# from onyx.context.search.models import InferenceSection
|
||||
# from onyx.tools.tool import Tool
|
||||
|
||||
|
||||
class OrchestratorStep(BaseModel):
|
||||
tool: str
|
||||
questions: list[str]
|
||||
# class OrchestratorStep(BaseModel):
|
||||
# tool: str
|
||||
# questions: list[str]
|
||||
|
||||
|
||||
class OrchestratorDecisonsNoPlan(BaseModel):
|
||||
reasoning: str
|
||||
next_step: OrchestratorStep
|
||||
# class OrchestratorDecisonsNoPlan(BaseModel):
|
||||
# reasoning: str
|
||||
# next_step: OrchestratorStep
|
||||
|
||||
|
||||
class OrchestrationPlan(BaseModel):
|
||||
reasoning: str
|
||||
plan: str
|
||||
# class OrchestrationPlan(BaseModel):
|
||||
# reasoning: str
|
||||
# plan: str
|
||||
|
||||
|
||||
class ClarificationGenerationResponse(BaseModel):
|
||||
clarification_needed: bool
|
||||
clarification_question: str
|
||||
# class ClarificationGenerationResponse(BaseModel):
|
||||
# clarification_needed: bool
|
||||
# clarification_question: str
|
||||
|
||||
|
||||
class DecisionResponse(BaseModel):
|
||||
reasoning: str
|
||||
decision: str
|
||||
# class DecisionResponse(BaseModel):
|
||||
# reasoning: str
|
||||
# decision: str
|
||||
|
||||
|
||||
class QueryEvaluationResponse(BaseModel):
|
||||
reasoning: str
|
||||
query_permitted: bool
|
||||
# class QueryEvaluationResponse(BaseModel):
|
||||
# reasoning: str
|
||||
# query_permitted: bool
|
||||
|
||||
|
||||
class OrchestrationClarificationInfo(BaseModel):
|
||||
clarification_question: str
|
||||
clarification_response: str | None = None
|
||||
# class OrchestrationClarificationInfo(BaseModel):
|
||||
# clarification_question: str
|
||||
# clarification_response: str | None = None
|
||||
|
||||
|
||||
class WebSearchAnswer(BaseModel):
|
||||
urls_to_open_indices: list[int]
|
||||
# class WebSearchAnswer(BaseModel):
|
||||
# urls_to_open_indices: list[int]
|
||||
|
||||
|
||||
class SearchAnswer(BaseModel):
|
||||
reasoning: str
|
||||
answer: str
|
||||
claims: list[str] | None = None
|
||||
# class SearchAnswer(BaseModel):
|
||||
# reasoning: str
|
||||
# answer: str
|
||||
# claims: list[str] | None = None
|
||||
|
||||
|
||||
class TestInfoCompleteResponse(BaseModel):
|
||||
reasoning: str
|
||||
complete: bool
|
||||
gaps: list[str]
|
||||
# class TestInfoCompleteResponse(BaseModel):
|
||||
# reasoning: str
|
||||
# complete: bool
|
||||
# gaps: list[str]
|
||||
|
||||
|
||||
# TODO: revisit with custom tools implementation in v2
|
||||
# each tool should be a class with the attributes below, plus the actual tool implementation
|
||||
# this will also allow custom tools to have their own cost
|
||||
class OrchestratorTool(BaseModel):
|
||||
tool_id: int
|
||||
name: str
|
||||
llm_path: str # the path for the LLM to refer by
|
||||
path: DRPath # the actual path in the graph
|
||||
description: str
|
||||
metadata: dict[str, str]
|
||||
cost: float
|
||||
tool_object: Tool | None = None # None for CLOSER
|
||||
# # TODO: revisit with custom tools implementation in v2
|
||||
# # each tool should be a class with the attributes below, plus the actual tool implementation
|
||||
# # this will also allow custom tools to have their own cost
|
||||
# class OrchestratorTool(BaseModel):
|
||||
# tool_id: int
|
||||
# name: str
|
||||
# llm_path: str # the path for the LLM to refer by
|
||||
# path: DRPath # the actual path in the graph
|
||||
# description: str
|
||||
# metadata: dict[str, str]
|
||||
# cost: float
|
||||
# tool_object: Tool | None = None # None for CLOSER
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
# class Config:
|
||||
# arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class IterationInstructions(BaseModel):
|
||||
iteration_nr: int
|
||||
plan: str | None
|
||||
reasoning: str
|
||||
purpose: str
|
||||
# class IterationInstructions(BaseModel):
|
||||
# iteration_nr: int
|
||||
# plan: str | None
|
||||
# reasoning: str
|
||||
# purpose: str
|
||||
|
||||
|
||||
class IterationAnswer(BaseModel):
|
||||
tool: str
|
||||
tool_id: int
|
||||
iteration_nr: int
|
||||
parallelization_nr: int
|
||||
question: str
|
||||
reasoning: str | None
|
||||
answer: str
|
||||
cited_documents: dict[int, InferenceSection]
|
||||
background_info: str | None = None
|
||||
claims: list[str] | None = None
|
||||
additional_data: dict[str, str] | None = None
|
||||
response_type: str | None = None
|
||||
data: dict | list | str | int | float | bool | None = None
|
||||
file_ids: list[str] | None = None
|
||||
# TODO: This is not ideal, but we'll can rework the schema
|
||||
# for deep research later
|
||||
is_web_fetch: bool = False
|
||||
# for image generation step-types
|
||||
generated_images: list[GeneratedImage] | None = None
|
||||
# for multi-query search tools (v2 web search and internal search)
|
||||
# TODO: Clean this up to be more flexible to tools
|
||||
queries: list[str] | None = None
|
||||
# class IterationAnswer(BaseModel):
|
||||
# tool: str
|
||||
# tool_id: int
|
||||
# iteration_nr: int
|
||||
# parallelization_nr: int
|
||||
# question: str
|
||||
# reasoning: str | None
|
||||
# answer: str
|
||||
# cited_documents: dict[int, InferenceSection]
|
||||
# background_info: str | None = None
|
||||
# claims: list[str] | None = None
|
||||
# additional_data: dict[str, str] | None = None
|
||||
# response_type: str | None = None
|
||||
# data: dict | list | str | int | float | bool | None = None
|
||||
# file_ids: list[str] | None = None
|
||||
# # TODO: This is not ideal, but we'll can rework the schema
|
||||
# # for deep research later
|
||||
# is_web_fetch: bool = False
|
||||
# # for image generation step-types
|
||||
# generated_images: list[GeneratedImage] | None = None
|
||||
# # for multi-query search tools (v2 web search and internal search)
|
||||
# # TODO: Clean this up to be more flexible to tools
|
||||
# queries: list[str] | None = None
|
||||
|
||||
|
||||
class AggregatedDRContext(BaseModel):
|
||||
context: str
|
||||
cited_documents: list[InferenceSection]
|
||||
is_internet_marker_dict: dict[str, bool]
|
||||
global_iteration_responses: list[IterationAnswer]
|
||||
# class AggregatedDRContext(BaseModel):
|
||||
# context: str
|
||||
# cited_documents: list[InferenceSection]
|
||||
# is_internet_marker_dict: dict[str, bool]
|
||||
# global_iteration_responses: list[IterationAnswer]
|
||||
|
||||
|
||||
class DRPromptPurpose(str, Enum):
|
||||
PLAN = "PLAN"
|
||||
NEXT_STEP = "NEXT_STEP"
|
||||
NEXT_STEP_REASONING = "NEXT_STEP_REASONING"
|
||||
NEXT_STEP_PURPOSE = "NEXT_STEP_PURPOSE"
|
||||
CLARIFICATION = "CLARIFICATION"
|
||||
# class DRPromptPurpose(str, Enum):
|
||||
# PLAN = "PLAN"
|
||||
# NEXT_STEP = "NEXT_STEP"
|
||||
# NEXT_STEP_REASONING = "NEXT_STEP_REASONING"
|
||||
# NEXT_STEP_PURPOSE = "NEXT_STEP_PURPOSE"
|
||||
# CLARIFICATION = "CLARIFICATION"
|
||||
|
||||
|
||||
class BaseSearchProcessingResponse(BaseModel):
|
||||
specified_source_types: list[str]
|
||||
rewritten_query: str
|
||||
time_filter: str
|
||||
# class BaseSearchProcessingResponse(BaseModel):
|
||||
# specified_source_types: list[str]
|
||||
# rewritten_query: str
|
||||
# time_filter: str
|
||||
|
||||
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -1,423 +1,418 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
# import re
|
||||
# from datetime import datetime
|
||||
# from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
from sqlalchemy.orm import Session
|
||||
# from langchain_core.runnables import RunnableConfig
|
||||
# from langgraph.types import StreamWriter
|
||||
# from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.dr.constants import MAX_CHAT_HISTORY_MESSAGES
|
||||
from onyx.agents.agent_search.dr.constants import MAX_NUM_CLOSER_SUGGESTIONS
|
||||
from onyx.agents.agent_search.dr.enums import DRPath
|
||||
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import AggregatedDRContext
|
||||
from onyx.agents.agent_search.dr.models import TestInfoCompleteResponse
|
||||
from onyx.agents.agent_search.dr.states import FinalUpdate
|
||||
from onyx.agents.agent_search.dr.states import MainState
|
||||
from onyx.agents.agent_search.dr.states import OrchestrationUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
|
||||
GeneratedImageFullResult,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.utils import aggregate_context
|
||||
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
|
||||
from onyx.agents.agent_search.dr.utils import get_chat_history_string
|
||||
from onyx.agents.agent_search.dr.utils import get_prompt_question
|
||||
from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.agents.agent_search.utils import create_question_prompt
|
||||
from onyx.chat.chat_utils import llm_doc_from_inference_section
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.chat import create_search_doc_from_inference_section
|
||||
from onyx.db.chat import update_db_session_with_messages
|
||||
from onyx.db.models import ChatMessage__SearchDoc
|
||||
from onyx.db.models import ResearchAgentIteration
|
||||
from onyx.db.models import ResearchAgentIterationSubStep
|
||||
from onyx.db.models import SearchDoc as DbSearchDoc
|
||||
from onyx.llm.utils import check_number_of_tokens
|
||||
from onyx.prompts.chat_prompts import PROJECT_INSTRUCTIONS_SEPARATOR
|
||||
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
|
||||
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
|
||||
from onyx.prompts.dr_prompts import TEST_INFO_COMPLETE_PROMPT
|
||||
from onyx.server.query_and_chat.streaming_models import CitationDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CitationStart
|
||||
from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.server.query_and_chat.streaming_models import StreamingType
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
# from onyx.agents.agent_search.dr.constants import MAX_CHAT_HISTORY_MESSAGES
|
||||
# from onyx.agents.agent_search.dr.constants import MAX_NUM_CLOSER_SUGGESTIONS
|
||||
# from onyx.agents.agent_search.dr.enums import DRPath
|
||||
# from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
|
||||
# from onyx.agents.agent_search.dr.models import AggregatedDRContext
|
||||
# from onyx.agents.agent_search.dr.models import TestInfoCompleteResponse
|
||||
# from onyx.agents.agent_search.dr.states import FinalUpdate
|
||||
# from onyx.agents.agent_search.dr.states import MainState
|
||||
# from onyx.agents.agent_search.dr.states import OrchestrationUpdate
|
||||
# from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
|
||||
# GeneratedImageFullResult,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.utils import aggregate_context
|
||||
# from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
|
||||
# from onyx.agents.agent_search.dr.utils import get_chat_history_string
|
||||
# from onyx.agents.agent_search.dr.utils import get_prompt_question
|
||||
# from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
|
||||
# from onyx.agents.agent_search.models import GraphConfig
|
||||
# from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
# from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
|
||||
# from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
# get_langgraph_node_log_string,
|
||||
# )
|
||||
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
# from onyx.agents.agent_search.utils import create_question_prompt
|
||||
# from onyx.chat.chat_utils import llm_doc_from_inference_section
|
||||
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
|
||||
# from onyx.context.search.models import InferenceSection
|
||||
# from onyx.db.chat import create_search_doc_from_inference_section
|
||||
# from onyx.db.chat import update_db_session_with_messages
|
||||
# from onyx.db.models import ChatMessage__SearchDoc
|
||||
# from onyx.db.models import ResearchAgentIteration
|
||||
# from onyx.db.models import ResearchAgentIterationSubStep
|
||||
# from onyx.db.models import SearchDoc as DbSearchDoc
|
||||
# from onyx.llm.utils import check_number_of_tokens
|
||||
# from onyx.prompts.chat_prompts import PROJECT_INSTRUCTIONS_SEPARATOR
|
||||
# from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
|
||||
# from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
|
||||
# from onyx.prompts.dr_prompts import TEST_INFO_COMPLETE_PROMPT
|
||||
# from onyx.server.query_and_chat.streaming_models import CitationDelta
|
||||
# from onyx.server.query_and_chat.streaming_models import CitationStart
|
||||
# from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
# from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
# from onyx.server.query_and_chat.streaming_models import StreamingType
|
||||
# from onyx.utils.logger import setup_logger
|
||||
# from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
logger = setup_logger()
|
||||
# logger = setup_logger()
|
||||
|
||||
|
||||
def extract_citation_numbers(text: str) -> list[int]:
    """
    Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
    Returns a list of all unique citation numbers found.
    """
    # Pattern to match [[number]] or [[number1, number2, ...]]
    pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
    matches = re.findall(pattern, text)
# def extract_citation_numbers(text: str) -> list[int]:
# """
# Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
# Returns a list of all unique citation numbers found.
# """
# # Pattern to match [[number]] or [[number1, number2, ...]]
# pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
# matches = re.findall(pattern, text)

    cited_numbers = []
    for match in matches:
        # Split by comma and extract all numbers
        numbers = [int(num.strip()) for num in match.split(",")]
        cited_numbers.extend(numbers)
# cited_numbers = []
# for match in matches:
# # Split by comma and extract all numbers
# numbers = [int(num.strip()) for num in match.split(",")]
# cited_numbers.extend(numbers)

    return list(set(cited_numbers))  # Return unique numbers
# return list(set(cited_numbers)) # Return unique numbers
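
A small, illustrative check of the citation parsing above (the answer string is invented for the example):

answer = "Revenue grew [[1]] and margins improved [[2, 5]]."
# extract_citation_numbers(answer) returns the unique numbers 1, 2 and 5
# (as a list in arbitrary order, since duplicates are dropped via set()).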
|
||||
|
||||
def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
|
||||
citation_content = match.group(1) # e.g., "3" or "3, 5, 7"
|
||||
numbers = [int(num.strip()) for num in citation_content.split(",")]
|
||||
# def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
|
||||
# citation_content = match.group(1) # e.g., "3" or "3, 5, 7"
|
||||
# numbers = [int(num.strip()) for num in citation_content.split(",")]
|
||||
|
||||
# For multiple citations like [[3, 5, 7]], create separate linked citations
|
||||
linked_citations = []
|
||||
for num in numbers:
|
||||
if num - 1 < len(docs): # Check bounds
|
||||
link = docs[num - 1].link or ""
|
||||
linked_citations.append(f"[[{num}]]({link})")
|
||||
else:
|
||||
linked_citations.append(f"[[{num}]]") # No link if out of bounds
|
||||
# # For multiple citations like [[3, 5, 7]], create separate linked citations
|
||||
# linked_citations = []
|
||||
# for num in numbers:
|
||||
# if num - 1 < len(docs): # Check bounds
|
||||
# link = docs[num - 1].link or ""
|
||||
# linked_citations.append(f"[[{num}]]({link})")
|
||||
# else:
|
||||
# linked_citations.append(f"[[{num}]]") # No link if out of bounds
|
||||
|
||||
return "".join(linked_citations)
|
||||
# return "".join(linked_citations)
|
||||
|
||||
|
||||
def insert_chat_message_search_doc_pair(
|
||||
message_id: int, search_doc_ids: list[int], db_session: Session
|
||||
) -> None:
|
||||
"""
|
||||
Insert a pair of message_id and search_doc_id into the chat_message__search_doc table.
|
||||
# def insert_chat_message_search_doc_pair(
|
||||
# message_id: int, search_doc_ids: list[int], db_session: Session
|
||||
# ) -> None:
|
||||
# """
|
||||
# Insert a pair of message_id and search_doc_id into the chat_message__search_doc table.
|
||||
|
||||
Args:
|
||||
message_id: The ID of the chat message
|
||||
search_doc_id: The ID of the search document
|
||||
db_session: The database session
|
||||
"""
|
||||
for search_doc_id in search_doc_ids:
|
||||
chat_message_search_doc = ChatMessage__SearchDoc(
|
||||
chat_message_id=message_id, search_doc_id=search_doc_id
|
||||
)
|
||||
db_session.add(chat_message_search_doc)
|
||||
# Args:
|
||||
# message_id: The ID of the chat message
|
||||
# search_doc_id: The ID of the search document
|
||||
# db_session: The database session
|
||||
# """
|
||||
# for search_doc_id in search_doc_ids:
|
||||
# chat_message_search_doc = ChatMessage__SearchDoc(
|
||||
# chat_message_id=message_id, search_doc_id=search_doc_id
|
||||
# )
|
||||
# db_session.add(chat_message_search_doc)
|
||||
|
||||
|
||||
def save_iteration(
|
||||
state: MainState,
|
||||
graph_config: GraphConfig,
|
||||
aggregated_context: AggregatedDRContext,
|
||||
final_answer: str,
|
||||
all_cited_documents: list[InferenceSection],
|
||||
is_internet_marker_dict: dict[str, bool],
|
||||
) -> None:
|
||||
db_session = graph_config.persistence.db_session
|
||||
message_id = graph_config.persistence.message_id
|
||||
research_type = graph_config.behavior.research_type
|
||||
db_session = graph_config.persistence.db_session
|
||||
# def save_iteration(
|
||||
# state: MainState,
|
||||
# graph_config: GraphConfig,
|
||||
# aggregated_context: AggregatedDRContext,
|
||||
# final_answer: str,
|
||||
# all_cited_documents: list[InferenceSection],
|
||||
# is_internet_marker_dict: dict[str, bool],
|
||||
# ) -> None:
|
||||
# db_session = graph_config.persistence.db_session
|
||||
# message_id = graph_config.persistence.message_id
|
||||
|
||||
# first, insert the search_docs
|
||||
search_docs = [
|
||||
create_search_doc_from_inference_section(
|
||||
inference_section=inference_section,
|
||||
is_internet=is_internet_marker_dict.get(
|
||||
inference_section.center_chunk.document_id, False
|
||||
), # TODO: revisit
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
for inference_section in all_cited_documents
|
||||
]
|
||||
# # first, insert the search_docs
|
||||
# search_docs = [
|
||||
# create_search_doc_from_inference_section(
|
||||
# inference_section=inference_section,
|
||||
# is_internet=is_internet_marker_dict.get(
|
||||
# inference_section.center_chunk.document_id, False
|
||||
# ), # TODO: revisit
|
||||
# db_session=db_session,
|
||||
# commit=False,
|
||||
# )
|
||||
# for inference_section in all_cited_documents
|
||||
# ]
|
||||
|
||||
# then, map_search_docs to message
|
||||
insert_chat_message_search_doc_pair(
|
||||
message_id, [search_doc.id for search_doc in search_docs], db_session
|
||||
)
|
||||
# # then, map_search_docs to message
|
||||
# insert_chat_message_search_doc_pair(
|
||||
# message_id, [search_doc.id for search_doc in search_docs], db_session
|
||||
# )
|
||||
|
||||
# lastly, insert the citations
|
||||
citation_dict: dict[int, int] = {}
|
||||
cited_doc_nrs = extract_citation_numbers(final_answer)
|
||||
for cited_doc_nr in cited_doc_nrs:
|
||||
citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
|
||||
# # lastly, insert the citations
|
||||
# citation_dict: dict[int, int] = {}
|
||||
# cited_doc_nrs = extract_citation_numbers(final_answer)
|
||||
# for cited_doc_nr in cited_doc_nrs:
|
||||
# citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
|
||||
|
||||
# TODO: generate plan as dict in the first place
|
||||
plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
|
||||
plan_of_record_dict = parse_plan_to_dict(plan_of_record)
|
||||
# # TODO: generate plan as dict in the first place
|
||||
# plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
|
||||
# plan_of_record_dict = parse_plan_to_dict(plan_of_record)
|
||||
|
||||
# Update the chat message and its parent message in database
|
||||
update_db_session_with_messages(
|
||||
db_session=db_session,
|
||||
chat_message_id=message_id,
|
||||
chat_session_id=graph_config.persistence.chat_session_id,
|
||||
is_agentic=graph_config.behavior.use_agentic_search,
|
||||
message=final_answer,
|
||||
citations=citation_dict,
|
||||
research_type=research_type,
|
||||
research_plan=plan_of_record_dict,
|
||||
final_documents=search_docs,
|
||||
update_parent_message=True,
|
||||
research_answer_purpose=ResearchAnswerPurpose.ANSWER,
|
||||
)
|
||||
# # Update the chat message and its parent message in database
|
||||
# update_db_session_with_messages(
|
||||
# db_session=db_session,
|
||||
# chat_message_id=message_id,
|
||||
# chat_session_id=graph_config.persistence.chat_session_id,
|
||||
# is_agentic=graph_config.behavior.use_agentic_search,
|
||||
# message=final_answer,
|
||||
# citations=citation_dict,
|
||||
# research_type=None, # research_type is deprecated
|
||||
# research_plan=plan_of_record_dict,
|
||||
# final_documents=search_docs,
|
||||
# update_parent_message=True,
|
||||
# research_answer_purpose=ResearchAnswerPurpose.ANSWER,
|
||||
# )
|
||||
|
||||
for iteration_preparation in state.iteration_instructions:
|
||||
research_agent_iteration_step = ResearchAgentIteration(
|
||||
primary_question_id=message_id,
|
||||
reasoning=iteration_preparation.reasoning,
|
||||
purpose=iteration_preparation.purpose,
|
||||
iteration_nr=iteration_preparation.iteration_nr,
|
||||
)
|
||||
db_session.add(research_agent_iteration_step)
|
||||
# for iteration_preparation in state.iteration_instructions:
|
||||
# research_agent_iteration_step = ResearchAgentIteration(
|
||||
# primary_question_id=message_id,
|
||||
# reasoning=iteration_preparation.reasoning,
|
||||
# purpose=iteration_preparation.purpose,
|
||||
# iteration_nr=iteration_preparation.iteration_nr,
|
||||
# )
|
||||
# db_session.add(research_agent_iteration_step)
|
||||
|
||||
for iteration_answer in aggregated_context.global_iteration_responses:
|
||||
# for iteration_answer in aggregated_context.global_iteration_responses:
|
||||
|
||||
retrieved_search_docs = convert_inference_sections_to_search_docs(
|
||||
list(iteration_answer.cited_documents.values())
|
||||
)
|
||||
# retrieved_search_docs = convert_inference_sections_to_search_docs(
|
||||
# list(iteration_answer.cited_documents.values())
|
||||
# )
|
||||
|
||||
# Convert SavedSearchDoc objects to JSON-serializable format
|
||||
serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
|
||||
# # Convert SavedSearchDoc objects to JSON-serializable format
|
||||
# serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
|
||||
|
||||
research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
|
||||
primary_question_id=message_id,
|
||||
iteration_nr=iteration_answer.iteration_nr,
|
||||
iteration_sub_step_nr=iteration_answer.parallelization_nr,
|
||||
sub_step_instructions=iteration_answer.question,
|
||||
sub_step_tool_id=iteration_answer.tool_id,
|
||||
sub_answer=iteration_answer.answer,
|
||||
reasoning=iteration_answer.reasoning,
|
||||
claims=iteration_answer.claims,
|
||||
cited_doc_results=serialized_search_docs,
|
||||
generated_images=(
|
||||
GeneratedImageFullResult(images=iteration_answer.generated_images)
|
||||
if iteration_answer.generated_images
|
||||
else None
|
||||
),
|
||||
additional_data=iteration_answer.additional_data,
|
||||
queries=iteration_answer.queries,
|
||||
)
|
||||
db_session.add(research_agent_iteration_sub_step)
|
||||
# research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
|
||||
# primary_question_id=message_id,
|
||||
# iteration_nr=iteration_answer.iteration_nr,
|
||||
# iteration_sub_step_nr=iteration_answer.parallelization_nr,
|
||||
# sub_step_instructions=iteration_answer.question,
|
||||
# sub_step_tool_id=iteration_answer.tool_id,
|
||||
# sub_answer=iteration_answer.answer,
|
||||
# reasoning=iteration_answer.reasoning,
|
||||
# claims=iteration_answer.claims,
|
||||
# cited_doc_results=serialized_search_docs,
|
||||
# generated_images=(
|
||||
# GeneratedImageFullResult(images=iteration_answer.generated_images)
|
||||
# if iteration_answer.generated_images
|
||||
# else None
|
||||
# ),
|
||||
# additional_data=iteration_answer.additional_data,
|
||||
# queries=iteration_answer.queries,
|
||||
# )
|
||||
# db_session.add(research_agent_iteration_sub_step)
|
||||
|
||||
db_session.commit()
|
||||
# db_session.commit()
|
||||
|
||||
|
||||
def closer(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> FinalUpdate | OrchestrationUpdate:
|
||||
"""
|
||||
LangGraph node to close the DR process and finalize the answer.
|
||||
"""
|
||||
# def closer(
|
||||
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
# ) -> FinalUpdate | OrchestrationUpdate:
|
||||
# """
|
||||
# LangGraph node to close the DR process and finalize the answer.
|
||||
# """
|
||||
|
||||
node_start_time = datetime.now()
|
||||
# TODO: generate final answer using all the previous steps
|
||||
# (right now, answers from each step are concatenated onto each other)
|
||||
# Also, add missing fields once usage in UI is clear.
|
||||
# node_start_time = datetime.now()
|
||||
# # TODO: generate final answer using all the previous steps
|
||||
# # (right now, answers from each step are concatenated onto each other)
|
||||
# # Also, add missing fields once usage in UI is clear.
|
||||
|
||||
current_step_nr = state.current_step_nr
|
||||
# current_step_nr = state.current_step_nr
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
base_question = state.original_question
|
||||
if not base_question:
|
||||
raise ValueError("Question is required for closer")
|
||||
# graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
# base_question = state.original_question
|
||||
# if not base_question:
|
||||
# raise ValueError("Question is required for closer")
|
||||
|
||||
research_type = graph_config.behavior.research_type
|
||||
# use_agentic_search = graph_config.behavior.use_agentic_search
|
||||
|
||||
assistant_system_prompt: str = state.assistant_system_prompt or ""
|
||||
assistant_task_prompt = state.assistant_task_prompt
|
||||
# assistant_system_prompt: str = state.assistant_system_prompt or ""
|
||||
# assistant_task_prompt = state.assistant_task_prompt
|
||||
|
||||
uploaded_context = state.uploaded_test_context or ""
|
||||
# uploaded_context = state.uploaded_test_context or ""
|
||||
|
||||
clarification = state.clarification
|
||||
prompt_question = get_prompt_question(base_question, clarification)
|
||||
# clarification = state.clarification
|
||||
# prompt_question = get_prompt_question(base_question, clarification)
|
||||
|
||||
chat_history_string = (
|
||||
get_chat_history_string(
|
||||
graph_config.inputs.prompt_builder.message_history,
|
||||
MAX_CHAT_HISTORY_MESSAGES,
|
||||
)
|
||||
or "(No chat history yet available)"
|
||||
)
|
||||
# chat_history_string = (
|
||||
# get_chat_history_string(
|
||||
# graph_config.inputs.prompt_builder.message_history,
|
||||
# MAX_CHAT_HISTORY_MESSAGES,
|
||||
# )
|
||||
# or "(No chat history yet available)"
|
||||
# )
|
||||
|
||||
aggregated_context_w_docs = aggregate_context(
|
||||
state.iteration_responses, include_documents=True
|
||||
)
|
||||
# aggregated_context_w_docs = aggregate_context(
|
||||
# state.iteration_responses, include_documents=True
|
||||
# )
|
||||
|
||||
aggregated_context_wo_docs = aggregate_context(
|
||||
state.iteration_responses, include_documents=False
|
||||
)
|
||||
# aggregated_context_wo_docs = aggregate_context(
|
||||
# state.iteration_responses, include_documents=False
|
||||
# )
|
||||
|
||||
iteration_responses_w_docs_string = aggregated_context_w_docs.context
|
||||
iteration_responses_wo_docs_string = aggregated_context_wo_docs.context
|
||||
all_cited_documents = aggregated_context_w_docs.cited_documents
|
||||
# iteration_responses_w_docs_string = aggregated_context_w_docs.context
|
||||
# iteration_responses_wo_docs_string = aggregated_context_wo_docs.context
|
||||
# all_cited_documents = aggregated_context_w_docs.cited_documents
|
||||
|
||||
num_closer_suggestions = state.num_closer_suggestions
|
||||
# num_closer_suggestions = state.num_closer_suggestions
|
||||
|
||||
if (
|
||||
num_closer_suggestions < MAX_NUM_CLOSER_SUGGESTIONS
|
||||
and research_type == ResearchType.DEEP
|
||||
):
|
||||
test_info_complete_prompt = TEST_INFO_COMPLETE_PROMPT.build(
|
||||
base_question=prompt_question,
|
||||
questions_answers_claims=iteration_responses_wo_docs_string,
|
||||
chat_history_string=chat_history_string,
|
||||
high_level_plan=(
|
||||
state.plan_of_record.plan
|
||||
if state.plan_of_record
|
||||
else "No plan available"
|
||||
),
|
||||
)
|
||||
# if (
|
||||
# num_closer_suggestions < MAX_NUM_CLOSER_SUGGESTIONS
|
||||
# and use_agentic_search
|
||||
# ):
|
||||
# test_info_complete_prompt = TEST_INFO_COMPLETE_PROMPT.build(
|
||||
# base_question=prompt_question,
|
||||
# questions_answers_claims=iteration_responses_wo_docs_string,
|
||||
# chat_history_string=chat_history_string,
|
||||
# high_level_plan=(
|
||||
# state.plan_of_record.plan
|
||||
# if state.plan_of_record
|
||||
# else "No plan available"
|
||||
# ),
|
||||
# )
|
||||
|
||||
test_info_complete_json = invoke_llm_json(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
assistant_system_prompt,
|
||||
test_info_complete_prompt + (assistant_task_prompt or ""),
|
||||
),
|
||||
schema=TestInfoCompleteResponse,
|
||||
timeout_override=TF_DR_TIMEOUT_LONG,
|
||||
# max_tokens=1000,
|
||||
)
|
||||
# test_info_complete_json = invoke_llm_json(
|
||||
# llm=graph_config.tooling.primary_llm,
|
||||
# prompt=create_question_prompt(
|
||||
# assistant_system_prompt,
|
||||
# test_info_complete_prompt + (assistant_task_prompt or ""),
|
||||
# ),
|
||||
# schema=TestInfoCompleteResponse,
|
||||
# timeout_override=TF_DR_TIMEOUT_LONG,
|
||||
# # max_tokens=1000,
|
||||
# )
|
||||
|
||||
if test_info_complete_json.complete:
|
||||
pass
|
||||
# if test_info_complete_json.complete:
|
||||
# pass
|
||||
|
||||
else:
|
||||
return OrchestrationUpdate(
|
||||
tools_used=[DRPath.ORCHESTRATOR.value],
|
||||
query_list=[],
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="closer",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
gaps=test_info_complete_json.gaps,
|
||||
num_closer_suggestions=num_closer_suggestions + 1,
|
||||
)
|
||||
# else:
|
||||
# return OrchestrationUpdate(
|
||||
# tools_used=[DRPath.ORCHESTRATOR.value],
|
||||
# query_list=[],
|
||||
# log_messages=[
|
||||
# get_langgraph_node_log_string(
|
||||
# graph_component="main",
|
||||
# node_name="closer",
|
||||
# node_start_time=node_start_time,
|
||||
# )
|
||||
# ],
|
||||
# gaps=test_info_complete_json.gaps,
|
||||
# num_closer_suggestions=num_closer_suggestions + 1,
|
||||
# )
|
||||
|
||||
retrieved_search_docs = convert_inference_sections_to_search_docs(
|
||||
all_cited_documents
|
||||
)
|
||||
# retrieved_search_docs = convert_inference_sections_to_search_docs(
|
||||
# all_cited_documents
|
||||
# )
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
MessageStart(
|
||||
content="",
|
||||
final_documents=retrieved_search_docs,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
# write_custom_event(
|
||||
# current_step_nr,
|
||||
# MessageStart(
|
||||
# content="",
|
||||
# final_documents=retrieved_search_docs,
|
||||
# ),
|
||||
# writer,
|
||||
# )
|
||||
|
||||
if research_type in [ResearchType.THOUGHTFUL, ResearchType.FAST]:
|
||||
final_answer_base_prompt = FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
|
||||
elif research_type == ResearchType.DEEP:
|
||||
final_answer_base_prompt = FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
|
||||
else:
|
||||
raise ValueError(f"Invalid research type: {research_type}")
|
||||
# if not use_agentic_search:
|
||||
# final_answer_base_prompt = FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
|
||||
# else:
|
||||
# final_answer_base_prompt = FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
|
||||
|
||||
estimated_final_answer_prompt_tokens = check_number_of_tokens(
|
||||
final_answer_base_prompt.build(
|
||||
base_question=prompt_question,
|
||||
iteration_responses_string=iteration_responses_w_docs_string,
|
||||
chat_history_string=chat_history_string,
|
||||
uploaded_context=uploaded_context,
|
||||
)
|
||||
)
|
||||
# estimated_final_answer_prompt_tokens = check_number_of_tokens(
|
||||
# final_answer_base_prompt.build(
|
||||
# base_question=prompt_question,
|
||||
# iteration_responses_string=iteration_responses_w_docs_string,
|
||||
# chat_history_string=chat_history_string,
|
||||
# uploaded_context=uploaded_context,
|
||||
# )
|
||||
# )
|
||||
|
||||
# for DR, rely only on sub-answers and claims to save tokens if context is too long
|
||||
# TODO: consider compression step for Thoughtful mode if context is too long.
|
||||
# Should generally not be the case though.
|
||||
# # for DR, rely only on sub-answers and claims to save tokens if context is too long
|
||||
# # TODO: consider compression step for Thoughtful mode if context is too long.
|
||||
# # Should generally not be the case though.
|
||||
|
||||
max_allowed_input_tokens = graph_config.tooling.primary_llm.config.max_input_tokens
|
||||
# max_allowed_input_tokens = graph_config.tooling.primary_llm.config.max_input_tokens
|
||||
|
||||
if (
|
||||
estimated_final_answer_prompt_tokens > 0.8 * max_allowed_input_tokens
|
||||
and research_type == ResearchType.DEEP
|
||||
):
|
||||
iteration_responses_string = iteration_responses_wo_docs_string
|
||||
else:
|
||||
iteration_responses_string = iteration_responses_w_docs_string
|
||||
# if (
|
||||
# estimated_final_answer_prompt_tokens > 0.8 * max_allowed_input_tokens
|
||||
# and use_agentic_search
|
||||
# ):
|
||||
# iteration_responses_string = iteration_responses_wo_docs_string
|
||||
# else:
|
||||
# iteration_responses_string = iteration_responses_w_docs_string
|
||||
|
||||
final_answer_prompt = final_answer_base_prompt.build(
|
||||
base_question=prompt_question,
|
||||
iteration_responses_string=iteration_responses_string,
|
||||
chat_history_string=chat_history_string,
|
||||
uploaded_context=uploaded_context,
|
||||
)
|
||||
# final_answer_prompt = final_answer_base_prompt.build(
|
||||
# base_question=prompt_question,
|
||||
# iteration_responses_string=iteration_responses_string,
|
||||
# chat_history_string=chat_history_string,
|
||||
# uploaded_context=uploaded_context,
|
||||
# )
|
||||
|
||||
if graph_config.inputs.project_instructions:
|
||||
assistant_system_prompt = (
|
||||
assistant_system_prompt
|
||||
+ PROJECT_INSTRUCTIONS_SEPARATOR
|
||||
+ (graph_config.inputs.project_instructions or "")
|
||||
)
|
||||
# if graph_config.inputs.project_instructions:
|
||||
# assistant_system_prompt = (
|
||||
# assistant_system_prompt
|
||||
# + PROJECT_INSTRUCTIONS_SEPARATOR
|
||||
# + (graph_config.inputs.project_instructions or "")
|
||||
# )
|
||||
|
||||
all_context_llmdocs = [
|
||||
llm_doc_from_inference_section(inference_section)
|
||||
for inference_section in all_cited_documents
|
||||
]
|
||||
# all_context_llmdocs = [
|
||||
# llm_doc_from_inference_section(inference_section)
|
||||
# for inference_section in all_cited_documents
|
||||
# ]
|
||||
|
||||
try:
|
||||
streamed_output, _, citation_infos = run_with_timeout(
|
||||
int(3 * TF_DR_TIMEOUT_LONG),
|
||||
lambda: stream_llm_answer(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
assistant_system_prompt,
|
||||
final_answer_prompt + (assistant_task_prompt or ""),
|
||||
),
|
||||
event_name="basic_response",
|
||||
writer=writer,
|
||||
agent_answer_level=0,
|
||||
agent_answer_question_num=0,
|
||||
agent_answer_type="agent_level_answer",
|
||||
timeout_override=int(2 * TF_DR_TIMEOUT_LONG),
|
||||
answer_piece=StreamingType.MESSAGE_DELTA.value,
|
||||
ind=current_step_nr,
|
||||
context_docs=all_context_llmdocs,
|
||||
replace_citations=True,
|
||||
# max_tokens=None,
|
||||
),
|
||||
)
|
||||
# try:
|
||||
# streamed_output, _, citation_infos = run_with_timeout(
|
||||
# int(3 * TF_DR_TIMEOUT_LONG),
|
||||
# lambda: stream_llm_answer(
|
||||
# llm=graph_config.tooling.primary_llm,
|
||||
# prompt=create_question_prompt(
|
||||
# assistant_system_prompt,
|
||||
# final_answer_prompt + (assistant_task_prompt or ""),
|
||||
# ),
|
||||
# event_name="basic_response",
|
||||
# writer=writer,
|
||||
# agent_answer_level=0,
|
||||
# agent_answer_question_num=0,
|
||||
# agent_answer_type="agent_level_answer",
|
||||
# timeout_override=int(2 * TF_DR_TIMEOUT_LONG),
|
||||
# answer_piece=StreamingType.MESSAGE_DELTA.value,
|
||||
# ind=current_step_nr,
|
||||
# context_docs=all_context_llmdocs,
|
||||
# replace_citations=True,
|
||||
# # max_tokens=None,
|
||||
# ),
|
||||
# )
|
||||
|
||||
final_answer = "".join(streamed_output)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error in consolidate_research: {e}")
|
||||
# final_answer = "".join(streamed_output)
|
||||
# except Exception as e:
|
||||
# raise ValueError(f"Error in consolidate_research: {e}")
|
||||
|
||||
write_custom_event(current_step_nr, SectionEnd(), writer)
|
||||
# write_custom_event(current_step_nr, SectionEnd(), writer)
|
||||
|
||||
current_step_nr += 1
|
||||
# current_step_nr += 1
|
||||
|
||||
write_custom_event(current_step_nr, CitationStart(), writer)
|
||||
write_custom_event(current_step_nr, CitationDelta(citations=citation_infos), writer)
|
||||
write_custom_event(current_step_nr, SectionEnd(), writer)
|
||||
# write_custom_event(current_step_nr, CitationStart(), writer)
|
||||
# write_custom_event(current_step_nr, CitationDelta(citations=citation_infos), writer)
|
||||
# write_custom_event(current_step_nr, SectionEnd(), writer)
|
||||
|
||||
current_step_nr += 1
|
||||
# current_step_nr += 1
|
||||
|
||||
# Log the research agent steps
|
||||
# save_iteration(
|
||||
# state,
|
||||
# graph_config,
|
||||
# aggregated_context,
|
||||
# final_answer,
|
||||
# all_cited_documents,
|
||||
# is_internet_marker_dict,
|
||||
# )
|
||||
# # Log the research agent steps
|
||||
# # save_iteration(
|
||||
# # state,
|
||||
# # graph_config,
|
||||
# # aggregated_context,
|
||||
# # final_answer,
|
||||
# # all_cited_documents,
|
||||
# # is_internet_marker_dict,
|
||||
# # )
|
||||
|
||||
return FinalUpdate(
|
||||
final_answer=final_answer,
|
||||
all_cited_documents=all_cited_documents,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="closer",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
# return FinalUpdate(
|
||||
# final_answer=final_answer,
|
||||
# all_cited_documents=all_cited_documents,
|
||||
# log_messages=[
|
||||
# get_langgraph_node_log_string(
|
||||
# graph_component="main",
|
||||
# node_name="closer",
|
||||
# node_start_time=node_start_time,
|
||||
# )
|
||||
# ],
|
||||
# )
|
||||
|
||||
@@ -1,248 +1,246 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
# import re
|
||||
# from datetime import datetime
|
||||
# from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
from sqlalchemy.orm import Session
|
||||
# from langchain_core.runnables import RunnableConfig
|
||||
# from langgraph.types import StreamWriter
|
||||
# from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
|
||||
from onyx.agents.agent_search.dr.models import AggregatedDRContext
|
||||
from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.dr.states import MainState
|
||||
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
|
||||
GeneratedImageFullResult,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.utils import aggregate_context
|
||||
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
|
||||
from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.chat import create_search_doc_from_inference_section
|
||||
from onyx.db.chat import update_db_session_with_messages
|
||||
from onyx.db.models import ChatMessage__SearchDoc
|
||||
from onyx.db.models import ResearchAgentIteration
|
||||
from onyx.db.models import ResearchAgentIterationSubStep
|
||||
from onyx.db.models import SearchDoc as DbSearchDoc
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.utils.logger import setup_logger
|
||||
# from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
|
||||
# from onyx.agents.agent_search.dr.models import AggregatedDRContext
|
||||
# from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
# from onyx.agents.agent_search.dr.states import MainState
|
||||
# from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
|
||||
# GeneratedImageFullResult,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.utils import aggregate_context
|
||||
# from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
|
||||
# from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
|
||||
# from onyx.agents.agent_search.models import GraphConfig
|
||||
# from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
# get_langgraph_node_log_string,
|
||||
# )
|
||||
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
# from onyx.context.search.models import InferenceSection
|
||||
# from onyx.db.chat import create_search_doc_from_inference_section
|
||||
# from onyx.db.chat import update_db_session_with_messages
|
||||
# from onyx.db.models import ChatMessage__SearchDoc
|
||||
# from onyx.db.models import ResearchAgentIteration
|
||||
# from onyx.db.models import ResearchAgentIterationSubStep
|
||||
# from onyx.db.models import SearchDoc as DbSearchDoc
|
||||
# from onyx.natural_language_processing.utils import get_tokenizer
|
||||
# from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
# from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
# logger = setup_logger()
|
||||
|
||||
|
||||
def _extract_citation_numbers(text: str) -> list[int]:
|
||||
"""
|
||||
Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
|
||||
Returns a list of all unique citation numbers found.
|
||||
"""
|
||||
# Pattern to match [[number]] or [[number1, number2, ...]]
|
||||
pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
|
||||
matches = re.findall(pattern, text)
|
||||
# def _extract_citation_numbers(text: str) -> list[int]:
|
||||
# """
|
||||
# Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
|
||||
# Returns a list of all unique citation numbers found.
|
||||
# """
|
||||
# # Pattern to match [[number]] or [[number1, number2, ...]]
|
||||
# pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
|
||||
# matches = re.findall(pattern, text)
|
||||
|
||||
cited_numbers = []
|
||||
for match in matches:
|
||||
# Split by comma and extract all numbers
|
||||
numbers = [int(num.strip()) for num in match.split(",")]
|
||||
cited_numbers.extend(numbers)
|
||||
# cited_numbers = []
|
||||
# for match in matches:
|
||||
# # Split by comma and extract all numbers
|
||||
# numbers = [int(num.strip()) for num in match.split(",")]
|
||||
# cited_numbers.extend(numbers)
|
||||
|
||||
return list(set(cited_numbers)) # Return unique numbers
|
||||
# return list(set(cited_numbers)) # Return unique numbers
|
||||
|
||||
|
||||
def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
|
||||
citation_content = match.group(1) # e.g., "3" or "3, 5, 7"
|
||||
numbers = [int(num.strip()) for num in citation_content.split(",")]
|
||||
# def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
|
||||
# citation_content = match.group(1) # e.g., "3" or "3, 5, 7"
|
||||
# numbers = [int(num.strip()) for num in citation_content.split(",")]
|
||||
|
||||
# For multiple citations like [[3, 5, 7]], create separate linked citations
|
||||
linked_citations = []
|
||||
for num in numbers:
|
||||
if num - 1 < len(docs): # Check bounds
|
||||
link = docs[num - 1].link or ""
|
||||
linked_citations.append(f"[[{num}]]({link})")
|
||||
else:
|
||||
linked_citations.append(f"[[{num}]]") # No link if out of bounds
|
||||
# # For multiple citations like [[3, 5, 7]], create separate linked citations
|
||||
# linked_citations = []
|
||||
# for num in numbers:
|
||||
# if num - 1 < len(docs): # Check bounds
|
||||
# link = docs[num - 1].link or ""
|
||||
# linked_citations.append(f"[[{num}]]({link})")
|
||||
# else:
|
||||
# linked_citations.append(f"[[{num}]]") # No link if out of bounds
|
||||
|
||||
return "".join(linked_citations)
|
||||
# return "".join(linked_citations)
|
||||
|
||||
|
||||
def _insert_chat_message_search_doc_pair(
|
||||
message_id: int, search_doc_ids: list[int], db_session: Session
|
||||
) -> None:
|
||||
"""
|
||||
Insert a pair of message_id and search_doc_id into the chat_message__search_doc table.
|
||||
# def _insert_chat_message_search_doc_pair(
|
||||
# message_id: int, search_doc_ids: list[int], db_session: Session
|
||||
# ) -> None:
|
||||
# """
|
||||
# Insert a pair of message_id and search_doc_id into the chat_message__search_doc table.
|
||||
|
||||
Args:
|
||||
message_id: The ID of the chat message
|
||||
search_doc_id: The ID of the search document
|
||||
db_session: The database session
|
||||
"""
|
||||
for search_doc_id in search_doc_ids:
|
||||
chat_message_search_doc = ChatMessage__SearchDoc(
|
||||
chat_message_id=message_id, search_doc_id=search_doc_id
|
||||
)
|
||||
db_session.add(chat_message_search_doc)
|
||||
# Args:
|
||||
# message_id: The ID of the chat message
|
||||
# search_doc_id: The ID of the search document
|
||||
# db_session: The database session
|
||||
# """
|
||||
# for search_doc_id in search_doc_ids:
|
||||
# chat_message_search_doc = ChatMessage__SearchDoc(
|
||||
# chat_message_id=message_id, search_doc_id=search_doc_id
|
||||
# )
|
||||
# db_session.add(chat_message_search_doc)
|
||||
|
||||
|
||||
def save_iteration(
|
||||
state: MainState,
|
||||
graph_config: GraphConfig,
|
||||
aggregated_context: AggregatedDRContext,
|
||||
final_answer: str,
|
||||
all_cited_documents: list[InferenceSection],
|
||||
is_internet_marker_dict: dict[str, bool],
|
||||
num_tokens: int,
|
||||
) -> None:
|
||||
db_session = graph_config.persistence.db_session
|
||||
message_id = graph_config.persistence.message_id
|
||||
research_type = graph_config.behavior.research_type
|
||||
db_session = graph_config.persistence.db_session
|
||||
# def save_iteration(
|
||||
# state: MainState,
|
||||
# graph_config: GraphConfig,
|
||||
# aggregated_context: AggregatedDRContext,
|
||||
# final_answer: str,
|
||||
# all_cited_documents: list[InferenceSection],
|
||||
# is_internet_marker_dict: dict[str, bool],
|
||||
# num_tokens: int,
|
||||
# ) -> None:
|
||||
# db_session = graph_config.persistence.db_session
|
||||
# message_id = graph_config.persistence.message_id
|
||||
|
||||
# first, insert the search_docs
|
||||
search_docs = [
|
||||
create_search_doc_from_inference_section(
|
||||
inference_section=inference_section,
|
||||
is_internet=is_internet_marker_dict.get(
|
||||
inference_section.center_chunk.document_id, False
|
||||
), # TODO: revisit
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
for inference_section in all_cited_documents
|
||||
]
|
||||
# # first, insert the search_docs
|
||||
# search_docs = [
|
||||
# create_search_doc_from_inference_section(
|
||||
# inference_section=inference_section,
|
||||
# is_internet=is_internet_marker_dict.get(
|
||||
# inference_section.center_chunk.document_id, False
|
||||
# ), # TODO: revisit
|
||||
# db_session=db_session,
|
||||
# commit=False,
|
||||
# )
|
||||
# for inference_section in all_cited_documents
|
||||
# ]
|
||||
|
||||
# then, map_search_docs to message
|
||||
_insert_chat_message_search_doc_pair(
|
||||
message_id, [search_doc.id for search_doc in search_docs], db_session
|
||||
)
|
||||
# # then, map_search_docs to message
|
||||
# _insert_chat_message_search_doc_pair(
|
||||
# message_id, [search_doc.id for search_doc in search_docs], db_session
|
||||
# )
|
||||
|
||||
# lastly, insert the citations
|
||||
citation_dict: dict[int, int] = {}
|
||||
cited_doc_nrs = _extract_citation_numbers(final_answer)
|
||||
if search_docs:
|
||||
for cited_doc_nr in cited_doc_nrs:
|
||||
citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
|
||||
# # lastly, insert the citations
|
||||
# citation_dict: dict[int, int] = {}
|
||||
# cited_doc_nrs = _extract_citation_numbers(final_answer)
|
||||
# if search_docs:
|
||||
# for cited_doc_nr in cited_doc_nrs:
|
||||
# citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
|
||||
|
||||
# TODO: generate plan as dict in the first place
|
||||
plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
|
||||
plan_of_record_dict = parse_plan_to_dict(plan_of_record)
|
||||
# # TODO: generate plan as dict in the first place
|
||||
# plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
|
||||
# plan_of_record_dict = parse_plan_to_dict(plan_of_record)
|
||||
|
||||
# Update the chat message and its parent message in database
|
||||
update_db_session_with_messages(
|
||||
db_session=db_session,
|
||||
chat_message_id=message_id,
|
||||
chat_session_id=graph_config.persistence.chat_session_id,
|
||||
is_agentic=graph_config.behavior.use_agentic_search,
|
||||
message=final_answer,
|
||||
citations=citation_dict,
|
||||
research_type=research_type,
|
||||
research_plan=plan_of_record_dict,
|
||||
final_documents=search_docs,
|
||||
update_parent_message=True,
|
||||
research_answer_purpose=ResearchAnswerPurpose.ANSWER,
|
||||
token_count=num_tokens,
|
||||
)
|
||||
# # Update the chat message and its parent message in database
|
||||
# update_db_session_with_messages(
|
||||
# db_session=db_session,
|
||||
# chat_message_id=message_id,
|
||||
# chat_session_id=graph_config.persistence.chat_session_id,
|
||||
# is_agentic=graph_config.behavior.use_agentic_search,
|
||||
# message=final_answer,
|
||||
# citations=citation_dict,
|
||||
# research_type=None, # research_type is deprecated
|
||||
# research_plan=plan_of_record_dict,
|
||||
# final_documents=search_docs,
|
||||
# update_parent_message=True,
|
||||
# research_answer_purpose=ResearchAnswerPurpose.ANSWER,
|
||||
# token_count=num_tokens,
|
||||
# )
|
||||
|
||||
for iteration_preparation in state.iteration_instructions:
|
||||
research_agent_iteration_step = ResearchAgentIteration(
|
||||
primary_question_id=message_id,
|
||||
reasoning=iteration_preparation.reasoning,
|
||||
purpose=iteration_preparation.purpose,
|
||||
iteration_nr=iteration_preparation.iteration_nr,
|
||||
)
|
||||
db_session.add(research_agent_iteration_step)
|
||||
# for iteration_preparation in state.iteration_instructions:
|
||||
# research_agent_iteration_step = ResearchAgentIteration(
|
||||
# primary_question_id=message_id,
|
||||
# reasoning=iteration_preparation.reasoning,
|
||||
# purpose=iteration_preparation.purpose,
|
||||
# iteration_nr=iteration_preparation.iteration_nr,
|
||||
# )
|
||||
# db_session.add(research_agent_iteration_step)
|
||||
|
||||
for iteration_answer in aggregated_context.global_iteration_responses:
|
||||
# for iteration_answer in aggregated_context.global_iteration_responses:
|
||||
|
||||
retrieved_search_docs = convert_inference_sections_to_search_docs(
|
||||
list(iteration_answer.cited_documents.values())
|
||||
)
|
||||
# retrieved_search_docs = convert_inference_sections_to_search_docs(
|
||||
# list(iteration_answer.cited_documents.values())
|
||||
# )
|
||||
|
||||
# Convert SavedSearchDoc objects to JSON-serializable format
|
||||
serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
|
||||
# # Convert SavedSearchDoc objects to JSON-serializable format
|
||||
# serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
|
||||
|
||||
research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
|
||||
primary_question_id=message_id,
|
||||
iteration_nr=iteration_answer.iteration_nr,
|
||||
iteration_sub_step_nr=iteration_answer.parallelization_nr,
|
||||
sub_step_instructions=iteration_answer.question,
|
||||
sub_step_tool_id=iteration_answer.tool_id,
|
||||
sub_answer=iteration_answer.answer,
|
||||
reasoning=iteration_answer.reasoning,
|
||||
claims=iteration_answer.claims,
|
||||
cited_doc_results=serialized_search_docs,
|
||||
generated_images=(
|
||||
GeneratedImageFullResult(images=iteration_answer.generated_images)
|
||||
if iteration_answer.generated_images
|
||||
else None
|
||||
),
|
||||
additional_data=iteration_answer.additional_data,
|
||||
queries=iteration_answer.queries,
|
||||
)
|
||||
db_session.add(research_agent_iteration_sub_step)
|
||||
# research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
|
||||
# primary_question_id=message_id,
|
||||
# iteration_nr=iteration_answer.iteration_nr,
|
||||
# iteration_sub_step_nr=iteration_answer.parallelization_nr,
|
||||
# sub_step_instructions=iteration_answer.question,
|
||||
# sub_step_tool_id=iteration_answer.tool_id,
|
||||
# sub_answer=iteration_answer.answer,
|
||||
# reasoning=iteration_answer.reasoning,
|
||||
# claims=iteration_answer.claims,
|
||||
# cited_doc_results=serialized_search_docs,
|
||||
# generated_images=(
|
||||
# GeneratedImageFullResult(images=iteration_answer.generated_images)
|
||||
# if iteration_answer.generated_images
|
||||
# else None
|
||||
# ),
|
||||
# additional_data=iteration_answer.additional_data,
|
||||
# queries=iteration_answer.queries,
|
||||
# )
|
||||
# db_session.add(research_agent_iteration_sub_step)
|
||||
|
||||
db_session.commit()
|
||||
# db_session.commit()
|
||||
|
||||
|
||||
def logging(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> LoggerUpdate:
|
||||
"""
|
||||
LangGraph node to close the DR process and finalize the answer.
|
||||
"""
|
||||
# def logging(
|
||||
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
# ) -> LoggerUpdate:
|
||||
# """
|
||||
# LangGraph node to close the DR process and finalize the answer.
|
||||
# """
|
||||
|
||||
node_start_time = datetime.now()
|
||||
# TODO: generate final answer using all the previous steps
|
||||
# (right now, answers from each step are concatenated onto each other)
|
||||
# Also, add missing fields once usage in UI is clear.
|
||||
# node_start_time = datetime.now()
|
||||
# # TODO: generate final answer using all the previous steps
|
||||
# # (right now, answers from each step are concatenated onto each other)
|
||||
# # Also, add missing fields once usage in UI is clear.
|
||||
|
||||
current_step_nr = state.current_step_nr
|
||||
# current_step_nr = state.current_step_nr
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
base_question = state.original_question
|
||||
if not base_question:
|
||||
raise ValueError("Question is required for closer")
|
||||
# graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
# base_question = state.original_question
|
||||
# if not base_question:
|
||||
# raise ValueError("Question is required for closer")
|
||||
|
||||
aggregated_context = aggregate_context(
|
||||
state.iteration_responses, include_documents=True
|
||||
)
|
||||
# aggregated_context = aggregate_context(
|
||||
# state.iteration_responses, include_documents=True
|
||||
# )
|
||||
|
||||
all_cited_documents = aggregated_context.cited_documents
|
||||
# all_cited_documents = aggregated_context.cited_documents
|
||||
|
||||
is_internet_marker_dict = aggregated_context.is_internet_marker_dict
|
||||
# is_internet_marker_dict = aggregated_context.is_internet_marker_dict
|
||||
|
||||
final_answer = state.final_answer or ""
|
||||
llm_provider = graph_config.tooling.primary_llm.config.model_provider
|
||||
llm_model_name = graph_config.tooling.primary_llm.config.model_name
|
||||
# final_answer = state.final_answer or ""
|
||||
# llm_provider = graph_config.tooling.primary_llm.config.model_provider
|
||||
# llm_model_name = graph_config.tooling.primary_llm.config.model_name
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=llm_model_name,
|
||||
provider_type=llm_provider,
|
||||
)
|
||||
num_tokens = len(llm_tokenizer.encode(final_answer or ""))
|
||||
# llm_tokenizer = get_tokenizer(
|
||||
# model_name=llm_model_name,
|
||||
# provider_type=llm_provider,
|
||||
# )
|
||||
# num_tokens = len(llm_tokenizer.encode(final_answer or ""))
|
||||
|
||||
write_custom_event(current_step_nr, OverallStop(), writer)
|
||||
# write_custom_event(current_step_nr, OverallStop(), writer)
|
||||
|
||||
# Log the research agent steps
|
||||
save_iteration(
|
||||
state,
|
||||
graph_config,
|
||||
aggregated_context,
|
||||
final_answer,
|
||||
all_cited_documents,
|
||||
is_internet_marker_dict,
|
||||
num_tokens,
|
||||
)
|
||||
# # Log the research agent steps
|
||||
# save_iteration(
|
||||
# state,
|
||||
# graph_config,
|
||||
# aggregated_context,
|
||||
# final_answer,
|
||||
# all_cited_documents,
|
||||
# is_internet_marker_dict,
|
||||
# num_tokens,
|
||||
# )
|
||||
|
||||
return LoggerUpdate(
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="logger",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
# return LoggerUpdate(
|
||||
# log_messages=[
|
||||
# get_langgraph_node_log_string(
|
||||
# graph_component="main",
|
||||
# node_name="logger",
|
||||
# node_start_time=node_start_time,
|
||||
# )
|
||||
# ],
|
||||
# )
|
||||
|
||||
@@ -1,132 +1,131 @@
|
||||
from collections.abc import Iterator
|
||||
from typing import cast
|
||||
# from collections.abc import Iterator
|
||||
# from typing import cast
|
||||
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langgraph.types import StreamWriter
|
||||
from pydantic import BaseModel
|
||||
# from langchain_core.messages import AIMessageChunk
|
||||
# from langchain_core.messages import BaseMessage
|
||||
# from langgraph.types import StreamWriter
|
||||
# from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import CitationInfo
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
|
||||
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
|
||||
from onyx.chat.stream_processing.answer_response_handler import (
|
||||
PassThroughAnswerResponseHandler,
|
||||
)
|
||||
from onyx.chat.stream_processing.utils import map_document_id_order
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.server.query_and_chat.streaming_models import CitationDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CitationStart
|
||||
from onyx.server.query_and_chat.streaming_models import MessageDelta
|
||||
from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.utils.logger import setup_logger
|
||||
# from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
|
||||
# from onyx.chat.models import AgentAnswerPiece
|
||||
# from onyx.chat.models import CitationInfo
|
||||
# from onyx.chat.models import LlmDoc
|
||||
# from onyx.chat.models import OnyxAnswerPiece
|
||||
# from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
|
||||
# from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
|
||||
# from onyx.chat.stream_processing.answer_response_handler import (
|
||||
# PassThroughAnswerResponseHandler,
|
||||
# )
|
||||
# from onyx.chat.stream_processing.utils import map_document_id_order
|
||||
# from onyx.context.search.models import InferenceSection
|
||||
# from onyx.server.query_and_chat.streaming_models import CitationDelta
|
||||
# from onyx.server.query_and_chat.streaming_models import CitationStart
|
||||
# from onyx.server.query_and_chat.streaming_models import MessageDelta
|
||||
# from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
# from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
# from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
# logger = setup_logger()
|
||||
|
||||
|
||||
class BasicSearchProcessedStreamResults(BaseModel):
|
||||
ai_message_chunk: AIMessageChunk = AIMessageChunk(content="")
|
||||
full_answer: str | None = None
|
||||
cited_references: list[InferenceSection] = []
|
||||
retrieved_documents: list[LlmDoc] = []
|
||||
# class BasicSearchProcessedStreamResults(BaseModel):
|
||||
# ai_message_chunk: AIMessageChunk = AIMessageChunk(content="")
|
||||
# full_answer: str | None = None
|
||||
# cited_references: list[InferenceSection] = []
|
||||
# retrieved_documents: list[LlmDoc] = []
|
||||
|
||||
|
||||
def process_llm_stream(
|
||||
messages: Iterator[BaseMessage],
|
||||
should_stream_answer: bool,
|
||||
writer: StreamWriter,
|
||||
ind: int,
|
||||
search_results: list[LlmDoc] | None = None,
|
||||
generate_final_answer: bool = False,
|
||||
chat_message_id: str | None = None,
|
||||
) -> BasicSearchProcessedStreamResults:
|
||||
tool_call_chunk = AIMessageChunk(content="")
|
||||
# def process_llm_stream(
|
||||
# messages: Iterator[BaseMessage],
|
||||
# should_stream_answer: bool,
|
||||
# writer: StreamWriter,
|
||||
# ind: int,
|
||||
# search_results: list[LlmDoc] | None = None,
|
||||
# generate_final_answer: bool = False,
|
||||
# chat_message_id: str | None = None,
|
||||
# ) -> BasicSearchProcessedStreamResults:
|
||||
# tool_call_chunk = AIMessageChunk(content="")
|
||||
|
||||
if search_results:
|
||||
answer_handler: AnswerResponseHandler = CitationResponseHandler(
|
||||
context_docs=search_results,
|
||||
doc_id_to_rank_map=map_document_id_order(search_results),
|
||||
)
|
||||
else:
|
||||
answer_handler = PassThroughAnswerResponseHandler()
|
||||
# if search_results:
|
||||
# answer_handler: AnswerResponseHandler = CitationResponseHandler(
|
||||
# context_docs=search_results,
|
||||
# doc_id_to_rank_map=map_document_id_order(search_results),
|
||||
# )
|
||||
# else:
|
||||
# answer_handler = PassThroughAnswerResponseHandler()
|
||||
|
||||
full_answer = ""
|
||||
start_final_answer_streaming_set = False
|
||||
# Accumulate citation infos if handler emits them
|
||||
collected_citation_infos: list[CitationInfo] = []
|
||||
# full_answer = ""
|
||||
# start_final_answer_streaming_set = False
|
||||
# # Accumulate citation infos if handler emits them
|
||||
# collected_citation_infos: list[CitationInfo] = []
|
||||
|
||||
# This stream will be the llm answer if no tool is chosen. When a tool is chosen,
|
||||
# the stream will contain AIMessageChunks with tool call information.
|
||||
for message in messages:
|
||||
# # This stream will be the llm answer if no tool is chosen. When a tool is chosen,
|
||||
# # the stream will contain AIMessageChunks with tool call information.
|
||||
# for message in messages:
|
||||
|
||||
answer_piece = message.content
|
||||
if not isinstance(answer_piece, str):
|
||||
# this is only used for logging, so fine to
|
||||
# just add the string representation
|
||||
answer_piece = str(answer_piece)
|
||||
full_answer += answer_piece
|
||||
# answer_piece = message.content
|
||||
# if not isinstance(answer_piece, str):
|
||||
# # this is only used for logging, so fine to
|
||||
# # just add the string representation
|
||||
# answer_piece = str(answer_piece)
|
||||
# full_answer += answer_piece
|
||||
|
||||
if isinstance(message, AIMessageChunk) and (
|
||||
message.tool_call_chunks or message.tool_calls
|
||||
):
|
||||
tool_call_chunk += message # type: ignore
|
||||
elif should_stream_answer:
|
||||
for response_part in answer_handler.handle_response_part(message):
|
||||
# if isinstance(message, AIMessageChunk) and (
|
||||
# message.tool_call_chunks or message.tool_calls
|
||||
# ):
|
||||
# tool_call_chunk += message # type: ignore
|
||||
# elif should_stream_answer:
|
||||
# for response_part in answer_handler.handle_response_part(message):
|
||||
|
||||
# only stream out answer parts
|
||||
if (
|
||||
isinstance(response_part, (OnyxAnswerPiece, AgentAnswerPiece))
|
||||
and generate_final_answer
|
||||
and response_part.answer_piece
|
||||
):
|
||||
if chat_message_id is None:
|
||||
raise ValueError(
|
||||
"chat_message_id is required when generating final answer"
|
||||
)
|
||||
# # only stream out answer parts
|
||||
# if (
|
||||
# isinstance(response_part, (OnyxAnswerPiece, AgentAnswerPiece))
|
||||
# and generate_final_answer
|
||||
# and response_part.answer_piece
|
||||
# ):
|
||||
# if chat_message_id is None:
|
||||
# raise ValueError(
|
||||
# "chat_message_id is required when generating final answer"
|
||||
# )
|
||||
|
||||
if not start_final_answer_streaming_set:
|
||||
# Convert LlmDocs to SavedSearchDocs
|
||||
saved_search_docs = saved_search_docs_from_llm_docs(
|
||||
search_results
|
||||
)
|
||||
write_custom_event(
|
||||
ind,
|
||||
MessageStart(content="", final_documents=saved_search_docs),
|
||||
writer,
|
||||
)
|
||||
start_final_answer_streaming_set = True
|
||||
# if not start_final_answer_streaming_set:
|
||||
# # Convert LlmDocs to SavedSearchDocs
|
||||
# saved_search_docs = saved_search_docs_from_llm_docs(
|
||||
# search_results
|
||||
# )
|
||||
# write_custom_event(
|
||||
# ind,
|
||||
# MessageStart(content="", final_documents=saved_search_docs),
|
||||
# writer,
|
||||
# )
|
||||
# start_final_answer_streaming_set = True
|
||||
|
||||
write_custom_event(
|
||||
ind,
|
||||
MessageDelta(content=response_part.answer_piece),
|
||||
writer,
|
||||
)
|
||||
# collect citation info objects
|
||||
elif isinstance(response_part, CitationInfo):
|
||||
collected_citation_infos.append(response_part)
|
||||
# write_custom_event(
|
||||
# ind,
|
||||
# MessageDelta(content=response_part.answer_piece),
|
||||
# writer,
|
||||
# )
|
||||
# # collect citation info objects
|
||||
# elif isinstance(response_part, CitationInfo):
|
||||
# collected_citation_infos.append(response_part)
|
||||
|
||||
if generate_final_answer and start_final_answer_streaming_set:
|
||||
# start_final_answer_streaming_set is only set if the answer is verbal and not a tool call
|
||||
write_custom_event(
|
||||
ind,
|
||||
SectionEnd(),
|
||||
writer,
|
||||
)
|
||||
# if generate_final_answer and start_final_answer_streaming_set:
|
||||
# # start_final_answer_streaming_set is only set if the answer is verbal and not a tool call
|
||||
# write_custom_event(
|
||||
# ind,
|
||||
# SectionEnd(),
|
||||
# writer,
|
||||
# )
|
||||
|
||||
# Emit citations section if any were collected
|
||||
if collected_citation_infos:
|
||||
write_custom_event(ind, CitationStart(), writer)
|
||||
write_custom_event(
|
||||
ind, CitationDelta(citations=collected_citation_infos), writer
|
||||
)
|
||||
write_custom_event(ind, SectionEnd(), writer)
|
||||
# # Emit citations section if any were collected
|
||||
# if collected_citation_infos:
|
||||
# write_custom_event(ind, CitationStart(), writer)
|
||||
# write_custom_event(
|
||||
# ind, CitationDelta(citations=collected_citation_infos), writer
|
||||
# )
|
||||
# write_custom_event(ind, SectionEnd(), writer)
|
||||
|
||||
logger.debug(f"Full answer: {full_answer}")
|
||||
return BasicSearchProcessedStreamResults(
|
||||
ai_message_chunk=cast(AIMessageChunk, tool_call_chunk), full_answer=full_answer
|
||||
)
|
||||
# logger.debug(f"Full answer: {full_answer}")
|
||||
# return BasicSearchProcessedStreamResults(
|
||||
# ai_message_chunk=cast(AIMessageChunk, tool_call_chunk), full_answer=full_answer
|
||||
# )
|
||||
|
||||
@@ -1,82 +1,82 @@
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
from typing import Any
|
||||
from typing import TypedDict
|
||||
# from operator import add
|
||||
# from typing import Annotated
|
||||
# from typing import Any
|
||||
# from typing import TypedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
# from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.core_state import CoreState
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import IterationInstructions
|
||||
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
|
||||
from onyx.agents.agent_search.dr.models import OrchestrationPlan
|
||||
from onyx.agents.agent_search.dr.models import OrchestratorTool
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.connector import DocumentSource
|
||||
# from onyx.agents.agent_search.core_state import CoreState
|
||||
# from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
# from onyx.agents.agent_search.dr.models import IterationInstructions
|
||||
# from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
|
||||
# from onyx.agents.agent_search.dr.models import OrchestrationPlan
|
||||
# from onyx.agents.agent_search.dr.models import OrchestratorTool
|
||||
# from onyx.context.search.models import InferenceSection
|
||||
# from onyx.db.connector import DocumentSource
|
||||
|
||||
### States ###
|
||||
# ### States ###
|
||||
|
||||
|
||||
class LoggerUpdate(BaseModel):
|
||||
log_messages: Annotated[list[str], add] = []
|
||||
# class LoggerUpdate(BaseModel):
|
||||
# log_messages: Annotated[list[str], add] = []
|
||||
|
||||
|
||||
class OrchestrationUpdate(LoggerUpdate):
|
||||
tools_used: Annotated[list[str], add] = []
|
||||
query_list: list[str] = []
|
||||
iteration_nr: int = 0
|
||||
current_step_nr: int = 1
|
||||
plan_of_record: OrchestrationPlan | None = None # None for Thoughtful
|
||||
remaining_time_budget: float = 2.0 # set by default to about 2 searches
|
||||
num_closer_suggestions: int = 0 # how many times the closer was suggested
|
||||
gaps: list[str] = (
|
||||
[]
|
||||
) # gaps that may be identified by the closer before being able to answer the question.
|
||||
iteration_instructions: Annotated[list[IterationInstructions], add] = []
|
||||
# class OrchestrationUpdate(LoggerUpdate):
|
||||
# tools_used: Annotated[list[str], add] = []
|
||||
# query_list: list[str] = []
|
||||
# iteration_nr: int = 0
|
||||
# current_step_nr: int = 1
|
||||
# plan_of_record: OrchestrationPlan | None = None # None for Thoughtful
|
||||
# remaining_time_budget: float = 2.0 # set by default to about 2 searches
|
||||
# num_closer_suggestions: int = 0 # how many times the closer was suggested
|
||||
# gaps: list[str] = (
|
||||
# []
|
||||
# ) # gaps that may be identified by the closer before being able to answer the question.
|
||||
# iteration_instructions: Annotated[list[IterationInstructions], add] = []
|
||||
|
||||
|
||||
class OrchestrationSetup(OrchestrationUpdate):
|
||||
original_question: str | None = None
|
||||
chat_history_string: str | None = None
|
||||
clarification: OrchestrationClarificationInfo | None = None
|
||||
available_tools: dict[str, OrchestratorTool] | None = None
|
||||
num_closer_suggestions: int = 0 # how many times the closer was suggested
|
||||
# class OrchestrationSetup(OrchestrationUpdate):
|
||||
# original_question: str | None = None
|
||||
# chat_history_string: str | None = None
|
||||
# clarification: OrchestrationClarificationInfo | None = None
|
||||
# available_tools: dict[str, OrchestratorTool] | None = None
|
||||
# num_closer_suggestions: int = 0 # how many times the closer was suggested
|
||||
|
||||
active_source_types: list[DocumentSource] | None = None
|
||||
active_source_types_descriptions: str | None = None
|
||||
assistant_system_prompt: str | None = None
|
||||
assistant_task_prompt: str | None = None
|
||||
uploaded_test_context: str | None = None
|
||||
uploaded_image_context: list[dict[str, Any]] | None = None
|
||||
# active_source_types: list[DocumentSource] | None = None
|
||||
# active_source_types_descriptions: str | None = None
|
||||
# assistant_system_prompt: str | None = None
|
||||
# assistant_task_prompt: str | None = None
|
||||
# uploaded_test_context: str | None = None
|
||||
# uploaded_image_context: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class AnswerUpdate(LoggerUpdate):
|
||||
iteration_responses: Annotated[list[IterationAnswer], add] = []
|
||||
# class AnswerUpdate(LoggerUpdate):
|
||||
# iteration_responses: Annotated[list[IterationAnswer], add] = []
|
||||
|
||||
|
||||
class FinalUpdate(LoggerUpdate):
|
||||
final_answer: str | None = None
|
||||
all_cited_documents: list[InferenceSection] = []
|
||||
# class FinalUpdate(LoggerUpdate):
|
||||
# final_answer: str | None = None
|
||||
# all_cited_documents: list[InferenceSection] = []
|
||||
|
||||
|
||||
## Graph Input State
|
||||
class MainInput(CoreState):
|
||||
pass
|
||||
# ## Graph Input State
|
||||
# class MainInput(CoreState):
|
||||
# pass
|
||||
|
||||
|
||||
## Graph State
|
||||
class MainState(
|
||||
# This includes the core state
|
||||
MainInput,
|
||||
OrchestrationSetup,
|
||||
AnswerUpdate,
|
||||
FinalUpdate,
|
||||
):
|
||||
pass
|
||||
# ## Graph State
|
||||
# class MainState(
|
||||
# # This includes the core state
|
||||
# MainInput,
|
||||
# OrchestrationSetup,
|
||||
# AnswerUpdate,
|
||||
# FinalUpdate,
|
||||
# ):
|
||||
# pass
|
||||
|
||||
|
||||
## Graph Output State
|
||||
class MainOutput(TypedDict):
|
||||
log_messages: list[str]
|
||||
final_answer: str | None
|
||||
all_cited_documents: list[InferenceSection]
|
||||
# ## Graph Output State
|
||||
# class MainOutput(TypedDict):
|
||||
# log_messages: list[str]
|
||||
# final_answer: str | None
|
||||
# all_cited_documents: list[InferenceSection]
|
||||
|
||||
@@ -1,47 +1,47 @@
|
||||
from datetime import datetime
|
||||
# from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
# from langchain_core.runnables import RunnableConfig
|
||||
# from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolStart
|
||||
from onyx.utils.logger import setup_logger
|
||||
# from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
# from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
# get_langgraph_node_log_string,
|
||||
# )
|
||||
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
# from onyx.server.query_and_chat.streaming_models import SearchToolStart
|
||||
# from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
# logger = setup_logger()
|
||||
|
||||
|
||||
def basic_search_branch(
|
||||
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> LoggerUpdate:
|
||||
"""
|
||||
LangGraph node to perform a standard search as part of the DR process.
|
||||
"""
|
||||
# def basic_search_branch(
|
||||
# state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
# ) -> LoggerUpdate:
|
||||
# """
|
||||
# LangGraph node to perform a standard search as part of the DR process.
|
||||
# """
|
||||
|
||||
node_start_time = datetime.now()
|
||||
iteration_nr = state.iteration_nr
|
||||
current_step_nr = state.current_step_nr
|
||||
# node_start_time = datetime.now()
|
||||
# iteration_nr = state.iteration_nr
|
||||
# current_step_nr = state.current_step_nr
|
||||
|
||||
logger.debug(f"Search start for Basic Search {iteration_nr} at {datetime.now()}")
|
||||
# logger.debug(f"Search start for Basic Search {iteration_nr} at {datetime.now()}")
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SearchToolStart(
|
||||
is_internet_search=False,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
# write_custom_event(
|
||||
# current_step_nr,
|
||||
# SearchToolStart(
|
||||
# is_internet_search=False,
|
||||
# ),
|
||||
# writer,
|
||||
# )
|
||||
|
||||
return LoggerUpdate(
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="basic_search",
|
||||
node_name="branching",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
# return LoggerUpdate(
|
||||
# log_messages=[
|
||||
# get_langgraph_node_log_string(
|
||||
# graph_component="basic_search",
|
||||
# node_name="branching",
|
||||
# node_start_time=node_start_time,
|
||||
# )
|
||||
# ],
|
||||
# )
|
||||
|
||||
@@ -1,286 +1,261 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
# import re
|
||||
# from datetime import datetime
|
||||
# from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
# from langchain_core.runnables import RunnableConfig
|
||||
# from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import BaseSearchProcessingResponse
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import SearchAnswer
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
|
||||
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
|
||||
from onyx.agents.agent_search.dr.utils import extract_document_citations
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.agents.agent_search.utils import create_question_prompt
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.connector import DocumentSource
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.prompts.dr_prompts import BASE_SEARCH_PROCESSING_PROMPT
|
||||
from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
|
||||
from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.utils.logger import setup_logger
|
||||
# from onyx.agents.agent_search.dr.models import BaseSearchProcessingResponse
|
||||
# from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
# from onyx.agents.agent_search.dr.models import SearchAnswer
|
||||
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
|
||||
# from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
|
||||
# from onyx.agents.agent_search.dr.utils import extract_document_citations
|
||||
# from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
|
||||
# from onyx.agents.agent_search.models import GraphConfig
|
||||
# from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
# from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
# get_langgraph_node_log_string,
|
||||
# )
|
||||
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
# from onyx.agents.agent_search.utils import create_question_prompt
|
||||
# from onyx.chat.models import LlmDoc
|
||||
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
|
||||
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
|
||||
# from onyx.context.search.models import InferenceSection
|
||||
# from onyx.db.connector import DocumentSource
|
||||
# from onyx.prompts.dr_prompts import BASE_SEARCH_PROCESSING_PROMPT
|
||||
# from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
|
||||
# from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
|
||||
# from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
# from onyx.tools.models import SearchToolOverrideKwargs
|
||||
# from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
# from onyx.tools.tool_implementations.search_like_tool_utils import (
|
||||
# SEARCH_INFERENCE_SECTIONS_ID,
|
||||
# )
|
||||
# from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
# logger = setup_logger()
|
||||
|
||||
|
||||
def basic_search(
|
||||
state: BranchInput,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> BranchUpdate:
|
||||
"""
|
||||
LangGraph node to perform a standard search as part of the DR process.
|
||||
"""
|
||||
# def basic_search(
|
||||
# state: BranchInput,
|
||||
# config: RunnableConfig,
|
||||
# writer: StreamWriter = lambda _: None,
|
||||
# ) -> BranchUpdate:
|
||||
# """
|
||||
# LangGraph node to perform a standard search as part of the DR process.
|
||||
# """
|
||||
|
||||
node_start_time = datetime.now()
|
||||
iteration_nr = state.iteration_nr
|
||||
parallelization_nr = state.parallelization_nr
|
||||
current_step_nr = state.current_step_nr
|
||||
assistant_system_prompt = state.assistant_system_prompt
|
||||
assistant_task_prompt = state.assistant_task_prompt
|
||||
# node_start_time = datetime.now()
|
||||
# iteration_nr = state.iteration_nr
|
||||
# parallelization_nr = state.parallelization_nr
|
||||
# current_step_nr = state.current_step_nr
|
||||
# assistant_system_prompt = state.assistant_system_prompt
|
||||
# assistant_task_prompt = state.assistant_task_prompt
|
||||
|
||||
branch_query = state.branch_question
|
||||
if not branch_query:
|
||||
raise ValueError("branch_query is not set")
|
||||
# branch_query = state.branch_question
|
||||
# if not branch_query:
|
||||
# raise ValueError("branch_query is not set")
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
base_question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
research_type = graph_config.behavior.research_type
|
||||
# graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
# base_question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
# use_agentic_search = graph_config.behavior.use_agentic_search
|
||||
|
||||
if not state.available_tools:
|
||||
raise ValueError("available_tools is not set")
|
||||
# if not state.available_tools:
|
||||
# raise ValueError("available_tools is not set")
|
||||
|
||||
elif len(state.tools_used) == 0:
|
||||
raise ValueError("tools_used is empty")
|
||||
# elif len(state.tools_used) == 0:
|
||||
# raise ValueError("tools_used is empty")
|
||||
|
||||
search_tool_info = state.available_tools[state.tools_used[-1]]
|
||||
search_tool = cast(SearchTool, search_tool_info.tool_object)
|
||||
force_use_tool = graph_config.tooling.force_use_tool
|
||||
# search_tool_info = state.available_tools[state.tools_used[-1]]
|
||||
# search_tool = cast(SearchTool, search_tool_info.tool_object)
|
||||
# graph_config.tooling.force_use_tool
|
||||
|
||||
# sanity check
|
||||
if search_tool != graph_config.tooling.search_tool:
|
||||
raise ValueError("search_tool does not match the configured search tool")
|
||||
# # sanity check
|
||||
# if search_tool != graph_config.tooling.search_tool:
|
||||
# raise ValueError("search_tool does not match the configured search tool")
|
||||
|
||||
# rewrite query and identify source types
|
||||
active_source_types_str = ", ".join(
|
||||
[source.value for source in state.active_source_types or []]
|
||||
)
|
||||
# # rewrite query and identify source types
|
||||
# active_source_types_str = ", ".join(
|
||||
# [source.value for source in state.active_source_types or []]
|
||||
# )
|
||||
|
||||
base_search_processing_prompt = BASE_SEARCH_PROCESSING_PROMPT.build(
|
||||
active_source_types_str=active_source_types_str,
|
||||
branch_query=branch_query,
|
||||
current_time=datetime.now().strftime("%Y-%m-%d %H:%M"),
|
||||
)
|
||||
# base_search_processing_prompt = BASE_SEARCH_PROCESSING_PROMPT.build(
|
||||
# active_source_types_str=active_source_types_str,
|
||||
# branch_query=branch_query,
|
||||
# current_time=datetime.now().strftime("%Y-%m-%d %H:%M"),
|
||||
# )
|
||||
|
||||
try:
|
||||
search_processing = invoke_llm_json(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
assistant_system_prompt, base_search_processing_prompt
|
||||
),
|
||||
schema=BaseSearchProcessingResponse,
|
||||
timeout_override=TF_DR_TIMEOUT_SHORT,
|
||||
# max_tokens=100,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not process query: {e}")
|
||||
raise e
|
||||
# try:
|
||||
# search_processing = invoke_llm_json(
|
||||
# llm=graph_config.tooling.primary_llm,
|
||||
# prompt=create_question_prompt(
|
||||
# assistant_system_prompt, base_search_processing_prompt
|
||||
# ),
|
||||
# schema=BaseSearchProcessingResponse,
|
||||
# timeout_override=TF_DR_TIMEOUT_SHORT,
|
||||
# # max_tokens=100,
|
||||
# )
|
||||
# except Exception as e:
|
||||
# logger.error(f"Could not process query: {e}")
|
||||
# raise e
|
||||
|
||||
rewritten_query = search_processing.rewritten_query
|
||||
# rewritten_query = search_processing.rewritten_query
|
||||
|
||||
# give back the query so we can render it in the UI
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SearchToolDelta(
|
||||
queries=[rewritten_query],
|
||||
documents=[],
|
||||
),
|
||||
writer,
|
||||
)
|
||||
# # give back the query so we can render it in the UI
|
||||
# write_custom_event(
|
||||
# current_step_nr,
|
||||
# SearchToolDelta(
|
||||
# queries=[rewritten_query],
|
||||
# documents=[],
|
||||
# ),
|
||||
# writer,
|
||||
# )
|
||||
|
||||
implied_start_date = search_processing.time_filter
|
||||
# implied_start_date = search_processing.time_filter
|
||||
|
||||
# Validate time_filter format if it exists
|
||||
implied_time_filter = None
|
||||
if implied_start_date:
|
||||
# # Validate time_filter format if it exists
|
||||
# implied_time_filter = None
|
||||
# if implied_start_date:
|
||||
|
||||
# Check if time_filter is in YYYY-MM-DD format
|
||||
date_pattern = r"^\d{4}-\d{2}-\d{2}$"
|
||||
if re.match(date_pattern, implied_start_date):
|
||||
implied_time_filter = datetime.strptime(implied_start_date, "%Y-%m-%d")
|
||||
# # Check if time_filter is in YYYY-MM-DD format
|
||||
# date_pattern = r"^\d{4}-\d{2}-\d{2}$"
|
||||
# if re.match(date_pattern, implied_start_date):
|
||||
# implied_time_filter = datetime.strptime(implied_start_date, "%Y-%m-%d")
|
||||
|
||||
specified_source_types: list[DocumentSource] | None = (
|
||||
strings_to_document_sources(search_processing.specified_source_types)
|
||||
if search_processing.specified_source_types
|
||||
else None
|
||||
)
|
||||
# specified_source_types: list[DocumentSource] | None = (
|
||||
# strings_to_document_sources(search_processing.specified_source_types)
|
||||
# if search_processing.specified_source_types
|
||||
# else None
|
||||
# )
|
||||
|
||||
if specified_source_types is not None and len(specified_source_types) == 0:
|
||||
specified_source_types = None
|
||||
# if specified_source_types is not None and len(specified_source_types) == 0:
|
||||
# specified_source_types = None
|
||||
|
||||
logger.debug(
|
||||
f"Search start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
)
|
||||
# logger.debug(
|
||||
# f"Search start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
# )
|
||||
|
||||
retrieved_docs: list[InferenceSection] = []
|
||||
callback_container: list[list[InferenceSection]] = []
|
||||
# retrieved_docs: list[InferenceSection] = []
|
||||
|
||||
user_file_ids: list[UUID] | None = None
|
||||
project_id: int | None = None
|
||||
if force_use_tool.override_kwargs and isinstance(
|
||||
force_use_tool.override_kwargs, SearchToolOverrideKwargs
|
||||
):
|
||||
override_kwargs = force_use_tool.override_kwargs
|
||||
user_file_ids = override_kwargs.user_file_ids
|
||||
project_id = override_kwargs.project_id
|
||||
# for tool_response in search_tool.run(
|
||||
# query=rewritten_query,
|
||||
# document_sources=specified_source_types,
|
||||
# time_filter=implied_time_filter,
|
||||
# override_kwargs=SearchToolOverrideKwargs(original_query=rewritten_query),
|
||||
# ):
|
||||
# # get retrieved docs to send to the rest of the graph
|
||||
# if tool_response.id == SEARCH_INFERENCE_SECTIONS_ID:
|
||||
# retrieved_docs = cast(list[InferenceSection], tool_response.response)
|
||||
|
||||
# new db session to avoid concurrency issues
|
||||
with get_session_with_current_tenant() as search_db_session:
|
||||
for tool_response in search_tool.run(
|
||||
query=rewritten_query,
|
||||
document_sources=specified_source_types,
|
||||
time_filter=implied_time_filter,
|
||||
override_kwargs=SearchToolOverrideKwargs(
|
||||
force_no_rerank=True,
|
||||
alternate_db_session=search_db_session,
|
||||
retrieved_sections_callback=callback_container.append,
|
||||
skip_query_analysis=True,
|
||||
original_query=rewritten_query,
|
||||
user_file_ids=user_file_ids,
|
||||
project_id=project_id,
|
||||
),
|
||||
):
|
||||
# get retrieved docs to send to the rest of the graph
|
||||
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
response = cast(SearchResponseSummary, tool_response.response)
|
||||
retrieved_docs = response.top_sections
|
||||
# break
|
||||
|
||||
break
|
||||
# # render the retrieved docs in the UI
|
||||
# write_custom_event(
|
||||
# current_step_nr,
|
||||
# SearchToolDelta(
|
||||
# queries=[],
|
||||
# documents=convert_inference_sections_to_search_docs(
|
||||
# retrieved_docs, is_internet=False
|
||||
# ),
|
||||
# ),
|
||||
# writer,
|
||||
# )
|
||||
|
||||
# render the retrieved docs in the UI
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SearchToolDelta(
|
||||
queries=[],
|
||||
documents=convert_inference_sections_to_search_docs(
|
||||
retrieved_docs, is_internet=False
|
||||
),
|
||||
),
|
||||
writer,
|
||||
)
|
||||
# document_texts_list = []
|
||||
|
||||
document_texts_list = []
|
||||
# for doc_num, retrieved_doc in enumerate(retrieved_docs[:15]):
|
||||
# if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
|
||||
# raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
|
||||
# chunk_text = build_document_context(retrieved_doc, doc_num + 1)
|
||||
# document_texts_list.append(chunk_text)
|
||||
|
||||
for doc_num, retrieved_doc in enumerate(retrieved_docs[:15]):
|
||||
if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
|
||||
raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
|
||||
chunk_text = build_document_context(retrieved_doc, doc_num + 1)
|
||||
document_texts_list.append(chunk_text)
|
||||
# document_texts = "\n\n".join(document_texts_list)
|
||||
|
||||
document_texts = "\n\n".join(document_texts_list)
|
||||
# logger.debug(
|
||||
# f"Search end/LLM start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
# )
|
||||
|
||||
logger.debug(
|
||||
f"Search end/LLM start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
)
|
||||
# # Built prompt
|
||||
|
||||
# Built prompt
|
||||
# if use_agentic_search:
|
||||
# search_prompt = INTERNAL_SEARCH_PROMPTS[use_agentic_search].build(
|
||||
# search_query=branch_query,
|
||||
# base_question=base_question,
|
||||
# document_text=document_texts,
|
||||
# )
|
||||
|
||||
if research_type == ResearchType.DEEP:
|
||||
search_prompt = INTERNAL_SEARCH_PROMPTS[research_type].build(
|
||||
search_query=branch_query,
|
||||
base_question=base_question,
|
||||
document_text=document_texts,
|
||||
)
|
||||
# # Run LLM
|
||||
|
||||
# Run LLM
|
||||
# # search_answer_json = None
|
||||
# search_answer_json = invoke_llm_json(
|
||||
# llm=graph_config.tooling.primary_llm,
|
||||
# prompt=create_question_prompt(
|
||||
# assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
|
||||
# ),
|
||||
# schema=SearchAnswer,
|
||||
# timeout_override=TF_DR_TIMEOUT_LONG,
|
||||
# # max_tokens=1500,
|
||||
# )
|
||||
|
||||
# search_answer_json = None
|
||||
search_answer_json = invoke_llm_json(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
|
||||
),
|
||||
schema=SearchAnswer,
|
||||
timeout_override=TF_DR_TIMEOUT_LONG,
|
||||
# max_tokens=1500,
|
||||
)
|
||||
# logger.debug(
|
||||
# f"LLM/all done for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
# )
|
||||
|
||||
logger.debug(
|
||||
f"LLM/all done for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
)
|
||||
# # get cited documents
|
||||
# answer_string = search_answer_json.answer
|
||||
# claims = search_answer_json.claims or []
|
||||
# reasoning = search_answer_json.reasoning
|
||||
# # answer_string = ""
|
||||
# # claims = []
|
||||
|
||||
# get cited documents
|
||||
answer_string = search_answer_json.answer
|
||||
claims = search_answer_json.claims or []
|
||||
reasoning = search_answer_json.reasoning
|
||||
# answer_string = ""
|
||||
# claims = []
|
||||
# (
|
||||
# citation_numbers,
|
||||
# answer_string,
|
||||
# claims,
|
||||
# ) = extract_document_citations(answer_string, claims)
|
||||
|
||||
(
|
||||
citation_numbers,
|
||||
answer_string,
|
||||
claims,
|
||||
) = extract_document_citations(answer_string, claims)
|
||||
# if citation_numbers and (
|
||||
# (max(citation_numbers) > len(retrieved_docs)) or min(citation_numbers) < 1
|
||||
# ):
|
||||
# raise ValueError("Citation numbers are out of range for retrieved docs.")
|
||||
|
||||
if citation_numbers and (
|
||||
(max(citation_numbers) > len(retrieved_docs)) or min(citation_numbers) < 1
|
||||
):
|
||||
raise ValueError("Citation numbers are out of range for retrieved docs.")
|
||||
# cited_documents = {
|
||||
# citation_number: retrieved_docs[citation_number - 1]
|
||||
# for citation_number in citation_numbers
|
||||
# }
|
||||
|
||||
cited_documents = {
|
||||
citation_number: retrieved_docs[citation_number - 1]
|
||||
for citation_number in citation_numbers
|
||||
}
|
||||
# else:
|
||||
# answer_string = ""
|
||||
# claims = []
|
||||
# cited_documents = {
|
||||
# doc_num + 1: retrieved_doc
|
||||
# for doc_num, retrieved_doc in enumerate(retrieved_docs[:15])
|
||||
# }
|
||||
# reasoning = ""
|
||||
|
||||
else:
|
||||
answer_string = ""
|
||||
claims = []
|
||||
cited_documents = {
|
||||
doc_num + 1: retrieved_doc
|
||||
for doc_num, retrieved_doc in enumerate(retrieved_docs[:15])
|
||||
}
|
||||
reasoning = ""
|
||||
|
||||
return BranchUpdate(
|
||||
branch_iteration_responses=[
|
||||
IterationAnswer(
|
||||
tool=search_tool_info.llm_path,
|
||||
tool_id=search_tool_info.tool_id,
|
||||
iteration_nr=iteration_nr,
|
||||
parallelization_nr=parallelization_nr,
|
||||
question=branch_query,
|
||||
answer=answer_string,
|
||||
claims=claims,
|
||||
cited_documents=cited_documents,
|
||||
reasoning=reasoning,
|
||||
additional_data=None,
|
||||
)
|
||||
],
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="basic_search",
|
||||
node_name="searching",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
# return BranchUpdate(
|
||||
# branch_iteration_responses=[
|
||||
# IterationAnswer(
|
||||
# tool=search_tool_info.llm_path,
|
||||
# tool_id=search_tool_info.tool_id,
|
||||
# iteration_nr=iteration_nr,
|
||||
# parallelization_nr=parallelization_nr,
|
||||
# question=branch_query,
|
||||
# answer=answer_string,
|
||||
# claims=claims,
|
||||
# cited_documents=cited_documents,
|
||||
# reasoning=reasoning,
|
||||
# additional_data=None,
|
||||
# )
|
||||
# ],
|
||||
# log_messages=[
|
||||
# get_langgraph_node_log_string(
|
||||
# graph_component="basic_search",
|
||||
# node_name="searching",
|
||||
# node_start_time=node_start_time,
|
||||
# )
|
||||
# ],
|
||||
# )
|
||||
|
||||
@@ -1,77 +1,77 @@
|
||||
from datetime import datetime
|
||||
# from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
# from langchain_core.runnables import RunnableConfig
|
||||
# from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.utils.logger import setup_logger
|
||||
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
|
||||
# from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
# get_langgraph_node_log_string,
|
||||
# )
|
||||
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
# from onyx.context.search.models import SavedSearchDoc
|
||||
# from onyx.context.search.models import SearchDoc
|
||||
# from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
# from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
# logger = setup_logger()
|
||||
|
||||
|
||||
def is_reducer(
|
||||
state: SubAgentMainState,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> SubAgentUpdate:
|
||||
"""
|
||||
LangGraph node to perform a standard search as part of the DR process.
|
||||
"""
|
||||
# def is_reducer(
|
||||
# state: SubAgentMainState,
|
||||
# config: RunnableConfig,
|
||||
# writer: StreamWriter = lambda _: None,
|
||||
# ) -> SubAgentUpdate:
|
||||
# """
|
||||
# LangGraph node to perform a standard search as part of the DR process.
|
||||
# """
|
||||
|
||||
node_start_time = datetime.now()
|
||||
# node_start_time = datetime.now()
|
||||
|
||||
branch_updates = state.branch_iteration_responses
|
||||
current_iteration = state.iteration_nr
|
||||
current_step_nr = state.current_step_nr
|
||||
# branch_updates = state.branch_iteration_responses
|
||||
# current_iteration = state.iteration_nr
|
||||
# current_step_nr = state.current_step_nr
|
||||
|
||||
new_updates = [
|
||||
update for update in branch_updates if update.iteration_nr == current_iteration
|
||||
]
|
||||
# new_updates = [
|
||||
# update for update in branch_updates if update.iteration_nr == current_iteration
|
||||
# ]
|
||||
|
||||
[update.question for update in new_updates]
|
||||
doc_lists = [list(update.cited_documents.values()) for update in new_updates]
|
||||
# [update.question for update in new_updates]
|
||||
# doc_lists = [list(update.cited_documents.values()) for update in new_updates]
|
||||
|
||||
doc_list = []
|
||||
# doc_list = []
|
||||
|
||||
for xs in doc_lists:
|
||||
for x in xs:
|
||||
doc_list.append(x)
|
||||
# for xs in doc_lists:
|
||||
# for x in xs:
|
||||
# doc_list.append(x)
|
||||
|
||||
# Convert InferenceSections to SavedSearchDocs
|
||||
search_docs = SearchDoc.from_chunks_or_sections(doc_list)
|
||||
retrieved_saved_search_docs = [
|
||||
SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
|
||||
for search_doc in search_docs
|
||||
]
|
||||
# # Convert InferenceSections to SavedSearchDocs
|
||||
# search_docs = SearchDoc.from_chunks_or_sections(doc_list)
|
||||
# retrieved_saved_search_docs = [
|
||||
# SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
|
||||
# for search_doc in search_docs
|
||||
# ]
|
||||
|
||||
for retrieved_saved_search_doc in retrieved_saved_search_docs:
|
||||
retrieved_saved_search_doc.is_internet = False
|
||||
# for retrieved_saved_search_doc in retrieved_saved_search_docs:
|
||||
# retrieved_saved_search_doc.is_internet = False
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SectionEnd(),
|
||||
writer,
|
||||
)
|
||||
# write_custom_event(
|
||||
# current_step_nr,
|
||||
# SectionEnd(),
|
||||
# writer,
|
||||
# )
|
||||
|
||||
current_step_nr += 1
|
||||
# current_step_nr += 1
|
||||
|
||||
return SubAgentUpdate(
|
||||
iteration_responses=new_updates,
|
||||
current_step_nr=current_step_nr,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="basic_search",
|
||||
node_name="consolidation",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
# return SubAgentUpdate(
|
||||
# iteration_responses=new_updates,
|
||||
# current_step_nr=current_step_nr,
|
||||
# log_messages=[
|
||||
# get_langgraph_node_log_string(
|
||||
# graph_component="basic_search",
|
||||
# node_name="consolidation",
|
||||
# node_start_time=node_start_time,
|
||||
# )
|
||||
# ],
|
||||
# )
|
||||
|
||||
@@ -1,50 +1,50 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
# from langgraph.graph import END
|
||||
# from langgraph.graph import START
|
||||
# from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_1_branch import (
|
||||
basic_search_branch,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import (
|
||||
basic_search,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_3_reduce import (
|
||||
is_reducer,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_image_generation_conditional_edges import (
|
||||
branching_router,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.utils.logger import setup_logger
|
||||
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_1_branch import (
|
||||
# basic_search_branch,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import (
|
||||
# basic_search,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_3_reduce import (
|
||||
# is_reducer,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_image_generation_conditional_edges import (
|
||||
# branching_router,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
# from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
# logger = setup_logger()
|
||||
|
||||
|
||||
def dr_basic_search_graph_builder() -> StateGraph:
|
||||
"""
|
||||
LangGraph graph builder for Web Search Sub-Agent
|
||||
"""
|
||||
# def dr_basic_search_graph_builder() -> StateGraph:
|
||||
# """
|
||||
# LangGraph graph builder for Web Search Sub-Agent
|
||||
# """
|
||||
|
||||
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
|
||||
# graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
|
||||
|
||||
### Add nodes ###
|
||||
# ### Add nodes ###
|
||||
|
||||
graph.add_node("branch", basic_search_branch)
|
||||
# graph.add_node("branch", basic_search_branch)
|
||||
|
||||
graph.add_node("act", basic_search)
|
||||
# graph.add_node("act", basic_search)
|
||||
|
||||
graph.add_node("reducer", is_reducer)
|
||||
# graph.add_node("reducer", is_reducer)
|
||||
|
||||
### Add edges ###
|
||||
# ### Add edges ###
|
||||
|
||||
graph.add_edge(start_key=START, end_key="branch")
|
||||
# graph.add_edge(start_key=START, end_key="branch")
|
||||
|
||||
graph.add_conditional_edges("branch", branching_router)
|
||||
# graph.add_conditional_edges("branch", branching_router)
|
||||
|
||||
graph.add_edge(start_key="act", end_key="reducer")
|
||||
# graph.add_edge(start_key="act", end_key="reducer")
|
||||
|
||||
graph.add_edge(start_key="reducer", end_key=END)
|
||||
# graph.add_edge(start_key="reducer", end_key=END)
|
||||
|
||||
return graph
|
||||
# return graph
|
||||
|
||||
@@ -1,30 +1,30 @@
|
||||
from collections.abc import Hashable
|
||||
# from collections.abc import Hashable
|
||||
|
||||
from langgraph.types import Send
|
||||
# from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
# from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
|
||||
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
|
||||
|
||||
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
|
||||
return [
|
||||
Send(
|
||||
"act",
|
||||
BranchInput(
|
||||
iteration_nr=state.iteration_nr,
|
||||
parallelization_nr=parallelization_nr,
|
||||
branch_question=query,
|
||||
current_step_nr=state.current_step_nr,
|
||||
context="",
|
||||
active_source_types=state.active_source_types,
|
||||
tools_used=state.tools_used,
|
||||
available_tools=state.available_tools,
|
||||
assistant_system_prompt=state.assistant_system_prompt,
|
||||
assistant_task_prompt=state.assistant_task_prompt,
|
||||
),
|
||||
)
|
||||
for parallelization_nr, query in enumerate(
|
||||
state.query_list[:MAX_DR_PARALLEL_SEARCH]
|
||||
)
|
||||
]
|
||||
# def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
|
||||
# return [
|
||||
# Send(
|
||||
# "act",
|
||||
# BranchInput(
|
||||
# iteration_nr=state.iteration_nr,
|
||||
# parallelization_nr=parallelization_nr,
|
||||
# branch_question=query,
|
||||
# current_step_nr=state.current_step_nr,
|
||||
# context="",
|
||||
# active_source_types=state.active_source_types,
|
||||
# tools_used=state.tools_used,
|
||||
# available_tools=state.available_tools,
|
||||
# assistant_system_prompt=state.assistant_system_prompt,
|
||||
# assistant_task_prompt=state.assistant_task_prompt,
|
||||
# ),
|
||||
# )
|
||||
# for parallelization_nr, query in enumerate(
|
||||
# state.query_list[:MAX_DR_PARALLEL_SEARCH]
|
||||
# )
|
||||
# ]
|
||||
|
||||
@@ -1,36 +1,36 @@
|
||||
from datetime import datetime
|
||||
# from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
# from langchain_core.runnables import RunnableConfig
|
||||
# from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
# from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
# from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
# get_langgraph_node_log_string,
|
||||
# )
|
||||
# from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
# logger = setup_logger()
|
||||
|
||||
|
||||
def custom_tool_branch(
|
||||
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> LoggerUpdate:
|
||||
"""
|
||||
LangGraph node to perform a generic tool call as part of the DR process.
|
||||
"""
|
||||
# def custom_tool_branch(
|
||||
# state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
# ) -> LoggerUpdate:
|
||||
# """
|
||||
# LangGraph node to perform a generic tool call as part of the DR process.
|
||||
# """
|
||||
|
||||
node_start_time = datetime.now()
|
||||
iteration_nr = state.iteration_nr
|
||||
# node_start_time = datetime.now()
|
||||
# iteration_nr = state.iteration_nr
|
||||
|
||||
logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")
|
||||
# logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")
|
||||
|
||||
return LoggerUpdate(
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="custom_tool",
|
||||
node_name="branching",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
# return LoggerUpdate(
|
||||
# log_messages=[
|
||||
# get_langgraph_node_log_string(
|
||||
# graph_component="custom_tool",
|
||||
# node_name="branching",
|
||||
# node_start_time=node_start_time,
|
||||
# )
|
||||
# ],
|
||||
# )
|
||||
|
||||
@@ -1,169 +1,164 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
# import json
|
||||
# from datetime import datetime
|
||||
# from typing import cast
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
# from langchain_core.messages import AIMessage
|
||||
# from langchain_core.runnables import RunnableConfig
|
||||
# from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
|
||||
from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
|
||||
from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
|
||||
from onyx.tools.tool_implementations.mcp.mcp_tool import MCP_TOOL_RESPONSE_ID
|
||||
from onyx.utils.logger import setup_logger
|
||||
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
|
||||
# from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
|
||||
# from onyx.agents.agent_search.models import GraphConfig
|
||||
# from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
# get_langgraph_node_log_string,
|
||||
# )
|
||||
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
|
||||
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
|
||||
# from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
|
||||
# from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
|
||||
# from onyx.tools.tool_implementations.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
|
||||
# from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
|
||||
# from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
|
||||
# from onyx.tools.tool_implementations.mcp.mcp_tool import MCP_TOOL_RESPONSE_ID
|
||||
# from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
# logger = setup_logger()
|
||||
|
||||
|
||||
def custom_tool_act(
|
||||
state: BranchInput,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> BranchUpdate:
|
||||
"""
|
||||
LangGraph node to perform a generic tool call as part of the DR process.
|
||||
"""
|
||||
# def custom_tool_act(
|
||||
# state: BranchInput,
|
||||
# config: RunnableConfig,
|
||||
# writer: StreamWriter = lambda _: None,
|
||||
# ) -> BranchUpdate:
|
||||
# """
|
||||
# LangGraph node to perform a generic tool call as part of the DR process.
|
||||
# """
|
||||
|
||||
node_start_time = datetime.now()
|
||||
iteration_nr = state.iteration_nr
|
||||
parallelization_nr = state.parallelization_nr
|
||||
# node_start_time = datetime.now()
|
||||
# iteration_nr = state.iteration_nr
|
||||
# parallelization_nr = state.parallelization_nr
|
||||
|
||||
if not state.available_tools:
|
||||
raise ValueError("available_tools is not set")
|
||||
# if not state.available_tools:
|
||||
# raise ValueError("available_tools is not set")
|
||||
|
||||
custom_tool_info = state.available_tools[state.tools_used[-1]]
|
||||
custom_tool_name = custom_tool_info.name
|
||||
custom_tool = cast(CustomTool, custom_tool_info.tool_object)
|
||||
# custom_tool_info = state.available_tools[state.tools_used[-1]]
|
||||
# custom_tool_name = custom_tool_info.name
|
||||
# custom_tool = cast(CustomTool, custom_tool_info.tool_object)
|
||||
|
||||
branch_query = state.branch_question
|
||||
if not branch_query:
|
||||
raise ValueError("branch_query is not set")
|
||||
# branch_query = state.branch_question
|
||||
# if not branch_query:
|
||||
# raise ValueError("branch_query is not set")
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
base_question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
# graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
# base_question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
|
||||
logger.debug(
|
||||
f"Tool call start for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
)
|
||||
# logger.debug(
|
||||
# f"Tool call start for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
# )
|
||||
|
||||
# get tool call args
|
||||
tool_args: dict | None = None
|
||||
if graph_config.tooling.using_tool_calling_llm:
|
||||
# get tool call args from tool-calling LLM
|
||||
tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
|
||||
query=branch_query,
|
||||
base_question=base_question,
|
||||
tool_description=custom_tool_info.description,
|
||||
)
|
||||
tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
|
||||
tool_use_prompt,
|
||||
tools=[custom_tool.tool_definition()],
|
||||
tool_choice="required",
|
||||
timeout_override=TF_DR_TIMEOUT_LONG,
|
||||
)
|
||||
# # get tool call args
|
||||
# tool_args: dict | None = None
|
||||
# if graph_config.tooling.using_tool_calling_llm:
|
||||
# # get tool call args from tool-calling LLM
|
||||
# tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
|
||||
# query=branch_query,
|
||||
# base_question=base_question,
|
||||
# tool_description=custom_tool_info.description,
|
||||
# )
|
||||
# tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
|
||||
# tool_use_prompt,
|
||||
# tools=[custom_tool.tool_definition()],
|
||||
# tool_choice="required",
|
||||
# timeout_override=TF_DR_TIMEOUT_LONG,
|
||||
# )
|
||||
|
||||
# make sure we got a tool call
|
||||
if (
|
||||
isinstance(tool_calling_msg, AIMessage)
|
||||
and len(tool_calling_msg.tool_calls) == 1
|
||||
):
|
||||
tool_args = tool_calling_msg.tool_calls[0]["args"]
|
||||
else:
|
||||
logger.warning("Tool-calling LLM did not emit a tool call")
|
||||
# # make sure we got a tool call
|
||||
# if (
|
||||
# isinstance(tool_calling_msg, AIMessage)
|
||||
# and len(tool_calling_msg.tool_calls) == 1
|
||||
# ):
|
||||
# tool_args = tool_calling_msg.tool_calls[0]["args"]
|
||||
# else:
|
||||
# logger.warning("Tool-calling LLM did not emit a tool call")
|
||||
|
||||
if tool_args is None:
|
||||
# get tool call args from non-tool-calling LLM or for failed tool-calling LLM
|
||||
tool_args = custom_tool.get_args_for_non_tool_calling_llm(
|
||||
query=branch_query,
|
||||
history=[],
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
force_run=True,
|
||||
)
|
||||
# if tool_args is None:
|
||||
# raise ValueError(
|
||||
# "Failed to obtain tool arguments from LLM - tool calling is required"
|
||||
# )
|
||||
|
||||
if tool_args is None:
|
||||
raise ValueError("Failed to obtain tool arguments from LLM")
|
||||
# # run the tool
|
||||
# response_summary: CustomToolCallSummary | None = None
|
||||
# for tool_response in custom_tool.run(**tool_args):
|
||||
# if tool_response.id in {CUSTOM_TOOL_RESPONSE_ID, MCP_TOOL_RESPONSE_ID}:
|
||||
# response_summary = cast(CustomToolCallSummary, tool_response.response)
|
||||
# break
|
||||
|
||||
# run the tool
|
||||
response_summary: CustomToolCallSummary | None = None
|
||||
for tool_response in custom_tool.run(**tool_args):
|
||||
if tool_response.id in {CUSTOM_TOOL_RESPONSE_ID, MCP_TOOL_RESPONSE_ID}:
|
||||
response_summary = cast(CustomToolCallSummary, tool_response.response)
|
||||
break
|
||||
# if not response_summary:
|
||||
# raise ValueError("Custom tool did not return a valid response summary")
|
||||
|
||||
if not response_summary:
|
||||
raise ValueError("Custom tool did not return a valid response summary")
|
||||
# # summarise tool result
|
||||
# if not response_summary.response_type:
|
||||
# raise ValueError("Response type is not returned.")
|
||||
|
||||
# summarise tool result
|
||||
if not response_summary.response_type:
|
||||
raise ValueError("Response type is not returned.")
|
||||
# if response_summary.response_type == "json":
|
||||
# tool_result_str = json.dumps(response_summary.tool_result, ensure_ascii=False)
|
||||
# elif response_summary.response_type in {"image", "csv"}:
|
||||
# tool_result_str = f"{response_summary.response_type} files: {response_summary.tool_result.file_ids}"
|
||||
# else:
|
||||
# tool_result_str = str(response_summary.tool_result)
|
||||
|
||||
if response_summary.response_type == "json":
|
||||
tool_result_str = json.dumps(response_summary.tool_result, ensure_ascii=False)
|
||||
elif response_summary.response_type in {"image", "csv"}:
|
||||
tool_result_str = f"{response_summary.response_type} files: {response_summary.tool_result.file_ids}"
|
||||
else:
|
||||
tool_result_str = str(response_summary.tool_result)
|
||||
# tool_str = (
|
||||
# f"Tool used: {custom_tool_name}\n"
|
||||
# f"Description: {custom_tool_info.description}\n"
|
||||
# f"Result: {tool_result_str}"
|
||||
# )
|
||||
|
||||
tool_str = (
|
||||
f"Tool used: {custom_tool_name}\n"
|
||||
f"Description: {custom_tool_info.description}\n"
|
||||
f"Result: {tool_result_str}"
|
||||
)
|
||||
# tool_summary_prompt = CUSTOM_TOOL_USE_PROMPT.build(
|
||||
# query=branch_query, base_question=base_question, tool_response=tool_str
|
||||
# )
|
||||
# answer_string = str(
|
||||
# graph_config.tooling.primary_llm.invoke(
|
||||
# tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
|
||||
# ).content
|
||||
# ).strip()
|
||||
|
||||
tool_summary_prompt = CUSTOM_TOOL_USE_PROMPT.build(
|
||||
query=branch_query, base_question=base_question, tool_response=tool_str
|
||||
)
|
||||
answer_string = str(
|
||||
graph_config.tooling.primary_llm.invoke_langchain(
|
||||
tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
|
||||
).content
|
||||
).strip()
|
||||
# tool_summary_prompt = CUSTOM_TOOL_USE_PROMPT.build(
|
||||
# query=branch_query, base_question=base_question, tool_response=tool_str
|
||||
# )
|
||||
# answer_string = str(
|
||||
# graph_config.tooling.primary_llm.invoke_langchain(
|
||||
# tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
|
||||
# ).content
|
||||
# ).strip()
|
||||
|
||||
# get file_ids:
|
||||
file_ids = None
|
||||
if response_summary.response_type in {"image", "csv"} and hasattr(
|
||||
response_summary.tool_result, "file_ids"
|
||||
):
|
||||
file_ids = response_summary.tool_result.file_ids
|
||||
# logger.debug(
|
||||
# f"Tool call end for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
# )
|
||||
|
||||
logger.debug(
|
||||
f"Tool call end for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
)
|
||||
|
||||
return BranchUpdate(
|
||||
branch_iteration_responses=[
|
||||
IterationAnswer(
|
||||
tool=custom_tool_name,
|
||||
tool_id=custom_tool_info.tool_id,
|
||||
iteration_nr=iteration_nr,
|
||||
parallelization_nr=parallelization_nr,
|
||||
question=branch_query,
|
||||
answer=answer_string,
|
||||
claims=[],
|
||||
cited_documents={},
|
||||
reasoning="",
|
||||
additional_data=None,
|
||||
response_type=response_summary.response_type,
|
||||
data=response_summary.tool_result,
|
||||
file_ids=file_ids,
|
||||
)
|
||||
],
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="custom_tool",
|
||||
node_name="tool_calling",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
# return BranchUpdate(
|
||||
# branch_iteration_responses=[
|
||||
# IterationAnswer(
|
||||
# tool=custom_tool_name,
|
||||
# tool_id=custom_tool_info.tool_id,
|
||||
# iteration_nr=iteration_nr,
|
||||
# parallelization_nr=parallelization_nr,
|
||||
# question=branch_query,
|
||||
# answer=answer_string,
|
||||
# claims=[],
|
||||
# cited_documents={},
|
||||
# reasoning="",
|
||||
# additional_data=None,
|
||||
# response_type=response_summary.response_type,
|
||||
# data=response_summary.tool_result,
|
||||
# file_ids=file_ids,
|
||||
# )
|
||||
# ],
|
||||
# log_messages=[
|
||||
# get_langgraph_node_log_string(
|
||||
# graph_component="custom_tool",
|
||||
# node_name="tool_calling",
|
||||
# node_start_time=node_start_time,
|
||||
# )
|
||||
# ],
|
||||
# )
|
||||
|
||||
@@ -1,82 +1,82 @@
|
||||
from datetime import datetime
|
||||
# from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
# from langchain_core.runnables import RunnableConfig
|
||||
# from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.utils.logger import setup_logger
|
||||
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
|
||||
# from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
# get_langgraph_node_log_string,
|
||||
# )
|
||||
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
# from onyx.server.query_and_chat.streaming_models import CustomToolDelta
|
||||
# from onyx.server.query_and_chat.streaming_models import CustomToolStart
|
||||
# from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
# from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
# logger = setup_logger()
|
||||
|
||||
|
||||
def custom_tool_reducer(
|
||||
state: SubAgentMainState,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> SubAgentUpdate:
|
||||
"""
|
||||
LangGraph node to perform a generic tool call as part of the DR process.
|
||||
"""
|
||||
# def custom_tool_reducer(
|
||||
# state: SubAgentMainState,
|
||||
# config: RunnableConfig,
|
||||
# writer: StreamWriter = lambda _: None,
|
||||
# ) -> SubAgentUpdate:
|
||||
# """
|
||||
# LangGraph node to perform a generic tool call as part of the DR process.
|
||||
# """
|
||||
|
||||
node_start_time = datetime.now()
|
||||
# node_start_time = datetime.now()
|
||||
|
||||
current_step_nr = state.current_step_nr
|
||||
# current_step_nr = state.current_step_nr
|
||||
|
||||
branch_updates = state.branch_iteration_responses
|
||||
current_iteration = state.iteration_nr
|
||||
# branch_updates = state.branch_iteration_responses
|
||||
# current_iteration = state.iteration_nr
|
||||
|
||||
new_updates = [
|
||||
update for update in branch_updates if update.iteration_nr == current_iteration
|
||||
]
|
||||
# new_updates = [
|
||||
# update for update in branch_updates if update.iteration_nr == current_iteration
|
||||
# ]
|
||||
|
||||
for new_update in new_updates:
|
||||
# for new_update in new_updates:
|
||||
|
||||
if not new_update.response_type:
|
||||
raise ValueError("Response type is not returned.")
|
||||
# if not new_update.response_type:
|
||||
# raise ValueError("Response type is not returned.")
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
CustomToolStart(
|
||||
tool_name=new_update.tool,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
# write_custom_event(
|
||||
# current_step_nr,
|
||||
# CustomToolStart(
|
||||
# tool_name=new_update.tool,
|
||||
# ),
|
||||
# writer,
|
||||
# )
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
CustomToolDelta(
|
||||
tool_name=new_update.tool,
|
||||
response_type=new_update.response_type,
|
||||
data=new_update.data,
|
||||
file_ids=new_update.file_ids,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
# write_custom_event(
|
||||
# current_step_nr,
|
||||
# CustomToolDelta(
|
||||
# tool_name=new_update.tool,
|
||||
# response_type=new_update.response_type,
|
||||
# data=new_update.data,
|
||||
# file_ids=new_update.file_ids,
|
||||
# ),
|
||||
# writer,
|
||||
# )
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SectionEnd(),
|
||||
writer,
|
||||
)
|
||||
# write_custom_event(
|
||||
# current_step_nr,
|
||||
# SectionEnd(),
|
||||
# writer,
|
||||
# )
|
||||
|
||||
current_step_nr += 1
|
||||
# current_step_nr += 1
|
||||
|
||||
return SubAgentUpdate(
|
||||
iteration_responses=new_updates,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="custom_tool",
|
||||
node_name="consolidation",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
# return SubAgentUpdate(
|
||||
# iteration_responses=new_updates,
|
||||
# log_messages=[
|
||||
# get_langgraph_node_log_string(
|
||||
# graph_component="custom_tool",
|
||||
# node_name="consolidation",
|
||||
# node_start_time=node_start_time,
|
||||
# )
|
||||
# ],
|
||||
# )
|
||||
|
||||
@@ -1,28 +1,28 @@
|
||||
from collections.abc import Hashable
|
||||
# from collections.abc import Hashable
|
||||
|
||||
from langgraph.types import Send
|
||||
# from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import (
|
||||
SubAgentInput,
|
||||
)
|
||||
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
# from onyx.agents.agent_search.dr.sub_agents.states import (
|
||||
# SubAgentInput,
|
||||
# )
|
||||
|
||||
|
||||
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
|
||||
return [
|
||||
Send(
|
||||
"act",
|
||||
BranchInput(
|
||||
iteration_nr=state.iteration_nr,
|
||||
parallelization_nr=parallelization_nr,
|
||||
branch_question=query,
|
||||
context="",
|
||||
active_source_types=state.active_source_types,
|
||||
tools_used=state.tools_used,
|
||||
available_tools=state.available_tools,
|
||||
),
|
||||
)
|
||||
for parallelization_nr, query in enumerate(
|
||||
state.query_list[:1] # no parallel call for now
|
||||
)
|
||||
]
|
||||
# def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
|
||||
# return [
|
||||
# Send(
|
||||
# "act",
|
||||
# BranchInput(
|
||||
# iteration_nr=state.iteration_nr,
|
||||
# parallelization_nr=parallelization_nr,
|
||||
# branch_question=query,
|
||||
# context="",
|
||||
# active_source_types=state.active_source_types,
|
||||
# tools_used=state.tools_used,
|
||||
# available_tools=state.available_tools,
|
||||
# ),
|
||||
# )
|
||||
# for parallelization_nr, query in enumerate(
|
||||
# state.query_list[:1] # no parallel call for now
|
||||
# )
|
||||
# ]
|
||||
|
||||
@@ -1,50 +1,50 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
# from langgraph.graph import END
|
||||
# from langgraph.graph import START
|
||||
# from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_1_branch import (
|
||||
custom_tool_branch,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_2_act import (
|
||||
custom_tool_act,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_3_reduce import (
|
||||
custom_tool_reducer,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_conditional_edges import (
|
||||
branching_router,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.utils.logger import setup_logger
|
||||
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_1_branch import (
|
||||
# custom_tool_branch,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_2_act import (
|
||||
# custom_tool_act,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_3_reduce import (
|
||||
# custom_tool_reducer,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_conditional_edges import (
|
||||
# branching_router,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
# from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
# logger = setup_logger()
|
||||
|
||||
|
||||
def dr_custom_tool_graph_builder() -> StateGraph:
|
||||
"""
|
||||
LangGraph graph builder for Generic Tool Sub-Agent
|
||||
"""
|
||||
# def dr_custom_tool_graph_builder() -> StateGraph:
|
||||
# """
|
||||
# LangGraph graph builder for Generic Tool Sub-Agent
|
||||
# """
|
||||
|
||||
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
|
||||
# graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
|
||||
|
||||
### Add nodes ###
|
||||
# ### Add nodes ###
|
||||
|
||||
graph.add_node("branch", custom_tool_branch)
|
||||
# graph.add_node("branch", custom_tool_branch)
|
||||
|
||||
graph.add_node("act", custom_tool_act)
|
||||
# graph.add_node("act", custom_tool_act)
|
||||
|
||||
graph.add_node("reducer", custom_tool_reducer)
|
||||
# graph.add_node("reducer", custom_tool_reducer)
|
||||
|
||||
### Add edges ###
|
||||
# ### Add edges ###
|
||||
|
||||
graph.add_edge(start_key=START, end_key="branch")
|
||||
# graph.add_edge(start_key=START, end_key="branch")
|
||||
|
||||
graph.add_conditional_edges("branch", branching_router)
|
||||
# graph.add_conditional_edges("branch", branching_router)
|
||||
|
||||
graph.add_edge(start_key="act", end_key="reducer")
|
||||
# graph.add_edge(start_key="act", end_key="reducer")
|
||||
|
||||
graph.add_edge(start_key="reducer", end_key=END)
|
||||
# graph.add_edge(start_key="reducer", end_key=END)
|
||||
|
||||
return graph
|
||||
# return graph
|
||||
|
||||
@@ -1,36 +1,36 @@
|
||||
from datetime import datetime
|
||||
# from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
# from langchain_core.runnables import RunnableConfig
|
||||
# from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
# from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
# from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
# get_langgraph_node_log_string,
|
||||
# )
|
||||
# from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
# logger = setup_logger()
|
||||
|
||||
|
||||
def generic_internal_tool_branch(
|
||||
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> LoggerUpdate:
|
||||
"""
|
||||
LangGraph node to perform a generic tool call as part of the DR process.
|
||||
"""
|
||||
# def generic_internal_tool_branch(
|
||||
# state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
# ) -> LoggerUpdate:
|
||||
# """
|
||||
# LangGraph node to perform a generic tool call as part of the DR process.
|
||||
# """
|
||||
|
||||
node_start_time = datetime.now()
|
||||
iteration_nr = state.iteration_nr
|
||||
# node_start_time = datetime.now()
|
||||
# iteration_nr = state.iteration_nr
|
||||
|
||||
logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")
|
||||
# logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")
|
||||
|
||||
return LoggerUpdate(
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="generic_internal_tool",
|
||||
node_name="branching",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
# return LoggerUpdate(
|
||||
# log_messages=[
|
||||
# get_langgraph_node_log_string(
|
||||
# graph_component="generic_internal_tool",
|
||||
# node_name="branching",
|
||||
# node_start_time=node_start_time,
|
||||
# )
|
||||
# ],
|
||||
# )
|
||||
|
||||
@@ -1,149 +1,147 @@
import json
from datetime import datetime
from typing import cast

from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
from onyx.prompts.dr_prompts import OKTA_TOOL_USE_SPECIAL_PROMPT
from onyx.utils.logger import setup_logger

logger = setup_logger()


def generic_internal_tool_act(
    state: BranchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
    """
    LangGraph node to perform a generic tool call as part of the DR process.
    """

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr

    if not state.available_tools:
        raise ValueError("available_tools is not set")

    generic_internal_tool_info = state.available_tools[state.tools_used[-1]]
    generic_internal_tool_name = generic_internal_tool_info.llm_path
    generic_internal_tool = generic_internal_tool_info.tool_object

    if generic_internal_tool is None:
        raise ValueError("generic_internal_tool is not set")

    branch_query = state.branch_question
    if not branch_query:
        raise ValueError("branch_query is not set")

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = graph_config.inputs.prompt_builder.raw_user_query

    logger.debug(
        f"Tool call start for {generic_internal_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    # get tool call args
    tool_args: dict | None = None
    if graph_config.tooling.using_tool_calling_llm:
        # get tool call args from tool-calling LLM
        tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
            query=branch_query,
            base_question=base_question,
            tool_description=generic_internal_tool_info.description,
        )
        tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
            tool_use_prompt,
            tools=[generic_internal_tool.tool_definition()],
            tool_choice="required",
            timeout_override=TF_DR_TIMEOUT_SHORT,
        )

        # make sure we got a tool call
        if (
            isinstance(tool_calling_msg, AIMessage)
            and len(tool_calling_msg.tool_calls) == 1
        ):
            tool_args = tool_calling_msg.tool_calls[0]["args"]
        else:
            logger.warning("Tool-calling LLM did not emit a tool call")

    if tool_args is None:
        # get tool call args from non-tool-calling LLM or for failed tool-calling LLM
        tool_args = generic_internal_tool.get_args_for_non_tool_calling_llm(
            query=branch_query,
            history=[],
            llm=graph_config.tooling.primary_llm,
            force_run=True,
        )

    if tool_args is None:
        raise ValueError("Failed to obtain tool arguments from LLM")

    # run the tool
    tool_responses = list(generic_internal_tool.run(**tool_args))
    final_data = generic_internal_tool.final_result(*tool_responses)
    tool_result_str = json.dumps(final_data, ensure_ascii=False)

    tool_str = (
        f"Tool used: {generic_internal_tool.display_name}\n"
        f"Description: {generic_internal_tool_info.description}\n"
        f"Result: {tool_result_str}"
    )

    if generic_internal_tool.display_name == "Okta Profile":
        tool_prompt = OKTA_TOOL_USE_SPECIAL_PROMPT
    else:
        tool_prompt = CUSTOM_TOOL_USE_PROMPT

    tool_summary_prompt = tool_prompt.build(
        query=branch_query, base_question=base_question, tool_response=tool_str
    )
    answer_string = str(
        graph_config.tooling.primary_llm.invoke_langchain(
            tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
        ).content
    ).strip()

    logger.debug(
        f"Tool call end for {generic_internal_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=generic_internal_tool.llm_name,
                tool_id=generic_internal_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=branch_query,
                answer=answer_string,
                claims=[],
                cited_documents={},
                reasoning="",
                additional_data=None,
                response_type="text",  # TODO: convert all response types to enums
                data=answer_string,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="custom_tool",
                node_name="tool_calling",
                node_start_time=node_start_time,
            )
        ],
    )

@@ -1,82 +1,82 @@
from datetime import datetime

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger


logger = setup_logger()


def generic_internal_tool_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    LangGraph node to perform a generic tool call as part of the DR process.
    """

    node_start_time = datetime.now()

    current_step_nr = state.current_step_nr

    branch_updates = state.branch_iteration_responses
    current_iteration = state.iteration_nr

    new_updates = [
        update for update in branch_updates if update.iteration_nr == current_iteration
    ]

    for new_update in new_updates:

        if not new_update.response_type:
            raise ValueError("Response type is not returned.")

        write_custom_event(
            current_step_nr,
            CustomToolStart(
                tool_name=new_update.tool,
            ),
            writer,
        )

        write_custom_event(
            current_step_nr,
            CustomToolDelta(
                tool_name=new_update.tool,
                response_type=new_update.response_type,
                data=new_update.data,
                file_ids=[],
            ),
            writer,
        )

        write_custom_event(
            current_step_nr,
            SectionEnd(),
            writer,
        )

        current_step_nr += 1

    return SubAgentUpdate(
        iteration_responses=new_updates,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="custom_tool",
                node_name="consolidation",
                node_start_time=node_start_time,
            )
        ],
    )

@@ -1,28 +1,28 @@
from collections.abc import Hashable

from langgraph.types import Send

from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import (
    SubAgentInput,
)


def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    return [
        Send(
            "act",
            BranchInput(
                iteration_nr=state.iteration_nr,
                parallelization_nr=parallelization_nr,
                branch_question=query,
                context="",
                active_source_types=state.active_source_types,
                tools_used=state.tools_used,
                available_tools=state.available_tools,
            ),
        )
        for parallelization_nr, query in enumerate(
            state.query_list[:1]  # no parallel call for now
        )
    ]

@@ -1,50 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_1_branch import (
    generic_internal_tool_branch,
)
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_2_act import (
    generic_internal_tool_act,
)
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_3_reduce import (
    generic_internal_tool_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_conditional_edges import (
    branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger


logger = setup_logger()


def dr_generic_internal_tool_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for Generic Tool Sub-Agent
    """

    graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    ### Add nodes ###

    graph.add_node("branch", generic_internal_tool_branch)

    graph.add_node("act", generic_internal_tool_act)

    graph.add_node("reducer", generic_internal_tool_reducer)

    ### Add edges ###

    graph.add_edge(start_key=START, end_key="branch")

    graph.add_conditional_edges("branch", branching_router)

    graph.add_edge(start_key="act", end_key="reducer")

    graph.add_edge(start_key="reducer", end_key=END)

    return graph

@@ -1,45 +1,45 @@
from datetime import datetime

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
from onyx.utils.logger import setup_logger

logger = setup_logger()


def image_generation_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph node to perform an image generation as part of the DR process.
    """

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr

    logger.debug(f"Image generation start {iteration_nr} at {datetime.now()}")

    # tell frontend that we are starting the image generation tool
    write_custom_event(
        state.current_step_nr,
        ImageGenerationToolStart(),
        writer,
    )

    return LoggerUpdate(
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="image_generation",
                node_name="branching",
                node_start_time=node_start_time,
            )
        ],
    )

@@ -1,189 +1,187 @@
import json
from datetime import datetime
from typing import cast

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.dr.models import GeneratedImage
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.file_store.utils import build_frontend_file_url
from onyx.file_store.utils import save_files
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeartbeat
from onyx.tools.tool_implementations.images.image_generation_tool import (
    IMAGE_GENERATION_HEARTBEAT_ID,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
    IMAGE_GENERATION_RESPONSE_ID,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
    ImageGenerationResponse,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
    ImageGenerationTool,
)
from onyx.tools.tool_implementations.images.image_generation_tool import ImageShape
from onyx.utils.logger import setup_logger

logger = setup_logger()


def image_generation(
    state: BranchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
    """
    LangGraph node to perform a standard search as part of the DR process.
    """

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr
    state.assistant_system_prompt
    state.assistant_task_prompt

    branch_query = state.branch_question
    if not branch_query:
        raise ValueError("branch_query is not set")

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    graph_config.inputs.prompt_builder.raw_user_query
    graph_config.behavior.research_type

    if not state.available_tools:
        raise ValueError("available_tools is not set")

    image_tool_info = state.available_tools[state.tools_used[-1]]
    image_tool = cast(ImageGenerationTool, image_tool_info.tool_object)

    image_prompt = branch_query
    requested_shape: ImageShape | None = None

    try:
        parsed_query = json.loads(branch_query)
    except json.JSONDecodeError:
        parsed_query = None

    if isinstance(parsed_query, dict):
        prompt_from_llm = parsed_query.get("prompt")
        if isinstance(prompt_from_llm, str) and prompt_from_llm.strip():
            image_prompt = prompt_from_llm.strip()

        raw_shape = parsed_query.get("shape")
        if isinstance(raw_shape, str):
            try:
                requested_shape = ImageShape(raw_shape)
            except ValueError:
                logger.warning(
                    "Received unsupported image shape '%s' from LLM. Falling back to square.",
                    raw_shape,
                )

    logger.debug(
        f"Image generation start for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    # Generate images using the image generation tool
    image_generation_responses: list[ImageGenerationResponse] = []

    if requested_shape is not None:
        tool_iterator = image_tool.run(
            prompt=image_prompt,
            shape=requested_shape.value,
        )
    else:
        tool_iterator = image_tool.run(prompt=image_prompt)

    for tool_response in tool_iterator:
        if tool_response.id == IMAGE_GENERATION_HEARTBEAT_ID:
            # Stream heartbeat to frontend
            write_custom_event(
                state.current_step_nr,
                ImageGenerationToolHeartbeat(),
                writer,
            )
        elif tool_response.id == IMAGE_GENERATION_RESPONSE_ID:
            response = cast(list[ImageGenerationResponse], tool_response.response)
            image_generation_responses = response
            break

    # save images to file store
    file_ids = save_files(
        urls=[],
        base64_files=[img.image_data for img in image_generation_responses],
    )

    final_generated_images = [
        GeneratedImage(
            file_id=file_id,
            url=build_frontend_file_url(file_id),
            revised_prompt=img.revised_prompt,
            shape=(requested_shape or ImageShape.SQUARE).value,
        )
        for file_id, img in zip(file_ids, image_generation_responses)
    ]

    logger.debug(
        f"Image generation complete for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    # Create answer string describing the generated images
    if final_generated_images:
        image_descriptions = []
        for i, img in enumerate(final_generated_images, 1):
            if img.shape and img.shape != ImageShape.SQUARE.value:
                image_descriptions.append(
                    f"Image {i}: {img.revised_prompt} (shape: {img.shape})"
                )
            else:
                image_descriptions.append(f"Image {i}: {img.revised_prompt}")

        answer_string = (
            f"Generated {len(final_generated_images)} image(s) based on the request: {image_prompt}\n\n"
            + "\n".join(image_descriptions)
        )
        if requested_shape:
            reasoning = (
                "Used image generation tool to create "
                f"{len(final_generated_images)} image(s) in {requested_shape.value} orientation."
            )
        else:
            reasoning = (
                "Used image generation tool to create "
                f"{len(final_generated_images)} image(s) based on the user's request."
            )
    else:
        answer_string = f"Failed to generate images for request: {image_prompt}"
        reasoning = "Image generation tool did not return any results."

    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=image_tool_info.llm_path,
                tool_id=image_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=branch_query,
                answer=answer_string,
                claims=[],
                cited_documents={},
                reasoning=reasoning,
                generated_images=final_generated_images,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="image_generation",
                node_name="generating",
                node_start_time=node_start_time,
            )
        ],
    )

@@ -1,71 +1,71 @@
from datetime import datetime

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.dr.models import GeneratedImage
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolDelta
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger


logger = setup_logger()


def is_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    LangGraph node to perform a standard search as part of the DR process.
    """

    node_start_time = datetime.now()

    branch_updates = state.branch_iteration_responses
    current_iteration = state.iteration_nr
    current_step_nr = state.current_step_nr

    new_updates = [
        update for update in branch_updates if update.iteration_nr == current_iteration
    ]
    generated_images: list[GeneratedImage] = []
    for update in new_updates:
        if update.generated_images:
            generated_images.extend(update.generated_images)

    # Write the results to the stream
    write_custom_event(
        current_step_nr,
        ImageGenerationToolDelta(
            images=generated_images,
        ),
        writer,
    )

    write_custom_event(
        current_step_nr,
        SectionEnd(),
        writer,
    )

    current_step_nr += 1

    return SubAgentUpdate(
        iteration_responses=new_updates,
        current_step_nr=current_step_nr,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="image_generation",
                node_name="consolidation",
                node_start_time=node_start_time,
            )
        ],
    )

@@ -1,29 +1,29 @@
from collections.abc import Hashable

from langgraph.types import Send

from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput


def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    return [
        Send(
            "act",
            BranchInput(
                iteration_nr=state.iteration_nr,
                parallelization_nr=parallelization_nr,
                branch_question=query,
                context="",
                active_source_types=state.active_source_types,
                tools_used=state.tools_used,
                available_tools=state.available_tools,
                assistant_system_prompt=state.assistant_system_prompt,
                assistant_task_prompt=state.assistant_task_prompt,
            ),
        )
        for parallelization_nr, query in enumerate(
            state.query_list[:MAX_DR_PARALLEL_SEARCH]
        )
    ]

@@ -1,50 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_1_branch import (
    image_generation_branch,
)
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_2_act import (
    image_generation,
)
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_3_reduce import (
    is_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_conditional_edges import (
    branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger


logger = setup_logger()


def dr_image_generation_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for Image Generation Sub-Agent
    """

    graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    ### Add nodes ###

    graph.add_node("branch", image_generation_branch)

    graph.add_node("act", image_generation)

    graph.add_node("reducer", is_reducer)

    ### Add edges ###

    graph.add_edge(start_key=START, end_key="branch")

    graph.add_conditional_edges("branch", branching_router)

    graph.add_edge(start_key="act", end_key="reducer")

    graph.add_edge(start_key="reducer", end_key=END)

    return graph

@@ -1,13 +1,13 @@
from pydantic import BaseModel


class GeneratedImage(BaseModel):
    file_id: str
    url: str
    revised_prompt: str
    shape: str | None = None


# Needed for PydanticType
class GeneratedImageFullResult(BaseModel):
    images: list[GeneratedImage]

@@ -1,36 +1,36 @@
from datetime import datetime

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger

logger = setup_logger()


def kg_search_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph node to perform a KG search as part of the DR process.
    """

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr

    logger.debug(f"Search start for KG Search {iteration_nr} at {datetime.now()}")

    return LoggerUpdate(
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="kg_search",
                node_name="branching",
                node_start_time=node_start_time,
            )
        ],
    )

@@ -1,97 +1,97 @@
from datetime import datetime

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.utils import extract_document_citations
from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
from onyx.agents.agent_search.kb_search.states import MainInput as KbMainInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger

logger = setup_logger()


def kg_search(
    state: BranchInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BranchUpdate:
    """
    LangGraph node to perform a KG search as part of the DR process.
    """

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    state.current_step_nr
    parallelization_nr = state.parallelization_nr

    search_query = state.branch_question
    if not search_query:
        raise ValueError("search_query is not set")

    logger.debug(
        f"Search start for KG Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    if not state.available_tools:
        raise ValueError("available_tools is not set")

    kg_tool_info = state.available_tools[state.tools_used[-1]]

    kb_graph = kb_graph_builder().compile()

    kb_results = kb_graph.invoke(
        input=KbMainInput(question=search_query, individual_flow=False),
        config=config,
    )

    # get cited documents
    answer_string = kb_results.get("final_answer") or "No answer provided"
    claims: list[str] = []
    retrieved_docs: list[InferenceSection] = kb_results.get("retrieved_documents", [])

    (
        citation_numbers,
        answer_string,
        claims,
    ) = extract_document_citations(answer_string, claims)

    # if citation is empty, the answer must have come from the KG rather than a doc
    # in that case, simply cite the docs returned by the KG
    if not citation_numbers:
        citation_numbers = [i + 1 for i in range(len(retrieved_docs))]

    cited_documents = {
        citation_number: retrieved_docs[citation_number - 1]
        for citation_number in citation_numbers
        if citation_number <= len(retrieved_docs)
    }

    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=kg_tool_info.llm_path,
                tool_id=kg_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=search_query,
                answer=answer_string,
                claims=claims,
                cited_documents=cited_documents,
                reasoning=None,
                additional_data=None,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="kg_search",
                node_name="searching",
                node_start_time=node_start_time,
            )
        ],
    )

@@ -1,121 +1,121 @@
from datetime import datetime

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger


logger = setup_logger()

_MAX_KG_STEAMED_ANSWER_LENGTH = 1000  # num characters


def kg_search_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    LangGraph node to perform a KG search as part of the DR process.
    """

    node_start_time = datetime.now()

    branch_updates = state.branch_iteration_responses
    current_iteration = state.iteration_nr
    current_step_nr = state.current_step_nr

    new_updates = [
        update for update in branch_updates if update.iteration_nr == current_iteration
    ]

    queries = [update.question for update in new_updates]
    doc_lists = [list(update.cited_documents.values()) for update in new_updates]
|
||||
# queries = [update.question for update in new_updates]
|
||||
# doc_lists = [list(update.cited_documents.values()) for update in new_updates]
|
||||
|
||||
doc_list = []
|
||||
# doc_list = []
|
||||
|
||||
for xs in doc_lists:
|
||||
for x in xs:
|
||||
doc_list.append(x)
|
||||
# for xs in doc_lists:
|
||||
# for x in xs:
|
||||
# doc_list.append(x)
|
||||
|
||||
retrieved_search_docs = convert_inference_sections_to_search_docs(doc_list)
|
||||
kg_answer = (
|
||||
"The Knowledge Graph Answer:\n\n" + new_updates[0].answer
|
||||
if len(queries) == 1
|
||||
else None
|
||||
)
|
||||
# retrieved_search_docs = convert_inference_sections_to_search_docs(doc_list)
|
||||
# kg_answer = (
|
||||
# "The Knowledge Graph Answer:\n\n" + new_updates[0].answer
|
||||
# if len(queries) == 1
|
||||
# else None
|
||||
# )
|
||||
|
||||
if len(retrieved_search_docs) > 0:
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SearchToolStart(
|
||||
is_internet_search=False,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SearchToolDelta(
|
||||
queries=queries,
|
||||
documents=retrieved_search_docs,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SectionEnd(),
|
||||
writer,
|
||||
)
|
||||
# if len(retrieved_search_docs) > 0:
|
||||
# write_custom_event(
|
||||
# current_step_nr,
|
||||
# SearchToolStart(
|
||||
# is_internet_search=False,
|
||||
# ),
|
||||
# writer,
|
||||
# )
|
||||
# write_custom_event(
|
||||
# current_step_nr,
|
||||
# SearchToolDelta(
|
||||
# queries=queries,
|
||||
# documents=retrieved_search_docs,
|
||||
# ),
|
||||
# writer,
|
||||
# )
|
||||
# write_custom_event(
|
||||
# current_step_nr,
|
||||
# SectionEnd(),
|
||||
# writer,
|
||||
# )
|
||||
|
||||
current_step_nr += 1
|
||||
# current_step_nr += 1
|
||||
|
||||
if kg_answer is not None:
|
||||
# if kg_answer is not None:
|
||||
|
||||
kg_display_answer = (
|
||||
f"{kg_answer[:_MAX_KG_STEAMED_ANSWER_LENGTH]}..."
|
||||
if len(kg_answer) > _MAX_KG_STEAMED_ANSWER_LENGTH
|
||||
else kg_answer
|
||||
)
|
||||
# kg_display_answer = (
|
||||
# f"{kg_answer[:_MAX_KG_STEAMED_ANSWER_LENGTH]}..."
|
||||
# if len(kg_answer) > _MAX_KG_STEAMED_ANSWER_LENGTH
|
||||
# else kg_answer
|
||||
# )
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
ReasoningStart(),
|
||||
writer,
|
||||
)
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
ReasoningDelta(reasoning=kg_display_answer),
|
||||
writer,
|
||||
)
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SectionEnd(),
|
||||
writer,
|
||||
)
|
||||
# write_custom_event(
|
||||
# current_step_nr,
|
||||
# ReasoningStart(),
|
||||
# writer,
|
||||
# )
|
||||
# write_custom_event(
|
||||
# current_step_nr,
|
||||
# ReasoningDelta(reasoning=kg_display_answer),
|
||||
# writer,
|
||||
# )
|
||||
# write_custom_event(
|
||||
# current_step_nr,
|
||||
# SectionEnd(),
|
||||
# writer,
|
||||
# )
|
||||
|
||||
current_step_nr += 1
|
||||
# current_step_nr += 1
|
||||
|
||||
return SubAgentUpdate(
|
||||
iteration_responses=new_updates,
|
||||
current_step_nr=current_step_nr,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="kg_search",
|
||||
node_name="consolidation",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
# return SubAgentUpdate(
|
||||
# iteration_responses=new_updates,
|
||||
# current_step_nr=current_step_nr,
|
||||
# log_messages=[
|
||||
# get_langgraph_node_log_string(
|
||||
# graph_component="kg_search",
|
||||
# node_name="consolidation",
|
||||
# node_start_time=node_start_time,
|
||||
# )
|
||||
# ],
|
||||
# )
|
||||
|
||||
@@ -1,27 +1,27 @@
from collections.abc import Hashable
# from collections.abc import Hashable

from langgraph.types import Send
# from langgraph.types import Send

from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput


def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    return [
        Send(
            "act",
            BranchInput(
                iteration_nr=state.iteration_nr,
                parallelization_nr=parallelization_nr,
                branch_question=query,
                context="",
                tools_used=state.tools_used,
                available_tools=state.available_tools,
                assistant_system_prompt=state.assistant_system_prompt,
                assistant_task_prompt=state.assistant_task_prompt,
            ),
        )
        for parallelization_nr, query in enumerate(
            state.query_list[:1] # no parallel search for now
        )
    ]
# def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
# return [
# Send(
# "act",
# BranchInput(
# iteration_nr=state.iteration_nr,
# parallelization_nr=parallelization_nr,
# branch_question=query,
# context="",
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# assistant_system_prompt=state.assistant_system_prompt,
# assistant_task_prompt=state.assistant_task_prompt,
# ),
# )
# for parallelization_nr, query in enumerate(
# state.query_list[:1] # no parallel search for now
# )
# ]

@@ -1,50 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph

from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_1_branch import (
    kg_search_branch,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_2_act import (
    kg_search,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_3_reduce import (
    kg_search_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_conditional_edges import (
    branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_1_branch import (
# kg_search_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_2_act import (
# kg_search,
# )
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_3_reduce import (
# kg_search_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_conditional_edges import (
# branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.utils.logger import setup_logger


logger = setup_logger()
# logger = setup_logger()


def dr_kg_search_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for KG Search Sub-Agent
    """
# def dr_kg_search_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for KG Search Sub-Agent
# """

    graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
# graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    ### Add nodes ###
# ### Add nodes ###

    graph.add_node("branch", kg_search_branch)
# graph.add_node("branch", kg_search_branch)

    graph.add_node("act", kg_search)
# graph.add_node("act", kg_search)

    graph.add_node("reducer", kg_search_reducer)
# graph.add_node("reducer", kg_search_reducer)

    ### Add edges ###
# ### Add edges ###

    graph.add_edge(start_key=START, end_key="branch")
# graph.add_edge(start_key=START, end_key="branch")

    graph.add_conditional_edges("branch", branching_router)
# graph.add_conditional_edges("branch", branching_router)

    graph.add_edge(start_key="act", end_key="reducer")
# graph.add_edge(start_key="act", end_key="reducer")

    graph.add_edge(start_key="reducer", end_key=END)
# graph.add_edge(start_key="reducer", end_key=END)

    return graph
# return graph

@@ -1,46 +1,46 @@
from operator import add
from typing import Annotated
# from operator import add
# from typing import Annotated

from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.db.connector import DocumentSource
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.models import OrchestratorTool
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.db.connector import DocumentSource


class SubAgentUpdate(LoggerUpdate):
    iteration_responses: Annotated[list[IterationAnswer], add] = []
    current_step_nr: int = 1
# class SubAgentUpdate(LoggerUpdate):
# iteration_responses: Annotated[list[IterationAnswer], add] = []
# current_step_nr: int = 1


class BranchUpdate(LoggerUpdate):
    branch_iteration_responses: Annotated[list[IterationAnswer], add] = []
# class BranchUpdate(LoggerUpdate):
# branch_iteration_responses: Annotated[list[IterationAnswer], add] = []


class SubAgentInput(LoggerUpdate):
    iteration_nr: int = 0
    current_step_nr: int = 1
    query_list: list[str] = []
    context: str | None = None
    active_source_types: list[DocumentSource] | None = None
    tools_used: Annotated[list[str], add] = []
    available_tools: dict[str, OrchestratorTool] | None = None
    assistant_system_prompt: str | None = None
    assistant_task_prompt: str | None = None
# class SubAgentInput(LoggerUpdate):
# iteration_nr: int = 0
# current_step_nr: int = 1
# query_list: list[str] = []
# context: str | None = None
# active_source_types: list[DocumentSource] | None = None
# tools_used: Annotated[list[str], add] = []
# available_tools: dict[str, OrchestratorTool] | None = None
# assistant_system_prompt: str | None = None
# assistant_task_prompt: str | None = None


class SubAgentMainState(
    # This includes the core state
    SubAgentInput,
    SubAgentUpdate,
    BranchUpdate,
):
    pass
# class SubAgentMainState(
# # This includes the core state
# SubAgentInput,
# SubAgentUpdate,
# BranchUpdate,
# ):
# pass


class BranchInput(SubAgentInput):
    parallelization_nr: int = 0
    branch_question: str
# class BranchInput(SubAgentInput):
# parallelization_nr: int = 0
# branch_question: str


class CustomToolBranchInput(LoggerUpdate):
    tool_info: OrchestratorTool
# class CustomToolBranchInput(LoggerUpdate):
# tool_info: OrchestratorTool
|
||||
|
||||
@@ -1,70 +1,71 @@
|
||||
from collections.abc import Sequence
|
||||
# from collections.abc import Sequence
|
||||
|
||||
from exa_py import Exa
|
||||
from exa_py.api import HighlightsContentsOptions
|
||||
# from exa_py import Exa
|
||||
# from exa_py.api import HighlightsContentsOptions
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
WebContent,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
WebSearchProvider,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
# WebContent,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
# WebSearchProvider,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
# WebSearchResult,
|
||||
# )
|
||||
# from onyx.configs.chat_configs import EXA_API_KEY
|
||||
# from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
# from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
|
||||
class ExaClient(WebSearchProvider):
|
||||
def __init__(self, api_key: str) -> None:
|
||||
self.exa = Exa(api_key=api_key)
|
||||
# class ExaClient(WebSearchProvider):
|
||||
# def __init__(self, api_key: str | None = EXA_API_KEY) -> None:
|
||||
# self.exa = Exa(api_key=api_key)
|
||||
|
||||
@retry_builder(tries=3, delay=1, backoff=2)
|
||||
def search(self, query: str) -> list[WebSearchResult]:
|
||||
response = self.exa.search_and_contents(
|
||||
query,
|
||||
type="auto",
|
||||
highlights=HighlightsContentsOptions(
|
||||
num_sentences=2,
|
||||
highlights_per_url=1,
|
||||
),
|
||||
num_results=10,
|
||||
)
|
||||
# @retry_builder(tries=3, delay=1, backoff=2)
|
||||
# def search(self, query: str) -> list[WebSearchResult]:
|
||||
# response = self.exa.search_and_contents(
|
||||
# query,
|
||||
# type="auto",
|
||||
# highlights=HighlightsContentsOptions(
|
||||
# num_sentences=2,
|
||||
# highlights_per_url=1,
|
||||
# ),
|
||||
# num_results=10,
|
||||
# )
|
||||
|
||||
return [
|
||||
WebSearchResult(
|
||||
title=result.title or "",
|
||||
link=result.url,
|
||||
snippet=result.highlights[0] if result.highlights else "",
|
||||
author=result.author,
|
||||
published_date=(
|
||||
time_str_to_utc(result.published_date)
|
||||
if result.published_date
|
||||
else None
|
||||
),
|
||||
)
|
||||
for result in response.results
|
||||
]
|
||||
# return [
|
||||
# WebSearchResult(
|
||||
# title=result.title or "",
|
||||
# link=result.url,
|
||||
# snippet=result.highlights[0] if result.highlights else "",
|
||||
# author=result.author,
|
||||
# published_date=(
|
||||
# time_str_to_utc(result.published_date)
|
||||
# if result.published_date
|
||||
# else None
|
||||
# ),
|
||||
# )
|
||||
# for result in response.results
|
||||
# ]
|
||||
|
||||
@retry_builder(tries=3, delay=1, backoff=2)
|
||||
def contents(self, urls: Sequence[str]) -> list[WebContent]:
|
||||
response = self.exa.get_contents(
|
||||
urls=list(urls),
|
||||
text=True,
|
||||
livecrawl="preferred",
|
||||
)
|
||||
# @retry_builder(tries=3, delay=1, backoff=2)
|
||||
# def contents(self, urls: Sequence[str]) -> list[WebContent]:
|
||||
# response = self.exa.get_contents(
|
||||
# urls=list(urls),
|
||||
# text=True,
|
||||
# livecrawl="preferred",
|
||||
# )
|
||||
|
||||
return [
|
||||
WebContent(
|
||||
title=result.title or "",
|
||||
link=result.url,
|
||||
full_content=result.text or "",
|
||||
published_date=(
|
||||
time_str_to_utc(result.published_date)
|
||||
if result.published_date
|
||||
else None
|
||||
),
|
||||
)
|
||||
for result in response.results
|
||||
]
|
||||
# return [
|
||||
# WebContent(
|
||||
# title=result.title or "",
|
||||
# link=result.url,
|
||||
# full_content=result.text or "",
|
||||
# published_date=(
|
||||
# time_str_to_utc(result.published_date)
|
||||
# if result.published_date
|
||||
# else None
|
||||
# ),
|
||||
# )
|
||||
# for result in response.results
|
||||
# ]
|
||||
|
||||
@@ -1,159 +1,148 @@
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
# import json
|
||||
# from collections.abc import Sequence
|
||||
# from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import requests
|
||||
# import requests
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
WebContent,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
WebSearchProvider,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
# WebContent,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
# WebSearchProvider,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
# WebSearchResult,
|
||||
# )
|
||||
# from onyx.configs.chat_configs import SERPER_API_KEY
|
||||
# from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
# from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
SERPER_SEARCH_URL = "https://google.serper.dev/search"
|
||||
SERPER_CONTENTS_URL = "https://scrape.serper.dev"
|
||||
# SERPER_SEARCH_URL = "https://google.serper.dev/search"
|
||||
# SERPER_CONTENTS_URL = "https://scrape.serper.dev"
|
||||
|
||||
|
||||
class SerperClient(WebSearchProvider):
|
||||
def __init__(self, api_key: str) -> None:
|
||||
self.headers = {
|
||||
"X-API-KEY": api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
# class SerperClient(WebSearchProvider):
|
||||
# def __init__(self, api_key: str | None = SERPER_API_KEY) -> None:
|
||||
# self.headers = {
|
||||
# "X-API-KEY": api_key,
|
||||
# "Content-Type": "application/json",
|
||||
# }
|
||||
|
||||
@retry_builder(tries=3, delay=1, backoff=2)
|
||||
def search(self, query: str) -> list[WebSearchResult]:
|
||||
payload = {
|
||||
"q": query,
|
||||
}
|
||||
# @retry_builder(tries=3, delay=1, backoff=2)
|
||||
# def search(self, query: str) -> list[WebSearchResult]:
|
||||
# payload = {
|
||||
# "q": query,
|
||||
# }
|
||||
|
||||
response = requests.post(
|
||||
SERPER_SEARCH_URL,
|
||||
headers=self.headers,
|
||||
data=json.dumps(payload),
|
||||
)
|
||||
# response = requests.post(
|
||||
# SERPER_SEARCH_URL,
|
||||
# headers=self.headers,
|
||||
# data=json.dumps(payload),
|
||||
# )
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except Exception:
|
||||
# Avoid leaking API keys/URLs
|
||||
raise ValueError(
|
||||
"Serper search failed. Check credentials or quota."
|
||||
) from None
|
||||
# response.raise_for_status()
|
||||
|
||||
results = response.json()
|
||||
organic_results = results["organic"]
|
||||
# results = response.json()
|
||||
# organic_results = results["organic"]
|
||||
|
||||
return [
|
||||
WebSearchResult(
|
||||
title=result["title"],
|
||||
link=result["link"],
|
||||
snippet=result["snippet"],
|
||||
author=None,
|
||||
published_date=None,
|
||||
)
|
||||
for result in organic_results
|
||||
]
|
||||
# return [
|
||||
# WebSearchResult(
|
||||
# title=result["title"],
|
||||
# link=result["link"],
|
||||
# snippet=result["snippet"],
|
||||
# author=None,
|
||||
# published_date=None,
|
||||
# )
|
||||
# for result in organic_results
|
||||
# ]
|
||||
|
||||
def contents(self, urls: Sequence[str]) -> list[WebContent]:
|
||||
if not urls:
|
||||
return []
|
||||
# def contents(self, urls: Sequence[str]) -> list[WebContent]:
|
||||
# if not urls:
|
||||
# return []
|
||||
|
||||
# Serper can respond with 500s regularly. We want to retry,
|
||||
# but in the event of failure, return an unsuccessful scrape.
|
||||
def safe_get_webpage_content(url: str) -> WebContent:
|
||||
try:
|
||||
return self._get_webpage_content(url)
|
||||
except Exception:
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
# # Serper can respond with 500s regularly. We want to retry,
|
||||
# # but in the event of failure, return an unsuccessful scrape.
|
||||
# def safe_get_webpage_content(url: str) -> WebContent:
|
||||
# try:
|
||||
# return self._get_webpage_content(url)
|
||||
# except Exception:
|
||||
# return WebContent(
|
||||
# title="",
|
||||
# link=url,
|
||||
# full_content="",
|
||||
# published_date=None,
|
||||
# scrape_successful=False,
|
||||
# )
|
||||
|
||||
with ThreadPoolExecutor(max_workers=min(8, len(urls))) as e:
|
||||
return list(e.map(safe_get_webpage_content, urls))
|
||||
# with ThreadPoolExecutor(max_workers=min(8, len(urls))) as e:
|
||||
# return list(e.map(safe_get_webpage_content, urls))
|
||||
|
||||
@retry_builder(tries=3, delay=1, backoff=2)
|
||||
def _get_webpage_content(self, url: str) -> WebContent:
|
||||
payload = {
|
||||
"url": url,
|
||||
}
|
||||
# @retry_builder(tries=3, delay=1, backoff=2)
|
||||
# def _get_webpage_content(self, url: str) -> WebContent:
|
||||
# payload = {
|
||||
# "url": url,
|
||||
# }
|
||||
|
||||
response = requests.post(
|
||||
SERPER_CONTENTS_URL,
|
||||
headers=self.headers,
|
||||
data=json.dumps(payload),
|
||||
)
|
||||
# response = requests.post(
|
||||
# SERPER_CONTENTS_URL,
|
||||
# headers=self.headers,
|
||||
# data=json.dumps(payload),
|
||||
# )
|
||||
|
||||
# 400 returned when serper cannot scrape
|
||||
if response.status_code == 400:
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
# # 400 returned when serper cannot scrape
|
||||
# if response.status_code == 400:
|
||||
# return WebContent(
|
||||
# title="",
|
||||
# link=url,
|
||||
# full_content="",
|
||||
# published_date=None,
|
||||
# scrape_successful=False,
|
||||
# )
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except Exception:
|
||||
# Avoid leaking API keys/URLs
|
||||
raise ValueError(
|
||||
"Serper content fetch failed. Check credentials."
|
||||
) from None
|
||||
# response.raise_for_status()
|
||||
|
||||
response_json = response.json()
|
||||
# response_json = response.json()
|
||||
|
||||
# Response only guarantees text
|
||||
text = response_json["text"]
|
||||
# # Response only guarantees text
|
||||
# text = response_json["text"]
|
||||
|
||||
# metadata & jsonld is not guaranteed to be present
|
||||
metadata = response_json.get("metadata", {})
|
||||
jsonld = response_json.get("jsonld", {})
|
||||
# # metadata & jsonld is not guaranteed to be present
|
||||
# metadata = response_json.get("metadata", {})
|
||||
# jsonld = response_json.get("jsonld", {})
|
||||
|
||||
title = extract_title_from_metadata(metadata)
|
||||
# title = extract_title_from_metadata(metadata)
|
||||
|
||||
# Serper does not provide a reliable mechanism to extract the url
|
||||
response_url = url
|
||||
published_date_str = extract_published_date_from_jsonld(jsonld)
|
||||
published_date = None
|
||||
# # Serper does not provide a reliable mechanism to extract the url
|
||||
# response_url = url
|
||||
# published_date_str = extract_published_date_from_jsonld(jsonld)
|
||||
# published_date = None
|
||||
|
||||
if published_date_str:
|
||||
try:
|
||||
published_date = time_str_to_utc(published_date_str)
|
||||
except Exception:
|
||||
published_date = None
|
||||
# if published_date_str:
|
||||
# try:
|
||||
# published_date = time_str_to_utc(published_date_str)
|
||||
# except Exception:
|
||||
# published_date = None
|
||||
|
||||
return WebContent(
|
||||
title=title or "",
|
||||
link=response_url,
|
||||
full_content=text or "",
|
||||
published_date=published_date,
|
||||
)
|
||||
# return WebContent(
|
||||
# title=title or "",
|
||||
# link=response_url,
|
||||
# full_content=text or "",
|
||||
# published_date=published_date,
|
||||
# )
|
||||
|
||||
|
||||
def extract_title_from_metadata(metadata: dict[str, str]) -> str | None:
|
||||
keys = ["title", "og:title"]
|
||||
return extract_value_from_dict(metadata, keys)
|
||||
# def extract_title_from_metadata(metadata: dict[str, str]) -> str | None:
|
||||
# keys = ["title", "og:title"]
|
||||
# return extract_value_from_dict(metadata, keys)
|
||||
|
||||
|
||||
def extract_published_date_from_jsonld(jsonld: dict[str, str]) -> str | None:
|
||||
keys = ["dateModified"]
|
||||
return extract_value_from_dict(jsonld, keys)
|
||||
# def extract_published_date_from_jsonld(jsonld: dict[str, str]) -> str | None:
|
||||
# keys = ["dateModified"]
|
||||
# return extract_value_from_dict(jsonld, keys)
|
||||
|
||||
|
||||
def extract_value_from_dict(data: dict[str, str], keys: list[str]) -> str | None:
|
||||
for key in keys:
|
||||
if key in data:
|
||||
return data[key]
|
||||
return None
|
||||
# def extract_value_from_dict(data: dict[str, str], keys: list[str]) -> str | None:
|
||||
# for key in keys:
|
||||
# if key in data:
|
||||
# return data[key]
|
||||
# return None
|
||||
|
||||
@@ -1,47 +1,47 @@
from datetime import datetime
# from datetime import datetime

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter

from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import SearchToolStart
# from onyx.utils.logger import setup_logger

logger = setup_logger()
# logger = setup_logger()


def is_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph node to perform a web search as part of the DR process.
    """
# def is_branch(
# state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
# """
# LangGraph node to perform a web search as part of the DR process.
# """

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    current_step_nr = state.current_step_nr
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
# current_step_nr = state.current_step_nr

    logger.debug(f"Search start for Web Search {iteration_nr} at {datetime.now()}")
# logger.debug(f"Search start for Web Search {iteration_nr} at {datetime.now()}")

    write_custom_event(
        current_step_nr,
        SearchToolStart(
            is_internet_search=True,
        ),
        writer,
    )
# write_custom_event(
# current_step_nr,
# SearchToolStart(
# is_internet_search=True,
# ),
# writer,
# )

    return LoggerUpdate(
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="internet_search",
                node_name="branching",
                node_start_time=node_start_time,
            )
        ],
    )
# return LoggerUpdate(
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="internet_search",
# node_name="branching",
# node_start_time=node_start_time,
# )
# ],
# )
|
||||
|
||||
@@ -1,137 +1,128 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
# from datetime import datetime
|
||||
# from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
from langsmith import traceable
|
||||
# from langchain_core.runnables import RunnableConfig
|
||||
# from langgraph.types import StreamWriter
|
||||
# from langsmith import traceable
|
||||
|
||||
from onyx.agents.agent_search.dr.models import WebSearchAnswer
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
|
||||
get_default_provider,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
|
||||
InternetSearchInput,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
|
||||
InternetSearchUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.agents.agent_search.utils import create_question_prompt
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
|
||||
from onyx.prompts.dr_prompts import WEB_SEARCH_URL_SELECTION_PROMPT
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
from onyx.utils.logger import setup_logger
|
||||
# from onyx.agents.agent_search.dr.models import WebSearchAnswer
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
# WebSearchResult,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
|
||||
# get_default_provider,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
|
||||
# InternetSearchInput,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
|
||||
# InternetSearchUpdate,
|
||||
# )
|
||||
# from onyx.agents.agent_search.models import GraphConfig
|
||||
# from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
# from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
# get_langgraph_node_log_string,
|
||||
# )
|
||||
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
# from onyx.agents.agent_search.utils import create_question_prompt
|
||||
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
|
||||
# from onyx.prompts.dr_prompts import WEB_SEARCH_URL_SELECTION_PROMPT
|
||||
# from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
# from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
# logger = setup_logger()
|
||||
|
||||
|
||||
def web_search(
|
||||
state: InternetSearchInput,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> InternetSearchUpdate:
|
||||
"""
|
||||
LangGraph node to perform internet search and decide which URLs to fetch.
|
||||
"""
|
||||
# def web_search(
|
||||
# state: InternetSearchInput,
|
||||
# config: RunnableConfig,
|
||||
# writer: StreamWriter = lambda _: None,
|
||||
# ) -> InternetSearchUpdate:
|
||||
# """
|
||||
# LangGraph node to perform internet search and decide which URLs to fetch.
|
||||
# """
|
||||
|
||||
node_start_time = datetime.now()
|
||||
current_step_nr = state.current_step_nr
|
||||
# node_start_time = datetime.now()
|
||||
# current_step_nr = state.current_step_nr
|
||||
|
||||
if not current_step_nr:
|
||||
raise ValueError("Current step number is not set. This should not happen.")
|
||||
# if not current_step_nr:
|
||||
# raise ValueError("Current step number is not set. This should not happen.")
|
||||
|
||||
assistant_system_prompt = state.assistant_system_prompt
|
||||
assistant_task_prompt = state.assistant_task_prompt
|
||||
# assistant_system_prompt = state.assistant_system_prompt
|
||||
# assistant_task_prompt = state.assistant_task_prompt
|
||||
|
||||
if not state.available_tools:
|
||||
raise ValueError("available_tools is not set")
|
||||
search_query = state.branch_question
|
||||
# if not state.available_tools:
|
||||
# raise ValueError("available_tools is not set")
|
||||
# search_query = state.branch_question
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SearchToolDelta(
|
||||
queries=[search_query],
|
||||
documents=[],
|
||||
),
|
||||
writer,
|
||||
)
|
||||
# write_custom_event(
|
||||
# current_step_nr,
|
||||
# SearchToolDelta(
|
||||
# queries=[search_query],
|
||||
# documents=[],
|
||||
# ),
|
||||
# writer,
|
||||
# )
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
base_question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
# graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
# base_question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
|
||||
if graph_config.inputs.persona is None:
|
||||
raise ValueError("persona is not set")
|
||||
# if graph_config.inputs.persona is None:
|
||||
# raise ValueError("persona is not set")
|
||||
|
||||
provider = get_default_provider()
|
||||
if not provider:
|
||||
raise ValueError("No internet search provider found")
|
||||
# provider = get_default_provider()
|
||||
# if not provider:
|
||||
# raise ValueError("No internet search provider found")
|
||||
|
||||
# Log which provider type is being used
|
||||
provider_type = type(provider).__name__
|
||||
logger.info(
|
||||
f"Performing web search with {provider_type} for query: '{search_query}'"
|
||||
)
|
||||
# @traceable(name="Search Provider API Call")
|
||||
# def _search(search_query: str) -> list[WebSearchResult]:
|
||||
# search_results: list[WebSearchResult] = []
|
||||
# try:
|
||||
# search_results = list(provider.search(search_query))
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error performing search: {e}")
|
||||
# return search_results
|
||||
|
||||
@traceable(name="Search Provider API Call")
|
||||
def _search(search_query: str) -> list[WebSearchResult]:
|
||||
search_results: list[WebSearchResult] = []
|
||||
try:
|
||||
search_results = list(provider.search(search_query))
|
||||
logger.info(
|
||||
f"Search returned {len(search_results)} results using {provider_type}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error performing search with {provider_type}: {e}")
|
||||
return search_results
|
||||
|
||||
search_results: list[WebSearchResult] = _search(search_query)
|
||||
search_results_text = "\n\n".join(
|
||||
[
|
||||
f"{i}. {result.title}\n URL: {result.link}\n"
|
||||
+ (f" Author: {result.author}\n" if result.author else "")
|
||||
+ (
|
||||
f" Date: {result.published_date.strftime('%Y-%m-%d')}\n"
|
||||
if result.published_date
|
||||
else ""
|
||||
)
|
||||
+ (f" Snippet: {result.snippet}\n" if result.snippet else "")
|
||||
for i, result in enumerate(search_results)
|
||||
]
|
||||
)
|
||||
agent_decision_prompt = WEB_SEARCH_URL_SELECTION_PROMPT.build(
|
||||
search_query=search_query,
|
||||
base_question=base_question,
|
||||
search_results_text=search_results_text,
|
||||
)
|
||||
agent_decision = invoke_llm_json(
|
||||
llm=graph_config.tooling.fast_llm,
|
||||
prompt=create_question_prompt(
|
||||
assistant_system_prompt,
|
||||
agent_decision_prompt + (assistant_task_prompt or ""),
|
||||
),
|
||||
schema=WebSearchAnswer,
|
||||
timeout_override=TF_DR_TIMEOUT_SHORT,
|
||||
)
|
||||
results_to_open = [
|
||||
(search_query, search_results[i])
|
||||
for i in agent_decision.urls_to_open_indices
|
||||
if i < len(search_results) and i >= 0
|
||||
]
|
||||
return InternetSearchUpdate(
|
||||
results_to_open=results_to_open,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="internet_search",
|
||||
node_name="searching",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
# search_results: list[WebSearchResult] = _search(search_query)
|
||||
# search_results_text = "\n\n".join(
|
||||
# [
|
||||
# f"{i}. {result.title}\n URL: {result.link}\n"
|
||||
# + (f" Author: {result.author}\n" if result.author else "")
|
||||
# + (
|
||||
# f" Date: {result.published_date.strftime('%Y-%m-%d')}\n"
|
||||
# if result.published_date
|
||||
# else ""
|
||||
# )
|
||||
# + (f" Snippet: {result.snippet}\n" if result.snippet else "")
|
||||
# for i, result in enumerate(search_results)
|
||||
# ]
|
||||
# )
|
||||
# agent_decision_prompt = WEB_SEARCH_URL_SELECTION_PROMPT.build(
|
||||
# search_query=search_query,
|
||||
# base_question=base_question,
|
||||
# search_results_text=search_results_text,
|
||||
# )
|
||||
# agent_decision = invoke_llm_json(
|
||||
# llm=graph_config.tooling.fast_llm,
|
||||
# prompt=create_question_prompt(
|
||||
# assistant_system_prompt,
|
||||
# agent_decision_prompt + (assistant_task_prompt or ""),
|
||||
# ),
|
||||
# schema=WebSearchAnswer,
|
||||
# timeout_override=TF_DR_TIMEOUT_SHORT,
|
||||
# )
|
||||
# results_to_open = [
|
||||
# (search_query, search_results[i])
|
||||
# for i in agent_decision.urls_to_open_indices
|
||||
# if i < len(search_results) and i >= 0
|
||||
# ]
|
||||
# return InternetSearchUpdate(
|
||||
# results_to_open=results_to_open,
|
||||
# log_messages=[
|
||||
# get_langgraph_node_log_string(
|
||||
# graph_component="internet_search",
|
||||
# node_name="searching",
|
||||
# node_start_time=node_start_time,
|
||||
# )
|
||||
# ],
|
||||
# )
|
||||
|
||||
@@ -1,52 +1,52 @@
|
||||
from collections import defaultdict
|
||||
# from collections import defaultdict
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
# from langchain_core.runnables import RunnableConfig
|
||||
# from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
|
||||
InternetSearchInput,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
|
||||
dummy_inference_section_from_internet_search_result,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
# WebSearchResult,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
|
||||
# InternetSearchInput,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
|
||||
# dummy_inference_section_from_internet_search_result,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
|
||||
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
# from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
|
||||
|
||||
def dedup_urls(
|
||||
state: InternetSearchInput,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> InternetSearchInput:
|
||||
branch_questions_to_urls: dict[str, list[str]] = defaultdict(list)
|
||||
unique_results_by_link: dict[str, WebSearchResult] = {}
|
||||
for query, result in state.results_to_open:
|
||||
branch_questions_to_urls[query].append(result.link)
|
||||
if result.link not in unique_results_by_link:
|
||||
unique_results_by_link[result.link] = result
|
||||
# def dedup_urls(
|
||||
# state: InternetSearchInput,
|
||||
# config: RunnableConfig,
|
||||
# writer: StreamWriter = lambda _: None,
|
||||
# ) -> InternetSearchInput:
|
||||
# branch_questions_to_urls: dict[str, list[str]] = defaultdict(list)
|
||||
# unique_results_by_link: dict[str, WebSearchResult] = {}
|
||||
# for query, result in state.results_to_open:
|
||||
# branch_questions_to_urls[query].append(result.link)
|
||||
# if result.link not in unique_results_by_link:
|
||||
# unique_results_by_link[result.link] = result
|
||||
|
||||
unique_results = list(unique_results_by_link.values())
|
||||
dummy_docs_inference_sections = [
|
||||
dummy_inference_section_from_internet_search_result(doc)
|
||||
for doc in unique_results
|
||||
]
|
||||
# unique_results = list(unique_results_by_link.values())
|
||||
# dummy_docs_inference_sections = [
|
||||
# dummy_inference_section_from_internet_search_result(doc)
|
||||
# for doc in unique_results
|
||||
# ]
|
||||
|
||||
write_custom_event(
|
||||
state.current_step_nr,
|
||||
SearchToolDelta(
|
||||
queries=[],
|
||||
documents=convert_inference_sections_to_search_docs(
|
||||
dummy_docs_inference_sections, is_internet=True
|
||||
),
|
||||
),
|
||||
writer,
|
||||
)
|
||||
# write_custom_event(
|
||||
# state.current_step_nr,
|
||||
# SearchToolDelta(
|
||||
# queries=[],
|
||||
# documents=convert_inference_sections_to_search_docs(
|
||||
# dummy_docs_inference_sections, is_internet=True
|
||||
# ),
|
||||
# ),
|
||||
# writer,
|
||||
# )
|
||||
|
||||
return InternetSearchInput(
|
||||
results_to_open=[],
|
||||
branch_questions_to_urls=branch_questions_to_urls,
|
||||
)
|
||||
# return InternetSearchInput(
|
||||
# results_to_open=[],
|
||||
# branch_questions_to_urls=branch_questions_to_urls,
|
||||
# )
|
||||
|
||||
@@ -1,69 +1,69 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
# from datetime import datetime
|
||||
# from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
# from langchain_core.runnables import RunnableConfig
|
||||
# from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
|
||||
get_default_content_provider,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
|
||||
dummy_inference_section_from_internet_content,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.utils.logger import setup_logger
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
|
||||
# get_default_provider,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchInput
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchUpdate
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
|
||||
# inference_section_from_internet_page_scrape,
|
||||
# )
|
||||
# from onyx.agents.agent_search.models import GraphConfig
|
||||
# from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
# get_langgraph_node_log_string,
|
||||
# )
|
||||
# from onyx.context.search.models import InferenceSection
|
||||
# from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
# logger = setup_logger()
|
||||
|
||||
|
||||
def web_fetch(
|
||||
state: FetchInput,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> FetchUpdate:
|
||||
"""
|
||||
LangGraph node to fetch content from URLs and process the results.
|
||||
"""
|
||||
# def web_fetch(
|
||||
# state: FetchInput,
|
||||
# config: RunnableConfig,
|
||||
# writer: StreamWriter = lambda _: None,
|
||||
# ) -> FetchUpdate:
|
||||
# """
|
||||
# LangGraph node to fetch content from URLs and process the results.
|
||||
# """
|
||||
|
||||
node_start_time = datetime.now()
|
||||
# node_start_time = datetime.now()
|
||||
|
||||
if not state.available_tools:
|
||||
raise ValueError("available_tools is not set")
|
||||
# if not state.available_tools:
|
||||
# raise ValueError("available_tools is not set")
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
# graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
|
||||
if graph_config.inputs.persona is None:
|
||||
raise ValueError("persona is not set")
|
||||
# if graph_config.inputs.persona is None:
|
||||
# raise ValueError("persona is not set")
|
||||
|
||||
provider = get_default_content_provider()
|
||||
if provider is None:
|
||||
raise ValueError("No web content provider found")
|
||||
# provider = get_default_provider()
|
||||
# if provider is None:
|
||||
# raise ValueError("No web search provider found")
|
||||
|
||||
retrieved_docs: list[InferenceSection] = []
|
||||
try:
|
||||
retrieved_docs = [
|
||||
dummy_inference_section_from_internet_content(result)
|
||||
for result in provider.contents(state.urls_to_open)
|
||||
]
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
# retrieved_docs: list[InferenceSection] = []
|
||||
# try:
|
||||
# retrieved_docs = [
|
||||
# inference_section_from_internet_page_scrape(result)
|
||||
# for result in provider.contents(state.urls_to_open)
|
||||
# ]
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error fetching URLs: {e}")
|
||||
|
||||
if not retrieved_docs:
|
||||
logger.warning("No content retrieved from URLs")
|
||||
# if not retrieved_docs:
|
||||
# logger.warning("No content retrieved from URLs")
|
||||
|
||||
return FetchUpdate(
|
||||
raw_documents=retrieved_docs,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="internet_search",
|
||||
node_name="fetching",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
# return FetchUpdate(
|
||||
# raw_documents=retrieved_docs,
|
||||
# log_messages=[
|
||||
# get_langgraph_node_log_string(
|
||||
# graph_component="internet_search",
|
||||
# node_name="fetching",
|
||||
# node_start_time=node_start_time,
|
||||
# )
|
||||
# ],
|
||||
# )
|
||||
|
||||
@@ -1,19 +1,19 @@
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter

from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchInput
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
    InternetSearchInput,
)
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchInput
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
# InternetSearchInput,
# )


def collect_raw_docs(
    state: FetchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> InternetSearchInput:
    raw_documents = state.raw_documents
# def collect_raw_docs(
# state: FetchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> InternetSearchInput:
# raw_documents = state.raw_documents

    return InternetSearchInput(
        raw_documents=raw_documents,
    )
# return InternetSearchInput(
# raw_documents=raw_documents,
# )
|
||||
|
||||
@@ -1,133 +1,132 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
# from datetime import datetime
|
||||
# from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
# from langchain_core.runnables import RunnableConfig
|
||||
# from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import SearchAnswer
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.states import SummarizeInput
|
||||
from onyx.agents.agent_search.dr.utils import extract_document_citations
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.utils import create_question_prompt
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.url import normalize_url
|
||||
# from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
# from onyx.agents.agent_search.dr.models import SearchAnswer
|
||||
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import SummarizeInput
|
||||
# from onyx.agents.agent_search.dr.utils import extract_document_citations
|
||||
# from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
|
||||
# from onyx.agents.agent_search.models import GraphConfig
|
||||
# from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
# from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
# get_langgraph_node_log_string,
|
||||
# )
|
||||
# from onyx.agents.agent_search.utils import create_question_prompt
|
||||
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
|
||||
# from onyx.context.search.models import InferenceSection
|
||||
# from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
|
||||
# from onyx.utils.logger import setup_logger
|
||||
# from onyx.utils.url import normalize_url
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
# logger = setup_logger()
|
||||
|
||||
|
||||
def is_summarize(
|
||||
state: SummarizeInput,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> BranchUpdate:
|
||||
"""
|
||||
LangGraph node to perform an internet search as part of the DR process.
|
||||
"""
|
||||
# def is_summarize(
|
||||
# state: SummarizeInput,
|
||||
# config: RunnableConfig,
|
||||
# writer: StreamWriter = lambda _: None,
|
||||
# ) -> BranchUpdate:
|
||||
# """
|
||||
# LangGraph node to perform an internet search as part of the DR process.
|
||||
# """
|
||||
|
||||
node_start_time = datetime.now()
|
||||
# node_start_time = datetime.now()
|
||||
|
||||
# build branch iterations from fetch inputs
|
||||
# Normalize URLs to handle mismatches from query parameters (e.g., ?activeTab=explore)
|
||||
url_to_raw_document: dict[str, InferenceSection] = {}
|
||||
for raw_document in state.raw_documents:
|
||||
normalized_url = normalize_url(raw_document.center_chunk.semantic_identifier)
|
||||
url_to_raw_document[normalized_url] = raw_document
|
||||
# # build branch iterations from fetch inputs
|
||||
# # Normalize URLs to handle mismatches from query parameters (e.g., ?activeTab=explore)
|
||||
# url_to_raw_document: dict[str, InferenceSection] = {}
|
||||
# for raw_document in state.raw_documents:
|
||||
# normalized_url = normalize_url(raw_document.center_chunk.semantic_identifier)
|
||||
# url_to_raw_document[normalized_url] = raw_document
|
||||
|
||||
# Normalize the URLs from branch_questions_to_urls as well
|
||||
urls = [
|
||||
normalize_url(url)
|
||||
for url in state.branch_questions_to_urls[state.branch_question]
|
||||
]
|
||||
current_iteration = state.iteration_nr
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
research_type = graph_config.behavior.research_type
|
||||
if not state.available_tools:
|
||||
raise ValueError("available_tools is not set")
|
||||
is_tool_info = state.available_tools[state.tools_used[-1]]
|
||||
# # Normalize the URLs from branch_questions_to_urls as well
|
||||
# urls = [
|
||||
# normalize_url(url)
|
||||
# for url in state.branch_questions_to_urls[state.branch_question]
|
||||
# ]
|
||||
# current_iteration = state.iteration_nr
|
||||
# graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
# use_agentic_search = graph_config.behavior.use_agentic_search
|
||||
# if not state.available_tools:
|
||||
# raise ValueError("available_tools is not set")
|
||||
# is_tool_info = state.available_tools[state.tools_used[-1]]
|
||||
|
||||
if research_type == ResearchType.DEEP:
|
||||
cited_raw_documents = [url_to_raw_document[url] for url in urls]
|
||||
document_texts = _create_document_texts(cited_raw_documents)
|
||||
search_prompt = INTERNAL_SEARCH_PROMPTS[research_type].build(
|
||||
search_query=state.branch_question,
|
||||
base_question=graph_config.inputs.prompt_builder.raw_user_query,
|
||||
document_text=document_texts,
|
||||
)
|
||||
assistant_system_prompt = state.assistant_system_prompt
|
||||
assistant_task_prompt = state.assistant_task_prompt
|
||||
search_answer_json = invoke_llm_json(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
|
||||
),
|
||||
schema=SearchAnswer,
|
||||
timeout_override=TF_DR_TIMEOUT_SHORT,
|
||||
)
|
||||
answer_string = search_answer_json.answer
|
||||
claims = search_answer_json.claims or []
|
||||
reasoning = search_answer_json.reasoning or ""
|
||||
(
|
||||
citation_numbers,
|
||||
answer_string,
|
||||
claims,
|
||||
) = extract_document_citations(answer_string, claims)
|
||||
cited_documents = {
|
||||
citation_number: cited_raw_documents[citation_number - 1]
|
||||
for citation_number in citation_numbers
|
||||
}
|
||||
# if use_agentic_search:
|
||||
# cited_raw_documents = [url_to_raw_document[url] for url in urls]
|
||||
# document_texts = _create_document_texts(cited_raw_documents)
|
||||
# search_prompt = INTERNAL_SEARCH_PROMPTS[use_agentic_search].build(
|
||||
# search_query=state.branch_question,
|
||||
# base_question=graph_config.inputs.prompt_builder.raw_user_query,
|
||||
# document_text=document_texts,
|
||||
# )
|
||||
# assistant_system_prompt = state.assistant_system_prompt
|
||||
# assistant_task_prompt = state.assistant_task_prompt
|
||||
# search_answer_json = invoke_llm_json(
|
||||
# llm=graph_config.tooling.primary_llm,
|
||||
# prompt=create_question_prompt(
|
||||
# assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
|
||||
# ),
|
||||
# schema=SearchAnswer,
|
||||
# timeout_override=TF_DR_TIMEOUT_SHORT,
|
||||
# )
|
||||
# answer_string = search_answer_json.answer
|
||||
# claims = search_answer_json.claims or []
|
||||
# reasoning = search_answer_json.reasoning or ""
|
||||
# (
|
||||
# citation_numbers,
|
||||
# answer_string,
|
||||
# claims,
|
||||
# ) = extract_document_citations(answer_string, claims)
|
||||
# cited_documents = {
|
||||
# citation_number: cited_raw_documents[citation_number - 1]
|
||||
# for citation_number in citation_numbers
|
||||
# }
|
||||
|
||||
else:
|
||||
answer_string = ""
|
||||
reasoning = ""
|
||||
claims = []
|
||||
cited_raw_documents = [url_to_raw_document[url] for url in urls]
|
||||
cited_documents = {
|
||||
doc_num + 1: retrieved_doc
|
||||
for doc_num, retrieved_doc in enumerate(cited_raw_documents)
|
||||
}
|
||||
# else:
|
||||
# answer_string = ""
|
||||
# reasoning = ""
|
||||
# claims = []
|
||||
# cited_raw_documents = [url_to_raw_document[url] for url in urls]
|
||||
# cited_documents = {
|
||||
# doc_num + 1: retrieved_doc
|
||||
# for doc_num, retrieved_doc in enumerate(cited_raw_documents)
|
||||
# }
|
||||
|
||||
return BranchUpdate(
|
||||
branch_iteration_responses=[
|
||||
IterationAnswer(
|
||||
tool=is_tool_info.llm_path,
|
||||
tool_id=is_tool_info.tool_id,
|
||||
iteration_nr=current_iteration,
|
||||
parallelization_nr=0,
|
||||
question=state.branch_question,
|
||||
answer=answer_string,
|
||||
claims=claims,
|
||||
cited_documents=cited_documents,
|
||||
reasoning=reasoning,
|
||||
additional_data=None,
|
||||
)
|
||||
],
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="internet_search",
|
||||
node_name="summarizing",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
# return BranchUpdate(
|
||||
# branch_iteration_responses=[
|
||||
# IterationAnswer(
|
||||
# tool=is_tool_info.llm_path,
|
||||
# tool_id=is_tool_info.tool_id,
|
||||
# iteration_nr=current_iteration,
|
||||
# parallelization_nr=0,
|
||||
# question=state.branch_question,
|
||||
# answer=answer_string,
|
||||
# claims=claims,
|
||||
# cited_documents=cited_documents,
|
||||
# reasoning=reasoning,
|
||||
# additional_data=None,
|
||||
# )
|
||||
# ],
|
||||
# log_messages=[
|
||||
# get_langgraph_node_log_string(
|
||||
# graph_component="internet_search",
|
||||
# node_name="summarizing",
|
||||
# node_start_time=node_start_time,
|
||||
# )
|
||||
# ],
|
||||
# )
|
||||
|
||||
|
||||
def _create_document_texts(raw_documents: list[InferenceSection]) -> str:
|
||||
document_texts_list = []
|
||||
for doc_num, retrieved_doc in enumerate(raw_documents):
|
||||
if not isinstance(retrieved_doc, InferenceSection):
|
||||
raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
|
||||
chunk_text = build_document_context(retrieved_doc, doc_num + 1)
|
||||
document_texts_list.append(chunk_text)
|
||||
return "\n\n".join(document_texts_list)
|
||||
# def _create_document_texts(raw_documents: list[InferenceSection]) -> str:
|
||||
# document_texts_list = []
|
||||
# for doc_num, retrieved_doc in enumerate(raw_documents):
|
||||
# if not isinstance(retrieved_doc, InferenceSection):
|
||||
# raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
|
||||
# chunk_text = build_document_context(retrieved_doc, doc_num + 1)
|
||||
# document_texts_list.append(chunk_text)
|
||||
# return "\n\n".join(document_texts_list)
|
||||
|
||||
@@ -1,56 +1,56 @@
|
||||
from datetime import datetime
|
||||
# from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
# from langchain_core.runnables import RunnableConfig
|
||||
# from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.utils.logger import setup_logger
|
||||
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
|
||||
# from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
# get_langgraph_node_log_string,
|
||||
# )
|
||||
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
# from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
# from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
# logger = setup_logger()
|
||||
|
||||
|
||||
def is_reducer(
|
||||
state: SubAgentMainState,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> SubAgentUpdate:
|
||||
"""
|
||||
LangGraph node to perform an internet search as part of the DR process.
|
||||
"""
|
||||
# def is_reducer(
|
||||
# state: SubAgentMainState,
|
||||
# config: RunnableConfig,
|
||||
# writer: StreamWriter = lambda _: None,
|
||||
# ) -> SubAgentUpdate:
|
||||
# """
|
||||
# LangGraph node to perform an internet search as part of the DR process.
|
||||
# """
|
||||
|
||||
node_start_time = datetime.now()
|
||||
# node_start_time = datetime.now()
|
||||
|
||||
branch_updates = state.branch_iteration_responses
|
||||
current_iteration = state.iteration_nr
|
||||
current_step_nr = state.current_step_nr
|
||||
# branch_updates = state.branch_iteration_responses
|
||||
# current_iteration = state.iteration_nr
|
||||
# current_step_nr = state.current_step_nr
|
||||
|
||||
new_updates = [
|
||||
update for update in branch_updates if update.iteration_nr == current_iteration
|
||||
]
|
||||
# new_updates = [
|
||||
# update for update in branch_updates if update.iteration_nr == current_iteration
|
||||
# ]
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SectionEnd(),
|
||||
writer,
|
||||
)
|
||||
# write_custom_event(
|
||||
# current_step_nr,
|
||||
# SectionEnd(),
|
||||
# writer,
|
||||
# )
|
||||
|
||||
current_step_nr += 1
|
||||
# current_step_nr += 1
|
||||
|
||||
return SubAgentUpdate(
|
||||
iteration_responses=new_updates,
|
||||
current_step_nr=current_step_nr,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="internet_search",
|
||||
node_name="consolidation",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
# return SubAgentUpdate(
|
||||
# iteration_responses=new_updates,
|
||||
# current_step_nr=current_step_nr,
|
||||
# log_messages=[
|
||||
# get_langgraph_node_log_string(
|
||||
# graph_component="internet_search",
|
||||
# node_name="consolidation",
|
||||
# node_start_time=node_start_time,
|
||||
# )
|
||||
# ],
|
||||
# )
|
||||
|
||||
@@ -1,79 +1,79 @@
|
||||
from collections.abc import Hashable
|
||||
# from collections.abc import Hashable
|
||||
|
||||
from langgraph.types import Send
|
||||
# from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
|
||||
InternetSearchInput,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.states import SummarizeInput
|
||||
# from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
|
||||
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchInput
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
|
||||
# InternetSearchInput,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import SummarizeInput
|
||||
|
||||
|
||||
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
|
||||
return [
|
||||
Send(
|
||||
"search",
|
||||
InternetSearchInput(
|
||||
iteration_nr=state.iteration_nr,
|
||||
current_step_nr=state.current_step_nr,
|
||||
parallelization_nr=parallelization_nr,
|
||||
query_list=[query],
|
||||
branch_question=query,
|
||||
context="",
|
||||
tools_used=state.tools_used,
|
||||
available_tools=state.available_tools,
|
||||
assistant_system_prompt=state.assistant_system_prompt,
|
||||
assistant_task_prompt=state.assistant_task_prompt,
|
||||
results_to_open=[],
|
||||
),
|
||||
)
|
||||
for parallelization_nr, query in enumerate(
|
||||
state.query_list[:MAX_DR_PARALLEL_SEARCH]
|
||||
)
|
||||
]
|
||||
# def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
|
||||
# return [
|
||||
# Send(
|
||||
# "search",
|
||||
# InternetSearchInput(
|
||||
# iteration_nr=state.iteration_nr,
|
||||
# current_step_nr=state.current_step_nr,
|
||||
# parallelization_nr=parallelization_nr,
|
||||
# query_list=[query],
|
||||
# branch_question=query,
|
||||
# context="",
|
||||
# tools_used=state.tools_used,
|
||||
# available_tools=state.available_tools,
|
||||
# assistant_system_prompt=state.assistant_system_prompt,
|
||||
# assistant_task_prompt=state.assistant_task_prompt,
|
||||
# results_to_open=[],
|
||||
# ),
|
||||
# )
|
||||
# for parallelization_nr, query in enumerate(
|
||||
# state.query_list[:MAX_DR_PARALLEL_SEARCH]
|
||||
# )
|
||||
# ]
|
||||
|
||||
|
||||
def fetch_router(state: InternetSearchInput) -> list[Send | Hashable]:
|
||||
branch_questions_to_urls = state.branch_questions_to_urls
|
||||
return [
|
||||
Send(
|
||||
"fetch",
|
||||
FetchInput(
|
||||
iteration_nr=state.iteration_nr,
|
||||
urls_to_open=[url],
|
||||
tools_used=state.tools_used,
|
||||
available_tools=state.available_tools,
|
||||
assistant_system_prompt=state.assistant_system_prompt,
|
||||
assistant_task_prompt=state.assistant_task_prompt,
|
||||
current_step_nr=state.current_step_nr,
|
||||
branch_questions_to_urls=branch_questions_to_urls,
|
||||
raw_documents=state.raw_documents,
|
||||
),
|
||||
)
|
||||
for url in set(
|
||||
url for urls in branch_questions_to_urls.values() for url in urls
|
||||
)
|
||||
]
|
||||
# def fetch_router(state: InternetSearchInput) -> list[Send | Hashable]:
|
||||
# branch_questions_to_urls = state.branch_questions_to_urls
|
||||
# return [
|
||||
# Send(
|
||||
# "fetch",
|
||||
# FetchInput(
|
||||
# iteration_nr=state.iteration_nr,
|
||||
# urls_to_open=[url],
|
||||
# tools_used=state.tools_used,
|
||||
# available_tools=state.available_tools,
|
||||
# assistant_system_prompt=state.assistant_system_prompt,
|
||||
# assistant_task_prompt=state.assistant_task_prompt,
|
||||
# current_step_nr=state.current_step_nr,
|
||||
# branch_questions_to_urls=branch_questions_to_urls,
|
||||
# raw_documents=state.raw_documents,
|
||||
# ),
|
||||
# )
|
||||
# for url in set(
|
||||
# url for urls in branch_questions_to_urls.values() for url in urls
|
||||
# )
|
||||
# ]
|
||||
|
||||
|
||||
def summarize_router(state: InternetSearchInput) -> list[Send | Hashable]:
|
||||
branch_questions_to_urls = state.branch_questions_to_urls
|
||||
return [
|
||||
Send(
|
||||
"summarize",
|
||||
SummarizeInput(
|
||||
iteration_nr=state.iteration_nr,
|
||||
raw_documents=state.raw_documents,
|
||||
branch_questions_to_urls=branch_questions_to_urls,
|
||||
branch_question=branch_question,
|
||||
tools_used=state.tools_used,
|
||||
available_tools=state.available_tools,
|
||||
assistant_system_prompt=state.assistant_system_prompt,
|
||||
assistant_task_prompt=state.assistant_task_prompt,
|
||||
current_step_nr=state.current_step_nr,
|
||||
),
|
||||
)
|
||||
for branch_question in branch_questions_to_urls.keys()
|
||||
]
|
||||
# def summarize_router(state: InternetSearchInput) -> list[Send | Hashable]:
|
||||
# branch_questions_to_urls = state.branch_questions_to_urls
|
||||
# return [
|
||||
# Send(
|
||||
# "summarize",
|
||||
# SummarizeInput(
|
||||
# iteration_nr=state.iteration_nr,
|
||||
# raw_documents=state.raw_documents,
|
||||
# branch_questions_to_urls=branch_questions_to_urls,
|
||||
# branch_question=branch_question,
|
||||
# tools_used=state.tools_used,
|
||||
# available_tools=state.available_tools,
|
||||
# assistant_system_prompt=state.assistant_system_prompt,
|
||||
# assistant_task_prompt=state.assistant_task_prompt,
|
||||
# current_step_nr=state.current_step_nr,
|
||||
# ),
|
||||
# )
|
||||
# for branch_question in branch_questions_to_urls.keys()
|
||||
# ]
|
||||
|
||||
@@ -1,84 +1,84 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
# from langgraph.graph import END
|
||||
# from langgraph.graph import START
|
||||
# from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_1_branch import (
|
||||
is_branch,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_2_search import (
|
||||
web_search,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_3_dedup_urls import (
|
||||
dedup_urls,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_4_fetch import (
|
||||
web_fetch,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_5_collect_raw_docs import (
|
||||
collect_raw_docs,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_6_summarize import (
|
||||
is_summarize,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_7_reduce import (
|
||||
is_reducer,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_conditional_edges import (
|
||||
branching_router,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_conditional_edges import (
|
||||
fetch_router,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_conditional_edges import (
|
||||
summarize_router,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_1_branch import (
|
||||
# is_branch,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_2_search import (
|
||||
# web_search,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_3_dedup_urls import (
|
||||
# dedup_urls,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_4_fetch import (
|
||||
# web_fetch,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_5_collect_raw_docs import (
|
||||
# collect_raw_docs,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_6_summarize import (
|
||||
# is_summarize,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_7_reduce import (
|
||||
# is_reducer,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_conditional_edges import (
|
||||
# branching_router,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_conditional_edges import (
|
||||
# fetch_router,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_conditional_edges import (
|
||||
# summarize_router,
|
||||
# )
|
||||
# from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
# logger = setup_logger()
|
||||
|
||||
|
||||
def dr_ws_graph_builder() -> StateGraph:
|
||||
"""
|
||||
LangGraph graph builder for Internet Search Sub-Agent
|
||||
"""
|
||||
# def dr_ws_graph_builder() -> StateGraph:
|
||||
# """
|
||||
# LangGraph graph builder for Internet Search Sub-Agent
|
||||
# """
|
||||
|
||||
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
|
||||
# graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
|
||||
|
||||
### Add nodes ###
|
||||
# ### Add nodes ###
|
||||
|
||||
graph.add_node("branch", is_branch)
|
||||
# graph.add_node("branch", is_branch)
|
||||
|
||||
graph.add_node("search", web_search)
|
||||
# graph.add_node("search", web_search)
|
||||
|
||||
graph.add_node("dedup_urls", dedup_urls)
|
||||
# graph.add_node("dedup_urls", dedup_urls)
|
||||
|
||||
graph.add_node("fetch", web_fetch)
|
||||
# graph.add_node("fetch", web_fetch)
|
||||
|
||||
graph.add_node("collect_raw_docs", collect_raw_docs)
|
||||
# graph.add_node("collect_raw_docs", collect_raw_docs)
|
||||
|
||||
graph.add_node("summarize", is_summarize)
|
||||
# graph.add_node("summarize", is_summarize)
|
||||
|
||||
graph.add_node("reducer", is_reducer)
|
||||
# graph.add_node("reducer", is_reducer)
|
||||
|
||||
### Add edges ###
|
||||
# ### Add edges ###
|
||||
|
||||
graph.add_edge(start_key=START, end_key="branch")
|
||||
# graph.add_edge(start_key=START, end_key="branch")
|
||||
|
||||
graph.add_conditional_edges("branch", branching_router)
|
||||
# graph.add_conditional_edges("branch", branching_router)
|
||||
|
||||
graph.add_edge(start_key="search", end_key="dedup_urls")
|
||||
# graph.add_edge(start_key="search", end_key="dedup_urls")
|
||||
|
||||
graph.add_conditional_edges("dedup_urls", fetch_router)
|
||||
# graph.add_conditional_edges("dedup_urls", fetch_router)
|
||||
|
||||
graph.add_edge(start_key="fetch", end_key="collect_raw_docs")
|
||||
# graph.add_edge(start_key="fetch", end_key="collect_raw_docs")
|
||||
|
||||
graph.add_conditional_edges("collect_raw_docs", summarize_router)
|
||||
# graph.add_conditional_edges("collect_raw_docs", summarize_router)
|
||||
|
||||
graph.add_edge(start_key="summarize", end_key="reducer")
|
||||
# graph.add_edge(start_key="summarize", end_key="reducer")
|
||||
|
||||
graph.add_edge(start_key="reducer", end_key=END)
|
||||
# graph.add_edge(start_key="reducer", end_key=END)
|
||||
|
||||
return graph
|
||||
# return graph
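# Hedged usage sketch (caller-side, not part of this module): the builder
# returns an uncompiled StateGraph, so the parent DR graph compiles it before
# running. `sub_agent_input` and `runnable_config` below are hypothetical names.
#
#   web_search_graph = dr_ws_graph_builder().compile()
#   final_state = web_search_graph.invoke(sub_agent_input, config=runnable_config)
#
# Leaving compilation to the caller lets the parent graph attach its own
# checkpointer or embed the compiled sub-graph as a node.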
|
||||
|
||||
@@ -1,47 +1,53 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
# from abc import ABC
|
||||
# from abc import abstractmethod
|
||||
# from collections.abc import Sequence
|
||||
# from datetime import datetime
|
||||
# from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import field_validator
|
||||
# from pydantic import BaseModel
|
||||
# from pydantic import field_validator
|
||||
|
||||
from onyx.utils.url import normalize_url
|
||||
# from onyx.utils.url import normalize_url
|
||||
|
||||
|
||||
class WebSearchResult(BaseModel):
|
||||
title: str
|
||||
link: str
|
||||
snippet: str | None = None
|
||||
author: str | None = None
|
||||
published_date: datetime | None = None
|
||||
# class ProviderType(Enum):
|
||||
# """Enum for internet search provider types"""
|
||||
|
||||
@field_validator("link")
|
||||
@classmethod
|
||||
def normalize_link(cls, v: str) -> str:
|
||||
return normalize_url(v)
|
||||
# GOOGLE = "google"
|
||||
# EXA = "exa"
|
||||
|
||||
|
||||
class WebContent(BaseModel):
|
||||
title: str
|
||||
link: str
|
||||
full_content: str
|
||||
published_date: datetime | None = None
|
||||
scrape_successful: bool = True
|
||||
# class WebSearchResult(BaseModel):
|
||||
# title: str
|
||||
# link: str
|
||||
# snippet: str | None = None
|
||||
# author: str | None = None
|
||||
# published_date: datetime | None = None
|
||||
|
||||
@field_validator("link")
|
||||
@classmethod
|
||||
def normalize_link(cls, v: str) -> str:
|
||||
return normalize_url(v)
|
||||
# @field_validator("link")
|
||||
# @classmethod
|
||||
# def normalize_link(cls, v: str) -> str:
|
||||
# return normalize_url(v)
|
||||
|
||||
|
||||
class WebContentProvider(ABC):
|
||||
@abstractmethod
|
||||
def contents(self, urls: Sequence[str]) -> list[WebContent]:
|
||||
pass
|
||||
# class WebContent(BaseModel):
|
||||
# title: str
|
||||
# link: str
|
||||
# full_content: str
|
||||
# published_date: datetime | None = None
|
||||
# scrape_successful: bool = True
|
||||
|
||||
# @field_validator("link")
|
||||
# @classmethod
|
||||
# def normalize_link(cls, v: str) -> str:
|
||||
# return normalize_url(v)
|
||||
|
||||
|
||||
class WebSearchProvider(WebContentProvider):
|
||||
@abstractmethod
|
||||
def search(self, query: str) -> Sequence[WebSearchResult]:
|
||||
pass
|
||||
# class WebSearchProvider(ABC):
|
||||
# @abstractmethod
|
||||
# def search(self, query: str) -> Sequence[WebSearchResult]:
|
||||
# pass
|
||||
|
||||
# @abstractmethod
|
||||
# def contents(self, urls: Sequence[str]) -> list[WebContent]:
|
||||
# pass
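# Illustrative construction (hypothetical values): the field_validator on
# "link" runs normalize_url at construction time, so downstream lookups keyed
# on normalized URLs (e.g. the url_to_raw_document map in the summarize node)
# do not miss because of query-parameter or formatting differences.
def _example_web_search_result() -> WebSearchResult:
    return WebSearchResult(
        title="Onyx documentation",
        link="https://docs.onyx.app/quickstart?utm_source=example",  # placeholder URL
        snippet="Getting started with Onyx.",
    )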
|
||||
|
||||
@@ -1,199 +1,19 @@
|
||||
from typing import Any
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.exa_client import (
|
||||
ExaClient,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.firecrawl_client import (
|
||||
FIRECRAWL_SCRAPE_URL,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.firecrawl_client import (
|
||||
FirecrawlClient,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.google_pse_client import (
|
||||
GooglePSEClient,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.onyx_web_crawler_client import (
|
||||
OnyxWebCrawlerClient,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.serper_client import (
|
||||
SerperClient,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
WebContentProvider,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
WebSearchProvider,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.web_search import fetch_active_web_content_provider
|
||||
from onyx.db.web_search import fetch_active_web_search_provider
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.enums import WebContentProviderType
|
||||
from shared_configs.enums import WebSearchProviderType
|
||||
|
||||
logger = setup_logger()
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.clients.exa_client import (
|
||||
# ExaClient,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.clients.serper_client import (
|
||||
# SerperClient,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
# WebSearchProvider,
|
||||
# )
|
||||
# from onyx.configs.chat_configs import EXA_API_KEY
|
||||
# from onyx.configs.chat_configs import SERPER_API_KEY
|
||||
|
||||
|
||||
def build_search_provider_from_config(
|
||||
*,
|
||||
provider_type: WebSearchProviderType,
|
||||
api_key: str | None,
|
||||
config: dict[str, str] | None,
|
||||
provider_name: str = "web_search_provider",
|
||||
) -> WebSearchProvider | None:
|
||||
provider_type_value = provider_type.value
|
||||
try:
|
||||
provider_type_enum = WebSearchProviderType(provider_type_value)
|
||||
except ValueError:
|
||||
logger.error(
|
||||
f"Unknown web search provider type '{provider_type_value}'. "
|
||||
"Skipping provider initialization."
|
||||
)
|
||||
return None
|
||||
|
||||
# All web search providers require an API key
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
f"Web search provider '{provider_name}' is missing an API key."
|
||||
)
|
||||
assert api_key is not None
|
||||
|
||||
config = config or {}
|
||||
|
||||
if provider_type_enum == WebSearchProviderType.EXA:
|
||||
return ExaClient(api_key=api_key)
|
||||
if provider_type_enum == WebSearchProviderType.SERPER:
|
||||
return SerperClient(api_key=api_key)
|
||||
if provider_type_enum == WebSearchProviderType.GOOGLE_PSE:
|
||||
search_engine_id = (
|
||||
config.get("search_engine_id")
|
||||
or config.get("cx")
|
||||
or config.get("search_engine")
|
||||
)
|
||||
if not search_engine_id:
|
||||
raise ValueError(
|
||||
"Google PSE provider requires a search engine id (cx) in addition to the API key."
|
||||
)
|
||||
assert search_engine_id is not None
|
||||
try:
|
||||
num_results = int(config.get("num_results", 10))
|
||||
except (TypeError, ValueError):
|
||||
raise ValueError(
|
||||
"Invalid value for Google PSE 'num_results'; expected integer."
|
||||
)
|
||||
try:
|
||||
timeout_seconds = int(config.get("timeout_seconds", 10))
|
||||
except (TypeError, ValueError):
|
||||
raise ValueError(
|
||||
"Invalid value for Google PSE 'timeout_seconds'; expected integer."
|
||||
)
|
||||
return GooglePSEClient(
|
||||
api_key=api_key,
|
||||
search_engine_id=search_engine_id,
|
||||
num_results=num_results,
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
|
||||
logger.error(
|
||||
f"Unhandled web search provider type '{provider_type_value}'. "
|
||||
"Skipping provider initialization."
|
||||
)
|
||||
return None
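# Hedged usage sketch: building a Google PSE-backed provider straight from a
# config dict. The API key and search engine id are placeholders; real callers
# normally go through _build_search_provider with a persisted provider model.
def _example_build_google_pse_provider() -> WebSearchProvider | None:
    return build_search_provider_from_config(
        provider_type=WebSearchProviderType.GOOGLE_PSE,
        api_key="<google-api-key>",  # placeholder
        config={
            "search_engine_id": "<cx-id>",  # "cx" or "search_engine" are accepted aliases
            "num_results": "5",  # parsed with int(); non-integers raise ValueError
        },
        provider_name="example_google_pse",
    )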
|
||||
|
||||
|
||||
def _build_search_provider(provider_model: Any) -> WebSearchProvider | None:
|
||||
return build_search_provider_from_config(
|
||||
provider_type=WebSearchProviderType(provider_model.provider_type),
|
||||
api_key=provider_model.api_key,
|
||||
config=provider_model.config or {},
|
||||
provider_name=provider_model.name,
|
||||
)
|
||||
|
||||
|
||||
def build_content_provider_from_config(
|
||||
*,
|
||||
provider_type: WebContentProviderType,
|
||||
api_key: str | None,
|
||||
config: dict[str, str] | None,
|
||||
provider_name: str = "web_content_provider",
|
||||
) -> WebContentProvider | None:
|
||||
provider_type_value = provider_type.value
|
||||
try:
|
||||
provider_type_enum = WebContentProviderType(provider_type_value)
|
||||
except ValueError:
|
||||
logger.error(
|
||||
f"Unknown web content provider type '{provider_type_value}'. "
|
||||
"Skipping provider initialization."
|
||||
)
|
||||
return None
|
||||
|
||||
if provider_type_enum == WebContentProviderType.ONYX_WEB_CRAWLER:
|
||||
config = config or {}
|
||||
timeout_value = config.get("timeout_seconds", 15)
|
||||
try:
|
||||
timeout_seconds = int(timeout_value)
|
||||
except (TypeError, ValueError):
|
||||
raise ValueError(
|
||||
"Invalid value for Onyx Web Crawler 'timeout_seconds'; expected integer."
|
||||
)
|
||||
return OnyxWebCrawlerClient(timeout_seconds=timeout_seconds)
|
||||
|
||||
if provider_type_enum == WebContentProviderType.FIRECRAWL:
|
||||
if not api_key:
|
||||
raise ValueError("Firecrawl content provider requires an API key.")
|
||||
assert api_key is not None
|
||||
config = config or {}
|
||||
timeout_seconds_str = config.get("timeout_seconds")
|
||||
if timeout_seconds_str is None:
|
||||
timeout_seconds = 10
|
||||
else:
|
||||
try:
|
||||
timeout_seconds = int(timeout_seconds_str)
|
||||
except (TypeError, ValueError):
|
||||
raise ValueError(
|
||||
"Invalid value for Firecrawl 'timeout_seconds'; expected integer."
|
||||
)
|
||||
return FirecrawlClient(
|
||||
api_key=api_key,
|
||||
base_url=config.get("base_url") or FIRECRAWL_SCRAPE_URL,
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
|
||||
logger.error(
|
||||
f"Unhandled web content provider type '{provider_type_value}'. "
|
||||
"Skipping provider initialization."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _build_content_provider(provider_model: Any) -> WebContentProvider | None:
|
||||
return build_content_provider_from_config(
|
||||
provider_type=WebContentProviderType(provider_model.provider_type),
|
||||
api_key=provider_model.api_key,
|
||||
config=provider_model.config or {},
|
||||
provider_name=provider_model.name,
|
||||
)
|
||||
|
||||
|
||||
def get_default_provider() -> WebSearchProvider | None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
provider_model = fetch_active_web_search_provider(db_session)
|
||||
if provider_model is None:
|
||||
return None
|
||||
return _build_search_provider(provider_model)
|
||||
|
||||
|
||||
def get_default_content_provider() -> WebContentProvider | None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
provider_model = fetch_active_web_content_provider(db_session)
|
||||
if provider_model:
|
||||
provider = _build_content_provider(provider_model)
|
||||
if provider:
|
||||
return provider
|
||||
|
||||
# Fall back to built-in Onyx crawler when nothing is configured.
|
||||
try:
|
||||
return OnyxWebCrawlerClient()
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.error(f"Failed to initialize default Onyx crawler: {exc}")
|
||||
return None
|
||||
# def get_default_provider() -> WebSearchProvider | None:
|
||||
# if EXA_API_KEY:
|
||||
# return ExaClient()
|
||||
# if SERPER_API_KEY:
|
||||
# return SerperClient()
|
||||
# return None
|
||||
|
||||
@@ -1,37 +1,37 @@
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
# from operator import add
|
||||
# from typing import Annotated
|
||||
|
||||
from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
# from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
# WebSearchResult,
|
||||
# )
|
||||
# from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
class InternetSearchInput(SubAgentInput):
|
||||
results_to_open: Annotated[list[tuple[str, WebSearchResult]], add] = []
|
||||
parallelization_nr: int = 0
|
||||
branch_question: Annotated[str, lambda x, y: y] = ""
|
||||
branch_questions_to_urls: Annotated[dict[str, list[str]], lambda x, y: y] = {}
|
||||
raw_documents: Annotated[list[InferenceSection], add] = []
|
||||
# class InternetSearchInput(SubAgentInput):
|
||||
# results_to_open: Annotated[list[tuple[str, WebSearchResult]], add] = []
|
||||
# parallelization_nr: int = 0
|
||||
# branch_question: Annotated[str, lambda x, y: y] = ""
|
||||
# branch_questions_to_urls: Annotated[dict[str, list[str]], lambda x, y: y] = {}
|
||||
# raw_documents: Annotated[list[InferenceSection], add] = []
|
||||
|
||||
|
||||
class InternetSearchUpdate(LoggerUpdate):
|
||||
results_to_open: Annotated[list[tuple[str, WebSearchResult]], add] = []
|
||||
# class InternetSearchUpdate(LoggerUpdate):
|
||||
# results_to_open: Annotated[list[tuple[str, WebSearchResult]], add] = []
|
||||
|
||||
|
||||
class FetchInput(SubAgentInput):
|
||||
urls_to_open: Annotated[list[str], add] = []
|
||||
branch_questions_to_urls: dict[str, list[str]]
|
||||
raw_documents: Annotated[list[InferenceSection], add] = []
|
||||
# class FetchInput(SubAgentInput):
|
||||
# urls_to_open: Annotated[list[str], add] = []
|
||||
# branch_questions_to_urls: dict[str, list[str]]
|
||||
# raw_documents: Annotated[list[InferenceSection], add] = []
|
||||
|
||||
|
||||
class FetchUpdate(LoggerUpdate):
|
||||
raw_documents: Annotated[list[InferenceSection], add] = []
|
||||
# class FetchUpdate(LoggerUpdate):
|
||||
# raw_documents: Annotated[list[InferenceSection], add] = []
|
||||
|
||||
|
||||
class SummarizeInput(SubAgentInput):
|
||||
raw_documents: Annotated[list[InferenceSection], add] = []
|
||||
branch_questions_to_urls: dict[str, list[str]]
|
||||
branch_question: str
|
||||
# class SummarizeInput(SubAgentInput):
|
||||
# raw_documents: Annotated[list[InferenceSection], add] = []
|
||||
# branch_questions_to_urls: dict[str, list[str]]
|
||||
# branch_question: str
|
||||
|
||||
@@ -1,99 +1,99 @@
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
WebContent,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.chat.models import DOCUMENT_CITATION_NUMBER_EMPTY_VALUE
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import InferenceSection
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
# WebContent,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
# WebSearchResult,
|
||||
# )
|
||||
# from onyx.chat.models import DOCUMENT_CITATION_NUMBER_EMPTY_VALUE
|
||||
# from onyx.chat.models import LlmDoc
|
||||
# from onyx.configs.constants import DocumentSource
|
||||
# from onyx.context.search.models import InferenceChunk
|
||||
# from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
def truncate_search_result_content(content: str, max_chars: int = 10000) -> str:
|
||||
"""Truncate search result content to a maximum number of characters"""
|
||||
if len(content) <= max_chars:
|
||||
return content
|
||||
return content[:max_chars] + "..."
|
||||
# def truncate_search_result_content(content: str, max_chars: int = 10000) -> str:
|
||||
# """Truncate search result content to a maximum number of characters"""
|
||||
# if len(content) <= max_chars:
|
||||
# return content
|
||||
# return content[:max_chars] + "..."
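# Quick illustration: content at or under the limit passes through unchanged,
# anything longer is cut at max_chars and suffixed with "...".
#   truncate_search_result_content("short text")           -> "short text"
#   truncate_search_result_content("a" * 12_000)           -> 10_000 "a"s + "..."
#   truncate_search_result_content("abcdef", max_chars=3)  -> "abc..."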
|
||||
|
||||
|
||||
def dummy_inference_section_from_internet_content(
|
||||
result: WebContent,
|
||||
) -> InferenceSection:
|
||||
truncated_content = truncate_search_result_content(result.full_content)
|
||||
return InferenceSection(
|
||||
center_chunk=InferenceChunk(
|
||||
chunk_id=0,
|
||||
blurb=result.title,
|
||||
content=truncated_content,
|
||||
source_links={0: result.link},
|
||||
section_continuation=False,
|
||||
document_id="INTERNET_SEARCH_DOC_" + result.link,
|
||||
source_type=DocumentSource.WEB,
|
||||
semantic_identifier=result.link,
|
||||
title=result.title,
|
||||
boost=1,
|
||||
recency_bias=1.0,
|
||||
score=1.0,
|
||||
hidden=(not result.scrape_successful),
|
||||
metadata={},
|
||||
match_highlights=[],
|
||||
doc_summary=truncated_content,
|
||||
chunk_context=truncated_content,
|
||||
updated_at=result.published_date,
|
||||
image_file_id=None,
|
||||
),
|
||||
chunks=[],
|
||||
combined_content=truncated_content,
|
||||
)
|
||||
# def inference_section_from_internet_page_scrape(
|
||||
# result: WebContent,
|
||||
# ) -> InferenceSection:
|
||||
# truncated_content = truncate_search_result_content(result.full_content)
|
||||
# return InferenceSection(
|
||||
# center_chunk=InferenceChunk(
|
||||
# chunk_id=0,
|
||||
# blurb=result.title,
|
||||
# content=truncated_content,
|
||||
# source_links={0: result.link},
|
||||
# section_continuation=False,
|
||||
# document_id="INTERNET_SEARCH_DOC_" + result.link,
|
||||
# source_type=DocumentSource.WEB,
|
||||
# semantic_identifier=result.link,
|
||||
# title=result.title,
|
||||
# boost=1,
|
||||
# recency_bias=1.0,
|
||||
# score=1.0,
|
||||
# hidden=(not result.scrape_successful),
|
||||
# metadata={},
|
||||
# match_highlights=[],
|
||||
# doc_summary=truncated_content,
|
||||
# chunk_context=truncated_content,
|
||||
# updated_at=result.published_date,
|
||||
# image_file_id=None,
|
||||
# ),
|
||||
# chunks=[],
|
||||
# combined_content=truncated_content,
|
||||
# )
|
||||
|
||||
|
||||
def dummy_inference_section_from_internet_search_result(
|
||||
result: WebSearchResult,
|
||||
) -> InferenceSection:
|
||||
return InferenceSection(
|
||||
center_chunk=InferenceChunk(
|
||||
chunk_id=0,
|
||||
blurb=result.title,
|
||||
content="",
|
||||
source_links={0: result.link},
|
||||
section_continuation=False,
|
||||
document_id="INTERNET_SEARCH_DOC_" + result.link,
|
||||
source_type=DocumentSource.WEB,
|
||||
semantic_identifier=result.link,
|
||||
title=result.title,
|
||||
boost=1,
|
||||
recency_bias=1.0,
|
||||
score=1.0,
|
||||
hidden=False,
|
||||
metadata={},
|
||||
match_highlights=[],
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
updated_at=result.published_date,
|
||||
image_file_id=None,
|
||||
),
|
||||
chunks=[],
|
||||
combined_content="",
|
||||
)
|
||||
# def dummy_inference_section_from_internet_search_result(
|
||||
# result: WebSearchResult,
|
||||
# ) -> InferenceSection:
|
||||
# return InferenceSection(
|
||||
# center_chunk=InferenceChunk(
|
||||
# chunk_id=0,
|
||||
# blurb=result.title,
|
||||
# content="",
|
||||
# source_links={0: result.link},
|
||||
# section_continuation=False,
|
||||
# document_id="INTERNET_SEARCH_DOC_" + result.link,
|
||||
# source_type=DocumentSource.WEB,
|
||||
# semantic_identifier=result.link,
|
||||
# title=result.title,
|
||||
# boost=1,
|
||||
# recency_bias=1.0,
|
||||
# score=1.0,
|
||||
# hidden=False,
|
||||
# metadata={},
|
||||
# match_highlights=[],
|
||||
# doc_summary="",
|
||||
# chunk_context="",
|
||||
# updated_at=result.published_date,
|
||||
# image_file_id=None,
|
||||
# ),
|
||||
# chunks=[],
|
||||
# combined_content="",
|
||||
# )
|
||||
|
||||
|
||||
def llm_doc_from_web_content(web_content: WebContent) -> LlmDoc:
|
||||
"""Create an LlmDoc from WebContent with the INTERNET_SEARCH_DOC_ prefix"""
|
||||
return LlmDoc(
|
||||
# TODO: Is this what we want to do for document_id? We're kind of overloading it since it
|
||||
# should ideally correspond to a document in the database. But I guess if you're calling this
|
||||
# function you know it won't be in the database.
|
||||
document_id="INTERNET_SEARCH_DOC_" + web_content.link,
|
||||
content=truncate_search_result_content(web_content.full_content),
|
||||
blurb=web_content.link,
|
||||
semantic_identifier=web_content.link,
|
||||
source_type=DocumentSource.WEB,
|
||||
metadata={},
|
||||
link=web_content.link,
|
||||
document_citation_number=DOCUMENT_CITATION_NUMBER_EMPTY_VALUE,
|
||||
updated_at=web_content.published_date,
|
||||
source_links={},
|
||||
match_highlights=[],
|
||||
)
|
||||
# def llm_doc_from_web_content(web_content: WebContent) -> LlmDoc:
|
||||
# """Create an LlmDoc from WebContent with the INTERNET_SEARCH_DOC_ prefix"""
|
||||
# return LlmDoc(
|
||||
# # TODO: Is this what we want to do for document_id? We're kind of overloading it since it
|
||||
# # should ideally correspond to a document in the database. But I guess if you're calling this
|
||||
# # function you know it won't be in the database.
|
||||
# document_id="INTERNET_SEARCH_DOC_" + web_content.link,
|
||||
# content=truncate_search_result_content(web_content.full_content),
|
||||
# blurb=web_content.link,
|
||||
# semantic_identifier=web_content.link,
|
||||
# source_type=DocumentSource.WEB,
|
||||
# metadata={},
|
||||
# link=web_content.link,
|
||||
# document_citation_number=DOCUMENT_CITATION_NUMBER_EMPTY_VALUE,
|
||||
# updated_at=web_content.published_date,
|
||||
# source_links={},
|
||||
# match_highlights=[],
|
||||
# )
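# Illustrative conversion (hypothetical page data): a scraped WebContent is
# wrapped as an LlmDoc whose document_id carries the INTERNET_SEARCH_DOC_
# prefix, marking it as a web result that is not persisted in the database.
def _example_llm_doc_from_web_content() -> LlmDoc:
    page = WebContent(
        title="Onyx release notes",
        link="https://docs.onyx.app/releases",  # placeholder URL
        full_content="Web search fixes and a new Open URL action.",
    )
    doc = llm_doc_from_web_content(page)
    # doc.document_id -> "INTERNET_SEARCH_DOC_" + the (normalized) link
    return doc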
|
||||
|
||||
@@ -1,277 +1,277 @@
|
||||
import copy
|
||||
import re
|
||||
# import copy
|
||||
# import re
|
||||
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.messages import HumanMessage
|
||||
# from langchain.schema.messages import BaseMessage
|
||||
# from langchain.schema.messages import HumanMessage
|
||||
|
||||
from onyx.agents.agent_search.dr.models import AggregatedDRContext
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_section_list,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import (
|
||||
WebSearchTool,
|
||||
)
|
||||
# from onyx.agents.agent_search.dr.models import AggregatedDRContext
|
||||
# from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
# from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
|
||||
# from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
|
||||
# from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
# dedup_inference_section_list,
|
||||
# )
|
||||
# from onyx.context.search.models import InferenceSection
|
||||
# from onyx.context.search.models import SavedSearchDoc
|
||||
# from onyx.context.search.models import SearchDoc
|
||||
# from onyx.tools.tool_implementations.web_search.web_search_tool import (
|
||||
# WebSearchTool,
|
||||
# )
|
||||
|
||||
|
||||
CITATION_PREFIX = "CITE:"
|
||||
# CITATION_PREFIX = "CITE:"
|
||||
|
||||
|
||||
def extract_document_citations(
|
||||
answer: str, claims: list[str]
|
||||
) -> tuple[list[int], str, list[str]]:
|
||||
"""
|
||||
Finds all citations of the form [1], [2, 3], etc. and returns the list of cited indices,
|
||||
as well as the answer and claims with the citations replaced with [<CITATION_PREFIX>1],
|
||||
etc., to help with citation deduplication later on.
|
||||
"""
|
||||
citations: set[int] = set()
|
||||
# def extract_document_citations(
|
||||
# answer: str, claims: list[str]
|
||||
# ) -> tuple[list[int], str, list[str]]:
|
||||
# """
|
||||
# Finds all citations of the form [1], [2, 3], etc. and returns the list of cited indices,
|
||||
# as well as the answer and claims with the citations replaced with [<CITATION_PREFIX>1],
|
||||
# etc., to help with citation deduplication later on.
|
||||
# """
|
||||
# citations: set[int] = set()
|
||||
|
||||
# Pattern to match both single citations [1] and multiple citations [1, 2, 3]
|
||||
# This regex matches:
|
||||
# - \[(\d+)\] for single citations like [1]
|
||||
# - \[(\d+(?:,\s*\d+)*)\] for multiple citations like [1, 2, 3]
|
||||
pattern = re.compile(r"\[(\d+(?:,\s*\d+)*)\]")
|
||||
# # Pattern to match both single citations [1] and multiple citations [1, 2, 3]
|
||||
# # This regex matches:
|
||||
# # - \[(\d+)\] for single citations like [1]
|
||||
# # - \[(\d+(?:,\s*\d+)*)\] for multiple citations like [1, 2, 3]
|
||||
# pattern = re.compile(r"\[(\d+(?:,\s*\d+)*)\]")
|
||||
|
||||
def _extract_and_replace(match: re.Match[str]) -> str:
|
||||
numbers = [int(num) for num in match.group(1).split(",")]
|
||||
citations.update(numbers)
|
||||
return "".join(f"[{CITATION_PREFIX}{num}]" for num in numbers)
|
||||
# def _extract_and_replace(match: re.Match[str]) -> str:
|
||||
# numbers = [int(num) for num in match.group(1).split(",")]
|
||||
# citations.update(numbers)
|
||||
# return "".join(f"[{CITATION_PREFIX}{num}]" for num in numbers)
|
||||
|
||||
new_answer = pattern.sub(_extract_and_replace, answer)
|
||||
new_claims = [pattern.sub(_extract_and_replace, claim) for claim in claims]
|
||||
# new_answer = pattern.sub(_extract_and_replace, answer)
|
||||
# new_claims = [pattern.sub(_extract_and_replace, claim) for claim in claims]
|
||||
|
||||
return list(citations), new_answer, new_claims
|
||||
# return list(citations), new_answer, new_claims
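# Illustrative example (hypothetical strings) of the citation rewrite:
def _example_extract_document_citations() -> None:
    answer = "The rollout finished early [1][2, 3]."
    claims = ["Latency dropped by 40% [2]"]
    cited, new_answer, new_claims = extract_document_citations(answer, claims)
    # cited contains 1, 2 and 3 (built from a set, so order is not guaranteed)
    assert sorted(cited) == [1, 2, 3]
    # [1] and [2, 3] are rewritten with the CITE: prefix so aggregate_context
    # can later remap local numbers to global ones without double-counting
    assert new_answer == "The rollout finished early [CITE:1][CITE:2][CITE:3]."
    assert new_claims == ["Latency dropped by 40% [CITE:2]"]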
|
||||
|
||||
|
||||
def aggregate_context(
|
||||
iteration_responses: list[IterationAnswer], include_documents: bool = True
|
||||
) -> AggregatedDRContext:
|
||||
"""
|
||||
Converts the iteration response into a single string with unified citations.
|
||||
For example,
|
||||
it 1: the answer is x [3][4]. {3: doc_abc, 4: doc_xyz}
|
||||
it 2: blah blah [1, 3]. {1: doc_xyz, 3: doc_pqr}
|
||||
Output:
|
||||
it 1: the answer is x [1][2].
|
||||
it 2: blah blah [2][3]
|
||||
[1]: doc_xyz
|
||||
[2]: doc_abc
|
||||
[3]: doc_pqr
|
||||
"""
|
||||
# dedupe and merge inference section contents
|
||||
unrolled_inference_sections: list[InferenceSection] = []
|
||||
is_internet_marker_dict: dict[str, bool] = {}
|
||||
for iteration_response in sorted(
|
||||
iteration_responses,
|
||||
key=lambda x: (x.iteration_nr, x.parallelization_nr),
|
||||
):
|
||||
# def aggregate_context(
|
||||
# iteration_responses: list[IterationAnswer], include_documents: bool = True
|
||||
# ) -> AggregatedDRContext:
|
||||
# """
|
||||
# Converts the iteration response into a single string with unified citations.
|
||||
# For example,
|
||||
# it 1: the answer is x [3][4]. {3: doc_abc, 4: doc_xyz}
|
||||
# it 2: blah blah [1, 3]. {1: doc_xyz, 3: doc_pqr}
|
||||
# Output:
|
||||
# it 1: the answer is x [1][2].
|
||||
# it 2: blah blah [2][3]
|
||||
# [1]: doc_xyz
|
||||
# [2]: doc_abc
|
||||
# [3]: doc_pqr
|
||||
# """
|
||||
# # dedupe and merge inference section contents
|
||||
# unrolled_inference_sections: list[InferenceSection] = []
|
||||
# is_internet_marker_dict: dict[str, bool] = {}
|
||||
# for iteration_response in sorted(
|
||||
# iteration_responses,
|
||||
# key=lambda x: (x.iteration_nr, x.parallelization_nr),
|
||||
# ):
|
||||
|
||||
iteration_tool = iteration_response.tool
|
||||
is_internet = iteration_tool == WebSearchTool._NAME
|
||||
# iteration_tool = iteration_response.tool
|
||||
# is_internet = iteration_tool == WebSearchTool._NAME
|
||||
|
||||
for cited_doc in iteration_response.cited_documents.values():
|
||||
unrolled_inference_sections.append(cited_doc)
|
||||
if cited_doc.center_chunk.document_id not in is_internet_marker_dict:
|
||||
is_internet_marker_dict[cited_doc.center_chunk.document_id] = (
|
||||
is_internet
|
||||
)
|
||||
cited_doc.center_chunk.score = None # None means maintain order
|
||||
# for cited_doc in iteration_response.cited_documents.values():
|
||||
# unrolled_inference_sections.append(cited_doc)
|
||||
# if cited_doc.center_chunk.document_id not in is_internet_marker_dict:
|
||||
# is_internet_marker_dict[cited_doc.center_chunk.document_id] = (
|
||||
# is_internet
|
||||
# )
|
||||
# cited_doc.center_chunk.score = None # None means maintain order
|
||||
|
||||
global_documents = dedup_inference_section_list(unrolled_inference_sections)
|
||||
# global_documents = dedup_inference_section_list(unrolled_inference_sections)
|
||||
|
||||
global_citations = {
|
||||
doc.center_chunk.document_id: i for i, doc in enumerate(global_documents, 1)
|
||||
}
|
||||
# global_citations = {
|
||||
# doc.center_chunk.document_id: i for i, doc in enumerate(global_documents, 1)
|
||||
# }
|
||||
|
||||
# build output string
|
||||
output_strings: list[str] = []
|
||||
global_iteration_responses: list[IterationAnswer] = []
|
||||
# # build output string
|
||||
# output_strings: list[str] = []
|
||||
# global_iteration_responses: list[IterationAnswer] = []
|
||||
|
||||
for iteration_response in sorted(
|
||||
iteration_responses,
|
||||
key=lambda x: (x.iteration_nr, x.parallelization_nr),
|
||||
):
|
||||
# add basic iteration info
|
||||
output_strings.append(
|
||||
f"Iteration: {iteration_response.iteration_nr}, "
|
||||
f"Question {iteration_response.parallelization_nr}"
|
||||
)
|
||||
output_strings.append(f"Tool: {iteration_response.tool}")
|
||||
output_strings.append(f"Question: {iteration_response.question}")
|
||||
# for iteration_response in sorted(
|
||||
# iteration_responses,
|
||||
# key=lambda x: (x.iteration_nr, x.parallelization_nr),
|
||||
# ):
|
||||
# # add basic iteration info
|
||||
# output_strings.append(
|
||||
# f"Iteration: {iteration_response.iteration_nr}, "
|
||||
# f"Question {iteration_response.parallelization_nr}"
|
||||
# )
|
||||
# output_strings.append(f"Tool: {iteration_response.tool}")
|
||||
# output_strings.append(f"Question: {iteration_response.question}")
|
||||
|
||||
# get answer and claims with global citations
|
||||
answer_str = iteration_response.answer
|
||||
claims = iteration_response.claims or []
|
||||
# # get answer and claims with global citations
|
||||
# answer_str = iteration_response.answer
|
||||
# claims = iteration_response.claims or []
|
||||
|
||||
iteration_citations: list[int] = []
|
||||
for local_number, cited_doc in iteration_response.cited_documents.items():
|
||||
global_number = global_citations[cited_doc.center_chunk.document_id]
|
||||
# translate local citations to global citations
|
||||
answer_str = answer_str.replace(
|
||||
f"[{CITATION_PREFIX}{local_number}]", f"[{global_number}]"
|
||||
)
|
||||
claims = [
|
||||
claim.replace(
|
||||
f"[{CITATION_PREFIX}{local_number}]", f"[{global_number}]"
|
||||
)
|
||||
for claim in claims
|
||||
]
|
||||
iteration_citations.append(global_number)
|
||||
# iteration_citations: list[int] = []
|
||||
# for local_number, cited_doc in iteration_response.cited_documents.items():
|
||||
# global_number = global_citations[cited_doc.center_chunk.document_id]
|
||||
# # translate local citations to global citations
|
||||
# answer_str = answer_str.replace(
|
||||
# f"[{CITATION_PREFIX}{local_number}]", f"[{global_number}]"
|
||||
# )
|
||||
# claims = [
|
||||
# claim.replace(
|
||||
# f"[{CITATION_PREFIX}{local_number}]", f"[{global_number}]"
|
||||
# )
|
||||
# for claim in claims
|
||||
# ]
|
||||
# iteration_citations.append(global_number)
|
||||
|
||||
# add answer, claims, and citation info
|
||||
if answer_str:
|
||||
output_strings.append(f"Answer: {answer_str}")
|
||||
if claims:
|
||||
output_strings.append(
|
||||
"Claims: " + "".join(f"\n - {claim}" for claim in claims or [])
|
||||
or "No claims provided"
|
||||
)
|
||||
if not answer_str and not claims:
|
||||
output_strings.append(
|
||||
"Retrieved documents: "
|
||||
+ (
|
||||
"".join(
|
||||
f"[{global_number}]"
|
||||
for global_number in sorted(iteration_citations)
|
||||
)
|
||||
or "No documents retrieved"
|
||||
)
|
||||
)
|
||||
output_strings.append("\n---\n")
|
||||
# # add answer, claims, and citation info
|
||||
# if answer_str:
|
||||
# output_strings.append(f"Answer: {answer_str}")
|
||||
# if claims:
|
||||
# output_strings.append(
|
||||
# "Claims: " + "".join(f"\n - {claim}" for claim in claims or [])
|
||||
# or "No claims provided"
|
||||
# )
|
||||
# if not answer_str and not claims:
|
||||
# output_strings.append(
|
||||
# "Retrieved documents: "
|
||||
# + (
|
||||
# "".join(
|
||||
# f"[{global_number}]"
|
||||
# for global_number in sorted(iteration_citations)
|
||||
# )
|
||||
# or "No documents retrieved"
|
||||
# )
|
||||
# )
|
||||
# output_strings.append("\n---\n")
|
||||
|
||||
# save global iteration response
|
||||
iteration_response_copy = iteration_response.model_copy()
|
||||
iteration_response_copy.answer = answer_str
|
||||
iteration_response_copy.claims = claims
|
||||
iteration_response_copy.cited_documents = {
|
||||
            global_citations[doc.center_chunk.document_id]: doc
            for doc in iteration_response.cited_documents.values()
        }
        global_iteration_responses.append(iteration_response_copy)

    # add document contents if requested
    if include_documents:
        if global_documents:
            output_strings.append("Cited document contents:")
            for doc in global_documents:
                output_strings.append(
                    build_document_context(
                        doc, global_citations[doc.center_chunk.document_id]
                    )
                )
            output_strings.append("\n---\n")

    return AggregatedDRContext(
        context="\n".join(output_strings),
        cited_documents=global_documents,
        is_internet_marker_dict=is_internet_marker_dict,
        global_iteration_responses=global_iteration_responses,
    )


def get_chat_history_string(chat_history: list[BaseMessage], max_messages: int) -> str:
    """
    Get the chat history (up to max_messages) as a string.
    """
    # get past max_messages USER, ASSISTANT message pairs
    past_messages = chat_history[-max_messages * 2 :]
    filtered_past_messages = copy.deepcopy(past_messages)

    for past_message_number, past_message in enumerate(past_messages):
        if isinstance(past_message.content, list):
            removal_indices = []
            for content_piece_number, content_piece in enumerate(past_message.content):
                if (
                    isinstance(content_piece, dict)
                    and content_piece.get("type") != "text"
                ):
                    removal_indices.append(content_piece_number)

            # Only rebuild the content list if there are items to remove
            if removal_indices:
                filtered_past_messages[past_message_number].content = [
                    content_piece
                    for content_piece_number, content_piece in enumerate(
                        past_message.content
                    )
                    if content_piece_number not in removal_indices
                ]

        else:
            continue

    return (
        "...\n" if len(chat_history) > len(filtered_past_messages) else ""
    ) + "\n".join(
        ("user" if isinstance(msg, HumanMessage) else "you")
        + f": {str(msg.content).strip()}"
        for msg in filtered_past_messages
    )
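
# Illustrative example (not part of the original module): for a HumanMessage
# "What changed?" followed by an assistant reply "The schema was updated.",
# get_chat_history_string produces
#
#     user: What changed?
#     you: The schema was updated.
#
# and prepends "...\n" when older messages were truncated away.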

def get_prompt_question(
    question: str, clarification: OrchestrationClarificationInfo | None
) -> str:
    if clarification:
        clarification_question = clarification.clarification_question
        clarification_response = clarification.clarification_response
        return (
            f"Initial User Question: {question}\n"
            f"(Clarification Question: {clarification_question}\n"
            f"User Response: {clarification_response})"
        )

    return question


def create_tool_call_string(tool_name: str, query_list: list[str]) -> str:
    """
    Create a string representation of the tool call.
    """
    questions_str = "\n - ".join(query_list)
    return f"Tool: {tool_name}\n\nQuestions:\n{questions_str}"


def parse_plan_to_dict(plan_text: str) -> dict[str, str]:
    # Convert plan string to numbered dict format
    if not plan_text:
        return {}

    # Split by numbered items (1., 2., 3., etc. or 1), 2), 3), etc.)
    parts = re.split(r"(\d+[.)])", plan_text)
    plan_dict = {}

    for i in range(
        1, len(parts), 2
    ):  # Skip empty first part, then take number and text pairs
        if i + 1 < len(parts):
            number = parts[i].rstrip(".)")  # Remove the dot or parenthesis
            text = parts[i + 1].strip()
            if text:  # Only add if there's actual content
                plan_dict[number] = text

    return plan_dict
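
# Illustrative example (not part of the original module): splitting on the
# numbered-item pattern turns
#     "1. Find the owning team 2) Summarize open issues"
# into
#     {"1": "Find the owning team", "2": "Summarize open issues"}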

def convert_inference_sections_to_search_docs(
    inference_sections: list[InferenceSection],
    is_internet: bool = False,
) -> list[SavedSearchDoc]:
    # Convert InferenceSections to SavedSearchDocs
    search_docs = SearchDoc.from_chunks_or_sections(inference_sections)
    for search_doc in search_docs:
        search_doc.is_internet = is_internet

    retrieved_saved_search_docs = [
        SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
        for search_doc in search_docs
    ]
    return retrieved_saved_search_docs
@@ -1,87 +1,87 @@
from collections.abc import Hashable
from datetime import datetime
from enum import Enum

from langgraph.types import Send

from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import KGSearchType
from onyx.agents.agent_search.kb_search.states import KGSourceDivisionType
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import ResearchObjectInput
from onyx.configs.kg_configs import KG_MAX_DECOMPOSITION_SEGMENTS
from onyx.utils.logger import setup_logger

logger = setup_logger()


class KGAnalysisPath(str, Enum):
    PROCESS_KG_ONLY_ANSWERS = "process_kg_only_answers"
    CONSTRUCT_DEEP_SEARCH_FILTERS = "construct_deep_search_filters"


def simple_vs_search(
    state: MainState,
) -> str:
    identified_strategy = state.updated_strategy or state.strategy

    if (
        identified_strategy == KGAnswerStrategy.DEEP
        or state.search_type == KGSearchType.SEARCH
    ):
        return KGAnalysisPath.CONSTRUCT_DEEP_SEARCH_FILTERS.value
    else:
        return KGAnalysisPath.PROCESS_KG_ONLY_ANSWERS.value


def research_individual_object(
    state: MainState,
) -> list[Send | Hashable] | str:
    edge_start_time = datetime.now()

    assert state.div_con_entities is not None
    assert state.broken_down_question is not None
    assert state.vespa_filter_results is not None

    if (
        state.search_type == KGSearchType.SQL
        and state.strategy == KGAnswerStrategy.DEEP
    ):
        if state.source_filters and state.source_division:
            segments = state.source_filters
            segment_type = KGSourceDivisionType.SOURCE.value
        else:
            segments = state.div_con_entities
            segment_type = KGSourceDivisionType.ENTITY.value

        if segments and (len(segments) > KG_MAX_DECOMPOSITION_SEGMENTS):
            logger.debug(f"Too many sources ({len(segments)}), using filtered search")
            return "filtered_search"

        return [
            Send(
                "process_individual_deep_search",
                ResearchObjectInput(
                    research_nr=research_nr + 1,
                    entity=entity,
                    broken_down_question=state.broken_down_question,
                    vespa_filter_results=state.vespa_filter_results,
                    source_division=state.source_division,
                    source_entity_filters=state.source_filters,
                    segment_type=segment_type,
                    log_messages=[
                        f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
                    ],
                    step_results=[],
                ),
            )
            for research_nr, entity in enumerate(segments)
        ]
    elif state.search_type == KGSearchType.SEARCH:
        return "filtered_search"
    else:
        raise ValueError(
            f"Invalid combination of search type: {state.search_type} and strategy: {state.strategy}"
        )
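
# Routing summary (added for clarity, not in the original file): simple_vs_search
# returns the name of the next node as a string, while research_individual_object
# either fans out a list of Send objects (one process_individual_deep_search run
# per segment), falls back to "filtered_search", or raises on an unsupported
# combination of search type and strategy.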
@@ -1,143 +1,143 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agents.agent_search.kb_search.conditional_edges import (
    research_individual_object,
)
from onyx.agents.agent_search.kb_search.conditional_edges import simple_vs_search
from onyx.agents.agent_search.kb_search.nodes.a1_extract_ert import extract_ert
from onyx.agents.agent_search.kb_search.nodes.a2_analyze import analyze
from onyx.agents.agent_search.kb_search.nodes.a3_generate_simple_sql import (
    generate_simple_sql,
)
from onyx.agents.agent_search.kb_search.nodes.b1_construct_deep_search_filters import (
    construct_deep_search_filters,
)
from onyx.agents.agent_search.kb_search.nodes.b2p_process_individual_deep_search import (
    process_individual_deep_search,
)
from onyx.agents.agent_search.kb_search.nodes.b2s_filtered_search import filtered_search
from onyx.agents.agent_search.kb_search.nodes.b3_consolidate_individual_deep_search import (
    consolidate_individual_deep_search,
)
from onyx.agents.agent_search.kb_search.nodes.c1_process_kg_only_answers import (
    process_kg_only_answers,
)
from onyx.agents.agent_search.kb_search.nodes.d1_generate_answer import generate_answer
from onyx.agents.agent_search.kb_search.nodes.d2_logging_node import log_data
from onyx.agents.agent_search.kb_search.states import MainInput
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.utils.logger import setup_logger

logger = setup_logger()


def kb_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for the knowledge graph search process.
    """
    graph = StateGraph(
        state_schema=MainState,
        input=MainInput,
    )

    ### Add nodes ###

    graph.add_node(
        "extract_ert",
        extract_ert,
    )
    graph.add_node(
        "generate_simple_sql",
        generate_simple_sql,
    )
    graph.add_node(
        "filtered_search",
        filtered_search,
    )
    graph.add_node(
        "analyze",
        analyze,
    )
    graph.add_node(
        "generate_answer",
        generate_answer,
    )
    graph.add_node(
        "log_data",
        log_data,
    )
    graph.add_node(
        "construct_deep_search_filters",
        construct_deep_search_filters,
    )
    graph.add_node(
        "process_individual_deep_search",
        process_individual_deep_search,
    )
    graph.add_node(
        "consolidate_individual_deep_search",
        consolidate_individual_deep_search,
    )
    graph.add_node("process_kg_only_answers", process_kg_only_answers)

    ### Add edges ###

    graph.add_edge(start_key=START, end_key="extract_ert")
    graph.add_edge(
        start_key="extract_ert",
        end_key="analyze",
    )
    graph.add_edge(
        start_key="analyze",
        end_key="generate_simple_sql",
    )
    graph.add_conditional_edges("generate_simple_sql", simple_vs_search)
    graph.add_edge(start_key="process_kg_only_answers", end_key="generate_answer")
    graph.add_conditional_edges(
        source="construct_deep_search_filters",
        path=research_individual_object,
        path_map=["process_individual_deep_search", "filtered_search"],
    )
    graph.add_edge(
        start_key="process_individual_deep_search",
        end_key="consolidate_individual_deep_search",
    )
    graph.add_edge(
        start_key="consolidate_individual_deep_search", end_key="generate_answer"
    )
    graph.add_edge(
        start_key="filtered_search",
        end_key="generate_answer",
    )
    graph.add_edge(
        start_key="generate_answer",
        end_key="log_data",
    )
    graph.add_edge(
        start_key="log_data",
        end_key=END,
    )

    return graph
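
# Usage sketch (added for illustration; not part of the original module). It assumes
# the standard LangGraph compile/invoke API and a MainInput instance built by the
# caller:
#
#     compiled_graph = kb_graph_builder().compile()
#     final_state = compiled_graph.invoke(main_input)
#
# The two add_conditional_edges calls above are what route a run either through
# process_individual_deep_search / consolidate_individual_deep_search or through
# filtered_search before generate_answer and log_data.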
@@ -1,244 +1,244 @@
import re

from onyx.agents.agent_search.kb_search.models import KGEntityDocInfo
from onyx.agents.agent_search.kb_search.models import KGExpandedGraphObjects
from onyx.agents.agent_search.kb_search.states import SubQuestionAnswerResults
from onyx.agents.agent_search.kb_search.step_definitions import (
    KG_SEARCH_STEP_DESCRIPTIONS,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.chat.models import LlmDoc
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
from onyx.db.document import get_kg_doc_info_for_entity_name
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.entities import get_document_id_for_entity
from onyx.db.entities import get_entity_name
from onyx.db.entity_type import get_entity_types
from onyx.kg.utils.formatting_utils import make_entity_id
from onyx.kg.utils.formatting_utils import split_relationship_id
from onyx.utils.logger import setup_logger

logger = setup_logger()


def _check_entities_disconnected(
    current_entities: list[str], current_relationships: list[str]
) -> bool:
    """
    Check whether the entities in current_entities are disconnected given the
    provided relationships.
    Relationships are in the format: source_entity__relationship_name__target_entity

    Args:
        current_entities: List of entity IDs to check connectivity for
        current_relationships: List of relationships in format source__relationship__target

    Returns:
        bool: True if the entities are not all connected to each other,
        False if every entity is reachable from the first one
    """
    if not current_entities:
        return True

    # Create a graph representation using adjacency list
    graph: dict[str, set[str]] = {entity: set() for entity in current_entities}

    # Build the graph from relationships
    for relationship in current_relationships:
        try:
            source, _, target = split_relationship_id(relationship)
            if source in graph and target in graph:
                graph[source].add(target)
                # Add reverse edge to capture that we do also have a relationship in the other direction,
                # albeit not quite the same one.
                graph[target].add(source)
        except ValueError:
            raise ValueError(f"Invalid relationship format: {relationship}")

    # Use BFS to check if all entities are connected
    visited: set[str] = set()
    start_entity = current_entities[0]

    def _bfs(start: str) -> None:
        queue = [start]
        visited.add(start)
        while queue:
            current = queue.pop(0)
            for neighbor in graph[current]:
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append(neighbor)

    # Start BFS from the first entity
    _bfs(start_entity)

    logger.debug(f"Number of visited entities: {len(visited)}")

    # Check if all current_entities are in visited
    return not all(entity in visited for entity in current_entities)
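
# Illustrative example (not part of the original module; the entity names are
# made up): for
#     current_entities = ["ACCOUNT::a1", "EMPLOYEE::e1"]
#     current_relationships = ["ACCOUNT::a1__has_owner__EMPLOYEE::e1"]
# the BFS reaches both entities, so the function returns False (connected);
# with an empty relationship list it returns True (disconnected).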

def create_minimal_connected_query_graph(
    entities: list[str], relationships: list[str], max_depth: int = 1
) -> KGExpandedGraphObjects:
    """
    TODO: Implement this. For now we'll trust the SQL generation to do the right thing.
    Return the original entities and relationships.
    """
    return KGExpandedGraphObjects(entities=entities, relationships=relationships)


def get_doc_information_for_entity(entity_id_name: str) -> KGEntityDocInfo:
    """
    Get document information for an entity, including its semantic name and document details.
    """
    if "::" not in entity_id_name:
        return KGEntityDocInfo(
            doc_id=None,
            doc_semantic_id=None,
            doc_link=None,
            semantic_entity_name=entity_id_name,
            semantic_linked_entity_name=entity_id_name,
        )

    entity_type, entity_name = map(str.strip, entity_id_name.split("::", 1))

    with get_session_with_current_tenant() as db_session:
        entity_document_id = get_document_id_for_entity(db_session, entity_id_name)
        if entity_document_id:
            return get_kg_doc_info_for_entity_name(
                db_session, entity_document_id, entity_type
            )
        else:
            entity_actual_name = get_entity_name(db_session, entity_id_name)

    return KGEntityDocInfo(
        doc_id=None,
        doc_semantic_id=None,
        doc_link=None,
        semantic_entity_name=f"{entity_type} {entity_actual_name or entity_id_name}",
        semantic_linked_entity_name=f"{entity_type} {entity_actual_name or entity_id_name}",
    )


def rename_entities_in_answer(answer: str) -> str:
    """
    Process entity references in the answer string by:
    1. Extracting all strings matching <str>:<str> or <str>: <str> patterns
    2. Looking up these references in the entity table
    3. Replacing valid references with their corresponding values

    Args:
        answer: The input string containing potential entity references

    Returns:
        str: The processed string with entity references replaced
    """
    logger.debug(f"Input answer: {answer}")

    # Clean up any spaces around ::
    answer = re.sub(r"::\s+", "::", answer)

    # Pattern to match entity_type::entity_name, with optional quotes
    pattern = r"(?:')?([a-zA-Z0-9-]+)::([a-zA-Z0-9]+)(?:')?"

    matches = list(re.finditer(pattern, answer))
    logger.debug(f"Found {len(matches)} matches")

    # get active entity types
    with get_session_with_current_tenant() as db_session:
        active_entity_types = [
            x.id_name for x in get_entity_types(db_session, active=True)
        ]
    logger.debug(f"Active entity types: {active_entity_types}")

    # Create dictionary for processed references
    processed_refs = {}

    for match in matches:
        entity_type = match.group(1).upper().strip()
        entity_name = match.group(2).strip()
        potential_entity_id_name = make_entity_id(entity_type, entity_name)

        if entity_type not in active_entity_types:
            continue

        replacement_candidate = get_doc_information_for_entity(potential_entity_id_name)

        if replacement_candidate.doc_id:
            # Store both the original match and the entity_id_name for replacement
            processed_refs[match.group(0)] = (
                replacement_candidate.semantic_linked_entity_name
            )
        else:
            processed_refs[match.group(0)] = replacement_candidate.semantic_entity_name

    # Replace all references in the answer
    for ref, replacement in processed_refs.items():
        answer = answer.replace(ref, replacement)
        logger.debug(f"Replaced {ref} with {replacement}")

    return answer


def build_document_context(
    document: InferenceSection | LlmDoc, document_number: int
) -> str:
    """
    Build a context string for a document.
    """
    metadata_list: list[str] = []
    document_content: str | None = None
    info_source: InferenceChunk | LlmDoc | None = None
    info_content: str | None = None

    if isinstance(document, InferenceSection):
        info_source = document.center_chunk
        info_content = document.combined_content
    elif isinstance(document, LlmDoc):
        info_source = document
        info_content = document.content

    for key, value in info_source.metadata.items():
        metadata_list.append(f"  - {key}: {value}")

    if metadata_list:
        metadata_str = "- Document Metadata:\n" + "\n".join(metadata_list)
    else:
        metadata_str = ""

    # Construct document header with number and semantic identifier
    doc_header = f"Document {str(document_number)}: {info_source.semantic_identifier}"

    # Combine all parts with proper spacing
    document_content = f"{doc_header}\n\n{metadata_str}\n\n{info_content}"

    return document_content
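
# Output sketch (illustrative, not in the original source): for document_number=1
# and a section whose center chunk has semantic_identifier "Quarterly Plan", the
# returned string looks roughly like
#
#     Document 1: Quarterly Plan
#
#     - Document Metadata:
#       - owner: ...
#
#     <combined chunk content>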

def get_near_empty_step_results(
    step_number: int,
    step_answer: str,
    verified_reranked_documents: list[InferenceSection] = [],
) -> SubQuestionAnswerResults:
    """
    Get near-empty step results from a list of step results.
    """
    return SubQuestionAnswerResults(
        question=KG_SEARCH_STEP_DESCRIPTIONS[step_number].description,
        question_id="0_" + str(step_number),
        answer=step_answer,
        verified_high_quality=True,
        sub_query_retrieval_results=[],
        verified_reranked_documents=verified_reranked_documents,
        context_documents=[],
        cited_documents=[],
        sub_question_retrieval_stats=AgentChunkRetrievalStats(
            verified_count=None,
            verified_avg_scores=None,
            rejected_count=None,
            rejected_avg_scores=None,
            verified_doc_chunk_ids=[],
            dismissed_doc_chunk_ids=[],
        ),
    )
@@ -1,55 +1,55 @@
from pydantic import BaseModel

from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import KGRelationshipDetection
from onyx.agents.agent_search.kb_search.states import KGSearchType
from onyx.agents.agent_search.kb_search.states import YesNoEnum


class KGQuestionEntityExtractionResult(BaseModel):
    entities: list[str]
    time_filter: str | None


class KGViewNames(BaseModel):
    allowed_docs_view_name: str
    kg_relationships_view_name: str
    kg_entity_view_name: str


class KGAnswerApproach(BaseModel):
    search_type: KGSearchType
    search_strategy: KGAnswerStrategy
    relationship_detection: KGRelationshipDetection
    format: KGAnswerFormat
    broken_down_question: str | None = None
    divide_and_conquer: YesNoEnum | None = None


class KGQuestionRelationshipExtractionResult(BaseModel):
    relationships: list[str]


class KGQuestionExtractionResult(BaseModel):
    entities: list[str]
    relationships: list[str]
    time_filter: str | None


class KGExpandedGraphObjects(BaseModel):
    entities: list[str]
    relationships: list[str]


class KGSteps(BaseModel):
    description: str
    activities: list[str]


class KGEntityDocInfo(BaseModel):
    doc_id: str | None
    doc_semantic_id: str | None
    doc_link: str | None
    semantic_entity_name: str
    semantic_linked_entity_name: str
@@ -1,255 +1,255 @@
from datetime import datetime
from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from pydantic import ValidationError

from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.models import KGQuestionEntityExtractionResult
from onyx.agents.agent_search.kb_search.models import (
    KGQuestionRelationshipExtractionResult,
)
from onyx.agents.agent_search.kb_search.states import EntityRelationshipExtractionUpdate
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.configs.kg_configs import KG_ENTITY_EXTRACTION_TIMEOUT
from onyx.configs.kg_configs import KG_RELATIONSHIP_EXTRACTION_TIMEOUT
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.kg_temp_view import create_views
from onyx.db.kg_temp_view import get_user_view_names
from onyx.db.relationships import get_allowed_relationship_type_pairs
from onyx.kg.utils.extraction_utils import get_entity_types_str
from onyx.kg.utils.extraction_utils import get_relationship_types_str
from onyx.prompts.kg_prompts import QUERY_ENTITY_EXTRACTION_PROMPT
from onyx.prompts.kg_prompts import QUERY_RELATIONSHIP_EXTRACTION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()


def extract_ert(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> EntityRelationshipExtractionUpdate:
    """
    LangGraph node to start the agentic search process.
    """
    # recheck KG enablement at the outset of the KG graph
    if not config["metadata"]["config"].behavior.kg_config_settings.KG_ENABLED:
        logger.error("KG approach is not enabled, the KG agent flow cannot run.")
        raise ValueError("KG approach is not enabled, the KG agent flow cannot run.")

    _KG_STEP_NR = 1

    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])

    if graph_config.tooling.search_tool is None:
        raise ValueError("Search tool is not set")
    elif graph_config.tooling.search_tool.user is None:
        raise ValueError("User is not set")
    else:
        user_email = graph_config.tooling.search_tool.user.email
        user_name = user_email.split("@")[0] or "unknown"

    # first four lines are duplicated from generate_initial_answer
    question = state.question
    today_date = datetime.now().strftime("%A, %Y-%m-%d")

    all_entity_types = get_entity_types_str(active=True)
    all_relationship_types = get_relationship_types_str(active=True)

    # Create temporary views. TODO: move into parallel step, if ultimately materialized
    tenant_id = get_current_tenant_id()
    kg_views = get_user_view_names(user_email, tenant_id)
    with get_session_with_current_tenant() as db_session:
        create_views(
            db_session,
            tenant_id=tenant_id,
            user_email=user_email,
            allowed_docs_view_name=kg_views.allowed_docs_view_name,
            kg_relationships_view_name=kg_views.kg_relationships_view_name,
            kg_entity_view_name=kg_views.kg_entity_view_name,
        )

    ### get the entities, terms, and filters

    query_extraction_pre_prompt = QUERY_ENTITY_EXTRACTION_PROMPT.format(
        entity_types=all_entity_types,
        relationship_types=all_relationship_types,
    )

    query_extraction_prompt = (
        query_extraction_pre_prompt.replace("---content---", question)
        .replace("---today_date---", today_date)
        .replace("---user_name---", f"EMPLOYEE:{user_name}")
        .replace("{{", "{")
        .replace("}}", "}")
    )

    msg = [
        HumanMessage(
            content=query_extraction_prompt,
        )
    ]
    primary_llm = graph_config.tooling.primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            KG_ENTITY_EXTRACTION_TIMEOUT,
            primary_llm.invoke_langchain,
            prompt=msg,
            timeout_override=15,
            max_tokens=300,
        )

        cleaned_response = (
            str(llm_response.content)
            .replace("{{", "{")
            .replace("}}", "}")
            .replace("```json\n", "")
            .replace("\n```", "")
            .replace("\n", "")
        )
        first_bracket = cleaned_response.find("{")
        last_bracket = cleaned_response.rfind("}")
        cleaned_response = cleaned_response[first_bracket : last_bracket + 1]

        entity_extraction_result = KGQuestionEntityExtractionResult.model_validate_json(
            cleaned_response
        )
    except ValidationError:
        logger.error("Failed to parse LLM response as JSON in Entity Extraction")
        entity_extraction_result = KGQuestionEntityExtractionResult(
            entities=[], time_filter=""
        )
    except Exception as e:
        logger.error(f"Error in extract_ert: {e}")
        entity_extraction_result = KGQuestionEntityExtractionResult(
            entities=[], time_filter=""
        )
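
    # Illustrative example of the clean-up above (not part of the original file):
    # an LLM reply wrapped in a ```json ... ``` code fence, such as
    #     {"entities": ["ACCOUNT::acme"], "time_filter": null}
    # has the fence markers and newlines stripped and is reduced to the substring
    # between the first "{" and the last "}" before being validated into
    # KGQuestionEntityExtractionResult.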
|
||||
# remove the attribute filters from the entities to for the purpose of the relationship
|
||||
entities_no_attributes = [
|
||||
entity.split("--")[0] for entity in entity_extraction_result.entities
|
||||
]
|
||||
ert_entities_string = f"Entities: {entities_no_attributes}\n"
|
||||
# # remove the attribute filters from the entities to for the purpose of the relationship
|
||||
# entities_no_attributes = [
|
||||
# entity.split("--")[0] for entity in entity_extraction_result.entities
|
||||
# ]
|
||||
# ert_entities_string = f"Entities: {entities_no_attributes}\n"
|
||||
|
||||
### get the relationships
|
||||
# ### get the relationships
|
||||
|
||||
# find the relationship types that match the extracted entity types
|
||||
# # find the relationship types that match the extracted entity types
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
allowed_relationship_pairs = get_allowed_relationship_type_pairs(
|
||||
db_session, entity_extraction_result.entities
|
||||
)
|
||||
# with get_session_with_current_tenant() as db_session:
|
||||
# allowed_relationship_pairs = get_allowed_relationship_type_pairs(
|
||||
# db_session, entity_extraction_result.entities
|
||||
# )
|
||||
|
||||
query_relationship_extraction_prompt = (
|
||||
QUERY_RELATIONSHIP_EXTRACTION_PROMPT.replace("---question---", question)
|
||||
.replace("---today_date---", today_date)
|
||||
.replace(
|
||||
"---relationship_type_options---",
|
||||
" - " + "\n - ".join(allowed_relationship_pairs),
|
||||
)
|
||||
.replace("---identified_entities---", ert_entities_string)
|
||||
.replace("---entity_types---", all_entity_types)
|
||||
.replace("{{", "{")
|
||||
.replace("}}", "}")
|
||||
)
|
||||
# query_relationship_extraction_prompt = (
|
||||
# QUERY_RELATIONSHIP_EXTRACTION_PROMPT.replace("---question---", question)
|
||||
# .replace("---today_date---", today_date)
|
||||
# .replace(
|
||||
# "---relationship_type_options---",
|
||||
# " - " + "\n - ".join(allowed_relationship_pairs),
|
||||
# )
|
||||
# .replace("---identified_entities---", ert_entities_string)
|
||||
# .replace("---entity_types---", all_entity_types)
|
||||
# .replace("{{", "{")
|
||||
# .replace("}}", "}")
|
||||
# )
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=query_relationship_extraction_prompt,
|
||||
)
|
||||
]
|
||||
primary_llm = graph_config.tooling.primary_llm
|
||||
# Grader
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
KG_RELATIONSHIP_EXTRACTION_TIMEOUT,
|
||||
primary_llm.invoke_langchain,
|
||||
prompt=msg,
|
||||
timeout_override=15,
|
||||
max_tokens=300,
|
||||
)
|
||||
# msg = [
|
||||
# HumanMessage(
|
||||
# content=query_relationship_extraction_prompt,
|
||||
# )
|
||||
# ]
|
||||
# primary_llm = graph_config.tooling.primary_llm
|
||||
# # Grader
|
||||
# try:
|
||||
# llm_response = run_with_timeout(
|
||||
# KG_RELATIONSHIP_EXTRACTION_TIMEOUT,
|
||||
# primary_llm.invoke_langchain,
|
||||
# prompt=msg,
|
||||
# timeout_override=15,
|
||||
# max_tokens=300,
|
||||
# )
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content)
|
||||
.replace("{{", "{")
|
||||
.replace("}}", "}")
|
||||
.replace("```json\n", "")
|
||||
.replace("\n```", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
first_bracket = cleaned_response.find("{")
|
||||
last_bracket = cleaned_response.rfind("}")
|
||||
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
|
||||
cleaned_response = cleaned_response.replace("{{", '{"')
|
||||
cleaned_response = cleaned_response.replace("}}", '"}')
|
||||
# cleaned_response = (
|
||||
# str(llm_response.content)
|
||||
# .replace("{{", "{")
|
||||
# .replace("}}", "}")
|
||||
# .replace("```json\n", "")
|
||||
# .replace("\n```", "")
|
||||
# .replace("\n", "")
|
||||
# )
|
||||
# first_bracket = cleaned_response.find("{")
|
||||
# last_bracket = cleaned_response.rfind("}")
|
||||
# cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
|
||||
# cleaned_response = cleaned_response.replace("{{", '{"')
|
||||
# cleaned_response = cleaned_response.replace("}}", '"}')
try:
|
||||
relationship_extraction_result = (
|
||||
KGQuestionRelationshipExtractionResult.model_validate_json(
|
||||
cleaned_response
|
||||
)
|
||||
)
|
||||
except ValidationError:
|
||||
logger.error(
|
||||
"Failed to parse LLM response as JSON in Relationship Extraction"
|
||||
)
|
||||
relationship_extraction_result = KGQuestionRelationshipExtractionResult(
|
||||
relationships=[],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in extract_ert: {e}")
|
||||
relationship_extraction_result = KGQuestionRelationshipExtractionResult(
|
||||
relationships=[],
|
||||
)
|
||||
# try:
|
||||
# relationship_extraction_result = (
|
||||
# KGQuestionRelationshipExtractionResult.model_validate_json(
|
||||
# cleaned_response
|
||||
# )
|
||||
# )
|
||||
# except ValidationError:
|
||||
# logger.error(
|
||||
# "Failed to parse LLM response as JSON in Relationship Extraction"
|
||||
# )
|
||||
# relationship_extraction_result = KGQuestionRelationshipExtractionResult(
|
||||
# relationships=[],
|
||||
# )
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error in extract_ert: {e}")
|
||||
# relationship_extraction_result = KGQuestionRelationshipExtractionResult(
|
||||
# relationships=[],
|
||||
# )
## STEP 1
|
||||
# Stream answer pieces out to the UI for Step 1
|
||||
# ## STEP 1
|
||||
# # Stream answer pieces out to the UI for Step 1
extracted_entity_string = " \n ".join(
|
||||
[x.split("--")[0] for x in entity_extraction_result.entities]
|
||||
)
|
||||
extracted_relationship_string = " \n ".join(
|
||||
relationship_extraction_result.relationships
|
||||
)
|
||||
# extracted_entity_string = " \n ".join(
|
||||
# [x.split("--")[0] for x in entity_extraction_result.entities]
|
||||
# )
|
||||
# extracted_relationship_string = " \n ".join(
|
||||
# relationship_extraction_result.relationships
|
||||
# )
step_answer = f"""Entities and relationships have been extracted from query - \n \
|
||||
Entities: {extracted_entity_string} - \n Relationships: {extracted_relationship_string}"""
|
||||
# step_answer = f"""Entities and relationships have been extracted from query - \n \
|
||||
# Entities: {extracted_entity_string} - \n Relationships: {extracted_relationship_string}"""
return EntityRelationshipExtractionUpdate(
|
||||
entities_types_str=all_entity_types,
|
||||
relationship_types_str=all_relationship_types,
|
||||
extracted_entities_w_attributes=entity_extraction_result.entities,
|
||||
extracted_entities_no_attributes=entities_no_attributes,
|
||||
extracted_relationships=relationship_extraction_result.relationships,
|
||||
time_filter=entity_extraction_result.time_filter,
|
||||
kg_doc_temp_view_name=kg_views.allowed_docs_view_name,
|
||||
kg_rel_temp_view_name=kg_views.kg_relationships_view_name,
|
||||
kg_entity_temp_view_name=kg_views.kg_entity_view_name,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="extract entities terms",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
step_results=[
|
||||
get_near_empty_step_results(
|
||||
step_number=_KG_STEP_NR,
|
||||
step_answer=step_answer,
|
||||
verified_reranked_documents=[],
|
||||
)
|
||||
],
|
||||
)
|
||||
# return EntityRelationshipExtractionUpdate(
|
||||
# entities_types_str=all_entity_types,
|
||||
# relationship_types_str=all_relationship_types,
|
||||
# extracted_entities_w_attributes=entity_extraction_result.entities,
|
||||
# extracted_entities_no_attributes=entities_no_attributes,
|
||||
# extracted_relationships=relationship_extraction_result.relationships,
|
||||
# time_filter=entity_extraction_result.time_filter,
|
||||
# kg_doc_temp_view_name=kg_views.allowed_docs_view_name,
|
||||
# kg_rel_temp_view_name=kg_views.kg_relationships_view_name,
|
||||
# kg_entity_temp_view_name=kg_views.kg_entity_view_name,
|
||||
# log_messages=[
|
||||
# get_langgraph_node_log_string(
|
||||
# graph_component="main",
|
||||
# node_name="extract entities terms",
|
||||
# node_start_time=node_start_time,
|
||||
# )
|
||||
# ],
|
||||
# step_results=[
|
||||
# get_near_empty_step_results(
|
||||
# step_number=_KG_STEP_NR,
|
||||
# step_answer=step_answer,
|
||||
# verified_reranked_documents=[],
|
||||
# )
|
||||
# ],
|
||||
# )
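For reference, a minimal sketch of the clean-then-validate pattern this node applies to LLM output. KGQuestionRelationshipExtractionResult is taken from the diff above, but its relationships field and the parse_relationship_response helper are illustrative assumptions; the entity-extraction branch earlier in the same node follows the same shape with KGQuestionEntityExtractionResult.

from pydantic import BaseModel, ValidationError


class KGQuestionRelationshipExtractionResult(BaseModel):
    # Field assumed from how the diff constructs the fallback result.
    relationships: list[str] = []


def parse_relationship_response(raw: str) -> KGQuestionRelationshipExtractionResult:
    # Strip markdown fences and newlines, then trim to the outermost JSON object.
    cleaned = raw.replace("```json\n", "").replace("\n```", "").replace("\n", "")
    first_bracket = cleaned.find("{")
    last_bracket = cleaned.rfind("}")
    if first_bracket == -1 or last_bracket == -1:
        # No JSON object at all: fall back to an empty result, as the node does.
        return KGQuestionRelationshipExtractionResult(relationships=[])
    cleaned = cleaned[first_bracket : last_bracket + 1]
    try:
        return KGQuestionRelationshipExtractionResult.model_validate_json(cleaned)
    except ValidationError:
        # Mirror the node's behavior: log-and-continue with an empty result.
        return KGQuestionRelationshipExtractionResult(relationships=[])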
@@ -1,312 +1,312 @@
from datetime import datetime
|
||||
from typing import cast
|
||||
# from datetime import datetime
|
||||
# from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
# from langchain_core.messages import HumanMessage
|
||||
# from langchain_core.runnables import RunnableConfig
|
||||
# from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
create_minimal_connected_query_graph,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
|
||||
from onyx.agents.agent_search.kb_search.models import KGAnswerApproach
|
||||
from onyx.agents.agent_search.kb_search.states import AnalysisUpdate
|
||||
from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
|
||||
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
|
||||
from onyx.agents.agent_search.kb_search.states import KGRelationshipDetection
|
||||
from onyx.agents.agent_search.kb_search.states import KGSearchType
|
||||
from onyx.agents.agent_search.kb_search.states import MainState
|
||||
from onyx.agents.agent_search.kb_search.states import YesNoEnum
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.kg_configs import KG_STRATEGY_GENERATION_TIMEOUT
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.entities import get_document_id_for_entity
|
||||
from onyx.kg.clustering.normalizations import normalize_entities
|
||||
from onyx.kg.clustering.normalizations import normalize_relationships
|
||||
from onyx.kg.utils.formatting_utils import split_relationship_id
|
||||
from onyx.prompts.kg_prompts import STRATEGY_GENERATION_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
# from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
# create_minimal_connected_query_graph,
|
||||
# )
|
||||
# from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
|
||||
# from onyx.agents.agent_search.kb_search.models import KGAnswerApproach
|
||||
# from onyx.agents.agent_search.kb_search.states import AnalysisUpdate
|
||||
# from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
|
||||
# from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
|
||||
# from onyx.agents.agent_search.kb_search.states import KGRelationshipDetection
|
||||
# from onyx.agents.agent_search.kb_search.states import KGSearchType
|
||||
# from onyx.agents.agent_search.kb_search.states import MainState
|
||||
# from onyx.agents.agent_search.kb_search.states import YesNoEnum
|
||||
# from onyx.agents.agent_search.models import GraphConfig
|
||||
# from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
# get_langgraph_node_log_string,
|
||||
# )
|
||||
# from onyx.configs.kg_configs import KG_STRATEGY_GENERATION_TIMEOUT
|
||||
# from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
# from onyx.db.entities import get_document_id_for_entity
|
||||
# from onyx.kg.clustering.normalizations import normalize_entities
|
||||
# from onyx.kg.clustering.normalizations import normalize_relationships
|
||||
# from onyx.kg.utils.formatting_utils import split_relationship_id
|
||||
# from onyx.prompts.kg_prompts import STRATEGY_GENERATION_PROMPT
|
||||
# from onyx.utils.logger import setup_logger
|
||||
# from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
logger = setup_logger()
|
||||
# logger = setup_logger()
|
||||
|
||||
|
||||
def _articulate_normalizations(
|
||||
entity_normalization_map: dict[str, str],
|
||||
relationship_normalization_map: dict[str, str],
|
||||
) -> str:
|
||||
# def _articulate_normalizations(
|
||||
# entity_normalization_map: dict[str, str],
|
||||
# relationship_normalization_map: dict[str, str],
|
||||
# ) -> str:
|
||||
|
||||
remark_list: list[str] = []
|
||||
# remark_list: list[str] = []
|
||||
|
||||
if entity_normalization_map:
|
||||
remark_list.append("\n Entities:")
|
||||
for extracted_entity, normalized_entity in entity_normalization_map.items():
|
||||
remark_list.append(f" - {extracted_entity} -> {normalized_entity}")
|
||||
# if entity_normalization_map:
|
||||
# remark_list.append("\n Entities:")
|
||||
# for extracted_entity, normalized_entity in entity_normalization_map.items():
|
||||
# remark_list.append(f" - {extracted_entity} -> {normalized_entity}")
|
||||
|
||||
if relationship_normalization_map:
|
||||
remark_list.append(" \n Relationships:")
|
||||
for (
|
||||
extracted_relationship,
|
||||
normalized_relationship,
|
||||
) in relationship_normalization_map.items():
|
||||
remark_list.append(
|
||||
f" - {extracted_relationship} -> {normalized_relationship}"
|
||||
)
|
||||
# if relationship_normalization_map:
|
||||
# remark_list.append(" \n Relationships:")
|
||||
# for (
|
||||
# extracted_relationship,
|
||||
# normalized_relationship,
|
||||
# ) in relationship_normalization_map.items():
|
||||
# remark_list.append(
|
||||
# f" - {extracted_relationship} -> {normalized_relationship}"
|
||||
# )
|
||||
|
||||
return " \n ".join(remark_list)
|
||||
# return " \n ".join(remark_list)
|
||||
|
||||
|
||||
def _get_fully_connected_entities(
|
||||
entities: list[str], relationships: list[str]
|
||||
) -> list[str]:
|
||||
"""
|
||||
Analyze the connectedness of the entities and relationships.
|
||||
"""
|
||||
# Build a dictionary to track connections for each entity
|
||||
entity_connections: dict[str, set[str]] = {entity: set() for entity in entities}
|
||||
# def _get_fully_connected_entities(
|
||||
# entities: list[str], relationships: list[str]
|
||||
# ) -> list[str]:
|
||||
# """
|
||||
# Analyze the connectedness of the entities and relationships.
|
||||
# """
|
||||
# # Build a dictionary to track connections for each entity
|
||||
# entity_connections: dict[str, set[str]] = {entity: set() for entity in entities}
|
||||
|
||||
# Parse relationships to build connection graph
|
||||
for relationship in relationships:
|
||||
# Split relationship into parts. Test for proper formatting just in case.
|
||||
# Should never be an error though at this point.
|
||||
parts = split_relationship_id(relationship)
|
||||
if len(parts) != 3:
|
||||
raise ValueError(f"Invalid relationship: {relationship}")
|
||||
# # Parse relationships to build connection graph
|
||||
# for relationship in relationships:
|
||||
# # Split relationship into parts. Test for proper formatting just in case.
|
||||
# # Should never be an error though at this point.
|
||||
# parts = split_relationship_id(relationship)
|
||||
# if len(parts) != 3:
|
||||
# raise ValueError(f"Invalid relationship: {relationship}")
|
||||
|
||||
entity1 = parts[0]
|
||||
entity2 = parts[2]
|
||||
# entity1 = parts[0]
|
||||
# entity2 = parts[2]
|
||||
|
||||
# Add bidirectional connections
|
||||
if entity1 in entity_connections:
|
||||
entity_connections[entity1].add(entity2)
|
||||
if entity2 in entity_connections:
|
||||
entity_connections[entity2].add(entity1)
|
||||
# # Add bidirectional connections
|
||||
# if entity1 in entity_connections:
|
||||
# entity_connections[entity1].add(entity2)
|
||||
# if entity2 in entity_connections:
|
||||
# entity_connections[entity2].add(entity1)
|
||||
|
||||
# Find entities connected to all others
|
||||
fully_connected_entities = []
|
||||
all_entities = set(entities)
|
||||
# # Find entities connected to all others
|
||||
# fully_connected_entities = []
|
||||
# all_entities = set(entities)
|
||||
|
||||
for entity, connections in entity_connections.items():
|
||||
# Check if this entity is connected to all other entities
|
||||
if connections == all_entities - {entity}:
|
||||
fully_connected_entities.append(entity)
|
||||
# for entity, connections in entity_connections.items():
|
||||
# # Check if this entity is connected to all other entities
|
||||
# if connections == all_entities - {entity}:
|
||||
# fully_connected_entities.append(entity)
|
||||
|
||||
return fully_connected_entities
|
||||
# return fully_connected_entities
|
||||
|
||||
|
||||
def _check_for_single_doc(
|
||||
normalized_entities: list[str],
|
||||
raw_entities: list[str],
|
||||
normalized_relationship_strings: list[str],
|
||||
raw_relationships: list[str],
|
||||
normalized_time_filter: str | None,
|
||||
) -> str | None:
|
||||
"""
|
||||
Check if the query is for a single document, like 'Summarize ticket ENG-2243K'.
|
||||
None is returned if the query is not for a single document.
|
||||
"""
|
||||
if (
|
||||
len(normalized_entities) == 1
|
||||
and len(raw_entities) == 1
|
||||
and len(normalized_relationship_strings) == 0
|
||||
and len(raw_relationships) == 0
|
||||
and normalized_time_filter is None
|
||||
):
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
single_doc_id = get_document_id_for_entity(
|
||||
db_session, normalized_entities[0]
|
||||
)
|
||||
else:
|
||||
single_doc_id = None
|
||||
return single_doc_id
|
||||
# def _check_for_single_doc(
|
||||
# normalized_entities: list[str],
|
||||
# raw_entities: list[str],
|
||||
# normalized_relationship_strings: list[str],
|
||||
# raw_relationships: list[str],
|
||||
# normalized_time_filter: str | None,
|
||||
# ) -> str | None:
|
||||
# """
|
||||
# Check if the query is for a single document, like 'Summarize ticket ENG-2243K'.
|
||||
# None is returned if the query is not for a single document.
|
||||
# """
|
||||
# if (
|
||||
# len(normalized_entities) == 1
|
||||
# and len(raw_entities) == 1
|
||||
# and len(normalized_relationship_strings) == 0
|
||||
# and len(raw_relationships) == 0
|
||||
# and normalized_time_filter is None
|
||||
# ):
|
||||
# with get_session_with_current_tenant() as db_session:
|
||||
# single_doc_id = get_document_id_for_entity(
|
||||
# db_session, normalized_entities[0]
|
||||
# )
|
||||
# else:
|
||||
# single_doc_id = None
|
||||
# return single_doc_id
|
||||
|
||||
|
||||
def analyze(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> AnalysisUpdate:
|
||||
"""
|
||||
LangGraph node to start the agentic search process.
|
||||
"""
|
||||
# def analyze(
|
||||
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
# ) -> AnalysisUpdate:
|
||||
# """
|
||||
# LangGraph node to start the agentic search process.
|
||||
# """
|
||||
|
||||
_KG_STEP_NR = 2
|
||||
# _KG_STEP_NR = 2
|
||||
|
||||
node_start_time = datetime.now()
|
||||
# node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = state.question
|
||||
entities = (
|
||||
state.extracted_entities_no_attributes
|
||||
) # attribute knowledge is not required for this step
|
||||
relationships = state.extracted_relationships
|
||||
time_filter = state.time_filter
|
||||
# graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
# question = state.question
|
||||
# entities = (
|
||||
# state.extracted_entities_no_attributes
|
||||
# ) # attribute knowledge is not required for this step
|
||||
# relationships = state.extracted_relationships
|
||||
# time_filter = state.time_filter
|
||||
|
||||
## STEP 2 - stream out goals
|
||||
# ## STEP 2 - stream out goals
|
||||
|
||||
# Continue with node
|
||||
# # Continue with node
|
||||
|
||||
normalized_entities = normalize_entities(
|
||||
entities,
|
||||
state.extracted_entities_w_attributes,
|
||||
allowed_docs_temp_view_name=state.kg_doc_temp_view_name,
|
||||
)
|
||||
# normalized_entities = normalize_entities(
|
||||
# entities,
|
||||
# state.extracted_entities_w_attributes,
|
||||
# allowed_docs_temp_view_name=state.kg_doc_temp_view_name,
|
||||
# )
|
||||
|
||||
normalized_relationships = normalize_relationships(
|
||||
relationships, normalized_entities.entity_normalization_map
|
||||
)
|
||||
normalized_time_filter = time_filter
|
||||
# normalized_relationships = normalize_relationships(
|
||||
# relationships, normalized_entities.entity_normalization_map
|
||||
# )
|
||||
# normalized_time_filter = time_filter
|
||||
|
||||
# If single-doc inquiry, send to single-doc processing directly
|
||||
# # If single-doc inquiry, send to single-doc processing directly
|
||||
|
||||
single_doc_id = _check_for_single_doc(
|
||||
normalized_entities=normalized_entities.entities,
|
||||
raw_entities=entities,
|
||||
normalized_relationship_strings=normalized_relationships.relationships,
|
||||
raw_relationships=relationships,
|
||||
normalized_time_filter=normalized_time_filter,
|
||||
)
|
||||
# single_doc_id = _check_for_single_doc(
|
||||
# normalized_entities=normalized_entities.entities,
|
||||
# raw_entities=entities,
|
||||
# normalized_relationship_strings=normalized_relationships.relationships,
|
||||
# raw_relationships=relationships,
|
||||
# normalized_time_filter=normalized_time_filter,
|
||||
# )
|
||||
|
||||
# Expand the entities and relationships to make sure that entities are connected
|
||||
# # Expand the entities and relationships to make sure that entities are connected
|
||||
|
||||
graph_expansion = create_minimal_connected_query_graph(
|
||||
normalized_entities.entities,
|
||||
normalized_relationships.relationships,
|
||||
max_depth=2,
|
||||
)
|
||||
# graph_expansion = create_minimal_connected_query_graph(
|
||||
# normalized_entities.entities,
|
||||
# normalized_relationships.relationships,
|
||||
# max_depth=2,
|
||||
# )
|
||||
|
||||
query_graph_entities = graph_expansion.entities
|
||||
query_graph_relationships = graph_expansion.relationships
|
||||
# query_graph_entities = graph_expansion.entities
|
||||
# query_graph_relationships = graph_expansion.relationships
|
||||
|
||||
# Evaluate whether a search needs to be done after identifying all entities and relationships
|
||||
# # Evaluate whether a search needs to be done after identifying all entities and relationships
|
||||
|
||||
strategy_generation_prompt = (
|
||||
STRATEGY_GENERATION_PROMPT.replace(
|
||||
"---entities---", "\n".join(query_graph_entities)
|
||||
)
|
||||
.replace("---relationships---", "\n".join(query_graph_relationships))
|
||||
.replace("---possible_entities---", state.entities_types_str)
|
||||
.replace("---possible_relationships---", state.relationship_types_str)
|
||||
.replace("---question---", question)
|
||||
)
|
||||
# strategy_generation_prompt = (
|
||||
# STRATEGY_GENERATION_PROMPT.replace(
|
||||
# "---entities---", "\n".join(query_graph_entities)
|
||||
# )
|
||||
# .replace("---relationships---", "\n".join(query_graph_relationships))
|
||||
# .replace("---possible_entities---", state.entities_types_str)
|
||||
# .replace("---possible_relationships---", state.relationship_types_str)
|
||||
# .replace("---question---", question)
|
||||
# )
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=strategy_generation_prompt,
|
||||
)
|
||||
]
|
||||
primary_llm = graph_config.tooling.primary_llm
|
||||
# Grader
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
KG_STRATEGY_GENERATION_TIMEOUT,
|
||||
# fast_llm.invoke,
|
||||
primary_llm.invoke_langchain,
|
||||
prompt=msg,
|
||||
timeout_override=5,
|
||||
max_tokens=100,
|
||||
)
|
||||
# msg = [
|
||||
# HumanMessage(
|
||||
# content=strategy_generation_prompt,
|
||||
# )
|
||||
# ]
|
||||
# primary_llm = graph_config.tooling.primary_llm
|
||||
# # Grader
|
||||
# try:
|
||||
# llm_response = run_with_timeout(
|
||||
# KG_STRATEGY_GENERATION_TIMEOUT,
|
||||
# # fast_llm.invoke,
|
||||
# primary_llm.invoke_langchain,
|
||||
# prompt=msg,
|
||||
# timeout_override=5,
|
||||
# max_tokens=100,
|
||||
# )
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content)
|
||||
.replace("```json\n", "")
|
||||
.replace("\n```", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
first_bracket = cleaned_response.find("{")
|
||||
last_bracket = cleaned_response.rfind("}")
|
||||
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
|
||||
# cleaned_response = (
|
||||
# str(llm_response.content)
|
||||
# .replace("```json\n", "")
|
||||
# .replace("\n```", "")
|
||||
# .replace("\n", "")
|
||||
# )
|
||||
# first_bracket = cleaned_response.find("{")
|
||||
# last_bracket = cleaned_response.rfind("}")
|
||||
# cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
|
||||
|
||||
try:
|
||||
approach_extraction_result = KGAnswerApproach.model_validate_json(
|
||||
cleaned_response
|
||||
)
|
||||
search_type = approach_extraction_result.search_type
|
||||
search_strategy = approach_extraction_result.search_strategy
|
||||
relationship_detection = (
|
||||
approach_extraction_result.relationship_detection.value
|
||||
)
|
||||
output_format = approach_extraction_result.format
|
||||
broken_down_question = approach_extraction_result.broken_down_question
|
||||
divide_and_conquer = approach_extraction_result.divide_and_conquer
|
||||
except ValueError:
|
||||
logger.error(
|
||||
"Failed to parse LLM response as JSON in Entity-Term Extraction"
|
||||
)
|
||||
search_type = KGSearchType.SEARCH
|
||||
search_strategy = KGAnswerStrategy.DEEP
|
||||
relationship_detection = KGRelationshipDetection.RELATIONSHIPS.value
|
||||
output_format = KGAnswerFormat.TEXT
|
||||
broken_down_question = None
|
||||
divide_and_conquer = YesNoEnum.NO
|
||||
if search_strategy is None or output_format is None:
|
||||
raise ValueError(f"Invalid strategy: {cleaned_response}")
|
||||
# try:
|
||||
# approach_extraction_result = KGAnswerApproach.model_validate_json(
|
||||
# cleaned_response
|
||||
# )
|
||||
# search_type = approach_extraction_result.search_type
|
||||
# search_strategy = approach_extraction_result.search_strategy
|
||||
# relationship_detection = (
|
||||
# approach_extraction_result.relationship_detection.value
|
||||
# )
|
||||
# output_format = approach_extraction_result.format
|
||||
# broken_down_question = approach_extraction_result.broken_down_question
|
||||
# divide_and_conquer = approach_extraction_result.divide_and_conquer
|
||||
# except ValueError:
|
||||
# logger.error(
|
||||
# "Failed to parse LLM response as JSON in Entity-Term Extraction"
|
||||
# )
|
||||
# search_type = KGSearchType.SEARCH
|
||||
# search_strategy = KGAnswerStrategy.DEEP
|
||||
# relationship_detection = KGRelationshipDetection.RELATIONSHIPS.value
|
||||
# output_format = KGAnswerFormat.TEXT
|
||||
# broken_down_question = None
|
||||
# divide_and_conquer = YesNoEnum.NO
|
||||
# if search_strategy is None or output_format is None:
|
||||
# raise ValueError(f"Invalid strategy: {cleaned_response}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in strategy generation: {e}")
|
||||
raise e
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error in strategy generation: {e}")
|
||||
# raise e
|
||||
|
||||
# Stream out relevant results
|
||||
# # Stream out relevant results
|
||||
|
||||
if single_doc_id:
|
||||
search_strategy = (
|
||||
KGAnswerStrategy.DEEP
|
||||
) # if a single doc is identified, we will want to look at the details.
|
||||
# if single_doc_id:
|
||||
# search_strategy = (
|
||||
# KGAnswerStrategy.DEEP
|
||||
# ) # if a single doc is identified, we will want to look at the details.
|
||||
|
||||
step_answer = f"Strategy and format have been extracted from query. Strategy: {search_strategy.value}, \
|
||||
Format: {output_format.value}, Broken down question: {broken_down_question}"
|
||||
# step_answer = f"Strategy and format have been extracted from query. Strategy: {search_strategy.value}, \
|
||||
# Format: {output_format.value}, Broken down question: {broken_down_question}"
|
||||
|
||||
extraction_detected_relationships = len(query_graph_relationships) > 0
|
||||
if extraction_detected_relationships:
|
||||
query_type = KGRelationshipDetection.RELATIONSHIPS.value
|
||||
# extraction_detected_relationships = len(query_graph_relationships) > 0
|
||||
# if extraction_detected_relationships:
|
||||
# query_type = KGRelationshipDetection.RELATIONSHIPS.value
|
||||
|
||||
if extraction_detected_relationships:
|
||||
logger.warning(
|
||||
"Fyi - Extraction detected relationships: "
|
||||
f"{extraction_detected_relationships}, "
|
||||
f"but relationship detection: {relationship_detection}"
|
||||
)
|
||||
else:
|
||||
query_type = KGRelationshipDetection.NO_RELATIONSHIPS.value
|
||||
# if extraction_detected_relationships:
|
||||
# logger.warning(
|
||||
# "Fyi - Extraction detected relationships: "
|
||||
# f"{extraction_detected_relationships}, "
|
||||
# f"but relationship detection: {relationship_detection}"
|
||||
# )
|
||||
# else:
|
||||
# query_type = KGRelationshipDetection.NO_RELATIONSHIPS.value
|
||||
|
||||
# End node
|
||||
# # End node
|
||||
|
||||
return AnalysisUpdate(
|
||||
normalized_core_entities=normalized_entities.entities,
|
||||
normalized_core_relationships=normalized_relationships.relationships,
|
||||
entity_normalization_map=normalized_entities.entity_normalization_map,
|
||||
relationship_normalization_map=normalized_relationships.relationship_normalization_map,
|
||||
query_graph_entities_no_attributes=query_graph_entities,
|
||||
query_graph_entities_w_attributes=normalized_entities.entities_w_attributes,
|
||||
query_graph_relationships=query_graph_relationships,
|
||||
normalized_terms=[], # TODO: remove fully later
|
||||
normalized_time_filter=normalized_time_filter,
|
||||
strategy=search_strategy,
|
||||
broken_down_question=broken_down_question,
|
||||
output_format=output_format,
|
||||
divide_and_conquer=divide_and_conquer,
|
||||
single_doc_id=single_doc_id,
|
||||
search_type=search_type,
|
||||
query_type=query_type,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="analyze",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
step_results=[
|
||||
get_near_empty_step_results(
|
||||
step_number=_KG_STEP_NR,
|
||||
step_answer=step_answer,
|
||||
verified_reranked_documents=[],
|
||||
)
|
||||
],
|
||||
remarks=[
|
||||
_articulate_normalizations(
|
||||
entity_normalization_map=normalized_entities.entity_normalization_map,
|
||||
relationship_normalization_map=normalized_relationships.relationship_normalization_map,
|
||||
)
|
||||
],
|
||||
)
|
||||
# return AnalysisUpdate(
|
||||
# normalized_core_entities=normalized_entities.entities,
|
||||
# normalized_core_relationships=normalized_relationships.relationships,
|
||||
# entity_normalization_map=normalized_entities.entity_normalization_map,
|
||||
# relationship_normalization_map=normalized_relationships.relationship_normalization_map,
|
||||
# query_graph_entities_no_attributes=query_graph_entities,
|
||||
# query_graph_entities_w_attributes=normalized_entities.entities_w_attributes,
|
||||
# query_graph_relationships=query_graph_relationships,
|
||||
# normalized_terms=[], # TODO: remove fully later
|
||||
# normalized_time_filter=normalized_time_filter,
|
||||
# strategy=search_strategy,
|
||||
# broken_down_question=broken_down_question,
|
||||
# output_format=output_format,
|
||||
# divide_and_conquer=divide_and_conquer,
|
||||
# single_doc_id=single_doc_id,
|
||||
# search_type=search_type,
|
||||
# query_type=query_type,
|
||||
# log_messages=[
|
||||
# get_langgraph_node_log_string(
|
||||
# graph_component="main",
|
||||
# node_name="analyze",
|
||||
# node_start_time=node_start_time,
|
||||
# )
|
||||
# ],
|
||||
# step_results=[
|
||||
# get_near_empty_step_results(
|
||||
# step_number=_KG_STEP_NR,
|
||||
# step_answer=step_answer,
|
||||
# verified_reranked_documents=[],
|
||||
# )
|
||||
# ],
|
||||
# remarks=[
|
||||
# _articulate_normalizations(
|
||||
# entity_normalization_map=normalized_entities.entity_normalization_map,
|
||||
# relationship_normalization_map=normalized_relationships.relationship_normalization_map,
|
||||
# )
|
||||
# ],
|
||||
# )
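As a reference for the connectivity check in the analyze node above, a self-contained sketch of the _get_fully_connected_entities logic. Splitting relationship ids on "__" is an assumption standing in for split_relationship_id.

def fully_connected_entities(entities: list[str], relationships: list[str]) -> list[str]:
    # Track bidirectional connections for each entity of interest.
    entity_connections: dict[str, set[str]] = {entity: set() for entity in entities}

    for relationship in relationships:
        # The real code uses split_relationship_id; "__" is assumed here.
        parts = relationship.split("__")
        if len(parts) != 3:
            raise ValueError(f"Invalid relationship: {relationship}")
        source, _, target = parts
        if source in entity_connections:
            entity_connections[source].add(target)
        if target in entity_connections:
            entity_connections[target].add(source)

    # Keep only entities connected to every other entity.
    all_entities = set(entities)
    return [
        entity
        for entity, connections in entity_connections.items()
        if connections == all_entities - {entity}
    ]


# e.g. fully_connected_entities(["a", "b", "c"], ["a__works_on__b", "b__reports_to__c", "a__owns__c"])
# returns ["a", "b", "c"]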
File diff suppressed because it is too large
@@ -1,185 +1,185 @@
from datetime import datetime
|
||||
from typing import cast
|
||||
# from datetime import datetime
|
||||
# from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
# from langchain_core.messages import HumanMessage
|
||||
# from langchain_core.runnables import RunnableConfig
|
||||
# from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.kb_search.states import DeepSearchFilterUpdate
|
||||
from onyx.agents.agent_search.kb_search.states import KGFilterConstructionResults
|
||||
from onyx.agents.agent_search.kb_search.states import MainState
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.kg_configs import KG_FILTER_CONSTRUCTION_TIMEOUT
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.entity_type import get_entity_types_with_grounded_source_name
|
||||
from onyx.kg.utils.formatting_utils import make_entity_id
|
||||
from onyx.prompts.kg_prompts import SEARCH_FILTER_CONSTRUCTION_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
# from onyx.agents.agent_search.kb_search.states import DeepSearchFilterUpdate
|
||||
# from onyx.agents.agent_search.kb_search.states import KGFilterConstructionResults
|
||||
# from onyx.agents.agent_search.kb_search.states import MainState
|
||||
# from onyx.agents.agent_search.models import GraphConfig
|
||||
# from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
# get_langgraph_node_log_string,
|
||||
# )
|
||||
# from onyx.configs.kg_configs import KG_FILTER_CONSTRUCTION_TIMEOUT
|
||||
# from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
# from onyx.db.entity_type import get_entity_types_with_grounded_source_name
|
||||
# from onyx.kg.utils.formatting_utils import make_entity_id
|
||||
# from onyx.prompts.kg_prompts import SEARCH_FILTER_CONSTRUCTION_PROMPT
|
||||
# from onyx.utils.logger import setup_logger
|
||||
# from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
# logger = setup_logger()
|
||||
|
||||
|
||||
def construct_deep_search_filters(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter
|
||||
) -> DeepSearchFilterUpdate:
|
||||
"""
|
||||
LangGraph node to start the agentic search process.
|
||||
"""
|
||||
# def construct_deep_search_filters(
|
||||
# state: MainState, config: RunnableConfig, writer: StreamWriter
|
||||
# ) -> DeepSearchFilterUpdate:
|
||||
# """
|
||||
# LangGraph node to start the agentic search process.
|
||||
# """
|
||||
|
||||
node_start_time = datetime.now()
|
||||
# node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = state.question
|
||||
# graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
# question = state.question
|
||||
|
||||
entities_types_str = state.entities_types_str
|
||||
entities = state.query_graph_entities_no_attributes
|
||||
relationships = state.query_graph_relationships
|
||||
simple_sql_query = state.sql_query
|
||||
simple_sql_results = state.sql_query_results
|
||||
source_document_results = state.source_document_results
|
||||
if simple_sql_results:
|
||||
simple_sql_results_str = "\n".join([str(x) for x in simple_sql_results])
|
||||
else:
|
||||
simple_sql_results_str = "(no SQL results generated)"
|
||||
if source_document_results:
|
||||
source_document_results_str = "\n".join(
|
||||
[str(x) for x in source_document_results]
|
||||
)
|
||||
else:
|
||||
source_document_results_str = "(no source document results generated)"
|
||||
# entities_types_str = state.entities_types_str
|
||||
# entities = state.query_graph_entities_no_attributes
|
||||
# relationships = state.query_graph_relationships
|
||||
# simple_sql_query = state.sql_query
|
||||
# simple_sql_results = state.sql_query_results
|
||||
# source_document_results = state.source_document_results
|
||||
# if simple_sql_results:
|
||||
# simple_sql_results_str = "\n".join([str(x) for x in simple_sql_results])
|
||||
# else:
|
||||
# simple_sql_results_str = "(no SQL results generated)"
|
||||
# if source_document_results:
|
||||
# source_document_results_str = "\n".join(
|
||||
# [str(x) for x in source_document_results]
|
||||
# )
|
||||
# else:
|
||||
# source_document_results_str = "(no source document results generated)"
|
||||
|
||||
search_filter_construction_prompt = (
|
||||
SEARCH_FILTER_CONSTRUCTION_PROMPT.replace(
|
||||
"---entity_type_descriptions---",
|
||||
entities_types_str,
|
||||
)
|
||||
.replace(
|
||||
"---entity_filters---",
|
||||
"\n".join(entities),
|
||||
)
|
||||
.replace(
|
||||
"---relationship_filters---",
|
||||
"\n".join(relationships),
|
||||
)
|
||||
.replace(
|
||||
"---sql_query---",
|
||||
simple_sql_query or "(no SQL generated)",
|
||||
)
|
||||
.replace(
|
||||
"---sql_results---",
|
||||
simple_sql_results_str or "(no SQL results generated)",
|
||||
)
|
||||
.replace(
|
||||
"---source_document_results---",
|
||||
source_document_results_str or "(no source document results generated)",
|
||||
)
|
||||
.replace(
|
||||
"---question---",
|
||||
question,
|
||||
)
|
||||
)
|
||||
# search_filter_construction_prompt = (
|
||||
# SEARCH_FILTER_CONSTRUCTION_PROMPT.replace(
|
||||
# "---entity_type_descriptions---",
|
||||
# entities_types_str,
|
||||
# )
|
||||
# .replace(
|
||||
# "---entity_filters---",
|
||||
# "\n".join(entities),
|
||||
# )
|
||||
# .replace(
|
||||
# "---relationship_filters---",
|
||||
# "\n".join(relationships),
|
||||
# )
|
||||
# .replace(
|
||||
# "---sql_query---",
|
||||
# simple_sql_query or "(no SQL generated)",
|
||||
# )
|
||||
# .replace(
|
||||
# "---sql_results---",
|
||||
# simple_sql_results_str or "(no SQL results generated)",
|
||||
# )
|
||||
# .replace(
|
||||
# "---source_document_results---",
|
||||
# source_document_results_str or "(no source document results generated)",
|
||||
# )
|
||||
# .replace(
|
||||
# "---question---",
|
||||
# question,
|
||||
# )
|
||||
# )
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=search_filter_construction_prompt,
|
||||
)
|
||||
]
|
||||
llm = graph_config.tooling.primary_llm
|
||||
# Grader
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
KG_FILTER_CONSTRUCTION_TIMEOUT,
|
||||
llm.invoke_langchain,
|
||||
prompt=msg,
|
||||
timeout_override=15,
|
||||
max_tokens=1400,
|
||||
)
|
||||
# msg = [
|
||||
# HumanMessage(
|
||||
# content=search_filter_construction_prompt,
|
||||
# )
|
||||
# ]
|
||||
# llm = graph_config.tooling.primary_llm
|
||||
# # Grader
|
||||
# try:
|
||||
# llm_response = run_with_timeout(
|
||||
# KG_FILTER_CONSTRUCTION_TIMEOUT,
|
||||
# llm.invoke_langchain,
|
||||
# prompt=msg,
|
||||
# timeout_override=15,
|
||||
# max_tokens=1400,
|
||||
# )
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content)
|
||||
.replace("```json\n", "")
|
||||
.replace("\n```", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
first_bracket = cleaned_response.find("{")
|
||||
last_bracket = cleaned_response.rfind("}")
|
||||
# cleaned_response = (
|
||||
# str(llm_response.content)
|
||||
# .replace("```json\n", "")
|
||||
# .replace("\n```", "")
|
||||
# .replace("\n", "")
|
||||
# )
|
||||
# first_bracket = cleaned_response.find("{")
|
||||
# last_bracket = cleaned_response.rfind("}")
|
||||
|
||||
if last_bracket == -1 or first_bracket == -1:
|
||||
raise ValueError("No valid JSON found in LLM response - no brackets found")
|
||||
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
|
||||
cleaned_response = cleaned_response.replace("{{", '{"')
|
||||
cleaned_response = cleaned_response.replace("}}", '"}')
|
||||
# if last_bracket == -1 or first_bracket == -1:
|
||||
# raise ValueError("No valid JSON found in LLM response - no brackets found")
|
||||
# cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
|
||||
# cleaned_response = cleaned_response.replace("{{", '{"')
|
||||
# cleaned_response = cleaned_response.replace("}}", '"}')
|
||||
|
||||
try:
|
||||
# try:
|
||||
|
||||
filter_results = KGFilterConstructionResults.model_validate_json(
|
||||
cleaned_response
|
||||
)
|
||||
except ValueError:
|
||||
logger.error(
|
||||
"Failed to parse LLM response as JSON in Entity-Term Extraction"
|
||||
)
|
||||
filter_results = KGFilterConstructionResults(
|
||||
global_entity_filters=[],
|
||||
global_relationship_filters=[],
|
||||
local_entity_filters=[],
|
||||
source_document_filters=[],
|
||||
structure=[],
|
||||
)
|
||||
# filter_results = KGFilterConstructionResults.model_validate_json(
|
||||
# cleaned_response
|
||||
# )
|
||||
# except ValueError:
|
||||
# logger.error(
|
||||
# "Failed to parse LLM response as JSON in Entity-Term Extraction"
|
||||
# )
|
||||
# filter_results = KGFilterConstructionResults(
|
||||
# global_entity_filters=[],
|
||||
# global_relationship_filters=[],
|
||||
# local_entity_filters=[],
|
||||
# source_document_filters=[],
|
||||
# structure=[],
|
||||
# )
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in construct_deep_search_filters: {e}")
|
||||
filter_results = KGFilterConstructionResults(
|
||||
global_entity_filters=[],
|
||||
global_relationship_filters=[],
|
||||
local_entity_filters=[],
|
||||
source_document_filters=[],
|
||||
structure=[],
|
||||
)
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error in construct_deep_search_filters: {e}")
|
||||
# filter_results = KGFilterConstructionResults(
|
||||
# global_entity_filters=[],
|
||||
# global_relationship_filters=[],
|
||||
# local_entity_filters=[],
|
||||
# source_document_filters=[],
|
||||
# structure=[],
|
||||
# )
|
||||
|
||||
div_con_structure = filter_results.structure
|
||||
# div_con_structure = filter_results.structure
|
||||
|
||||
logger.info(f"div_con_structure: {div_con_structure}")
|
||||
# logger.info(f"div_con_structure: {div_con_structure}")
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
double_grounded_entity_types = get_entity_types_with_grounded_source_name(
|
||||
db_session
|
||||
)
|
||||
# with get_session_with_current_tenant() as db_session:
|
||||
# double_grounded_entity_types = get_entity_types_with_grounded_source_name(
|
||||
# db_session
|
||||
# )
|
||||
|
||||
source_division = False
|
||||
# source_division = False
|
||||
|
||||
if div_con_structure:
|
||||
for entity_type in double_grounded_entity_types:
|
||||
# entity_type is guaranteed to have grounded_source_name
|
||||
if (
|
||||
cast(str, entity_type.grounded_source_name).lower()
|
||||
in div_con_structure[0].lower()
|
||||
):
|
||||
source_division = True
|
||||
break
|
||||
# if div_con_structure:
|
||||
# for entity_type in double_grounded_entity_types:
|
||||
# # entity_type is guaranteed to have grounded_source_name
|
||||
# if (
|
||||
# cast(str, entity_type.grounded_source_name).lower()
|
||||
# in div_con_structure[0].lower()
|
||||
# ):
|
||||
# source_division = True
|
||||
# break
|
||||
|
||||
return DeepSearchFilterUpdate(
|
||||
vespa_filter_results=filter_results,
|
||||
div_con_entities=div_con_structure,
|
||||
source_division=source_division,
|
||||
global_entity_filters=[
|
||||
make_entity_id(global_filter, "*")
|
||||
for global_filter in filter_results.global_entity_filters
|
||||
],
|
||||
global_relationship_filters=filter_results.global_relationship_filters,
|
||||
local_entity_filters=filter_results.local_entity_filters,
|
||||
source_filters=filter_results.source_document_filters,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="construct deep search filters",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
step_results=[],
|
||||
)
|
||||
# return DeepSearchFilterUpdate(
|
||||
# vespa_filter_results=filter_results,
|
||||
# div_con_entities=div_con_structure,
|
||||
# source_division=source_division,
|
||||
# global_entity_filters=[
|
||||
# make_entity_id(global_filter, "*")
|
||||
# for global_filter in filter_results.global_entity_filters
|
||||
# ],
|
||||
# global_relationship_filters=filter_results.global_relationship_filters,
|
||||
# local_entity_filters=filter_results.local_entity_filters,
|
||||
# source_filters=filter_results.source_document_filters,
|
||||
# log_messages=[
|
||||
# get_langgraph_node_log_string(
|
||||
# graph_component="main",
|
||||
# node_name="construct deep search filters",
|
||||
# node_start_time=node_start_time,
|
||||
# )
|
||||
# ],
|
||||
# step_results=[],
|
||||
# )
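For reference, the "---placeholder---" substitution style used to build these KG prompts, sketched as a standalone helper. The build_kg_prompt name and the example values are illustrative, not part of the codebase.

def build_kg_prompt(template: str, substitutions: dict[str, str]) -> str:
    # Fill each ---name--- slot; slots not supplied are left untouched.
    prompt = template
    for name, value in substitutions.items():
        prompt = prompt.replace(f"---{name}---", value)
    return prompt


# Mirrors the fallbacks used above when no SQL or source documents were produced.
example_template = "Question: ---question---\nSQL: ---sql_query---\nResults: ---sql_results---"
example_prompt = build_kg_prompt(
    example_template,
    {
        "question": "Which accounts were mentioned last quarter?",
        "sql_query": "(no SQL generated)",
        "sql_results": "(no SQL results generated)",
    },
)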
@@ -1,168 +1,168 @@
import copy
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
# import copy
|
||||
# from datetime import datetime
|
||||
# from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
# from langchain_core.messages import HumanMessage
|
||||
# from langchain_core.runnables import RunnableConfig
|
||||
# from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
|
||||
from onyx.agents.agent_search.kb_search.ops import research
|
||||
from onyx.agents.agent_search.kb_search.states import KGSourceDivisionType
|
||||
from onyx.agents.agent_search.kb_search.states import ResearchObjectInput
|
||||
from onyx.agents.agent_search.kb_search.states import ResearchObjectUpdate
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.configs.kg_configs import KG_MAX_SEARCH_DOCUMENTS
|
||||
from onyx.configs.kg_configs import KG_OBJECT_SOURCE_RESEARCH_TIMEOUT
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.kg.utils.formatting_utils import split_entity_id
|
||||
from onyx.prompts.kg_prompts import KG_OBJECT_SOURCE_RESEARCH_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
# from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
|
||||
# from onyx.agents.agent_search.kb_search.ops import research
|
||||
# from onyx.agents.agent_search.kb_search.states import KGSourceDivisionType
|
||||
# from onyx.agents.agent_search.kb_search.states import ResearchObjectInput
|
||||
# from onyx.agents.agent_search.kb_search.states import ResearchObjectUpdate
|
||||
# from onyx.agents.agent_search.models import GraphConfig
|
||||
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
# trim_prompt_piece,
|
||||
# )
|
||||
# from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
# get_langgraph_node_log_string,
|
||||
# )
|
||||
# from onyx.chat.models import LlmDoc
|
||||
# from onyx.configs.kg_configs import KG_MAX_SEARCH_DOCUMENTS
|
||||
# from onyx.configs.kg_configs import KG_OBJECT_SOURCE_RESEARCH_TIMEOUT
|
||||
# from onyx.context.search.models import InferenceSection
|
||||
# from onyx.kg.utils.formatting_utils import split_entity_id
|
||||
# from onyx.prompts.kg_prompts import KG_OBJECT_SOURCE_RESEARCH_PROMPT
|
||||
# from onyx.utils.logger import setup_logger
|
||||
# from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
logger = setup_logger()
|
||||
# logger = setup_logger()
|
||||
|
||||
|
||||
def process_individual_deep_search(
|
||||
state: ResearchObjectInput,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> ResearchObjectUpdate:
|
||||
"""
|
||||
LangGraph node to start the agentic search process.
|
||||
"""
|
||||
# def process_individual_deep_search(
|
||||
# state: ResearchObjectInput,
|
||||
# config: RunnableConfig,
|
||||
# writer: StreamWriter = lambda _: None,
|
||||
# ) -> ResearchObjectUpdate:
|
||||
# """
|
||||
# LangGraph node to start the agentic search process.
|
||||
# """
|
||||
|
||||
node_start_time = datetime.now()
|
||||
# node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
search_tool = graph_config.tooling.search_tool
|
||||
question = state.broken_down_question
|
||||
segment_type = state.segment_type
|
||||
# graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
# search_tool = graph_config.tooling.search_tool
|
||||
# question = state.broken_down_question
|
||||
# segment_type = state.segment_type
|
||||
|
||||
object = state.entity.replace("::", ":: ").lower()
|
||||
# object = state.entity.replace("::", ":: ").lower()
|
||||
|
||||
if not search_tool:
|
||||
raise ValueError("search_tool is not provided")
|
||||
# if not search_tool:
|
||||
# raise ValueError("search_tool is not provided")
|
||||
|
||||
state.research_nr
|
||||
# state.research_nr
|
||||
|
||||
if segment_type == KGSourceDivisionType.ENTITY.value:
|
||||
# if segment_type == KGSourceDivisionType.ENTITY.value:
|
||||
|
||||
object_id = split_entity_id(object)[1].strip()
|
||||
extended_question = f"{question} in regards to {object}"
|
||||
source_filters = state.source_entity_filters
|
||||
# object_id = split_entity_id(object)[1].strip()
|
||||
# extended_question = f"{question} in regards to {object}"
|
||||
# source_filters = state.source_entity_filters
|
||||
|
||||
# TODO: this does not really occur in V1. But needs to be changed for V2
|
||||
raw_kg_entity_filters = copy.deepcopy(
|
||||
list(
|
||||
set((state.vespa_filter_results.global_entity_filters + [state.entity]))
|
||||
)
|
||||
)
|
||||
# # TODO: this does not really occur in V1. But needs to be changed for V2
|
||||
# raw_kg_entity_filters = copy.deepcopy(
|
||||
# list(
|
||||
# set((state.vespa_filter_results.global_entity_filters + [state.entity]))
|
||||
# )
|
||||
# )
|
||||
|
||||
kg_entity_filters = []
|
||||
for raw_kg_entity_filter in raw_kg_entity_filters:
|
||||
if "::" not in raw_kg_entity_filter:
|
||||
raw_kg_entity_filter += "::*"
|
||||
kg_entity_filters.append(raw_kg_entity_filter)
|
||||
# kg_entity_filters = []
|
||||
# for raw_kg_entity_filter in raw_kg_entity_filters:
|
||||
# if "::" not in raw_kg_entity_filter:
|
||||
# raw_kg_entity_filter += "::*"
|
||||
# kg_entity_filters.append(raw_kg_entity_filter)
|
||||
|
||||
kg_relationship_filters = copy.deepcopy(
|
||||
state.vespa_filter_results.global_relationship_filters
|
||||
)
|
||||
# kg_relationship_filters = copy.deepcopy(
|
||||
# state.vespa_filter_results.global_relationship_filters
|
||||
# )
|
||||
|
||||
logger.debug("Research for object: " + object)
|
||||
logger.debug(f"kg_entity_filters: {kg_entity_filters}")
|
||||
logger.debug(f"kg_relationship_filters: {kg_relationship_filters}")
|
||||
# logger.debug("Research for object: " + object)
|
||||
# logger.debug(f"kg_entity_filters: {kg_entity_filters}")
|
||||
# logger.debug(f"kg_relationship_filters: {kg_relationship_filters}")
|
||||
|
||||
else:
|
||||
# if we came through the entity view route, in KG V1 the state entity
|
||||
# is the document to search for. No need to set other filters then.
|
||||
object_id = state.entity # source doc in this case
|
||||
extended_question = f"{question}"
|
||||
source_filters = [object_id]
|
||||
# else:
|
||||
# # if we came through the entity view route, in KG V1 the state entity
|
||||
# # is the document to search for. No need to set other filters then.
|
||||
# object_id = state.entity # source doc in this case
|
||||
# extended_question = f"{question}"
|
||||
# source_filters = [object_id]
|
||||
|
||||
kg_entity_filters = None
|
||||
kg_relationship_filters = None
|
||||
# kg_entity_filters = None
|
||||
# kg_relationship_filters = None
|
||||
|
||||
if source_filters and (len(source_filters) > KG_MAX_SEARCH_DOCUMENTS):
|
||||
logger.debug(
|
||||
f"Too many sources ({len(source_filters)}), setting to None and effectively filtered search"
|
||||
)
|
||||
source_filters = None
|
||||
# if source_filters and (len(source_filters) > KG_MAX_SEARCH_DOCUMENTS):
|
||||
# logger.debug(
|
||||
# f"Too many sources ({len(source_filters)}), setting to None and effectively filtered search"
|
||||
# )
|
||||
# source_filters = None
|
||||
|
||||
retrieved_docs = research(
|
||||
question=extended_question,
|
||||
kg_entities=kg_entity_filters,
|
||||
kg_relationships=kg_relationship_filters,
|
||||
kg_sources=source_filters,
|
||||
search_tool=search_tool,
|
||||
)
|
||||
# retrieved_docs = research(
|
||||
# question=extended_question,
|
||||
# kg_entities=kg_entity_filters,
|
||||
# kg_relationships=kg_relationship_filters,
|
||||
# kg_sources=source_filters,
|
||||
# search_tool=search_tool,
|
||||
# )
|
||||
|
||||
document_texts_list = []
|
||||
# document_texts_list = []
|
||||
|
||||
for doc_num, retrieved_doc in enumerate(retrieved_docs):
|
||||
if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
|
||||
raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
|
||||
chunk_text = build_document_context(retrieved_doc, doc_num + 1)
|
||||
document_texts_list.append(chunk_text)
|
||||
# for doc_num, retrieved_doc in enumerate(retrieved_docs):
|
||||
# if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
|
||||
# raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
|
||||
# chunk_text = build_document_context(retrieved_doc, doc_num + 1)
|
||||
# document_texts_list.append(chunk_text)
|
||||
|
||||
document_texts = "\n\n".join(document_texts_list)
|
||||
# document_texts = "\n\n".join(document_texts_list)
|
||||
|
||||
# Built prompt
|
||||
# # Built prompt
|
||||
|
||||
kg_object_source_research_prompt = KG_OBJECT_SOURCE_RESEARCH_PROMPT.format(
|
||||
question=extended_question,
|
||||
document_text=document_texts,
|
||||
)
|
||||
# kg_object_source_research_prompt = KG_OBJECT_SOURCE_RESEARCH_PROMPT.format(
|
||||
# question=extended_question,
|
||||
# document_text=document_texts,
|
||||
# )
|
||||
|
||||
# Run LLM
|
||||
# # Run LLM
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=trim_prompt_piece(
|
||||
config=graph_config.tooling.primary_llm.config,
|
||||
prompt_piece=kg_object_source_research_prompt,
|
||||
reserved_str="",
|
||||
),
|
||||
)
|
||||
]
|
||||
primary_llm = graph_config.tooling.primary_llm
|
||||
# Grader
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
KG_OBJECT_SOURCE_RESEARCH_TIMEOUT,
|
||||
primary_llm.invoke_langchain,
|
||||
prompt=msg,
|
||||
timeout_override=KG_OBJECT_SOURCE_RESEARCH_TIMEOUT,
|
||||
max_tokens=300,
|
||||
)
|
||||
# msg = [
|
||||
# HumanMessage(
|
||||
# content=trim_prompt_piece(
|
||||
# config=graph_config.tooling.primary_llm.config,
|
||||
# prompt_piece=kg_object_source_research_prompt,
|
||||
# reserved_str="",
|
||||
# ),
|
||||
# )
|
||||
# ]
|
||||
# primary_llm = graph_config.tooling.primary_llm
|
||||
# # Grader
|
||||
# try:
|
||||
# llm_response = run_with_timeout(
|
||||
# KG_OBJECT_SOURCE_RESEARCH_TIMEOUT,
|
||||
# primary_llm.invoke_langchain,
|
||||
# prompt=msg,
|
||||
# timeout_override=KG_OBJECT_SOURCE_RESEARCH_TIMEOUT,
|
||||
# max_tokens=300,
|
||||
# )
|
||||
|
||||
object_research_results = str(llm_response.content).replace("```json\n", "")
|
||||
# object_research_results = str(llm_response.content).replace("```json\n", "")
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error in research_object_source: {e}")
|
||||
# except Exception as e:
|
||||
# raise ValueError(f"Error in research_object_source: {e}")
|
||||
|
||||
logger.debug("DivCon Step A2 - Object Source Research - completed for an object")
|
||||
# logger.debug("DivCon Step A2 - Object Source Research - completed for an object")
|
||||
|
||||
return ResearchObjectUpdate(
|
||||
research_object_results=[
|
||||
{
|
||||
"object": object.replace("::", ":: ").capitalize(),
|
||||
"results": object_research_results,
|
||||
}
|
||||
],
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="process individual deep search",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
step_results=[],
|
||||
)
|
||||
# return ResearchObjectUpdate(
|
||||
# research_object_results=[
|
||||
# {
|
||||
# "object": object.replace("::", ":: ").capitalize(),
|
||||
# "results": object_research_results,
|
||||
# }
|
||||
# ],
|
||||
# log_messages=[
|
||||
# get_langgraph_node_log_string(
|
||||
# graph_component="main",
|
||||
# node_name="process individual deep search",
|
||||
# node_start_time=node_start_time,
|
||||
# )
|
||||
# ],
|
||||
# step_results=[],
|
||||
# )
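For reference, the wildcard normalization applied to KG entity filters in this node, and again in filtered_search below, as a standalone sketch; the helper name is illustrative.

def normalize_kg_entity_filters(raw_filters: list[str]) -> list[str]:
    # Filters without an explicit "::<id>" part are widened to match any id.
    normalized_filters = []
    for raw_filter in raw_filters:
        if "::" not in raw_filter:
            raw_filter += "::*"
        normalized_filters.append(raw_filter)
    return normalized_filters


# e.g. ["JIRA", "ACCOUNT::acme"] -> ["JIRA::*", "ACCOUNT::acme"]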
@@ -1,166 +1,166 @@
from datetime import datetime
|
||||
from typing import cast
|
||||
# from datetime import datetime
|
||||
# from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
# from langchain_core.messages import HumanMessage
|
||||
# from langchain_core.runnables import RunnableConfig
|
||||
# from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
|
||||
from onyx.agents.agent_search.kb_search.ops import research
|
||||
from onyx.agents.agent_search.kb_search.states import ConsolidatedResearchUpdate
|
||||
from onyx.agents.agent_search.kb_search.states import MainState
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.calculations import (
|
||||
get_answer_generation_documents,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.kg_configs import KG_FILTERED_SEARCH_TIMEOUT
|
||||
from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.prompts.kg_prompts import KG_SEARCH_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
# from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
|
||||
# from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
|
||||
# from onyx.agents.agent_search.kb_search.ops import research
|
||||
# from onyx.agents.agent_search.kb_search.states import ConsolidatedResearchUpdate
|
||||
# from onyx.agents.agent_search.kb_search.states import MainState
|
||||
# from onyx.agents.agent_search.models import GraphConfig
|
||||
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
# trim_prompt_piece,
|
||||
# )
|
||||
# from onyx.agents.agent_search.shared_graph_utils.calculations import (
|
||||
# get_answer_generation_documents,
|
||||
# )
|
||||
# from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
# get_langgraph_node_log_string,
|
||||
# )
|
||||
# from onyx.configs.kg_configs import KG_FILTERED_SEARCH_TIMEOUT
|
||||
# from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
|
||||
# from onyx.context.search.models import InferenceSection
|
||||
# from onyx.prompts.kg_prompts import KG_SEARCH_PROMPT
|
||||
# from onyx.utils.logger import setup_logger
|
||||
# from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
# logger = setup_logger()
|
||||
|
||||
|
||||
def filtered_search(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> ConsolidatedResearchUpdate:
|
||||
"""
|
||||
LangGraph node to do a filtered search.
|
||||
"""
|
||||
_KG_STEP_NR = 4
|
||||
# def filtered_search(
|
||||
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
# ) -> ConsolidatedResearchUpdate:
|
||||
# """
|
||||
# LangGraph node to do a filtered search.
|
||||
# """
|
||||
# _KG_STEP_NR = 4
|
||||
|
||||
node_start_time = datetime.now()
|
||||
# node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
search_tool = graph_config.tooling.search_tool
|
||||
question = state.question
|
||||
# graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
# search_tool = graph_config.tooling.search_tool
|
||||
# question = state.question
|
||||
|
||||
if not search_tool:
|
||||
raise ValueError("search_tool is not provided")
|
||||
# if not search_tool:
|
||||
# raise ValueError("search_tool is not provided")
|
||||
|
||||
if not state.vespa_filter_results:
|
||||
raise ValueError("vespa_filter_results is not provided")
|
||||
raw_kg_entity_filters = list(
|
||||
set((state.vespa_filter_results.global_entity_filters))
|
||||
)
|
||||
# if not state.vespa_filter_results:
|
||||
# raise ValueError("vespa_filter_results is not provided")
|
||||
# raw_kg_entity_filters = list(
|
||||
# set((state.vespa_filter_results.global_entity_filters))
|
||||
# )
|
||||
|
||||
kg_entity_filters = []
|
||||
for raw_kg_entity_filter in raw_kg_entity_filters:
|
||||
if "::" not in raw_kg_entity_filter:
|
||||
raw_kg_entity_filter += "::*"
|
||||
kg_entity_filters.append(raw_kg_entity_filter)
|
||||
# kg_entity_filters = []
|
||||
# for raw_kg_entity_filter in raw_kg_entity_filters:
|
||||
# if "::" not in raw_kg_entity_filter:
|
||||
# raw_kg_entity_filter += "::*"
|
||||
# kg_entity_filters.append(raw_kg_entity_filter)
|
||||
|
||||
kg_relationship_filters = state.vespa_filter_results.global_relationship_filters
|
||||
# kg_relationship_filters = state.vespa_filter_results.global_relationship_filters
|
||||
|
||||
logger.debug("Starting filtered search")
|
||||
logger.debug(f"kg_entity_filters: {kg_entity_filters}")
|
||||
logger.debug(f"kg_relationship_filters: {kg_relationship_filters}")
|
||||
# logger.debug("Starting filtered search")
|
||||
# logger.debug(f"kg_entity_filters: {kg_entity_filters}")
|
||||
# logger.debug(f"kg_relationship_filters: {kg_relationship_filters}")
|
||||
|
||||
retrieved_docs = cast(
|
||||
list[InferenceSection],
|
||||
research(
|
||||
question=question,
|
||||
kg_entities=kg_entity_filters,
|
||||
kg_relationships=kg_relationship_filters,
|
||||
kg_sources=None,
|
||||
search_tool=search_tool,
|
||||
inference_sections_only=True,
|
||||
),
|
||||
)
|
||||
# retrieved_docs = cast(
|
||||
# list[InferenceSection],
|
||||
# research(
|
||||
# question=question,
|
||||
# kg_entities=kg_entity_filters,
|
||||
# kg_relationships=kg_relationship_filters,
|
||||
# kg_sources=None,
|
||||
# search_tool=search_tool,
|
||||
# inference_sections_only=True,
|
||||
# ),
|
||||
# )
|
||||
|
||||
source_link_dict = {
|
||||
num + 1: doc.center_chunk.source_links[0]
|
||||
for num, doc in enumerate(retrieved_docs)
|
||||
if doc.center_chunk.source_links
|
||||
}
|
||||
# source_link_dict = {
|
||||
# num + 1: doc.center_chunk.source_links[0]
|
||||
# for num, doc in enumerate(retrieved_docs)
|
||||
# if doc.center_chunk.source_links
|
||||
# }
|
||||
|
||||
answer_generation_documents = get_answer_generation_documents(
|
||||
relevant_docs=retrieved_docs,
|
||||
context_documents=retrieved_docs,
|
||||
original_question_docs=retrieved_docs,
|
||||
max_docs=KG_RESEARCH_NUM_RETRIEVED_DOCS,
|
||||
)
|
||||
# answer_generation_documents = get_answer_generation_documents(
|
||||
# relevant_docs=retrieved_docs,
|
||||
# context_documents=retrieved_docs,
|
||||
# original_question_docs=retrieved_docs,
|
||||
# max_docs=KG_RESEARCH_NUM_RETRIEVED_DOCS,
|
||||
# )
|
||||
|
||||
document_texts_list = []
|
||||
# document_texts_list = []
|
||||
|
||||
for doc_num, retrieved_doc in enumerate(
|
||||
answer_generation_documents.context_documents
|
||||
):
|
||||
chunk_text = build_document_context(retrieved_doc, doc_num + 1)
|
||||
document_texts_list.append(chunk_text)
|
||||
# for doc_num, retrieved_doc in enumerate(
|
||||
# answer_generation_documents.context_documents
|
||||
# ):
|
||||
# chunk_text = build_document_context(retrieved_doc, doc_num + 1)
|
||||
# document_texts_list.append(chunk_text)
|
||||
|
||||
document_texts = "\n\n".join(document_texts_list)
|
||||
# document_texts = "\n\n".join(document_texts_list)
|
||||
|
||||
# Build prompt
|
||||
# # Built prompt
|
||||
|
||||
datetime.now().strftime("%A, %Y-%m-%d")
|
||||
# datetime.now().strftime("%A, %Y-%m-%d")
|
||||
|
||||
kg_object_source_research_prompt = KG_SEARCH_PROMPT.format(
|
||||
question=question,
|
||||
document_text=document_texts,
|
||||
)
|
||||
# kg_object_source_research_prompt = KG_SEARCH_PROMPT.format(
|
||||
# question=question,
|
||||
# document_text=document_texts,
|
||||
# )
|
||||
|
||||
# Run LLM
|
||||
# # Run LLM
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=trim_prompt_piece(
|
||||
config=graph_config.tooling.primary_llm.config,
|
||||
prompt_piece=kg_object_source_research_prompt,
|
||||
reserved_str="",
|
||||
),
|
||||
)
|
||||
]
|
||||
primary_llm = graph_config.tooling.primary_llm
|
||||
llm = primary_llm
|
||||
# Grader
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
KG_FILTERED_SEARCH_TIMEOUT,
|
||||
llm.invoke_langchain,
|
||||
prompt=msg,
|
||||
timeout_override=30,
|
||||
max_tokens=300,
|
||||
)
|
||||
# msg = [
|
||||
# HumanMessage(
|
||||
# content=trim_prompt_piece(
|
||||
# config=graph_config.tooling.primary_llm.config,
|
||||
# prompt_piece=kg_object_source_research_prompt,
|
||||
# reserved_str="",
|
||||
# ),
|
||||
# )
|
||||
# ]
|
||||
# primary_llm = graph_config.tooling.primary_llm
|
||||
# llm = primary_llm
|
||||
# # Grader
|
||||
# try:
|
||||
# llm_response = run_with_timeout(
|
||||
# KG_FILTERED_SEARCH_TIMEOUT,
|
||||
# llm.invoke_langchain,
|
||||
# prompt=msg,
|
||||
# timeout_override=30,
|
||||
# max_tokens=300,
|
||||
# )
|
||||
|
||||
filtered_search_answer = str(llm_response.content).replace("```json\n", "")
|
||||
# filtered_search_answer = str(llm_response.content).replace("```json\n", "")
|
||||
|
||||
# TODO: make sure the citations look correct. Currently, they do not.
|
||||
for source_link_num, source_link in source_link_dict.items():
|
||||
if f"[{source_link_num}]" in filtered_search_answer:
|
||||
filtered_search_answer = filtered_search_answer.replace(
|
||||
f"[{source_link_num}]", f"[{source_link_num}]({source_link})"
|
||||
)
|
||||
# # TODO: make sure the citations look correct. Currently, they do not.
|
||||
# for source_link_num, source_link in source_link_dict.items():
|
||||
# if f"[{source_link_num}]" in filtered_search_answer:
|
||||
# filtered_search_answer = filtered_search_answer.replace(
|
||||
# f"[{source_link_num}]", f"[{source_link_num}]({source_link})"
|
||||
# )
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error in filtered_search: {e}")
|
||||
# except Exception as e:
|
||||
# raise ValueError(f"Error in filtered_search: {e}")
|
||||
|
||||
step_answer = "Filtered search is complete."
|
||||
# step_answer = "Filtered search is complete."
|
||||
|
||||
return ConsolidatedResearchUpdate(
|
||||
consolidated_research_object_results_str=filtered_search_answer,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="filtered search",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
step_results=[
|
||||
get_near_empty_step_results(
|
||||
step_number=_KG_STEP_NR,
|
||||
step_answer=step_answer,
|
||||
verified_reranked_documents=retrieved_docs,
|
||||
)
|
||||
],
|
||||
)
|
||||
# return ConsolidatedResearchUpdate(
|
||||
# consolidated_research_object_results_str=filtered_search_answer,
|
||||
# log_messages=[
|
||||
# get_langgraph_node_log_string(
|
||||
# graph_component="main",
|
||||
# node_name="filtered search",
|
||||
# node_start_time=node_start_time,
|
||||
# )
|
||||
# ],
|
||||
# step_results=[
|
||||
# get_near_empty_step_results(
|
||||
# step_number=_KG_STEP_NR,
|
||||
# step_answer=step_answer,
|
||||
# verified_reranked_documents=retrieved_docs,
|
||||
# )
|
||||
# ],
|
||||
# )
|
||||
|
||||
@@ -1,54 +1,54 @@
from datetime import datetime
# from datetime import datetime

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter

from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import rename_entities_in_answer
from onyx.agents.agent_search.kb_search.states import ConsolidatedResearchUpdate
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
# from onyx.agents.agent_search.kb_search.graph_utils import rename_entities_in_answer
# from onyx.agents.agent_search.kb_search.states import ConsolidatedResearchUpdate
# from onyx.agents.agent_search.kb_search.states import MainState
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.utils.logger import setup_logger

logger = setup_logger()
# logger = setup_logger()


def consolidate_individual_deep_search(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ConsolidatedResearchUpdate:
"""
LangGraph node to start the agentic search process.
"""
# def consolidate_individual_deep_search(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> ConsolidatedResearchUpdate:
# """
# LangGraph node to start the agentic search process.
# """

_KG_STEP_NR = 4
node_start_time = datetime.now()
# _KG_STEP_NR = 4
# node_start_time = datetime.now()

research_object_results = state.research_object_results
# research_object_results = state.research_object_results

consolidated_research_object_results_str = "\n".join(
[f"{x['object']}: {x['results']}" for x in research_object_results]
)
# consolidated_research_object_results_str = "\n".join(
# [f"{x['object']}: {x['results']}" for x in research_object_results]
# )

consolidated_research_object_results_str = rename_entities_in_answer(
consolidated_research_object_results_str
)
# consolidated_research_object_results_str = rename_entities_in_answer(
# consolidated_research_object_results_str
# )

step_answer = "All research is complete. Consolidating results..."
# step_answer = "All research is complete. Consolidating results..."

return ConsolidatedResearchUpdate(
consolidated_research_object_results_str=consolidated_research_object_results_str,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="consolidate individual deep search",
node_start_time=node_start_time,
)
],
step_results=[
get_near_empty_step_results(
step_number=_KG_STEP_NR, step_answer=step_answer
)
],
)
# return ConsolidatedResearchUpdate(
# consolidated_research_object_results_str=consolidated_research_object_results_str,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="main",
# node_name="consolidate individual deep search",
# node_start_time=node_start_time,
# )
# ],
# step_results=[
# get_near_empty_step_results(
# step_number=_KG_STEP_NR, step_answer=step_answer
# )
# ],
# )

@@ -1,96 +1,96 @@
|
||||
from datetime import datetime
|
||||
# from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
# from langchain_core.runnables import RunnableConfig
|
||||
# from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
|
||||
from onyx.agents.agent_search.kb_search.states import MainState
|
||||
from onyx.agents.agent_search.kb_search.states import ResultsDataUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.db.document import get_base_llm_doc_information
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.utils.logger import setup_logger
|
||||
# from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
|
||||
# from onyx.agents.agent_search.kb_search.states import MainState
|
||||
# from onyx.agents.agent_search.kb_search.states import ResultsDataUpdate
|
||||
# from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
# get_langgraph_node_log_string,
|
||||
# )
|
||||
# from onyx.db.document import get_base_llm_doc_information
|
||||
# from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
# from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
# logger = setup_logger()
|
||||
|
||||
|
||||
def _get_formated_source_reference_results(
|
||||
source_document_results: list[str] | None,
|
||||
) -> str | None:
|
||||
"""
|
||||
Generate reference results from the query results data string.
|
||||
"""
|
||||
# def _get_formated_source_reference_results(
|
||||
# source_document_results: list[str] | None,
|
||||
# ) -> str | None:
|
||||
# """
|
||||
# Generate reference results from the query results data string.
|
||||
# """
|
||||
|
||||
if source_document_results is None:
|
||||
return None
|
||||
# if source_document_results is None:
|
||||
# return None
|
||||
|
||||
# get all entities that correspond to an Onyx document
|
||||
document_ids = source_document_results
|
||||
# # get all entities that correspond to an Onyx document
|
||||
# document_ids = source_document_results
|
||||
|
||||
with get_session_with_current_tenant() as session:
|
||||
llm_doc_information_results = get_base_llm_doc_information(
|
||||
session, document_ids
|
||||
)
|
||||
# with get_session_with_current_tenant() as session:
|
||||
# llm_doc_information_results = get_base_llm_doc_information(
|
||||
# session, document_ids
|
||||
# )
|
||||
|
||||
if len(llm_doc_information_results) == 0:
|
||||
return ""
|
||||
# if len(llm_doc_information_results) == 0:
|
||||
# return ""
|
||||
|
||||
return (
|
||||
f"\n \n Here are {len(llm_doc_information_results)} supporting documents or examples: \n \n "
|
||||
+ " \n \n ".join(llm_doc_information_results)
|
||||
)
|
||||
# return (
|
||||
# f"\n \n Here are {len(llm_doc_information_results)} supporting documents or examples: \n \n "
|
||||
# + " \n \n ".join(llm_doc_information_results)
|
||||
# )
|
||||
|
||||
|
||||
def process_kg_only_answers(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> ResultsDataUpdate:
|
||||
"""
|
||||
LangGraph node to start the agentic search process.
|
||||
"""
|
||||
# def process_kg_only_answers(
|
||||
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
# ) -> ResultsDataUpdate:
|
||||
# """
|
||||
# LangGraph node to start the agentic search process.
|
||||
# """
|
||||
|
||||
_KG_STEP_NR = 4
|
||||
# _KG_STEP_NR = 4
|
||||
|
||||
node_start_time = datetime.now()
|
||||
# node_start_time = datetime.now()
|
||||
|
||||
query_results = state.sql_query_results
|
||||
source_document_results = state.source_document_results
|
||||
# query_results = state.sql_query_results
|
||||
# source_document_results = state.source_document_results
|
||||
|
||||
if query_results:
|
||||
query_results_data_str = "\n".join(
|
||||
str(query_result).replace("::", ":: ").capitalize()
|
||||
for query_result in query_results
|
||||
)
|
||||
else:
|
||||
logger.warning("No query results were found")
|
||||
query_results_data_str = "(No query results were found)"
|
||||
# if query_results:
|
||||
# query_results_data_str = "\n".join(
|
||||
# str(query_result).replace("::", ":: ").capitalize()
|
||||
# for query_result in query_results
|
||||
# )
|
||||
# else:
|
||||
# logger.warning("No query results were found")
|
||||
# query_results_data_str = "(No query results were found)"
|
||||
|
||||
source_reference_result_str = _get_formated_source_reference_results(
|
||||
source_document_results
|
||||
)
|
||||
# source_reference_result_str = _get_formated_source_reference_results(
|
||||
# source_document_results
|
||||
# )
|
||||
|
||||
## STEP 4 - same components as Step 1
|
||||
# ## STEP 4 - same components as Step 1
|
||||
|
||||
step_answer = (
|
||||
"No further research is needed, the answer is derived from the knowledge graph."
|
||||
)
|
||||
# step_answer = (
|
||||
# "No further research is needed, the answer is derived from the knowledge graph."
|
||||
# )
|
||||
|
||||
return ResultsDataUpdate(
|
||||
query_results_data_str=query_results_data_str,
|
||||
individualized_query_results_data_str="",
|
||||
reference_results_str=source_reference_result_str,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="kg query results data processing",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
step_results=[
|
||||
get_near_empty_step_results(
|
||||
step_number=_KG_STEP_NR, step_answer=step_answer
|
||||
)
|
||||
],
|
||||
)
|
||||
# return ResultsDataUpdate(
|
||||
# query_results_data_str=query_results_data_str,
|
||||
# individualized_query_results_data_str="",
|
||||
# reference_results_str=source_reference_result_str,
|
||||
# log_messages=[
|
||||
# get_langgraph_node_log_string(
|
||||
# graph_component="main",
|
||||
# node_name="kg query results data processing",
|
||||
# node_start_time=node_start_time,
|
||||
# )
|
||||
# ],
|
||||
# step_results=[
|
||||
# get_near_empty_step_results(
|
||||
# step_number=_KG_STEP_NR, step_answer=step_answer
|
||||
# )
|
||||
# ],
|
||||
# )
|
||||
|
||||
@@ -1,213 +1,187 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
# from datetime import datetime
|
||||
# from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
# from langchain_core.runnables import RunnableConfig
|
||||
# from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.access.access import get_acl_for_user
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import rename_entities_in_answer
|
||||
from onyx.agents.agent_search.kb_search.ops import research
|
||||
from onyx.agents.agent_search.kb_search.states import FinalAnswerUpdate
|
||||
from onyx.agents.agent_search.kb_search.states import MainState
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.calculations import (
|
||||
get_answer_generation_documents,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import get_answer_from_llm
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
|
||||
from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
|
||||
from onyx.configs.kg_configs import KG_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.prompts.kg_prompts import OUTPUT_FORMAT_NO_EXAMPLES_PROMPT
|
||||
from onyx.prompts.kg_prompts import OUTPUT_FORMAT_NO_OVERALL_ANSWER_PROMPT
|
||||
from onyx.tools.tool_implementations.search.search_tool import IndexFilters
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchQueryInfo
|
||||
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
|
||||
from onyx.utils.logger import setup_logger
|
||||
# from onyx.access.access import get_acl_for_user
|
||||
# from onyx.agents.agent_search.kb_search.graph_utils import rename_entities_in_answer
|
||||
# from onyx.agents.agent_search.kb_search.ops import research
|
||||
# from onyx.agents.agent_search.kb_search.states import FinalAnswerUpdate
|
||||
# from onyx.agents.agent_search.kb_search.states import MainState
|
||||
# from onyx.agents.agent_search.models import GraphConfig
|
||||
# from onyx.agents.agent_search.shared_graph_utils.calculations import (
|
||||
# get_answer_generation_documents,
|
||||
# )
|
||||
# from onyx.agents.agent_search.shared_graph_utils.llm import get_answer_from_llm
|
||||
# from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
# get_langgraph_node_log_string,
|
||||
# )
|
||||
# from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
|
||||
# from onyx.configs.kg_configs import KG_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION
|
||||
# from onyx.context.search.models import InferenceSection
|
||||
# from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
# from onyx.prompts.kg_prompts import OUTPUT_FORMAT_NO_EXAMPLES_PROMPT
|
||||
# from onyx.prompts.kg_prompts import OUTPUT_FORMAT_NO_OVERALL_ANSWER_PROMPT
|
||||
# from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
# logger = setup_logger()
|
||||
|
||||
|
||||
def generate_answer(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> FinalAnswerUpdate:
|
||||
"""
|
||||
LangGraph node to start the agentic search process.
|
||||
"""
|
||||
# def generate_answer(
|
||||
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
# ) -> FinalAnswerUpdate:
|
||||
# """
|
||||
# LangGraph node to start the agentic search process.
|
||||
# """
|
||||
|
||||
node_start_time = datetime.now()
|
||||
# node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = state.question
|
||||
# graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
# question = state.question
|
||||
|
||||
final_answer: str | None = None
|
||||
# final_answer: str | None = None
|
||||
|
||||
user = (
|
||||
graph_config.tooling.search_tool.user
|
||||
if graph_config.tooling.search_tool
|
||||
else None
|
||||
)
|
||||
# user = (
|
||||
# graph_config.tooling.search_tool.user
|
||||
# if graph_config.tooling.search_tool
|
||||
# else None
|
||||
# )
|
||||
|
||||
if not user:
|
||||
raise ValueError("User is not set")
|
||||
# if not user:
|
||||
# raise ValueError("User is not set")
|
||||
|
||||
search_tool = graph_config.tooling.search_tool
|
||||
if search_tool is None:
|
||||
raise ValueError("Search tool is not set")
|
||||
# search_tool = graph_config.tooling.search_tool
|
||||
# if search_tool is None:
|
||||
# raise ValueError("Search tool is not set")
|
||||
|
||||
# Close out previous streams of steps
|
||||
# # Close out previous streams of steps
|
||||
|
||||
# DECLARE STEPS DONE
|
||||
# # DECLARE STEPS DONE
|
||||
|
||||
## MAIN ANSWER
|
||||
# ## MAIN ANSWER
|
||||
|
||||
# identify whether documents have already been retrieved
|
||||
# # identify whether documents have already been retrieved
|
||||
|
||||
retrieved_docs: list[InferenceSection] = []
|
||||
for step_result in state.step_results:
|
||||
retrieved_docs += step_result.verified_reranked_documents
|
||||
# retrieved_docs: list[InferenceSection] = []
|
||||
# for step_result in state.step_results:
|
||||
# retrieved_docs += step_result.verified_reranked_documents
|
||||
|
||||
# if still needed, get a search done and send the results to the UI
|
||||
# # if still needed, get a search done and send the results to the UI
|
||||
|
||||
if not retrieved_docs and state.source_document_results:
|
||||
assert graph_config.tooling.search_tool is not None
|
||||
retrieved_docs = cast(
|
||||
list[InferenceSection],
|
||||
research(
|
||||
question=question,
|
||||
kg_entities=[],
|
||||
kg_relationships=[],
|
||||
kg_sources=state.source_document_results[
|
||||
:KG_RESEARCH_NUM_RETRIEVED_DOCS
|
||||
],
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
kg_chunk_id_zero_only=True,
|
||||
inference_sections_only=True,
|
||||
),
|
||||
)
|
||||
# if not retrieved_docs and state.source_document_results:
|
||||
# assert graph_config.tooling.search_tool is not None
|
||||
# retrieved_docs = cast(
|
||||
# list[InferenceSection],
|
||||
# research(
|
||||
# question=question,
|
||||
# kg_entities=[],
|
||||
# kg_relationships=[],
|
||||
# kg_sources=state.source_document_results[
|
||||
# :KG_RESEARCH_NUM_RETRIEVED_DOCS
|
||||
# ],
|
||||
# search_tool=graph_config.tooling.search_tool,
|
||||
# kg_chunk_id_zero_only=True,
|
||||
# inference_sections_only=True,
|
||||
# ),
|
||||
# )
|
||||
|
||||
answer_generation_documents = get_answer_generation_documents(
|
||||
relevant_docs=retrieved_docs,
|
||||
context_documents=retrieved_docs,
|
||||
original_question_docs=retrieved_docs,
|
||||
max_docs=KG_RESEARCH_NUM_RETRIEVED_DOCS,
|
||||
)
|
||||
# answer_generation_documents = get_answer_generation_documents(
|
||||
# relevant_docs=retrieved_docs,
|
||||
# context_documents=retrieved_docs,
|
||||
# original_question_docs=retrieved_docs,
|
||||
# max_docs=KG_RESEARCH_NUM_RETRIEVED_DOCS,
|
||||
# )
|
||||
|
||||
relevance_list = relevance_from_docs(
|
||||
answer_generation_documents.streaming_documents
|
||||
)
|
||||
# assert graph_config.tooling.search_tool is not None
|
||||
|
||||
assert graph_config.tooling.search_tool is not None
|
||||
# with get_session_with_current_tenant() as graph_db_session:
|
||||
# list(get_acl_for_user(user, graph_db_session))
|
||||
|
||||
with get_session_with_current_tenant() as graph_db_session:
|
||||
user_acl = list(get_acl_for_user(user, graph_db_session))
|
||||
# # continue with the answer generation
|
||||
|
||||
for tool_response in yield_search_responses(
|
||||
query=question,
|
||||
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
|
||||
get_final_context_sections=lambda: answer_generation_documents.context_documents,
|
||||
search_query_info=SearchQueryInfo(
|
||||
predicted_search=SearchType.KEYWORD,
|
||||
# acl here is empty, because the search already happened and
|
||||
# we are streaming out the results.
|
||||
final_filters=IndexFilters(access_control_list=user_acl),
|
||||
recency_bias_multiplier=1.0,
|
||||
),
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
):
|
||||
# original document streaming
|
||||
pass
|
||||
# output_format = (
|
||||
# state.output_format.value
|
||||
# if state.output_format
|
||||
# else "<you be the judge how to best present the data>"
|
||||
# )
|
||||
|
||||
# continue with the answer generation
|
||||
# # if deep path was taken:
|
||||
|
||||
output_format = (
|
||||
state.output_format.value
|
||||
if state.output_format
|
||||
else "<you be the judge how to best present the data>"
|
||||
)
|
||||
# consolidated_research_object_results_str = (
|
||||
# state.consolidated_research_object_results_str
|
||||
# )
|
||||
# # reference_results_str = (
|
||||
# # state.reference_results_str
|
||||
# # ) # will not be part of LLM. Manually added to the answer
|
||||
|
||||
# if deep path was taken:
|
||||
# # if simple path was taken:
|
||||
# introductory_answer = state.query_results_data_str # from simple answer path only
|
||||
# if consolidated_research_object_results_str:
|
||||
# research_results = consolidated_research_object_results_str
|
||||
# else:
|
||||
# research_results = ""
|
||||
|
||||
consolidated_research_object_results_str = (
|
||||
state.consolidated_research_object_results_str
|
||||
)
|
||||
# reference_results_str = (
|
||||
# state.reference_results_str
|
||||
# ) # will not be part of LLM. Manually added to the answer
|
||||
# if introductory_answer:
|
||||
# output_format_prompt = (
|
||||
# OUTPUT_FORMAT_NO_EXAMPLES_PROMPT.replace("---question---", question)
|
||||
# .replace(
|
||||
# "---introductory_answer---",
|
||||
# rename_entities_in_answer(introductory_answer),
|
||||
# )
|
||||
# .replace("---output_format---", str(output_format) if output_format else "")
|
||||
# )
|
||||
# elif research_results and consolidated_research_object_results_str:
|
||||
# output_format_prompt = (
|
||||
# OUTPUT_FORMAT_NO_EXAMPLES_PROMPT.replace("---question---", question)
|
||||
# .replace(
|
||||
# "---introductory_answer---",
|
||||
# rename_entities_in_answer(consolidated_research_object_results_str),
|
||||
# )
|
||||
# .replace("---output_format---", str(output_format) if output_format else "")
|
||||
# )
|
||||
# elif research_results and not consolidated_research_object_results_str:
|
||||
# output_format_prompt = (
|
||||
# OUTPUT_FORMAT_NO_OVERALL_ANSWER_PROMPT.replace("---question---", question)
|
||||
# .replace("---output_format---", str(output_format) if output_format else "")
|
||||
# .replace(
|
||||
# "---research_results---", rename_entities_in_answer(research_results)
|
||||
# )
|
||||
# )
|
||||
# elif consolidated_research_object_results_str:
|
||||
# output_format_prompt = (
|
||||
# OUTPUT_FORMAT_NO_EXAMPLES_PROMPT.replace("---question---", question)
|
||||
# .replace("---output_format---", str(output_format) if output_format else "")
|
||||
# .replace(
|
||||
# "---research_results---", rename_entities_in_answer(research_results)
|
||||
# )
|
||||
# )
|
||||
# else:
|
||||
# raise ValueError("No research results or introductory answer provided")
|
||||
|
||||
# if simple path was taken:
|
||||
introductory_answer = state.query_results_data_str # from simple answer path only
|
||||
if consolidated_research_object_results_str:
|
||||
research_results = consolidated_research_object_results_str
|
||||
else:
|
||||
research_results = ""
|
||||
# try:
|
||||
|
||||
if introductory_answer:
|
||||
output_format_prompt = (
|
||||
OUTPUT_FORMAT_NO_EXAMPLES_PROMPT.replace("---question---", question)
|
||||
.replace(
|
||||
"---introductory_answer---",
|
||||
rename_entities_in_answer(introductory_answer),
|
||||
)
|
||||
.replace("---output_format---", str(output_format) if output_format else "")
|
||||
)
|
||||
elif research_results and consolidated_research_object_results_str:
|
||||
output_format_prompt = (
|
||||
OUTPUT_FORMAT_NO_EXAMPLES_PROMPT.replace("---question---", question)
|
||||
.replace(
|
||||
"---introductory_answer---",
|
||||
rename_entities_in_answer(consolidated_research_object_results_str),
|
||||
)
|
||||
.replace("---output_format---", str(output_format) if output_format else "")
|
||||
)
|
||||
elif research_results and not consolidated_research_object_results_str:
|
||||
output_format_prompt = (
|
||||
OUTPUT_FORMAT_NO_OVERALL_ANSWER_PROMPT.replace("---question---", question)
|
||||
.replace("---output_format---", str(output_format) if output_format else "")
|
||||
.replace(
|
||||
"---research_results---", rename_entities_in_answer(research_results)
|
||||
)
|
||||
)
|
||||
elif consolidated_research_object_results_str:
|
||||
output_format_prompt = (
|
||||
OUTPUT_FORMAT_NO_EXAMPLES_PROMPT.replace("---question---", question)
|
||||
.replace("---output_format---", str(output_format) if output_format else "")
|
||||
.replace(
|
||||
"---research_results---", rename_entities_in_answer(research_results)
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError("No research results or introductory answer provided")
|
||||
# final_answer = get_answer_from_llm(
|
||||
# llm=graph_config.tooling.primary_llm,
|
||||
# prompt=output_format_prompt,
|
||||
# stream=False,
|
||||
# json_string_flag=False,
|
||||
# timeout_override=KG_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
# )
|
||||
|
||||
try:
|
||||
# except Exception as e:
|
||||
# raise ValueError(f"Could not generate the answer. Error {e}")
|
||||
|
||||
final_answer = get_answer_from_llm(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=output_format_prompt,
|
||||
stream=False,
|
||||
json_string_flag=False,
|
||||
timeout_override=KG_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Could not generate the answer. Error {e}")
|
||||
|
||||
return FinalAnswerUpdate(
|
||||
final_answer=final_answer,
|
||||
retrieved_documents=answer_generation_documents.context_documents,
|
||||
step_results=[],
|
||||
remarks=[],
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="query completed",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
# return FinalAnswerUpdate(
|
||||
# final_answer=final_answer,
|
||||
# retrieved_documents=answer_generation_documents.context_documents,
|
||||
# step_results=[],
|
||||
# remarks=[],
|
||||
# log_messages=[
|
||||
# get_langgraph_node_log_string(
|
||||
# graph_component="main",
|
||||
# node_name="query completed",
|
||||
# node_start_time=node_start_time,
|
||||
# )
|
||||
# ],
|
||||
# )
|
||||
|
||||
@@ -1,60 +1,60 @@
from datetime import datetime
from typing import cast
# from datetime import datetime
# from typing import cast

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter

from onyx.agents.agent_search.kb_search.states import MainOutput
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.db.chat import log_agent_sub_question_results
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.kb_search.states import MainOutput
# from onyx.agents.agent_search.kb_search.states import MainState
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.db.chat import log_agent_sub_question_results
# from onyx.utils.logger import setup_logger

logger = setup_logger()
# logger = setup_logger()


def log_data(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> MainOutput:
"""
LangGraph node to start the agentic search process.
"""
# def log_data(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> MainOutput:
# """
# LangGraph node to start the agentic search process.
# """

node_start_time = datetime.now()
# node_start_time = datetime.now()

graph_config = cast(GraphConfig, config["metadata"]["config"])
# graph_config = cast(GraphConfig, config["metadata"]["config"])

search_tool = graph_config.tooling.search_tool
if search_tool is None:
raise ValueError("Search tool is not set")
# search_tool = graph_config.tooling.search_tool
# if search_tool is None:
# raise ValueError("Search tool is not set")

# commit original db_session
# # commit original db_session

query_db_session = graph_config.persistence.db_session
query_db_session.commit()
# query_db_session = graph_config.persistence.db_session
# query_db_session.commit()

chat_session_id = graph_config.persistence.chat_session_id
primary_message_id = graph_config.persistence.message_id
sub_question_answer_results = state.step_results
# chat_session_id = graph_config.persistence.chat_session_id
# primary_message_id = graph_config.persistence.message_id
# sub_question_answer_results = state.step_results

log_agent_sub_question_results(
db_session=query_db_session,
chat_session_id=chat_session_id,
primary_message_id=primary_message_id,
sub_question_answer_results=sub_question_answer_results,
)
# log_agent_sub_question_results(
# db_session=query_db_session,
# chat_session_id=chat_session_id,
# primary_message_id=primary_message_id,
# sub_question_answer_results=sub_question_answer_results,
# )

return MainOutput(
final_answer=state.final_answer,
retrieved_documents=state.retrieved_documents,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="query completed",
node_start_time=node_start_time,
)
],
)
# return MainOutput(
# final_answer=state.final_answer,
# retrieved_documents=state.retrieved_documents,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="main",
# node_name="query completed",
# node_start_time=node_start_time,
# )
# ],
# )

@@ -1,65 +1,47 @@
from datetime import datetime
from typing import cast
# from datetime import datetime
# from typing import cast

from onyx.chat.models import LlmDoc
from onyx.configs.constants import DocumentSource
from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
from onyx.context.search.models import InferenceSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
# from onyx.chat.models import LlmDoc
# from onyx.configs.constants import DocumentSource
# from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
# from onyx.context.search.models import InferenceSection
# from onyx.tools.models import SearchToolOverrideKwargs
# from onyx.tools.tool_implementations.search.search_tool import (
# FINAL_CONTEXT_DOCUMENTS_ID,
# )
# from onyx.tools.tool_implementations.search.search_tool import SearchTool


def research(
question: str,
search_tool: SearchTool,
document_sources: list[DocumentSource] | None = None,
time_cutoff: datetime | None = None,
kg_entities: list[str] | None = None,
kg_relationships: list[str] | None = None,
kg_terms: list[str] | None = None,
kg_sources: list[str] | None = None,
kg_chunk_id_zero_only: bool = False,
inference_sections_only: bool = False,
) -> list[LlmDoc] | list[InferenceSection]:
# new db session to avoid concurrency issues
# def research(
# question: str,
# search_tool: SearchTool,
# document_sources: list[DocumentSource] | None = None,
# time_cutoff: datetime | None = None,
# kg_entities: list[str] | None = None,
# kg_relationships: list[str] | None = None,
# kg_terms: list[str] | None = None,
# kg_sources: list[str] | None = None,
# kg_chunk_id_zero_only: bool = False,
# inference_sections_only: bool = False,
# ) -> list[LlmDoc] | list[InferenceSection]:
# # new db session to avoid concurrency issues

callback_container: list[list[InferenceSection]] = []
retrieved_docs: list[LlmDoc] | list[InferenceSection] = []
# retrieved_docs: list[LlmDoc] | list[InferenceSection] = []

with get_session_with_current_tenant() as db_session:
for tool_response in search_tool.run(
query=question,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=False,
alternate_db_session=db_session,
retrieved_sections_callback=callback_container.append,
skip_query_analysis=True,
document_sources=document_sources,
time_cutoff=time_cutoff,
kg_entities=kg_entities,
kg_relationships=kg_relationships,
kg_terms=kg_terms,
kg_sources=kg_sources,
kg_chunk_id_zero_only=kg_chunk_id_zero_only,
),
):
if (
inference_sections_only
and tool_response.id == "search_response_summary"
):
retrieved_docs = tool_response.response.top_sections[
:KG_RESEARCH_NUM_RETRIEVED_DOCS
]
retrieved_docs = cast(list[InferenceSection], retrieved_docs)
break
# get retrieved docs to send to the rest of the graph
elif tool_response.id == FINAL_CONTEXT_DOCUMENTS_ID:
retrieved_docs = cast(list[LlmDoc], tool_response.response)[
:KG_RESEARCH_NUM_RETRIEVED_DOCS
]
break
return retrieved_docs
# for tool_response in search_tool.run(
# query=question,
# override_kwargs=SearchToolOverrideKwargs(original_query=question),
# ):
# if inference_sections_only and tool_response.id == "search_response_summary":
# retrieved_docs = tool_response.response.top_sections[
# :KG_RESEARCH_NUM_RETRIEVED_DOCS
# ]
# retrieved_docs = cast(list[InferenceSection], retrieved_docs)
# break
# # get retrieved docs to send to the rest of the graph
# elif tool_response.id == FINAL_CONTEXT_DOCUMENTS_ID:
# retrieved_docs = cast(list[LlmDoc], tool_response.response)[
# :KG_RESEARCH_NUM_RETRIEVED_DOCS
# ]
# break
# return retrieved_docs

@@ -1,191 +1,191 @@
|
||||
from enum import Enum
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import TypedDict
|
||||
# from enum import Enum
|
||||
# from operator import add
|
||||
# from typing import Annotated
|
||||
# from typing import Any
|
||||
# from typing import Dict
|
||||
# from typing import TypedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
# from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.core_state import CoreState
|
||||
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import SubQuestionAnswerResults
|
||||
from onyx.context.search.models import InferenceSection
|
||||
# from onyx.agents.agent_search.core_state import CoreState
|
||||
# from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
|
||||
# from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
|
||||
# from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
# from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
|
||||
# from onyx.agents.agent_search.shared_graph_utils.models import SubQuestionAnswerResults
|
||||
# from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
### States ###
|
||||
# ### States ###
|
||||
|
||||
|
||||
class StepResults(BaseModel):
|
||||
question: str
|
||||
question_id: str
|
||||
answer: str
|
||||
sub_query_retrieval_results: list[QueryRetrievalResult]
|
||||
verified_reranked_documents: list[InferenceSection]
|
||||
context_documents: list[InferenceSection]
|
||||
cited_documents: list[InferenceSection]
|
||||
# class StepResults(BaseModel):
|
||||
# question: str
|
||||
# question_id: str
|
||||
# answer: str
|
||||
# sub_query_retrieval_results: list[QueryRetrievalResult]
|
||||
# verified_reranked_documents: list[InferenceSection]
|
||||
# context_documents: list[InferenceSection]
|
||||
# cited_documents: list[InferenceSection]
|
||||
|
||||
|
||||
class LoggerUpdate(BaseModel):
|
||||
log_messages: Annotated[list[str], add] = []
|
||||
step_results: Annotated[list[SubQuestionAnswerResults], add]
|
||||
remarks: Annotated[list[str], add] = []
|
||||
# class LoggerUpdate(BaseModel):
|
||||
# log_messages: Annotated[list[str], add] = []
|
||||
# step_results: Annotated[list[SubQuestionAnswerResults], add]
|
||||
# remarks: Annotated[list[str], add] = []
|
||||
|
||||
|
||||
class KGFilterConstructionResults(BaseModel):
|
||||
global_entity_filters: list[str]
|
||||
global_relationship_filters: list[str]
|
||||
local_entity_filters: list[list[str]]
|
||||
source_document_filters: list[str]
|
||||
structure: list[str]
|
||||
# class KGFilterConstructionResults(BaseModel):
|
||||
# global_entity_filters: list[str]
|
||||
# global_relationship_filters: list[str]
|
||||
# local_entity_filters: list[list[str]]
|
||||
# source_document_filters: list[str]
|
||||
# structure: list[str]
|
||||
|
||||
|
||||
class KGSearchType(Enum):
|
||||
SEARCH = "SEARCH"
|
||||
SQL = "SQL"
|
||||
# class KGSearchType(Enum):
|
||||
# SEARCH = "SEARCH"
|
||||
# SQL = "SQL"
|
||||
|
||||
|
||||
class KGAnswerStrategy(Enum):
|
||||
DEEP = "DEEP"
|
||||
SIMPLE = "SIMPLE"
|
||||
# class KGAnswerStrategy(Enum):
|
||||
# DEEP = "DEEP"
|
||||
# SIMPLE = "SIMPLE"
|
||||
|
||||
|
||||
class KGSourceDivisionType(Enum):
|
||||
SOURCE = "SOURCE"
|
||||
ENTITY = "ENTITY"
|
||||
# class KGSourceDivisionType(Enum):
|
||||
# SOURCE = "SOURCE"
|
||||
# ENTITY = "ENTITY"
|
||||
|
||||
|
||||
class KGRelationshipDetection(Enum):
|
||||
RELATIONSHIPS = "RELATIONSHIPS"
|
||||
NO_RELATIONSHIPS = "NO_RELATIONSHIPS"
|
||||
# class KGRelationshipDetection(Enum):
|
||||
# RELATIONSHIPS = "RELATIONSHIPS"
|
||||
# NO_RELATIONSHIPS = "NO_RELATIONSHIPS"
|
||||
|
||||
|
||||
class KGAnswerFormat(Enum):
|
||||
LIST = "LIST"
|
||||
TEXT = "TEXT"
|
||||
# class KGAnswerFormat(Enum):
|
||||
# LIST = "LIST"
|
||||
# TEXT = "TEXT"
|
||||
|
||||
|
||||
class YesNoEnum(str, Enum):
|
||||
YES = "yes"
|
||||
NO = "no"
|
||||
# class YesNoEnum(str, Enum):
|
||||
# YES = "yes"
|
||||
# NO = "no"
|
||||
|
||||
|
||||
class AnalysisUpdate(LoggerUpdate):
|
||||
normalized_core_entities: list[str] = []
|
||||
normalized_core_relationships: list[str] = []
|
||||
entity_normalization_map: dict[str, str] = {}
|
||||
relationship_normalization_map: dict[str, str] = {}
|
||||
query_graph_entities_no_attributes: list[str] = []
|
||||
query_graph_entities_w_attributes: list[str] = []
|
||||
query_graph_relationships: list[str] = []
|
||||
normalized_terms: list[str] = []
|
||||
normalized_time_filter: str | None = None
|
||||
strategy: KGAnswerStrategy | None = None
|
||||
output_format: KGAnswerFormat | None = None
|
||||
broken_down_question: str | None = None
|
||||
divide_and_conquer: YesNoEnum | None = None
|
||||
single_doc_id: str | None = None
|
||||
search_type: KGSearchType | None = None
|
||||
query_type: str | None = None
|
||||
# class AnalysisUpdate(LoggerUpdate):
|
||||
# normalized_core_entities: list[str] = []
|
||||
# normalized_core_relationships: list[str] = []
|
||||
# entity_normalization_map: dict[str, str] = {}
|
||||
# relationship_normalization_map: dict[str, str] = {}
|
||||
# query_graph_entities_no_attributes: list[str] = []
|
||||
# query_graph_entities_w_attributes: list[str] = []
|
||||
# query_graph_relationships: list[str] = []
|
||||
# normalized_terms: list[str] = []
|
||||
# normalized_time_filter: str | None = None
|
||||
# strategy: KGAnswerStrategy | None = None
|
||||
# output_format: KGAnswerFormat | None = None
|
||||
# broken_down_question: str | None = None
|
||||
# divide_and_conquer: YesNoEnum | None = None
|
||||
# single_doc_id: str | None = None
|
||||
# search_type: KGSearchType | None = None
|
||||
# query_type: str | None = None
|
||||
|
||||
|
||||
class SQLSimpleGenerationUpdate(LoggerUpdate):
|
||||
sql_query: str | None = None
|
||||
sql_query_results: list[Dict[Any, Any]] | None = None
|
||||
individualized_sql_query: str | None = None
|
||||
individualized_query_results: list[Dict[Any, Any]] | None = None
|
||||
source_documents_sql: str | None = None
|
||||
source_document_results: list[str] | None = None
|
||||
updated_strategy: KGAnswerStrategy | None = None
|
||||
# class SQLSimpleGenerationUpdate(LoggerUpdate):
|
||||
# sql_query: str | None = None
|
||||
# sql_query_results: list[Dict[Any, Any]] | None = None
|
||||
# individualized_sql_query: str | None = None
|
||||
# individualized_query_results: list[Dict[Any, Any]] | None = None
|
||||
# source_documents_sql: str | None = None
|
||||
# source_document_results: list[str] | None = None
|
||||
# updated_strategy: KGAnswerStrategy | None = None
|
||||
|
||||
|
||||
class ConsolidatedResearchUpdate(LoggerUpdate):
|
||||
consolidated_research_object_results_str: str | None = None
|
||||
# class ConsolidatedResearchUpdate(LoggerUpdate):
|
||||
# consolidated_research_object_results_str: str | None = None
|
||||
|
||||
|
||||
class DeepSearchFilterUpdate(LoggerUpdate):
|
||||
vespa_filter_results: KGFilterConstructionResults | None = None
|
||||
div_con_entities: list[str] | None = None
|
||||
source_division: bool | None = None
|
||||
global_entity_filters: list[str] | None = None
|
||||
global_relationship_filters: list[str] | None = None
|
||||
local_entity_filters: list[list[str]] | None = None
|
||||
source_filters: list[str] | None = None
|
||||
# class DeepSearchFilterUpdate(LoggerUpdate):
|
||||
# vespa_filter_results: KGFilterConstructionResults | None = None
|
||||
# div_con_entities: list[str] | None = None
|
||||
# source_division: bool | None = None
|
||||
# global_entity_filters: list[str] | None = None
|
||||
# global_relationship_filters: list[str] | None = None
|
||||
# local_entity_filters: list[list[str]] | None = None
|
||||
# source_filters: list[str] | None = None
|
||||
|
||||
|
||||
class ResearchObjectOutput(LoggerUpdate):
|
||||
research_object_results: Annotated[list[dict[str, Any]], add] = []
|
||||
# class ResearchObjectOutput(LoggerUpdate):
|
||||
# research_object_results: Annotated[list[dict[str, Any]], add] = []
|
||||
|
||||
|
||||
class EntityRelationshipExtractionUpdate(LoggerUpdate):
|
||||
entities_types_str: str = ""
|
||||
relationship_types_str: str = ""
|
||||
extracted_entities_w_attributes: list[str] = []
|
||||
extracted_entities_no_attributes: list[str] = []
|
||||
extracted_relationships: list[str] = []
|
||||
time_filter: str | None = None
|
||||
kg_doc_temp_view_name: str | None = None
|
||||
kg_rel_temp_view_name: str | None = None
|
||||
kg_entity_temp_view_name: str | None = None
|
||||
# class EntityRelationshipExtractionUpdate(LoggerUpdate):
|
||||
# entities_types_str: str = ""
|
||||
# relationship_types_str: str = ""
|
||||
# extracted_entities_w_attributes: list[str] = []
|
||||
# extracted_entities_no_attributes: list[str] = []
|
||||
# extracted_relationships: list[str] = []
|
||||
# time_filter: str | None = None
# kg_doc_temp_view_name: str | None = None
# kg_rel_temp_view_name: str | None = None
# kg_entity_temp_view_name: str | None = None


class ResultsDataUpdate(LoggerUpdate):
    query_results_data_str: str | None = None
    individualized_query_results_data_str: str | None = None
    reference_results_str: str | None = None
# class ResultsDataUpdate(LoggerUpdate):
# query_results_data_str: str | None = None
# individualized_query_results_data_str: str | None = None
# reference_results_str: str | None = None


class ResearchObjectUpdate(LoggerUpdate):
    research_object_results: Annotated[list[dict[str, Any]], add] = []
# class ResearchObjectUpdate(LoggerUpdate):
# research_object_results: Annotated[list[dict[str, Any]], add] = []


## Graph Input State
class MainInput(CoreState):
    question: str
    individual_flow: bool = True # used for UI display purposes
# ## Graph Input State
# class MainInput(CoreState):
# question: str
# individual_flow: bool = True # used for UI display purposes


class FinalAnswerUpdate(LoggerUpdate):
    final_answer: str | None = None
    retrieved_documents: list[InferenceSection] | None = None
# class FinalAnswerUpdate(LoggerUpdate):
# final_answer: str | None = None
# retrieved_documents: list[InferenceSection] | None = None


## Graph State
class MainState(
    # This includes the core state
    MainInput,
    ToolChoiceInput,
    ToolCallUpdate,
    ToolChoiceUpdate,
    EntityRelationshipExtractionUpdate,
    AnalysisUpdate,
    SQLSimpleGenerationUpdate,
    ResultsDataUpdate,
    ResearchObjectOutput,
    DeepSearchFilterUpdate,
    ResearchObjectUpdate,
    ConsolidatedResearchUpdate,
    FinalAnswerUpdate,
):
    pass
# ## Graph State
# class MainState(
# # This includes the core state
# MainInput,
# ToolChoiceInput,
# ToolCallUpdate,
# ToolChoiceUpdate,
# EntityRelationshipExtractionUpdate,
# AnalysisUpdate,
# SQLSimpleGenerationUpdate,
# ResultsDataUpdate,
# ResearchObjectOutput,
# DeepSearchFilterUpdate,
# ResearchObjectUpdate,
# ConsolidatedResearchUpdate,
# FinalAnswerUpdate,
# ):
# pass


## Graph Output State - presently not used
class MainOutput(TypedDict):
    log_messages: list[str]
    final_answer: str | None
    retrieved_documents: list[InferenceSection] | None
# ## Graph Output State - presently not used
# class MainOutput(TypedDict):
# log_messages: list[str]
# final_answer: str | None
# retrieved_documents: list[InferenceSection] | None


class ResearchObjectInput(LoggerUpdate):
    research_nr: int
    entity: str
    broken_down_question: str
    vespa_filter_results: KGFilterConstructionResults
    source_division: bool | None
    source_entity_filters: list[str] | None
    segment_type: str
    individual_flow: bool = True # used for UI display purposes
# class ResearchObjectInput(LoggerUpdate):
# research_nr: int
# entity: str
# broken_down_question: str
# vespa_filter_results: KGFilterConstructionResults
# source_division: bool | None
# source_entity_filters: list[str] | None
# segment_type: str
# individual_flow: bool = True # used for UI display purposes

@@ -1,33 +1,33 @@
from onyx.agents.agent_search.kb_search.models import KGSteps
# from onyx.agents.agent_search.kb_search.models import KGSteps

KG_SEARCH_STEP_DESCRIPTIONS: dict[int, KGSteps] = {
    1: KGSteps(
        description="Analyzing the question...",
        activities=[
            "Entities in Query",
            "Relationships in Query",
            "Terms in Query",
            "Time Filters",
        ],
    ),
    2: KGSteps(
        description="Planning the response approach...",
        activities=["Query Execution Strategy", "Answer Format"],
    ),
    3: KGSteps(
        description="Querying the Knowledge Graph...",
        activities=[
            "Knowledge Graph Query",
            "Knowledge Graph Query Results",
            "Query for Source Documents",
            "Source Documents",
        ],
    ),
    4: KGSteps(
        description="Conducting further research on source documents...", activities=[]
    ),
}
# KG_SEARCH_STEP_DESCRIPTIONS: dict[int, KGSteps] = {
# 1: KGSteps(
# description="Analyzing the question...",
# activities=[
# "Entities in Query",
# "Relationships in Query",
# "Terms in Query",
# "Time Filters",
# ],
# ),
# 2: KGSteps(
# description="Planning the response approach...",
# activities=["Query Execution Strategy", "Answer Format"],
# ),
# 3: KGSteps(
# description="Querying the Knowledge Graph...",
# activities=[
# "Knowledge Graph Query",
# "Knowledge Graph Query Results",
# "Query for Source Documents",
# "Source Documents",
# ],
# ),
# 4: KGSteps(
# description="Conducting further research on source documents...", activities=[]
# ),
# }

BASIC_SEARCH_STEP_DESCRIPTIONS: dict[int, KGSteps] = {
    1: KGSteps(description="Conducting a standard search...", activities=[]),
}
# BASIC_SEARCH_STEP_DESCRIPTIONS: dict[int, KGSteps] = {
# 1: KGSteps(description="Conducting a standard search...", activities=[]),
# }

@@ -1,88 +1,89 @@
from uuid import UUID
# from uuid import UUID

from pydantic import BaseModel
from pydantic import ConfigDict
from sqlalchemy.orm import Session
# from pydantic import BaseModel
# from sqlalchemy.orm import Session

from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.context.search.models import RerankingDetails
from onyx.db.models import Persona
from onyx.file_store.utils import InMemoryChatFile
from onyx.kg.models import KGConfigSettings
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
# from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
# from onyx.context.search.models import RerankingDetails
# from onyx.db.models import Persona
# from onyx.file_store.utils import InMemoryChatFile
# from onyx.kg.models import KGConfigSettings
# from onyx.llm.interfaces import LLM
# from onyx.tools.force import ForceUseTool
# from onyx.tools.tool import Tool
# from onyx.tools.tool_implementations.search.search_tool import SearchTool


class GraphInputs(BaseModel):
    """Input data required for the graph execution"""
# class GraphInputs(BaseModel):
# """Input data required for the graph execution"""

    persona: Persona | None = None
    rerank_settings: RerankingDetails | None = None
    prompt_builder: AnswerPromptBuilder
    files: list[InMemoryChatFile] | None = None
    structured_response_format: dict | None = None
    project_instructions: str | None = None
# persona: Persona | None = None
# rerank_settings: RerankingDetails | None = None
# prompt_builder: AnswerPromptBuilder
# files: list[InMemoryChatFile] | None = None
# structured_response_format: dict | None = None
# project_instructions: str | None = None

    model_config = ConfigDict(arbitrary_types_allowed=True)
# class Config:
# arbitrary_types_allowed = True


class GraphTooling(BaseModel):
    """Tools and LLMs available to the graph"""
# class GraphTooling(BaseModel):
# """Tools and LLMs available to the graph"""

    primary_llm: LLM
    fast_llm: LLM
    search_tool: SearchTool | None = None
    tools: list[Tool]
    # Whether to force use of a tool, or to
    # force tool args IF the tool is used
    force_use_tool: ForceUseTool
    using_tool_calling_llm: bool = False
# primary_llm: LLM
# fast_llm: LLM
# search_tool: SearchTool | None = None
# tools: list[Tool]
# # Whether to force use of a tool, or to
# # force tool args IF the tool is used
# force_use_tool: ForceUseTool
# using_tool_calling_llm: bool = False

    model_config = ConfigDict(arbitrary_types_allowed=True)
# class Config:
# arbitrary_types_allowed = True


class GraphPersistence(BaseModel):
    """Configuration for data persistence"""
# class GraphPersistence(BaseModel):
# """Configuration for data persistence"""

    chat_session_id: UUID
    # The message ID of the to-be-created first agent message
    # in response to the user message that triggered the Pro Search
    message_id: int
# chat_session_id: UUID
# # The message ID of the to-be-created first agent message
# # in response to the user message that triggered the Pro Search
# message_id: int

    # The database session the user and initial agent
    # message were flushed to; only needed for agentic search
    db_session: Session
# # The database session the user and initial agent
# # message were flushed to; only needed for agentic search
# db_session: Session

    model_config = ConfigDict(arbitrary_types_allowed=True)
# class Config:
# arbitrary_types_allowed = True


class GraphSearchConfig(BaseModel):
    """Configuration controlling search behavior"""
# class GraphSearchConfig(BaseModel):
# """Configuration controlling search behavior"""

    use_agentic_search: bool = False
    # Whether to perform initial search to inform decomposition
    perform_initial_search_decomposition: bool = True
# use_agentic_search: bool = False
# # Whether to perform initial search to inform decomposition
# perform_initial_search_decomposition: bool = True

    # Whether to allow creation of refinement questions (and entity extraction, etc.)
    allow_refinement: bool = True
    skip_gen_ai_answer_generation: bool = False
    allow_agent_reranking: bool = False
    kg_config_settings: KGConfigSettings = KGConfigSettings()
    research_type: ResearchType = ResearchType.THOUGHTFUL
# # Whether to allow creation of refinement questions (and entity extraction, etc.)
# allow_refinement: bool = True
# skip_gen_ai_answer_generation: bool = False
# allow_agent_reranking: bool = False
# kg_config_settings: KGConfigSettings = KGConfigSettings()


class GraphConfig(BaseModel):
    """
    Main container for data needed for Langgraph execution
    """
# class GraphConfig(BaseModel):
# """
# Main container for data needed for Langgraph execution
# """

    inputs: GraphInputs
    tooling: GraphTooling
    behavior: GraphSearchConfig
    # Only needed for agentic search
    persistence: GraphPersistence
# inputs: GraphInputs
# tooling: GraphTooling
# behavior: GraphSearchConfig
# # Only needed for agentic search
# persistence: GraphPersistence

    model_config = ConfigDict(arbitrary_types_allowed=True)
# class Config:
# arbitrary_types_allowed = True

@@ -1,50 +1,50 @@
from pydantic import BaseModel
from pydantic import ConfigDict
# from pydantic import BaseModel

from onyx.chat.prompt_builder.schemas import PromptSnapshot
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
# from onyx.chat.prompt_builder.schemas import PromptSnapshot
# from onyx.tools.message import ToolCallSummary
# from onyx.tools.models import SearchToolOverrideKwargs
# from onyx.tools.models import ToolCallFinalResult
# from onyx.tools.models import ToolCallKickoff
# from onyx.tools.models import ToolResponse
# from onyx.tools.tool import Tool


# TODO: adapt the tool choice/tool call to allow for parallel tool calls by
# creating a subgraph that can be invoked in parallel via Send/Command APIs
class ToolChoiceInput(BaseModel):
    should_stream_answer: bool = True
    # default to the prompt builder from the config, but
    # allow overrides for arbitrary tool calls
    prompt_snapshot: PromptSnapshot | None = None
# # TODO: adapt the tool choice/tool call to allow for parallel tool calls by
# # creating a subgraph that can be invoked in parallel via Send/Command APIs
# class ToolChoiceInput(BaseModel):
# should_stream_answer: bool = True
# # default to the prompt builder from the config, but
# # allow overrides for arbitrary tool calls
# prompt_snapshot: PromptSnapshot | None = None

    # names of tools to use for tool calling. Filters the tools available in the config
    tools: list[str] = []
# # names of tools to use for tool calling. Filters the tools available in the config
# tools: list[str] = []


class ToolCallOutput(BaseModel):
    tool_call_summary: ToolCallSummary
    tool_call_kickoff: ToolCallKickoff
    tool_call_responses: list[ToolResponse]
    tool_call_final_result: ToolCallFinalResult
# class ToolCallOutput(BaseModel):
# tool_call_summary: ToolCallSummary
# tool_call_kickoff: ToolCallKickoff
# tool_call_responses: list[ToolResponse]
# tool_call_final_result: ToolCallFinalResult


class ToolCallUpdate(BaseModel):
    tool_call_output: ToolCallOutput | None = None
# class ToolCallUpdate(BaseModel):
# tool_call_output: ToolCallOutput | None = None


class ToolChoice(BaseModel):
    tool: Tool
    tool_args: dict
    id: str | None
    search_tool_override_kwargs: SearchToolOverrideKwargs = SearchToolOverrideKwargs()
# class ToolChoice(BaseModel):
# tool: Tool
# tool_args: dict
# id: str | None
# search_tool_override_kwargs: SearchToolOverrideKwargs = SearchToolOverrideKwargs()

    model_config = ConfigDict(arbitrary_types_allowed=True)
# class Config:
# arbitrary_types_allowed = True


class ToolChoiceUpdate(BaseModel):
    tool_choice: ToolChoice | None = None
# class ToolChoiceUpdate(BaseModel):
# tool_choice: ToolChoice | None = None


class ToolChoiceState(ToolChoiceUpdate, ToolChoiceInput):
    pass
# class ToolChoiceState(ToolChoiceUpdate, ToolChoiceInput):
# pass

@@ -1,93 +1,93 @@
from collections.abc import Iterable
from typing import cast
# from collections.abc import Iterable
# from typing import cast

from langchain_core.runnables.schema import CustomStreamEvent
from langchain_core.runnables.schema import StreamEvent
from langfuse.langchain import CallbackHandler
from langgraph.graph.state import CompiledStateGraph
# from langchain_core.runnables.schema import CustomStreamEvent
# from langchain_core.runnables.schema import StreamEvent
# from langfuse.langchain import CallbackHandler
# from langgraph.graph.state import CompiledStateGraph

from onyx.agents.agent_search.dc_search_analysis.graph_builder import (
    divide_and_conquer_graph_builder,
)
from onyx.agents.agent_search.dc_search_analysis.states import MainInput as DCMainInput
from onyx.agents.agent_search.dr.graph_builder import dr_graph_builder
from onyx.agents.agent_search.dr.states import MainInput as DRMainInput
from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
from onyx.agents.agent_search.kb_search.states import MainInput as KBMainInput
from onyx.agents.agent_search.models import GraphConfig
from onyx.chat.models import AnswerStream
from onyx.configs.app_configs import LANGFUSE_PUBLIC_KEY
from onyx.configs.app_configs import LANGFUSE_SECRET_KEY
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dc_search_analysis.graph_builder import (
# divide_and_conquer_graph_builder,
# )
# from onyx.agents.agent_search.dc_search_analysis.states import MainInput as DCMainInput
# from onyx.agents.agent_search.dr.graph_builder import dr_graph_builder
# from onyx.agents.agent_search.dr.states import MainInput as DRMainInput
# from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
# from onyx.agents.agent_search.kb_search.states import MainInput as KBMainInput
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.chat.models import AnswerStream
# from onyx.configs.app_configs import LANGFUSE_PUBLIC_KEY
# from onyx.configs.app_configs import LANGFUSE_SECRET_KEY
# from onyx.server.query_and_chat.streaming_models import Packet
# from onyx.utils.logger import setup_logger


logger = setup_logger()
GraphInput = DCMainInput | KBMainInput | DRMainInput
# logger = setup_logger()
# GraphInput = DCMainInput | KBMainInput | DRMainInput


def manage_sync_streaming(
    compiled_graph: CompiledStateGraph,
    config: GraphConfig,
    graph_input: GraphInput,
) -> Iterable[StreamEvent]:
    message_id = config.persistence.message_id if config.persistence else None
    callbacks: list[CallbackHandler] = []
    if LANGFUSE_SECRET_KEY and LANGFUSE_PUBLIC_KEY:
        callbacks.append(CallbackHandler())
    for event in compiled_graph.stream(
        stream_mode="custom",
        input=graph_input,
        config={
            "metadata": {"config": config, "thread_id": str(message_id)},
            "callbacks": callbacks, # type: ignore
        },
    ):
        yield cast(CustomStreamEvent, event)
# def manage_sync_streaming(
# compiled_graph: CompiledStateGraph,
# config: GraphConfig,
# graph_input: GraphInput,
# ) -> Iterable[StreamEvent]:
# message_id = config.persistence.message_id if config.persistence else None
# callbacks: list[CallbackHandler] = []
# if LANGFUSE_SECRET_KEY and LANGFUSE_PUBLIC_KEY:
# callbacks.append(CallbackHandler())
# for event in compiled_graph.stream(
# stream_mode="custom",
# input=graph_input,
# config={
# "metadata": {"config": config, "thread_id": str(message_id)},
# "callbacks": callbacks, # type: ignore
# },
# ):
# yield cast(CustomStreamEvent, event)


def run_graph(
    compiled_graph: CompiledStateGraph,
    config: GraphConfig,
    input: GraphInput,
) -> AnswerStream:
# def run_graph(
# compiled_graph: CompiledStateGraph,
# config: GraphConfig,
# input: GraphInput,
# ) -> AnswerStream:

    for event in manage_sync_streaming(
        compiled_graph=compiled_graph, config=config, graph_input=input
    ):
# for event in manage_sync_streaming(
# compiled_graph=compiled_graph, config=config, graph_input=input
# ):

        yield cast(Packet, event["data"])
# yield cast(Packet, event["data"])


def run_kb_graph(
    config: GraphConfig,
) -> AnswerStream:
    graph = kb_graph_builder()
    compiled_graph = graph.compile()
    input = KBMainInput(
        log_messages=[], question=config.inputs.prompt_builder.raw_user_query
    )
# def run_kb_graph(
# config: GraphConfig,
# ) -> AnswerStream:
# graph = kb_graph_builder()
# compiled_graph = graph.compile()
# input = KBMainInput(
# log_messages=[], question=config.inputs.prompt_builder.raw_user_query
# )

    yield from run_graph(compiled_graph, config, input)
# yield from run_graph(compiled_graph, config, input)


def run_dr_graph(
    config: GraphConfig,
) -> AnswerStream:
    graph = dr_graph_builder()
    compiled_graph = graph.compile()
    input = DRMainInput(log_messages=[])
# def run_dr_graph(
# config: GraphConfig,
# ) -> AnswerStream:
# graph = dr_graph_builder()
# compiled_graph = graph.compile()
# input = DRMainInput(log_messages=[])

    yield from run_graph(compiled_graph, config, input)
# yield from run_graph(compiled_graph, config, input)


def run_dc_graph(
    config: GraphConfig,
) -> AnswerStream:
    graph = divide_and_conquer_graph_builder()
    compiled_graph = graph.compile()
    input = DCMainInput(log_messages=[])
    config.inputs.prompt_builder.raw_user_query = (
        config.inputs.prompt_builder.raw_user_query.strip()
    )
    return run_graph(compiled_graph, config, input)
# def run_dc_graph(
# config: GraphConfig,
# ) -> AnswerStream:
# graph = divide_and_conquer_graph_builder()
# compiled_graph = graph.compile()
# input = DCMainInput(log_messages=[])
# config.inputs.prompt_builder.raw_user_query = (
# config.inputs.prompt_builder.raw_user_query.strip()
# )
# return run_graph(compiled_graph, config, input)

@@ -1,176 +1,176 @@
from langchain.schema import AIMessage
from langchain.schema import HumanMessage
from langchain.schema import SystemMessage
from langchain_core.messages.tool import ToolMessage
# from langchain.schema import AIMessage
# from langchain.schema import HumanMessage
# from langchain.schema import SystemMessage
# from langchain_core.messages.tool import ToolMessage

from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.models import (
    AgentPromptEnrichmentComponents,
)
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_persona_agent_prompt_expressions,
)
from onyx.agents.agent_search.shared_graph_utils.utils import remove_document_citations
from onyx.agents.agent_search.shared_graph_utils.utils import summarize_history
from onyx.configs.agent_configs import AGENT_MAX_STATIC_HISTORY_WORD_LENGTH
from onyx.configs.constants import MessageType
from onyx.context.search.models import InferenceSection
from onyx.llm.interfaces import LLMConfig
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.natural_language_processing.utils import tokenizer_trim_content
from onyx.prompts.agent_search import HISTORY_FRAMING_PROMPT
from onyx.prompts.agent_search import SUB_QUESTION_RAG_PROMPT
from onyx.prompts.prompt_utils import build_date_time_string
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.models import (
# AgentPromptEnrichmentComponents,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_persona_agent_prompt_expressions,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import remove_document_citations
# from onyx.agents.agent_search.shared_graph_utils.utils import summarize_history
# from onyx.configs.agent_configs import AGENT_MAX_STATIC_HISTORY_WORD_LENGTH
# from onyx.configs.constants import MessageType
# from onyx.context.search.models import InferenceSection
# from onyx.llm.interfaces import LLMConfig
# from onyx.natural_language_processing.utils import get_tokenizer
# from onyx.natural_language_processing.utils import tokenizer_trim_content
# from onyx.prompts.agent_search import HISTORY_FRAMING_PROMPT
# from onyx.prompts.agent_search import SUB_QUESTION_RAG_PROMPT
# from onyx.prompts.prompt_utils import build_date_time_string
# from onyx.utils.logger import setup_logger

logger = setup_logger()
# logger = setup_logger()


def build_sub_question_answer_prompt(
    question: str,
    original_question: str,
    docs: list[InferenceSection],
    persona_specification: str,
    config: LLMConfig,
) -> list[SystemMessage | HumanMessage | AIMessage | ToolMessage]:
    system_message = SystemMessage(
        content=persona_specification,
    )
# def build_sub_question_answer_prompt(
# question: str,
# original_question: str,
# docs: list[InferenceSection],
# persona_specification: str,
# config: LLMConfig,
# ) -> list[SystemMessage | HumanMessage | AIMessage | ToolMessage]:
# system_message = SystemMessage(
# content=persona_specification,
# )

    date_str = build_date_time_string()
# date_str = build_date_time_string()

    docs_str = format_docs(docs)
# docs_str = format_docs(docs)

    docs_str = trim_prompt_piece(
        config=config,
        prompt_piece=docs_str,
        reserved_str=SUB_QUESTION_RAG_PROMPT + question + original_question + date_str,
    )
    human_message = HumanMessage(
        content=SUB_QUESTION_RAG_PROMPT.format(
            question=question,
            original_question=original_question,
            context=docs_str,
            date_prompt=date_str,
        )
    )
# docs_str = trim_prompt_piece(
# config=config,
# prompt_piece=docs_str,
# reserved_str=SUB_QUESTION_RAG_PROMPT + question + original_question + date_str,
# )
# human_message = HumanMessage(
# content=SUB_QUESTION_RAG_PROMPT.format(
# question=question,
# original_question=original_question,
# context=docs_str,
# date_prompt=date_str,
# )
# )

    return [system_message, human_message]
# return [system_message, human_message]


def trim_prompt_piece(config: LLMConfig, prompt_piece: str, reserved_str: str) -> str:
    # no need to trim if a conservative estimate of one token
    # per character is already less than the max tokens
    if len(prompt_piece) + len(reserved_str) < config.max_input_tokens:
        return prompt_piece
# def trim_prompt_piece(config: LLMConfig, prompt_piece: str, reserved_str: str) -> str:
# # no need to trim if a conservative estimate of one token
# # per character is already less than the max tokens
# if len(prompt_piece) + len(reserved_str) < config.max_input_tokens:
# return prompt_piece

    llm_tokenizer = get_tokenizer(
        provider_type=config.model_provider,
        model_name=config.model_name,
    )
# llm_tokenizer = get_tokenizer(
# provider_type=config.model_provider,
# model_name=config.model_name,
# )

    # slightly conservative trimming
    return tokenizer_trim_content(
        content=prompt_piece,
        desired_length=config.max_input_tokens
        - len(llm_tokenizer.encode(reserved_str)),
        tokenizer=llm_tokenizer,
    )
# # slightly conservative trimming
# return tokenizer_trim_content(
# content=prompt_piece,
# desired_length=config.max_input_tokens
# - len(llm_tokenizer.encode(reserved_str)),
# tokenizer=llm_tokenizer,
# )


def build_history_prompt(config: GraphConfig, question: str) -> str:
    prompt_builder = config.inputs.prompt_builder
    persona_base = get_persona_agent_prompt_expressions(
        config.inputs.persona, db_session=config.persistence.db_session
    ).base_prompt
# def build_history_prompt(config: GraphConfig, question: str) -> str:
# prompt_builder = config.inputs.prompt_builder
# persona_base = get_persona_agent_prompt_expressions(
# config.inputs.persona, db_session=config.persistence.db_session
# ).base_prompt

    if prompt_builder is None:
        return ""
# if prompt_builder is None:
# return ""

    if prompt_builder.single_message_history is not None:
        history = prompt_builder.single_message_history
    else:
        history_components = []
        previous_message_type = None
        for message in prompt_builder.raw_message_history:
            if message.message_type == MessageType.USER:
                history_components.append(f"User: {message.message}\n")
                previous_message_type = MessageType.USER
            elif message.message_type == MessageType.ASSISTANT:
                # Previously there could be multiple assistant messages in a row
                # Now this is handled at the message history construction
                assert previous_message_type is not MessageType.ASSISTANT
                history_components.append(f"You/Agent: {message.message}\n")
                previous_message_type = MessageType.ASSISTANT
            else:
                # Other message types are not included here, currently there should be no other message types
                logger.error(
                    f"Unhandled message type: {message.message_type} with message: {message.message}"
                )
                continue
# if prompt_builder.single_message_history is not None:
# history = prompt_builder.single_message_history
# else:
# history_components = []
# previous_message_type = None
# for message in prompt_builder.raw_message_history:
# if message.message_type == MessageType.USER:
# history_components.append(f"User: {message.message}\n")
# previous_message_type = MessageType.USER
# elif message.message_type == MessageType.ASSISTANT:
# # Previously there could be multiple assistant messages in a row
# # Now this is handled at the message history construction
# assert previous_message_type is not MessageType.ASSISTANT
# history_components.append(f"You/Agent: {message.message}\n")
# previous_message_type = MessageType.ASSISTANT
# else:
# # Other message types are not included here, currently there should be no other message types
# logger.error(
# f"Unhandled message type: {message.message_type} with message: {message.message}"
# )
# continue

        history = "\n".join(history_components)
        history = remove_document_citations(history)
        if len(history.split()) > AGENT_MAX_STATIC_HISTORY_WORD_LENGTH:
            history = summarize_history(
                history=history,
                question=question,
                persona_specification=persona_base,
                llm=config.tooling.fast_llm,
            )
# history = "\n".join(history_components)
# history = remove_document_citations(history)
# if len(history.split()) > AGENT_MAX_STATIC_HISTORY_WORD_LENGTH:
# history = summarize_history(
# history=history,
# question=question,
# persona_specification=persona_base,
# llm=config.tooling.fast_llm,
# )

    return HISTORY_FRAMING_PROMPT.format(history=history) if history else ""
# return HISTORY_FRAMING_PROMPT.format(history=history) if history else ""


def get_prompt_enrichment_components(
    config: GraphConfig,
) -> AgentPromptEnrichmentComponents:
    persona_prompts = get_persona_agent_prompt_expressions(
        config.inputs.persona, db_session=config.persistence.db_session
    )
# def get_prompt_enrichment_components(
# config: GraphConfig,
# ) -> AgentPromptEnrichmentComponents:
# persona_prompts = get_persona_agent_prompt_expressions(
# config.inputs.persona, db_session=config.persistence.db_session
# )

    history = build_history_prompt(config, config.inputs.prompt_builder.raw_user_query)
# history = build_history_prompt(config, config.inputs.prompt_builder.raw_user_query)

    date_str = build_date_time_string()
# date_str = build_date_time_string()

    return AgentPromptEnrichmentComponents(
        persona_prompts=persona_prompts,
        history=history,
        date_str=date_str,
    )
# return AgentPromptEnrichmentComponents(
# persona_prompts=persona_prompts,
# history=history,
# date_str=date_str,
# )


def binary_string_test(text: str, positive_value: str = "yes") -> bool:
    """
    Tests if a string contains a positive value (case-insensitive).
# def binary_string_test(text: str, positive_value: str = "yes") -> bool:
# """
# Tests if a string contains a positive value (case-insensitive).

    Args:
        text: The string to test
        positive_value: The value to look for (defaults to "yes")
# Args:
# text: The string to test
# positive_value: The value to look for (defaults to "yes")

    Returns:
        True if the positive value is found in the text
    """
    return positive_value.lower() in text.lower()
# Returns:
# True if the positive value is found in the text
# """
# return positive_value.lower() in text.lower()


def binary_string_test_after_answer_separator(
    text: str, positive_value: str = "yes", separator: str = "Answer:"
) -> bool:
    """
    Tests if a string contains a positive value (case-insensitive).
# def binary_string_test_after_answer_separator(
# text: str, positive_value: str = "yes", separator: str = "Answer:"
# ) -> bool:
# """
# Tests if a string contains a positive value (case-insensitive).

    Args:
        text: The string to test
        positive_value: The value to look for (defaults to "yes")
# Args:
# text: The string to test
# positive_value: The value to look for (defaults to "yes")

    Returns:
        True if the positive value is found in the text
    """
# Returns:
# True if the positive value is found in the text
# """

    if separator not in text:
        return False
    relevant_text = text.split(f"{separator}")[-1]
# if separator not in text:
# return False
# relevant_text = text.split(f"{separator}")[-1]

    return binary_string_test(relevant_text, positive_value)
# return binary_string_test(relevant_text, positive_value)

@@ -1,205 +1,205 @@
import numpy as np
# import numpy as np

from onyx.agents.agent_search.shared_graph_utils.models import AnswerGenerationDocuments
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitScoreMetrics
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
    dedup_inference_section_list,
)
from onyx.chat.models import SectionRelevancePiece
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.shared_graph_utils.models import AnswerGenerationDocuments
# from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitScoreMetrics
# from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
# from onyx.agents.agent_search.shared_graph_utils.operators import (
# dedup_inference_section_list,
# )
# from onyx.chat.models import SectionRelevancePiece
# from onyx.context.search.models import InferenceSection
# from onyx.utils.logger import setup_logger

logger = setup_logger()
# logger = setup_logger()


def unique_chunk_id(doc: InferenceSection) -> str:
    return f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}"
# def unique_chunk_id(doc: InferenceSection) -> str:
# return f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}"


def calculate_rank_shift(list1: list, list2: list, top_n: int = 20) -> float:
    shift = 0
    for rank_first, doc_id in enumerate(list1[:top_n], 1):
        try:
            rank_second = list2.index(doc_id) + 1
        except ValueError:
            rank_second = len(list2) # Document not found in second list
# def calculate_rank_shift(list1: list, list2: list, top_n: int = 20) -> float:
# shift = 0
# for rank_first, doc_id in enumerate(list1[:top_n], 1):
# try:
# rank_second = list2.index(doc_id) + 1
# except ValueError:
# rank_second = len(list2) # Document not found in second list

        shift += np.abs(rank_first - rank_second) / np.log(1 + rank_first * rank_second)
# shift += np.abs(rank_first - rank_second) / np.log(1 + rank_first * rank_second)

    return shift / top_n
# return shift / top_n


def get_fit_scores(
    pre_reranked_results: list[InferenceSection],
    post_reranked_results: list[InferenceSection] | list[SectionRelevancePiece],
) -> RetrievalFitStats | None:
    """
    Calculate retrieval metrics for search purposes
    """
# def get_fit_scores(
# pre_reranked_results: list[InferenceSection],
# post_reranked_results: list[InferenceSection] | list[SectionRelevancePiece],
# ) -> RetrievalFitStats | None:
# """
# Calculate retrieval metrics for search purposes
# """

    if len(pre_reranked_results) == 0 or len(post_reranked_results) == 0:
        return None
# if len(pre_reranked_results) == 0 or len(post_reranked_results) == 0:
# return None

    ranked_sections = {
        "initial": pre_reranked_results,
        "reranked": post_reranked_results,
    }
# ranked_sections = {
# "initial": pre_reranked_results,
# "reranked": post_reranked_results,
# }

    fit_eval: RetrievalFitStats = RetrievalFitStats(
        fit_score_lift=0,
        rerank_effect=0,
        fit_scores={
            "initial": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
            "reranked": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
        },
    )
# fit_eval: RetrievalFitStats = RetrievalFitStats(
# fit_score_lift=0,
# rerank_effect=0,
# fit_scores={
# "initial": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
# "reranked": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
# },
# )

    for rank_type, docs in ranked_sections.items():
        logger.debug(f"rank_type: {rank_type}")
# for rank_type, docs in ranked_sections.items():
# logger.debug(f"rank_type: {rank_type}")

        for i in [1, 5, 10]:
            fit_eval.fit_scores[rank_type].scores[str(i)] = (
                sum(
                    [
                        float(doc.center_chunk.score)
                        for doc in docs[:i]
                        if isinstance(doc, InferenceSection)
                        and doc.center_chunk.score is not None
                    ]
                )
                / i
            )
# for i in [1, 5, 10]:
# fit_eval.fit_scores[rank_type].scores[str(i)] = (
# sum(
# [
# float(doc.center_chunk.score)
# for doc in docs[:i]
# if isinstance(doc, InferenceSection)
# and doc.center_chunk.score is not None
# ]
# )
# / i
# )

        fit_eval.fit_scores[rank_type].scores["fit_score"] = (
            1
            / 3
            * (
                fit_eval.fit_scores[rank_type].scores["1"]
                + fit_eval.fit_scores[rank_type].scores["5"]
                + fit_eval.fit_scores[rank_type].scores["10"]
            )
        )
# fit_eval.fit_scores[rank_type].scores["fit_score"] = (
# 1
# / 3
# * (
# fit_eval.fit_scores[rank_type].scores["1"]
# + fit_eval.fit_scores[rank_type].scores["5"]
# + fit_eval.fit_scores[rank_type].scores["10"]
# )
# )

        fit_eval.fit_scores[rank_type].scores["fit_score"] = fit_eval.fit_scores[
            rank_type
        ].scores["1"]
# fit_eval.fit_scores[rank_type].scores["fit_score"] = fit_eval.fit_scores[
# rank_type
# ].scores["1"]

        fit_eval.fit_scores[rank_type].chunk_ids = [
            unique_chunk_id(doc) for doc in docs if isinstance(doc, InferenceSection)
        ]
# fit_eval.fit_scores[rank_type].chunk_ids = [
# unique_chunk_id(doc) for doc in docs if isinstance(doc, InferenceSection)
# ]

    fit_eval.fit_score_lift = (
        fit_eval.fit_scores["reranked"].scores["fit_score"]
        / fit_eval.fit_scores["initial"].scores["fit_score"]
    )
# fit_eval.fit_score_lift = (
# fit_eval.fit_scores["reranked"].scores["fit_score"]
# / fit_eval.fit_scores["initial"].scores["fit_score"]
# )

    fit_eval.rerank_effect = calculate_rank_shift(
        fit_eval.fit_scores["initial"].chunk_ids,
        fit_eval.fit_scores["reranked"].chunk_ids,
    )
# fit_eval.rerank_effect = calculate_rank_shift(
# fit_eval.fit_scores["initial"].chunk_ids,
# fit_eval.fit_scores["reranked"].chunk_ids,
# )

    return fit_eval
# return fit_eval


def get_answer_generation_documents(
    relevant_docs: list[InferenceSection],
    context_documents: list[InferenceSection],
    original_question_docs: list[InferenceSection],
    max_docs: int,
) -> AnswerGenerationDocuments:
    """
    Create a deduplicated list of documents to stream, prioritizing relevant docs.
# def get_answer_generation_documents(
# relevant_docs: list[InferenceSection],
# context_documents: list[InferenceSection],
# original_question_docs: list[InferenceSection],
# max_docs: int,
# ) -> AnswerGenerationDocuments:
# """
# Create a deduplicated list of documents to stream, prioritizing relevant docs.

    Args:
        relevant_docs: Primary documents to include
        context_documents: Additional context documents to append
        original_question_docs: Original question documents to append
        max_docs: Maximum number of documents to return
# Args:
# relevant_docs: Primary documents to include
# context_documents: Additional context documents to append
# original_question_docs: Original question documents to append
# max_docs: Maximum number of documents to return

    Returns:
        List of deduplicated documents, limited to max_docs
    """
    # get relevant_doc ids
    relevant_doc_ids = [doc.center_chunk.document_id for doc in relevant_docs]
# Returns:
# List of deduplicated documents, limited to max_docs
# """
# # get relevant_doc ids
# relevant_doc_ids = [doc.center_chunk.document_id for doc in relevant_docs]

    # Start with relevant docs or fallback to original question docs
    streaming_documents = relevant_docs.copy()
# # Start with relevant docs or fallback to original question docs
# streaming_documents = relevant_docs.copy()

    # Use a set for O(1) lookups of document IDs
    seen_doc_ids = {doc.center_chunk.document_id for doc in streaming_documents}
# # Use a set for O(1) lookups of document IDs
# seen_doc_ids = {doc.center_chunk.document_id for doc in streaming_documents}

    # Combine additional documents to check in one iteration
    additional_docs = context_documents + original_question_docs
    for doc_idx, doc in enumerate(additional_docs):
        doc_id = doc.center_chunk.document_id
        if doc_id not in seen_doc_ids:
            streaming_documents.append(doc)
            seen_doc_ids.add(doc_id)
# # Combine additional documents to check in one iteration
# additional_docs = context_documents + original_question_docs
# for doc_idx, doc in enumerate(additional_docs):
# doc_id = doc.center_chunk.document_id
# if doc_id not in seen_doc_ids:
# streaming_documents.append(doc)
# seen_doc_ids.add(doc_id)

    streaming_documents = dedup_inference_section_list(streaming_documents)
# streaming_documents = dedup_inference_section_list(streaming_documents)

    relevant_streaming_docs = [
        doc
        for doc in streaming_documents
        if doc.center_chunk.document_id in relevant_doc_ids
    ]
    relevant_streaming_docs = dedup_sort_inference_section_list(relevant_streaming_docs)
# relevant_streaming_docs = [
# doc
# for doc in streaming_documents
# if doc.center_chunk.document_id in relevant_doc_ids
# ]
# relevant_streaming_docs = dedup_sort_inference_section_list(relevant_streaming_docs)

    additional_streaming_docs = [
        doc
        for doc in streaming_documents
        if doc.center_chunk.document_id not in relevant_doc_ids
    ]
    additional_streaming_docs = dedup_sort_inference_section_list(
        additional_streaming_docs
    )
# additional_streaming_docs = [
# doc
# for doc in streaming_documents
# if doc.center_chunk.document_id not in relevant_doc_ids
# ]
# additional_streaming_docs = dedup_sort_inference_section_list(
# additional_streaming_docs
# )

    for doc in additional_streaming_docs:
        if doc.center_chunk.score:
            doc.center_chunk.score += -2.0
        else:
            doc.center_chunk.score = -2.0
# for doc in additional_streaming_docs:
# if doc.center_chunk.score:
# doc.center_chunk.score += -2.0
# else:
# doc.center_chunk.score = -2.0

    sorted_streaming_documents = relevant_streaming_docs + additional_streaming_docs
# sorted_streaming_documents = relevant_streaming_docs + additional_streaming_docs

    return AnswerGenerationDocuments(
        streaming_documents=sorted_streaming_documents[:max_docs],
        context_documents=relevant_streaming_docs[:max_docs],
    )
# return AnswerGenerationDocuments(
# streaming_documents=sorted_streaming_documents[:max_docs],
# context_documents=relevant_streaming_docs[:max_docs],
# )


def dedup_sort_inference_section_list(
    sections: list[InferenceSection],
) -> list[InferenceSection]:
    """Deduplicates InferenceSections by document_id and sorts by score.
# def dedup_sort_inference_section_list(
# sections: list[InferenceSection],
# ) -> list[InferenceSection]:
# """Deduplicates InferenceSections by document_id and sorts by score.

    Args:
        sections: List of InferenceSections to deduplicate and sort
# Args:
# sections: List of InferenceSections to deduplicate and sort

    Returns:
        Deduplicated list of InferenceSections sorted by score in descending order
    """
    # dedupe/merge with existing framework
    sections = dedup_inference_section_list(sections)
# Returns:
# Deduplicated list of InferenceSections sorted by score in descending order
# """
# # dedupe/merge with existing framework
# sections = dedup_inference_section_list(sections)

    # Use dict to deduplicate by document_id, keeping highest scored version
    unique_sections: dict[str, InferenceSection] = {}
    for section in sections:
        doc_id = section.center_chunk.document_id
        if doc_id not in unique_sections:
            unique_sections[doc_id] = section
            continue
# # Use dict to deduplicate by document_id, keeping highest scored version
# unique_sections: dict[str, InferenceSection] = {}
# for section in sections:
# doc_id = section.center_chunk.document_id
# if doc_id not in unique_sections:
# unique_sections[doc_id] = section
# continue

        # Keep version with higher score
        existing_score = unique_sections[doc_id].center_chunk.score or 0
        new_score = section.center_chunk.score or 0
        if new_score > existing_score:
            unique_sections[doc_id] = section
# # Keep version with higher score
# existing_score = unique_sections[doc_id].center_chunk.score or 0
# new_score = section.center_chunk.score or 0
# if new_score > existing_score:
# unique_sections[doc_id] = section

    # Sort by score in descending order, handling None scores
    sorted_sections = sorted(
        unique_sections.values(), key=lambda x: x.center_chunk.score or 0, reverse=True
    )
# # Sort by score in descending order, handling None scores
# sorted_sections = sorted(
# unique_sections.values(), key=lambda x: x.center_chunk.score or 0, reverse=True
# )

    return sorted_sections
# return sorted_sections

@@ -1,24 +1,24 @@
from enum import Enum
# from enum import Enum

AGENT_LLM_TIMEOUT_MESSAGE = "The agent timed out. Please try again."
AGENT_LLM_ERROR_MESSAGE = "The agent encountered an error. Please try again."
AGENT_LLM_RATELIMIT_MESSAGE = (
    "The agent encountered a rate limit error. Please try again."
)
LLM_ANSWER_ERROR_MESSAGE = "The question was not answered due to an LLM error."
# AGENT_LLM_TIMEOUT_MESSAGE = "The agent timed out. Please try again."
# AGENT_LLM_ERROR_MESSAGE = "The agent encountered an error. Please try again."
# AGENT_LLM_RATELIMIT_MESSAGE = (
# "The agent encountered a rate limit error. Please try again."
# )
# LLM_ANSWER_ERROR_MESSAGE = "The question was not answered due to an LLM error."

AGENT_POSITIVE_VALUE_STR = "yes"
AGENT_NEGATIVE_VALUE_STR = "no"
# AGENT_POSITIVE_VALUE_STR = "yes"
# AGENT_NEGATIVE_VALUE_STR = "no"

AGENT_ANSWER_SEPARATOR = "Answer:"
# AGENT_ANSWER_SEPARATOR = "Answer:"


EMBEDDING_KEY = "embedding"
IS_KEYWORD_KEY = "is_keyword"
KEYWORDS_KEY = "keywords"
# EMBEDDING_KEY = "embedding"
# IS_KEYWORD_KEY = "is_keyword"
# KEYWORDS_KEY = "keywords"


class AgentLLMErrorType(str, Enum):
    TIMEOUT = "timeout"
    RATE_LIMIT = "rate_limit"
    GENERAL_ERROR = "general_error"
# class AgentLLMErrorType(str, Enum):
# TIMEOUT = "timeout"
# RATE_LIMIT = "rate_limit"
# GENERAL_ERROR = "general_error"

@@ -1,243 +1,243 @@
import re
from datetime import datetime
from typing import cast
from typing import Literal
from typing import Type
from typing import TypeVar
# import re
# from datetime import datetime
# from typing import cast
# from typing import Literal
# from typing import Type
# from typing import TypeVar

from braintrust import traced
from langchain.schema.language_model import LanguageModelInput
from langchain_core.messages import HumanMessage
from langgraph.types import StreamWriter
from pydantic import BaseModel
# from braintrust import traced
# from langchain.schema.language_model import LanguageModelInput
# from langchain_core.messages import HumanMessage
# from langgraph.types import StreamWriter
# from pydantic import BaseModel

from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.stream_processing.citation_processing import CitationProcessorGraph
from onyx.chat.stream_processing.citation_processing import LlmDoc
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
from onyx.server.query_and_chat.streaming_models import StreamingType
from onyx.utils.threadpool_concurrency import run_with_timeout
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.chat.stream_processing.citation_processing import CitationProcessorGraph
# from onyx.chat.stream_processing.citation_processing import LlmDoc
# from onyx.llm.interfaces import LLM
# from onyx.llm.interfaces import ToolChoiceOptions
# from onyx.server.query_and_chat.streaming_models import CitationInfo
# from onyx.server.query_and_chat.streaming_models import MessageDelta
# from onyx.server.query_and_chat.streaming_models import ReasoningDelta
# from onyx.server.query_and_chat.streaming_models import StreamingType
# from onyx.utils.threadpool_concurrency import run_with_timeout

SchemaType = TypeVar("SchemaType", bound=BaseModel)
# SchemaType = TypeVar("SchemaType", bound=BaseModel)

# match ```json{...}``` or ```{...}```
JSON_PATTERN = re.compile(r"```(?:json)?\s*(\{.*?\})\s*```", re.DOTALL)
# # match ```json{...}``` or ```{...}```
# JSON_PATTERN = re.compile(r"```(?:json)?\s*(\{.*?\})\s*```", re.DOTALL)


@traced(name="stream llm", type="llm")
def stream_llm_answer(
    llm: LLM,
    prompt: LanguageModelInput,
    event_name: str,
    writer: StreamWriter,
    agent_answer_level: int,
    agent_answer_question_num: int,
    agent_answer_type: Literal["agent_level_answer", "agent_sub_answer"],
    timeout_override: int | None = None,
    max_tokens: int | None = None,
    answer_piece: str | None = None,
    ind: int | None = None,
    context_docs: list[LlmDoc] | None = None,
    replace_citations: bool = False,
) -> tuple[list[str], list[float], list[CitationInfo]]:
    """Stream the initial answer from the LLM.
# @traced(name="stream llm", type="llm")
# def stream_llm_answer(
# llm: LLM,
# prompt: LanguageModelInput,
# event_name: str,
# writer: StreamWriter,
# agent_answer_level: int,
# agent_answer_question_num: int,
# agent_answer_type: Literal["agent_level_answer", "agent_sub_answer"],
# timeout_override: int | None = None,
# max_tokens: int | None = None,
# answer_piece: str | None = None,
# ind: int | None = None,
# context_docs: list[LlmDoc] | None = None,
# replace_citations: bool = False,
# ) -> tuple[list[str], list[float], list[CitationInfo]]:
# """Stream the initial answer from the LLM.

    Args:
        llm: The LLM to use.
        prompt: The prompt to use.
        event_name: The name of the event to write.
        writer: The writer to write to.
        agent_answer_level: The level of the agent answer.
        agent_answer_question_num: The question number within the level.
        agent_answer_type: The type of answer ("agent_level_answer" or "agent_sub_answer").
        timeout_override: The LLM timeout to use.
        max_tokens: The LLM max tokens to use.
        answer_piece: The type of answer piece to write.
        ind: The index of the answer piece.
        tools: The tools to use.
        tool_choice: The tool choice to use.
        structured_response_format: The structured response format to use.
# Args:
# llm: The LLM to use.
# prompt: The prompt to use.
# event_name: The name of the event to write.
# writer: The writer to write to.
# agent_answer_level: The level of the agent answer.
# agent_answer_question_num: The question number within the level.
# agent_answer_type: The type of answer ("agent_level_answer" or "agent_sub_answer").
# timeout_override: The LLM timeout to use.
# max_tokens: The LLM max tokens to use.
# answer_piece: The type of answer piece to write.
# ind: The index of the answer piece.
# tools: The tools to use.
# tool_choice: The tool choice to use.
# structured_response_format: The structured response format to use.

    Returns:
        A tuple of the response and the dispatch timings.
    """
    response: list[str] = []
    dispatch_timings: list[float] = []
    citation_infos: list[CitationInfo] = []
# Returns:
# A tuple of the response and the dispatch timings.
# """
# response: list[str] = []
# dispatch_timings: list[float] = []
# citation_infos: list[CitationInfo] = []

    if context_docs:
        citation_processor = CitationProcessorGraph(
            context_docs=context_docs,
        )
    else:
        citation_processor = None
# if context_docs:
# citation_processor = CitationProcessorGraph(
# context_docs=context_docs,
# )
# else:
# citation_processor = None

    for message in llm.stream_langchain(
        prompt,
        timeout_override=timeout_override,
        max_tokens=max_tokens,
    ):
# for message in llm.stream_langchain(
# prompt,
# timeout_override=timeout_override,
# max_tokens=max_tokens,
# ):

        # TODO: in principle, the answer here COULD contain images, but we don't support that yet
        content = message.content
        if not isinstance(content, str):
            raise ValueError(
                f"Expected content to be a string, but got {type(content)}"
            )
# # TODO: in principle, the answer here COULD contain images, but we don't support that yet
# content = message.content
# if not isinstance(content, str):
# raise ValueError(
# f"Expected content to be a string, but got {type(content)}"
# )

start_stream_token = datetime.now()
|
||||
# start_stream_token = datetime.now()
|
||||
|
||||
if answer_piece == StreamingType.MESSAGE_DELTA.value:
|
||||
if ind is None:
|
||||
raise ValueError("index is required when answer_piece is message_delta")
|
||||
# if answer_piece == StreamingType.MESSAGE_DELTA.value:
|
||||
# if ind is None:
|
||||
# raise ValueError("index is required when answer_piece is message_delta")
|
||||
|
||||
if citation_processor:
|
||||
processed_token = citation_processor.process_token(content)
|
||||
# if citation_processor:
|
||||
# processed_token = citation_processor.process_token(content)
|
||||
|
||||
if isinstance(processed_token, tuple):
|
||||
content = processed_token[0]
|
||||
citation_infos.extend(processed_token[1])
|
||||
elif isinstance(processed_token, str):
|
||||
content = processed_token
|
||||
else:
|
||||
continue
|
||||
# if isinstance(processed_token, tuple):
|
||||
# content = processed_token[0]
|
||||
# citation_infos.extend(processed_token[1])
|
||||
# elif isinstance(processed_token, str):
|
||||
# content = processed_token
|
||||
# else:
|
||||
# continue
|
||||
|
||||
write_custom_event(
|
||||
ind,
|
||||
MessageDelta(content=content),
|
||||
writer,
|
||||
)
|
||||
# write_custom_event(
|
||||
# ind,
|
||||
# MessageDelta(content=content),
|
||||
# writer,
|
||||
# )
|
||||
|
||||
elif answer_piece == StreamingType.REASONING_DELTA.value:
|
||||
if ind is None:
|
||||
raise ValueError(
|
||||
"index is required when answer_piece is reasoning_delta"
|
||||
)
|
||||
write_custom_event(
|
||||
ind,
|
||||
ReasoningDelta(reasoning=content),
|
||||
writer,
|
||||
)
|
||||
# elif answer_piece == StreamingType.REASONING_DELTA.value:
|
||||
# if ind is None:
|
||||
# raise ValueError(
|
||||
# "index is required when answer_piece is reasoning_delta"
|
||||
# )
|
||||
# write_custom_event(
|
||||
# ind,
|
||||
# ReasoningDelta(reasoning=content),
|
||||
# writer,
|
||||
# )
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid answer piece: {answer_piece}")
|
||||
# else:
|
||||
# raise ValueError(f"Invalid answer piece: {answer_piece}")
|
||||
|
||||
end_stream_token = datetime.now()
|
||||
# end_stream_token = datetime.now()
|
||||
|
||||
dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
|
||||
response.append(content)
|
||||
# dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
|
||||
# response.append(content)
|
||||
|
||||
return response, dispatch_timings, citation_infos
|
||||
# return response, dispatch_timings, citation_infos
|
||||
|
||||
|
||||
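To make the control flow above concrete, here is a hedged sketch of driving the helper (which get_answer_from_llm below refers to as stream_llm_answer) in plain message-delta mode. The wrapper name stream_plain_answer and the "basic_response" event_name value are illustrative assumptions, and the agent_* arguments are assumed to have defaults; this is a sketch under those assumptions, not the canonical call.

def stream_plain_answer(
    llm: LLM,
    writer: StreamWriter,
    context_docs: list | None = None,
) -> str:
    # Stream a MESSAGE_DELTA answer at stream index 0 and join the tokens.
    # The event_name value and the omitted agent_* kwargs are assumptions
    # made purely for illustration.
    tokens, _timings, _citations = stream_llm_answer(
        llm=llm,
        prompt=[HumanMessage(content="Summarize the retrieved documents.")],
        event_name="basic_response",
        writer=writer,
        answer_piece=StreamingType.MESSAGE_DELTA.value,
        ind=0,
        context_docs=context_docs,
    )
    return "".join(tokens)
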
def invoke_llm_json(
    llm: LLM,
    prompt: LanguageModelInput,
    schema: Type[SchemaType],
    tools: list[dict] | None = None,
    tool_choice: ToolChoiceOptions | None = None,
    timeout_override: int | None = None,
    max_tokens: int | None = None,
) -> SchemaType:
    """
    Invoke an LLM, forcing it to respond in a specified JSON format if possible,
    and return an object of that schema.
    """
    from litellm.utils import get_supported_openai_params, supports_response_schema

    # check if the model supports response_format: json_schema
    supports_json = "response_format" in (
        get_supported_openai_params(llm.config.model_name, llm.config.model_provider)
        or []
    ) and supports_response_schema(llm.config.model_name, llm.config.model_provider)

    response_content = str(
        llm.invoke_langchain(
            prompt,
            tools=tools,
            tool_choice=tool_choice,
            timeout_override=timeout_override,
            max_tokens=max_tokens,
            **cast(
                dict, {"structured_response_format": schema} if supports_json else {}
            ),
        ).content
    )

    if not supports_json:
        # remove newlines as they often lead to JSON decoding errors
        response_content = response_content.replace("\n", " ")
        # hope the prompt is structured so that a JSON object is output...
        json_block_match = JSON_PATTERN.search(response_content)
        if json_block_match:
            response_content = json_block_match.group(1)
        else:
            first_bracket = response_content.find("{")
            last_bracket = response_content.rfind("}")
            response_content = response_content[first_bracket : last_bracket + 1]

    return schema.model_validate_json(response_content)

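A minimal usage sketch for invoke_llm_json, assuming an already-constructed llm instance. ExampleDecision and summarize_decision are hypothetical names introduced here purely for illustration; they are not part of this module.

from pydantic import BaseModel


class ExampleDecision(BaseModel):
    decision: str
    reasoning: str


def summarize_decision(llm: LLM, question: str) -> str:
    # Ask for a strict JSON object matching ExampleDecision. If the provider
    # does not support structured response formats, invoke_llm_json falls back
    # to extracting the first {...} block from the raw completion.
    result = invoke_llm_json(
        llm=llm,
        prompt=f"Answer as JSON with keys 'decision' and 'reasoning': {question}",
        schema=ExampleDecision,
        max_tokens=200,
    )
    return f"{result.decision}: {result.reasoning}"
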
def get_answer_from_llm(
    llm: LLM,
    prompt: str,
    timeout: int = 25,
    timeout_override: int = 5,
    max_tokens: int = 500,
    stream: bool = False,
    writer: StreamWriter = lambda _: None,
    agent_answer_level: int = 0,
    agent_answer_question_num: int = 0,
    agent_answer_type: Literal[
        "agent_sub_answer", "agent_level_answer"
    ] = "agent_level_answer",
    json_string_flag: bool = False,
) -> str:
    msg = [
        HumanMessage(
            content=prompt,
        )
    ]

    if stream:
        # TODO - adjust for new UI. This is currently not working for current UI/Basic Search
        stream_response, _, _ = run_with_timeout(
            timeout,
            lambda: stream_llm_answer(
                llm=llm,
                prompt=msg,
                event_name="sub_answers",
                writer=writer,
                agent_answer_level=agent_answer_level,
                agent_answer_question_num=agent_answer_question_num,
                agent_answer_type=agent_answer_type,
                timeout_override=timeout_override,
                max_tokens=max_tokens,
            ),
        )
        content = "".join(stream_response)
    else:
        llm_response = run_with_timeout(
            timeout,
            llm.invoke_langchain,
            prompt=msg,
            timeout_override=timeout_override,
            max_tokens=max_tokens,
        )
        content = str(llm_response.content)

    cleaned_response = content
    if json_string_flag:
        cleaned_response = (
            str(content).replace("```json\n", "").replace("\n```", "").replace("\n", "")
        )
        first_bracket = cleaned_response.find("{")
        last_bracket = cleaned_response.rfind("}")
        cleaned_response = cleaned_response[first_bracket : last_bracket + 1]

    return cleaned_response

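And a similarly hedged sketch of the non-streaming path through get_answer_from_llm; extract_json_answer is an illustrative wrapper name and the prompt text is invented for the example.

def extract_json_answer(llm: LLM, question: str) -> str:
    # json_string_flag=True strips ```json fences and trims the output to the
    # outermost {...} so the result can be handed straight to a JSON parser.
    return get_answer_from_llm(
        llm=llm,
        prompt=f"Respond only with a JSON object answering: {question}",
        max_tokens=300,
        json_string_flag=True,
    )
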
@@ -1,166 +1,166 @@
from enum import Enum
from typing import Any

from pydantic import BaseModel

from onyx.agents.agent_search.deep_search.main.models import (
    AgentAdditionalMetrics,
)
from onyx.agents.agent_search.deep_search.main.models import AgentBaseMetrics
from onyx.agents.agent_search.deep_search.main.models import (
    AgentRefinedMetrics,
)
from onyx.agents.agent_search.deep_search.main.models import AgentTimings
from onyx.context.search.models import InferenceSection
from onyx.tools.models import SearchQueryInfo


# Pydantic models for structured outputs
# class RewrittenQueries(BaseModel):
#     rewritten_queries: list[str]


# class BinaryDecision(BaseModel):
#     decision: Literal["yes", "no"]


# class BinaryDecisionWithReasoning(BaseModel):
#     reasoning: str
#     decision: Literal["yes", "no"]


class RetrievalFitScoreMetrics(BaseModel):
    scores: dict[str, float]
    chunk_ids: list[str]


class RetrievalFitStats(BaseModel):
    fit_score_lift: float
    rerank_effect: float
    fit_scores: dict[str, RetrievalFitScoreMetrics]


# class AgentChunkScores(BaseModel):
#     scores: dict[str, dict[str, list[int | float]]]


class AgentChunkRetrievalStats(BaseModel):
    verified_count: int | None = None
    verified_avg_scores: float | None = None
    rejected_count: int | None = None
    rejected_avg_scores: float | None = None
    verified_doc_chunk_ids: list[str] = []
    dismissed_doc_chunk_ids: list[str] = []


class InitialAgentResultStats(BaseModel):
    sub_questions: dict[str, float | int | None]
    original_question: dict[str, float | int | None]
    agent_effectiveness: dict[str, float | int | None]


class AgentErrorLog(BaseModel):
    error_message: str
    error_type: str
    error_result: str


class RefinedAgentStats(BaseModel):
    revision_doc_efficiency: float | None
    revision_question_efficiency: float | None


class Term(BaseModel):
    term_name: str = ""
    term_type: str = ""
    term_similar_to: list[str] = []


### Models ###


class Entity(BaseModel):
    entity_name: str = ""
    entity_type: str = ""


class Relationship(BaseModel):
    relationship_name: str = ""
    relationship_type: str = ""
    relationship_entities: list[str] = []


class EntityRelationshipTermExtraction(BaseModel):
    entities: list[Entity] = []
    relationships: list[Relationship] = []
    terms: list[Term] = []

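As a quick illustration of how the extraction models above compose, here is a small hand-built example; all entity, relationship, and term values are invented purely for illustration.

example_extraction = EntityRelationshipTermExtraction(
    entities=[Entity(entity_name="Onyx", entity_type="product")],
    relationships=[
        Relationship(
            relationship_name="integrates_with",
            relationship_type="integration",
            relationship_entities=["Onyx", "Slack"],
        )
    ],
    terms=[
        Term(term_name="connector", term_type="concept", term_similar_to=["integration"])
    ],
)
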
class EntityExtractionResult(BaseModel):
    retrieved_entities_relationships: EntityRelationshipTermExtraction


class QueryRetrievalResult(BaseModel):
    query: str
    retrieved_documents: list[InferenceSection]
    stats: RetrievalFitStats | None
    query_info: SearchQueryInfo | None


class SubQuestionAnswerResults(BaseModel):
    question: str
    question_id: str
    answer: str
    verified_high_quality: bool
    sub_query_retrieval_results: list[QueryRetrievalResult]
    verified_reranked_documents: list[InferenceSection]
    context_documents: list[InferenceSection]
    cited_documents: list[InferenceSection]
    sub_question_retrieval_stats: AgentChunkRetrievalStats


class StructuredSubquestionDocuments(BaseModel):
    cited_documents: list[InferenceSection]
    context_documents: list[InferenceSection]


class CombinedAgentMetrics(BaseModel):
    timings: AgentTimings
    base_metrics: AgentBaseMetrics | None
    refined_metrics: AgentRefinedMetrics
    additional_metrics: AgentAdditionalMetrics


class PersonaPromptExpressions(BaseModel):
    contextualized_prompt: str
    base_prompt: str | None


class AgentPromptEnrichmentComponents(BaseModel):
    persona_prompts: PersonaPromptExpressions
    history: str
    date_str: str


class LLMNodeErrorStrings(BaseModel):
    timeout: str = "LLM Timeout Error"
    rate_limit: str = "LLM Rate Limit Error"
    general_error: str = "General LLM Error"


class AnswerGenerationDocuments(BaseModel):
    streaming_documents: list[InferenceSection]
    context_documents: list[InferenceSection]


BaseMessage_Content = str | list[str | dict[str, Any]]


class QueryExpansionType(Enum):
    KEYWORD = "keyword"
    SEMANTIC = "semantic"


class ReferenceResults(BaseModel):
    citations: list[str]
    general_entities: list[str]

Some files were not shown because too many files have changed in this diff.