Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-21 09:45:46 +00:00)

Compare commits: testing...agent-sear (33 commits)
Commits (33):

6803548066
1111ce6ce4
3f68e8ea8e
06a8373ff4
86e770d968
f11216132e
1f7d05cd75
c8bf051fb6
14b54db033
0e9f9301ba
69c60feda4
a215ea9143
f81a42b4e8
b095e17827
2a758ae33f
3e58cf2667
b9c29f2a36
647adb9ba0
7d6d73529b
420476ad92
4ca7325d1a
8ddd95d0d4
1378364686
cc4953b560
fe3eae3680
2a7a22d953
f163b798ea
d4563b8693
a54ed77140
f27979ef7f
122a9af9b3
32a97e5479
bf30dab9c4
.github/pull_request_template.md (vendored, 1 change)

@@ -11,5 +11,4 @@
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.

- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)
- [ ] I have included a link to a Linear ticket in my description.
- [ ] [Optional] Override Linear Check

@@ -67,6 +67,7 @@ jobs:
           NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
           NEXT_PUBLIC_GTM_ENABLED=true
           NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
+          NODE_OPTIONS=--max-old-space-size=8192
         # needed due to weird interactions with the builds for different platforms
         no-cache: true
         labels: ${{ steps.meta.outputs.labels }}

@@ -60,6 +60,8 @@ jobs:
         push: true
         build-args: |
           ONYX_VERSION=${{ github.ref_name }}
+          NODE_OPTIONS=--max-old-space-size=8192
+
         # needed due to weird interactions with the builds for different platforms
         no-cache: true
         labels: ${{ steps.meta.outputs.labels }}

.gitignore (vendored, 4 changes)

@@ -7,4 +7,6 @@
 .vscode/
 *.sw?
 /backend/tests/regression/answer_quality/search_test_config.yaml
 /web/test-results/
+backend/onyx/agent_search/main/test_data.json
+backend/tests/regression/answer_quality/test_data.json

.vscode/env_template.txt (vendored, 6 changes)

@@ -52,3 +52,9 @@ BING_API_KEY=<REPLACE THIS>
 # Enable the full set of Danswer Enterprise Edition features
 # NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
 ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False
+
+# Agent Search configs  # TODO: remove these or give them proper names
+AGENT_RETRIEVAL_STATS=False  # Note: This setting will incur substantial re-ranking effort
+AGENT_RERANKING_STATS=True
+AGENT_MAX_QUERY_RETRIEVAL_RESULTS=20
+AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS=20

@@ -119,7 +119,7 @@ There are two editions of Onyx:
 - Whitelabeling
 - API key authentication
 - Encryption of secrets
-- Any many more! Checkout [our website](https://www.onyx.app/) for the latest.
+- And many more! Checkout [our website](https://www.onyx.app/) for the latest.

 To try the Onyx Enterprise Edition:

@@ -0,0 +1,29 @@
"""agent_doc_result_col

Revision ID: 1adf5ea20d2b
Revises: e9cf2bd7baed
Create Date: 2025-01-05 13:14:58.344316

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "1adf5ea20d2b"
down_revision = "e9cf2bd7baed"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add the new column with JSONB type
    op.add_column(
        "sub_question",
        sa.Column("sub_question_doc_results", postgresql.JSONB(), nullable=True),
    )


def downgrade() -> None:
    # Drop the column
    op.drop_column("sub_question", "sub_question_doc_results")

@@ -0,0 +1,35 @@
"""agent_metric_col_rename__s

Revision ID: 925b58bd75b6
Revises: 9787be927e58
Create Date: 2025-01-06 11:20:26.752441

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "925b58bd75b6"
down_revision = "9787be927e58"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Rename columns using PostgreSQL syntax
    op.alter_column(
        "agent__search_metrics", "base_duration_s", new_column_name="base_duration__s"
    )
    op.alter_column(
        "agent__search_metrics", "full_duration_s", new_column_name="full_duration__s"
    )


def downgrade() -> None:
    # Revert the column renames
    op.alter_column(
        "agent__search_metrics", "base_duration__s", new_column_name="base_duration_s"
    )
    op.alter_column(
        "agent__search_metrics", "full_duration__s", new_column_name="full_duration_s"
    )

@@ -0,0 +1,25 @@
"""agent_metric_table_renames__agent__

Revision ID: 9787be927e58
Revises: bceb76d618ec
Create Date: 2025-01-06 11:01:44.210160

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "9787be927e58"
down_revision = "bceb76d618ec"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Rename table from agent_search_metrics to agent__search_metrics
    op.rename_table("agent_search_metrics", "agent__search_metrics")


def downgrade() -> None:
    # Rename table back from agent__search_metrics to agent_search_metrics
    op.rename_table("agent__search_metrics", "agent_search_metrics")

backend/alembic/versions/98a5008d8711_agent_tracking.py (new file, 42 lines)

@@ -0,0 +1,42 @@
"""agent_tracking

Revision ID: 98a5008d8711
Revises: f1ca58b2f2ec
Create Date: 2025-01-04 14:41:52.732238

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "98a5008d8711"
down_revision = "f1ca58b2f2ec"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "agent_search_metrics",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=True),
        sa.Column("persona_id", sa.Integer(), nullable=True),
        sa.Column("agent_type", sa.String(), nullable=False),
        sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
        sa.Column("base_duration_s", sa.Float(), nullable=False),
        sa.Column("full_duration_s", sa.Float(), nullable=False),
        sa.Column("base_metrics", postgresql.JSONB(), nullable=True),
        sa.Column("refined_metrics", postgresql.JSONB(), nullable=True),
        sa.Column("all_metrics", postgresql.JSONB(), nullable=True),
        sa.ForeignKeyConstraint(
            ["persona_id"],
            ["persona.id"],
        ),
        sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
    )


def downgrade() -> None:
    op.drop_table("agent_search_metrics")

@@ -0,0 +1,84 @@
"""agent_table_renames__agent__

Revision ID: bceb76d618ec
Revises: c0132518a25b
Create Date: 2025-01-06 10:50:48.109285

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "bceb76d618ec"
down_revision = "c0132518a25b"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.drop_constraint(
        "sub_query__search_doc_sub_query_id_fkey",
        "sub_query__search_doc",
        type_="foreignkey",
    )
    op.drop_constraint(
        "sub_query__search_doc_search_doc_id_fkey",
        "sub_query__search_doc",
        type_="foreignkey",
    )
    # Rename tables
    op.rename_table("sub_query", "agent__sub_query")
    op.rename_table("sub_question", "agent__sub_question")
    op.rename_table("sub_query__search_doc", "agent__sub_query__search_doc")

    # Update both foreign key constraints for agent__sub_query__search_doc

    # Create new foreign keys with updated names
    op.create_foreign_key(
        "agent__sub_query__search_doc_sub_query_id_fkey",
        "agent__sub_query__search_doc",
        "agent__sub_query",
        ["sub_query_id"],
        ["id"],
    )
    op.create_foreign_key(
        "agent__sub_query__search_doc_search_doc_id_fkey",
        "agent__sub_query__search_doc",
        "search_doc",  # This table name doesn't change
        ["search_doc_id"],
        ["id"],
    )


def downgrade() -> None:
    # Update foreign key constraints for sub_query__search_doc
    op.drop_constraint(
        "agent__sub_query__search_doc_sub_query_id_fkey",
        "agent__sub_query__search_doc",
        type_="foreignkey",
    )
    op.drop_constraint(
        "agent__sub_query__search_doc_search_doc_id_fkey",
        "agent__sub_query__search_doc",
        type_="foreignkey",
    )

    # Rename tables back
    op.rename_table("agent__sub_query__search_doc", "sub_query__search_doc")
    op.rename_table("agent__sub_question", "sub_question")
    op.rename_table("agent__sub_query", "sub_query")

    op.create_foreign_key(
        "sub_query__search_doc_sub_query_id_fkey",
        "sub_query__search_doc",
        "sub_query",
        ["sub_query_id"],
        ["id"],
    )
    op.create_foreign_key(
        "sub_query__search_doc_search_doc_id_fkey",
        "sub_query__search_doc",
        "search_doc",  # This table name doesn't change
        ["search_doc_id"],
        ["id"],
    )

@@ -0,0 +1,40 @@
"""agent_table_changes_rename_level

Revision ID: c0132518a25b
Revises: 1adf5ea20d2b
Create Date: 2025-01-05 16:38:37.660152

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "c0132518a25b"
down_revision = "1adf5ea20d2b"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add level and level_question_nr columns with NOT NULL constraint
    op.add_column(
        "sub_question",
        sa.Column("level", sa.Integer(), nullable=False, server_default="0"),
    )
    op.add_column(
        "sub_question",
        sa.Column(
            "level_question_nr", sa.Integer(), nullable=False, server_default="0"
        ),
    )

    # Remove the server_default after the columns are created
    op.alter_column("sub_question", "level", server_default=None)
    op.alter_column("sub_question", "level_question_nr", server_default=None)


def downgrade() -> None:
    # Remove the columns
    op.drop_column("sub_question", "level_question_nr")
    op.drop_column("sub_question", "level")

@@ -0,0 +1,68 @@
"""create pro search persistence tables

Revision ID: e9cf2bd7baed
Revises: 98a5008d8711
Create Date: 2025-01-02 17:55:56.544246

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID


# revision identifiers, used by Alembic.
revision = "e9cf2bd7baed"
down_revision = "98a5008d8711"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Create sub_question table
    op.create_table(
        "sub_question",
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("primary_question_id", sa.Integer, sa.ForeignKey("chat_message.id")),
        sa.Column(
            "chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id")
        ),
        sa.Column("sub_question", sa.Text),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.func.now()
        ),
        sa.Column("sub_answer", sa.Text),
    )

    # Create sub_query table
    op.create_table(
        "sub_query",
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("parent_question_id", sa.Integer, sa.ForeignKey("sub_question.id")),
        sa.Column(
            "chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id")
        ),
        sa.Column("sub_query", sa.Text),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.func.now()
        ),
    )

    # Create sub_query__search_doc association table
    op.create_table(
        "sub_query__search_doc",
        sa.Column(
            "sub_query_id", sa.Integer, sa.ForeignKey("sub_query.id"), primary_key=True
        ),
        sa.Column(
            "search_doc_id",
            sa.Integer,
            sa.ForeignKey("search_doc.id"),
            primary_key=True,
        ),
    )


def downgrade() -> None:
    op.drop_table("sub_query__search_doc")
    op.drop_table("sub_query")
    op.drop_table("sub_question")

@@ -0,0 +1,33 @@
"""add passthrough auth to tool

Revision ID: f1ca58b2f2ec
Revises: c7bf5721733e
Create Date: 2024-03-19

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "f1ca58b2f2ec"
down_revision: Union[str, None] = "c7bf5721733e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # Add passthrough_auth column to tool table with default value of False
    op.add_column(
        "tool",
        sa.Column(
            "passthrough_auth", sa.Boolean(), nullable=False, server_default=sa.false()
        ),
    )


def downgrade() -> None:
    # Remove passthrough_auth column from tool table
    op.drop_column("tool", "passthrough_auth")

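Taken together, these revisions form a single linear chain on top of c7bf5721733e: f1ca58b2f2ec -> 98a5008d8711 -> e9cf2bd7baed -> 1adf5ea20d2b -> c0132518a25b -> bceb76d618ec -> 9787be927e58 -> 925b58bd75b6. Below is a minimal sketch of applying and rolling back the chain programmatically (equivalent to "alembic upgrade head" on the CLI), assuming it runs from the backend directory that contains alembic.ini; the config path is an assumption, not part of this diff.

from alembic import command
from alembic.config import Config

# Assumed location of the Alembic config relative to the working directory;
# adjust to your checkout.
alembic_cfg = Config("alembic.ini")

# Apply everything up to and including 925b58bd75b6.
command.upgrade(alembic_cfg, "head")

# Roll back just this agent-search chain by returning to its parent revision.
command.downgrade(alembic_cfg, "c7bf5721733e")
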
@@ -98,10 +98,9 @@ def get_page_of_chat_sessions(
     conditions = _build_filter_conditions(start_time, end_time, feedback_filter)

     subquery = (
-        select(ChatSession.id, ChatSession.time_created)
+        select(ChatSession.id)
         .filter(*conditions)
-        .order_by(ChatSession.id, desc(ChatSession.time_created))
-        .distinct(ChatSession.id)
+        .order_by(desc(ChatSession.time_created), ChatSession.id)
         .limit(page_size)
         .offset(page_num * page_size)
         .subquery()

@@ -118,7 +117,11 @@ def get_page_of_chat_sessions(
             ChatMessage.chat_message_feedbacks
         ),
     )
-    .order_by(desc(ChatSession.time_created), asc(ChatMessage.id))
+    .order_by(
+        desc(ChatSession.time_created),
+        ChatSession.id,
+        asc(ChatMessage.id),  # Ensure chronological message order
+    )
     )

     return db_session.scalars(stmt).unique().all()

@@ -179,6 +179,7 @@ def handle_simplified_chat_message(
         chunks_below=0,
         full_doc=chat_message_req.full_doc,
         structured_response_format=chat_message_req.structured_response_format,
+        use_agentic_search=chat_message_req.use_agentic_search,
     )

     packets = stream_chat_message_objects(

@@ -301,6 +302,7 @@ def handle_send_message_simple_with_history(
         chunks_below=0,
         full_doc=req.full_doc,
         structured_response_format=req.structured_response_format,
+        use_agentic_search=req.use_agentic_search,
     )

     packets = stream_chat_message_objects(

@@ -57,6 +57,9 @@ class BasicCreateChatMessageRequest(ChunkContext):
     # https://platform.openai.com/docs/guides/structured-outputs/introduction
     structured_response_format: dict | None = None
+
+    # If True, uses agentic search instead of basic search
+    use_agentic_search: bool = False


 class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
     # Last element is the new query. All previous elements are historical context

@@ -71,6 +74,8 @@ class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
     # only works if using an OpenAI model. See the following for more details:
     # https://platform.openai.com/docs/guides/structured-outputs/introduction
     structured_response_format: dict | None = None
+    # If True, uses agentic search instead of basic search
+    use_agentic_search: bool = False


 class SimpleDoc(BaseModel):

@@ -123,6 +128,9 @@ class OneShotQARequest(ChunkContext):
     # If True, skips generating an AI response to the search query
     skip_gen_ai_answer_generation: bool = False
+
+    # If True, uses pro search instead of basic search
+    use_agentic_search: bool = False

     @model_validator(mode="after")
     def check_persona_fields(self) -> "OneShotQARequest":
         if self.persona_override_config is None and self.persona_id is None:

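A sketch of how a client would opt into the new flag on the request models above; aside from the two fields shown in this diff, the field names below are hypothetical placeholders for the rest of the request body.

# Hypothetical request body for a BasicCreateChatMessageWithHistoryRequest-style
# endpoint; "messages" stands in for the history field described above.
payload = {
    "messages": ["earlier question", "earlier answer", "what can you do with onyx?"],
    "structured_response_format": None,
    "use_agentic_search": True,  # new in this change; defaults to False
}
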
@@ -196,6 +196,8 @@ def get_answer_stream(
         retrieval_details=query_request.retrieval_options,
         rerank_settings=query_request.rerank_settings,
         db_session=db_session,
+        use_agentic_search=query_request.use_agentic_search,
+        skip_gen_ai_answer_generation=query_request.skip_gen_ai_answer_generation,
     )

     packets = stream_chat_message_objects(

backend/onyx/agents/agent_search/basic/graph_builder.py (new file, 71 lines)

@@ -0,0 +1,71 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agents.agent_search.basic.nodes.basic_use_tool_response import (
    basic_use_tool_response,
)
from onyx.agents.agent_search.basic.nodes.llm_tool_choice import llm_tool_choice
from onyx.agents.agent_search.basic.nodes.tool_call import tool_call
from onyx.agents.agent_search.basic.states import BasicInput
from onyx.agents.agent_search.basic.states import BasicOutput
from onyx.agents.agent_search.basic.states import BasicState
from onyx.utils.logger import setup_logger

logger = setup_logger()


def basic_graph_builder() -> StateGraph:
    graph = StateGraph(
        state_schema=BasicState,
        input=BasicInput,
        output=BasicOutput,
    )

    ### Add nodes ###

    graph.add_node(
        node="llm_tool_choice",
        action=llm_tool_choice,
    )

    graph.add_node(
        node="tool_call",
        action=tool_call,
    )

    graph.add_node(
        node="basic_use_tool_response",
        action=basic_use_tool_response,
    )

    ### Add edges ###

    graph.add_edge(start_key=START, end_key="llm_tool_choice")

    graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])

    graph.add_edge(
        start_key="tool_call",
        end_key="basic_use_tool_response",
    )

    graph.add_edge(
        start_key="basic_use_tool_response",
        end_key=END,
    )

    return graph


def should_continue(state: BasicState) -> str:
    return (
        # If there are no tool calls, the basic graph has already streamed the answer
        END
        if state["tool_choice"] is None
        else "tool_call"
    )


if __name__ == "__main__":
    pass

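A minimal usage sketch for this graph, modeled on the __main__ blocks of the other graph builders in this change set; get_test_config and the "configurable"/"config" wiring are assumptions carried over from those blocks rather than part of this file.

from onyx.agents.agent_search.basic.graph_builder import basic_graph_builder
from onyx.agents.agent_search.basic.states import BasicInput
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms

compiled_graph = basic_graph_builder().compile()
primary_llm, fast_llm = get_default_llms()
search_request = SearchRequest(query="what can you do with onyx?")

with get_session_context_manager() as db_session:
    # get_test_config builds an AgentSearchConfig the same way the other
    # __main__ blocks in this change set do (assumed behavior).
    agent_search_config, _search_tool = get_test_config(
        db_session, primary_llm, fast_llm, search_request
    )
    inputs = BasicInput(should_stream_answer=True)
    for event in compiled_graph.stream(
        input=inputs,
        config={"configurable": {"config": agent_search_config}},
    ):
        print(event)
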
@@ -0,0 +1,57 @@
from typing import cast

from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.basic.states import BasicOutput
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.chat.models import LlmDoc
from onyx.tools.tool_implementations.search_like_tool_utils import (
    FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.tools.tool_implementations.search_like_tool_utils import (
    ORIGINAL_CONTEXT_DOCUMENTS_ID,
)


def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicOutput:
    agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
    structured_response_format = agent_config.structured_response_format
    llm = agent_config.primary_llm
    tool_choice = state["tool_choice"]
    if tool_choice is None:
        raise ValueError("Tool choice is None")
    tool = tool_choice["tool"]
    prompt_builder = agent_config.prompt_builder
    tool_call_summary = state["tool_call_summary"]
    tool_call_responses = state["tool_call_responses"]
    state["tool_call_final_result"]
    new_prompt_builder = tool.build_next_prompt(
        prompt_builder=prompt_builder,
        tool_call_summary=tool_call_summary,
        tool_responses=tool_call_responses,
        using_tool_calling_llm=agent_config.using_tool_calling_llm,
    )

    initial_search_results = []
    for yield_item in tool_call_responses:
        if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
            cast(list[LlmDoc], yield_item.response)
        elif yield_item.id == ORIGINAL_CONTEXT_DOCUMENTS_ID:
            search_contexts = yield_item.response.contexts
            for doc in search_contexts:
                if doc.document_id not in initial_search_results:
                    initial_search_results.append(doc)

    initial_search_results = cast(list[LlmDoc], initial_search_results)

    stream = llm.stream(
        prompt=new_prompt_builder.build(),
        structured_response_format=structured_response_format,
    )

    # For now, we don't do multiple tool calls, so we ignore the tool_message
    process_llm_stream(stream, True)

    return BasicOutput()

backend/onyx/agents/agent_search/basic/nodes/llm_tool_choice.py (new file, 134 lines)

@@ -0,0 +1,134 @@
from typing import cast
from uuid import uuid4

from langchain_core.messages import ToolCall
from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.states import ToolChoice
from onyx.agents.agent_search.basic.states import ToolChoiceUpdate
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.chat.tool_handling.tool_response_handler import (
    get_tool_call_for_non_tool_calling_llm_impl,
)
from onyx.tools.tool import Tool
from onyx.utils.logger import setup_logger

logger = setup_logger()


# TODO: break this out into an implementation function
# and a function that handles extracting the necessary fields
# from the state and config
# TODO: fan-out to multiple tool call nodes? Make this configurable?
def llm_tool_choice(state: BasicState, config: RunnableConfig) -> ToolChoiceUpdate:
    """
    This node is responsible for calling the LLM to choose a tool. If no tool is chosen,
    the node MAY emit an answer, depending on whether state["should_stream_answer"] is set.
    """
    should_stream_answer = state["should_stream_answer"]

    agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
    using_tool_calling_llm = agent_config.using_tool_calling_llm
    prompt_builder = agent_config.prompt_builder
    llm = agent_config.primary_llm
    skip_gen_ai_answer_generation = agent_config.skip_gen_ai_answer_generation

    structured_response_format = agent_config.structured_response_format
    tools = agent_config.tools or []
    force_use_tool = agent_config.force_use_tool

    tool, tool_args = None, None
    if force_use_tool.force_use and force_use_tool.args is not None:
        tool_name, tool_args = (
            force_use_tool.tool_name,
            force_use_tool.args,
        )
        tool = get_tool_by_name(tools, tool_name)

    # special pre-logic for non-tool calling LLM case
    elif not using_tool_calling_llm and tools:
        chosen_tool_and_args = get_tool_call_for_non_tool_calling_llm_impl(
            force_use_tool=force_use_tool,
            tools=tools,
            prompt_builder=prompt_builder,
            llm=llm,
        )
        if chosen_tool_and_args:
            tool, tool_args = chosen_tool_and_args

    # If we have a tool and tool args, we are ready to request a tool call.
    # This only happens if the tool call was forced or we are using a non-tool calling LLM.
    if tool and tool_args:
        return ToolChoiceUpdate(
            tool_choice=ToolChoice(
                tool=tool,
                tool_args=tool_args,
                id=str(uuid4()),
            ),
        )

    # if we're skipping gen ai answer generation, we should only
    # continue if we're forcing a tool call (which will be emitted by
    # the tool calling llm in the stream() below)
    if skip_gen_ai_answer_generation and not force_use_tool.force_use:
        return ToolChoiceUpdate(
            tool_choice=None,
        )

    # At this point, we are either using a tool calling LLM or we are skipping the tool call.
    # DEBUG: good breakpoint
    stream = llm.stream(
        # For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
        # may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
        prompt=prompt_builder.build(),
        tools=[tool.tool_definition() for tool in tools] or None,
        tool_choice=("required" if tools and force_use_tool.force_use else None),
        structured_response_format=structured_response_format,
    )

    tool_message = process_llm_stream(stream, should_stream_answer)

    # If no tool calls are emitted by the LLM, we should not choose a tool
    if len(tool_message.tool_calls) == 0:
        return ToolChoiceUpdate(
            tool_choice=None,
        )

    # TODO: here we could handle parallel tool calls. Right now
    # we just pick the first one that matches.
    selected_tool: Tool | None = None
    selected_tool_call_request: ToolCall | None = None
    for tool_call_request in tool_message.tool_calls:
        known_tools_by_name = [
            tool for tool in tools if tool.name == tool_call_request["name"]
        ]

        if known_tools_by_name:
            selected_tool = known_tools_by_name[0]
            selected_tool_call_request = tool_call_request
            break

        logger.error(
            "Tool call requested with unknown name field. \n"
            f"tools: {tools}"
            f"tool_call_request: {tool_call_request}"
        )

    if not selected_tool or not selected_tool_call_request:
        raise ValueError(
            f"Tool call attempted with tool {selected_tool}, request {selected_tool_call_request}"
        )

    logger.info(f"Selected tool: {selected_tool.name}")
    logger.debug(f"Selected tool call request: {selected_tool_call_request}")

    return ToolChoiceUpdate(
        tool_choice=ToolChoice(
            tool=selected_tool,
            tool_args=selected_tool_call_request["args"],
            id=selected_tool_call_request["id"],
        ),
    )

backend/onyx/agents/agent_search/basic/nodes/tool_call.py (new file, 69 lines)

@@ -0,0 +1,69 @@
from typing import cast

from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import AIMessageChunk
from langchain_core.messages.tool import ToolCall
from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.states import ToolCallUpdate
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.chat.models import AnswerPacket
from onyx.tools.message import build_tool_message
from onyx.tools.message import ToolCallSummary
from onyx.tools.tool_runner import ToolRunner
from onyx.utils.logger import setup_logger


logger = setup_logger()


def emit_packet(packet: AnswerPacket) -> None:
    dispatch_custom_event("basic_response", packet)


# TODO: handle is_cancelled
def tool_call(state: BasicState, config: RunnableConfig) -> ToolCallUpdate:
    """Calls the tool specified in the state and updates the state with the result"""
    # TODO: implement

    cast(AgentSearchConfig, config["metadata"]["config"])
    # Unnecessary now, node should only be called if there is a tool call
    # if not self.tool_call_chunk or not self.tool_call_chunk.tool_calls:
    #     return

    tool_choice = state["tool_choice"]
    if tool_choice is None:
        raise ValueError("Cannot invoke tool call node without a tool choice")

    tool = tool_choice["tool"]
    tool_args = tool_choice["tool_args"]
    tool_id = tool_choice["id"]
    tool_runner = ToolRunner(tool, tool_args)
    tool_kickoff = tool_runner.kickoff()

    # TODO: custom events for yields
    emit_packet(tool_kickoff)

    tool_responses = []
    for response in tool_runner.tool_responses():
        tool_responses.append(response)
        emit_packet(response)

    tool_final_result = tool_runner.tool_final_result()
    emit_packet(tool_final_result)

    tool_call = ToolCall(name=tool.name, args=tool_args, id=tool_id)
    tool_call_summary = ToolCallSummary(
        tool_call_request=AIMessageChunk(content="", tool_calls=[tool_call]),
        tool_call_result=build_tool_message(
            tool_call, tool_runner.tool_message_content()
        ),
    )

    return ToolCallUpdate(
        tool_call_summary=tool_call_summary,
        tool_call_kickoff=tool_kickoff,
        tool_call_responses=tool_responses,
        tool_call_final_result=tool_final_result,
    )

backend/onyx/agents/agent_search/basic/states.py (new file, 55 lines)

@@ -0,0 +1,55 @@
from typing import TypedDict

from onyx.tools.message import ToolCallSummary
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool

# States contain values that change over the course of graph execution.
# Config is for values that are set at the start and never change.
# If you are using a value from the config and realize it needs to change,
# you should add it to the state and use/update the version in the state.


## Graph Input State


class BasicInput(TypedDict):
    should_stream_answer: bool


## Graph Output State


class BasicOutput(TypedDict):
    pass


## Update States
class ToolCallUpdate(TypedDict):
    tool_call_summary: ToolCallSummary
    tool_call_kickoff: ToolCallKickoff
    tool_call_responses: list[ToolResponse]
    tool_call_final_result: ToolCallFinalResult


class ToolChoice(TypedDict):
    tool: Tool
    tool_args: dict
    id: str | None


class ToolChoiceUpdate(TypedDict):
    tool_choice: ToolChoice | None


## Graph State


class BasicState(
    BasicInput,
    ToolCallUpdate,
    ToolChoiceUpdate,
    BasicOutput,
):
    pass

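A small sketch of how these TypedDicts compose at runtime: LangGraph shallow-merges whatever update slice a node returns into the running BasicState, so each node only returns the keys it owns. The node below is hypothetical and for illustration only.

from onyx.agents.agent_search.basic.states import ToolChoiceUpdate


def choose_nothing() -> ToolChoiceUpdate:
    # A node that declines to pick a tool returns just its own slice;
    # LangGraph merges {"tool_choice": None} into BasicState.
    return ToolChoiceUpdate(tool_choice=None)


update = choose_nothing()
assert update["tool_choice"] is None
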
backend/onyx/agents/agent_search/basic/utils.py (new file, 52 lines)

@@ -0,0 +1,52 @@
from collections.abc import Iterator
from typing import cast

from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage

from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxAnswerPiece
from onyx.utils.logger import setup_logger

logger = setup_logger()

# TODO: handle citations here; below is what was previously passed in
# see basic_use_tool_response.py for where these variables come from
# answer_handler = CitationResponseHandler(
#     context_docs=final_search_results,
#     final_doc_id_to_rank_map=map_document_id_order(final_search_results),
#     display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
# )


def process_llm_stream(
    stream: Iterator[BaseMessage],
    should_stream_answer: bool,
    final_search_results: list[LlmDoc] | None = None,
    displayed_search_results: list[LlmDoc] | None = None,
) -> AIMessageChunk:
    tool_call_chunk = AIMessageChunk(content="")
    # for response in response_handler_manager.handle_llm_response(stream):

    # This stream will be the llm answer if no tool is chosen. When a tool is chosen,
    # the stream will contain AIMessageChunks with tool call information.
    for response in stream:
        answer_piece = response.content
        if not isinstance(answer_piece, str):
            # TODO: handle non-string content
            logger.warning(f"Received non-string content: {type(answer_piece)}")
            answer_piece = str(answer_piece)

        if isinstance(response, AIMessageChunk) and (
            response.tool_call_chunks or response.tool_calls
        ):
            tool_call_chunk += response  # type: ignore
        elif should_stream_answer:
            # TODO: handle emitting of CitationInfo
            dispatch_custom_event(
                "basic_response",
                OnyxAnswerPiece(answer_piece=answer_piece),
            )

    return cast(AIMessageChunk, tool_call_chunk)

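A tiny, self-contained sketch of what process_llm_stream returns: plain content chunks are streamed out as OnyxAnswerPiece custom events (only when should_stream_answer is true), while tool-call chunks are merged and handed back. The fake two-chunk stream below is for illustration only.

from langchain_core.messages import AIMessageChunk

from onyx.agents.agent_search.basic.utils import process_llm_stream

# Fake LLM output: two plain content chunks and no tool calls.
fake_stream = iter(
    [
        AIMessageChunk(content="Hello, "),
        AIMessageChunk(content="world."),
    ]
)

# should_stream_answer=False, so no custom events are dispatched outside of a
# LangGraph run; the returned chunk carries any tool calls (none here).
tool_call_message = process_llm_stream(fake_stream, should_stream_answer=False)
assert tool_call_message.tool_calls == []
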
backend/onyx/agents/agent_search/core_state.py (new file, 21 lines)

@@ -0,0 +1,21 @@
from operator import add
from typing import Annotated

from pydantic import BaseModel


class CoreState(BaseModel):
    """
    This is the core state that is shared across all subgraphs.
    """

    base_question: str = ""
    log_messages: Annotated[list[str], add] = []


class SubgraphCoreState(BaseModel):
    """
    This is the core state that is shared across all subgraphs.
    """

    log_messages: Annotated[list[str], add]

backend/onyx/agents/agent_search/db_operations.py (new file, 66 lines)

@@ -0,0 +1,66 @@
from uuid import UUID

from sqlalchemy.orm import Session

from onyx.db.models import AgentSubQuery
from onyx.db.models import AgentSubQuestion


def create_sub_question(
    db_session: Session,
    chat_session_id: UUID,
    primary_message_id: int,
    sub_question: str,
    sub_answer: str,
) -> AgentSubQuestion:
    """Create a new sub-question record in the database."""
    sub_q = AgentSubQuestion(
        chat_session_id=chat_session_id,
        primary_question_id=primary_message_id,
        sub_question=sub_question,
        sub_answer=sub_answer,
    )
    db_session.add(sub_q)
    db_session.flush()
    return sub_q


def create_sub_query(
    db_session: Session,
    chat_session_id: UUID,
    parent_question_id: int,
    sub_query: str,
) -> AgentSubQuery:
    """Create a new sub-query record in the database."""
    sub_q = AgentSubQuery(
        chat_session_id=chat_session_id,
        parent_question_id=parent_question_id,
        sub_query=sub_query,
    )
    db_session.add(sub_q)
    db_session.flush()
    return sub_q


def get_sub_questions_for_message(
    db_session: Session,
    primary_message_id: int,
) -> list[AgentSubQuestion]:
    """Get all sub-questions for a given primary message."""
    return (
        db_session.query(AgentSubQuestion)
        .filter(AgentSubQuestion.primary_question_id == primary_message_id)
        .all()
    )


def get_sub_queries_for_question(
    db_session: Session,
    sub_question_id: int,
) -> list[AgentSubQuery]:
    """Get all sub-queries for a given sub-question."""
    return (
        db_session.query(AgentSubQuery)
        .filter(AgentSubQuery.parent_question_id == sub_question_id)
        .all()
    )

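A minimal sketch of persisting one sub-question and a sub-query beneath it, using the session helper that appears in the __main__ blocks elsewhere in this change set; the chat session UUID and message id below are placeholders for rows that must already exist.

from uuid import UUID

from onyx.agents.agent_search.db_operations import create_sub_query
from onyx.agents.agent_search.db_operations import create_sub_question
from onyx.db.engine import get_session_context_manager

chat_session_id = UUID("00000000-0000-0000-0000-000000000000")  # placeholder
primary_message_id = 1  # placeholder: id of the originating chat message

with get_session_context_manager() as db_session:
    sub_question = create_sub_question(
        db_session=db_session,
        chat_session_id=chat_session_id,
        primary_message_id=primary_message_id,
        sub_question="What connectors does Onyx support?",
        sub_answer="",  # filled in after answer generation
    )
    create_sub_query(
        db_session=db_session,
        chat_session_id=chat_session_id,
        parent_question_id=sub_question.id,  # populated by the flush() inside
        sub_query="onyx supported connectors",
    )
    db_session.commit()  # the create_* helpers only flush; commit to persist
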
@@ -0,0 +1,29 @@
from collections.abc import Hashable
from datetime import datetime

from langgraph.types import Send

from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
    AnswerQuestionInput,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
    ExpandedRetrievalInput,
)
from onyx.utils.logger import setup_logger

logger = setup_logger()


def send_to_expanded_retrieval(state: AnswerQuestionInput) -> Send | Hashable:
    logger.debug("sending to expanded retrieval via edge")
    now_start = datetime.now()

    return Send(
        "initial_sub_question_expanded_retrieval",
        ExpandedRetrievalInput(
            question=state.question,
            base_search=False,
            sub_question_id=state.question_id,
            log_messages=[f"{now_start} -- Sending to expanded retrieval"],
        ),
    )

@@ -0,0 +1,126 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.edges import (
    send_to_expanded_retrieval,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.answer_check import (
    answer_check,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.answer_generation import (
    answer_generation,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.format_answer import (
    format_answer,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.ingest_retrieval import (
    ingest_retrieval,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
    AnswerQuestionInput,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
    AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
    AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.graph_builder import (
    expanded_retrieval_graph_builder,
)
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger

logger = setup_logger()


def answer_query_graph_builder() -> StateGraph:
    graph = StateGraph(
        state_schema=AnswerQuestionState,
        input=AnswerQuestionInput,
        output=AnswerQuestionOutput,
    )

    ### Add nodes ###

    expanded_retrieval = expanded_retrieval_graph_builder().compile()
    graph.add_node(
        node="initial_sub_question_expanded_retrieval",
        action=expanded_retrieval,
    )
    graph.add_node(
        node="answer_check",
        action=answer_check,
    )
    graph.add_node(
        node="answer_generation",
        action=answer_generation,
    )
    graph.add_node(
        node="format_answer",
        action=format_answer,
    )
    graph.add_node(
        node="ingest_retrieval",
        action=ingest_retrieval,
    )

    ### Add edges ###

    graph.add_conditional_edges(
        source=START,
        path=send_to_expanded_retrieval,
        path_map=["initial_sub_question_expanded_retrieval"],
    )
    graph.add_edge(
        start_key="initial_sub_question_expanded_retrieval",
        end_key="ingest_retrieval",
    )
    graph.add_edge(
        start_key="ingest_retrieval",
        end_key="answer_generation",
    )
    graph.add_edge(
        start_key="answer_generation",
        end_key="answer_check",
    )
    graph.add_edge(
        start_key="answer_check",
        end_key="format_answer",
    )
    graph.add_edge(
        start_key="format_answer",
        end_key=END,
    )

    return graph


if __name__ == "__main__":
    from onyx.db.engine import get_session_context_manager
    from onyx.llm.factory import get_default_llms
    from onyx.context.search.models import SearchRequest

    graph = answer_query_graph_builder()
    compiled_graph = graph.compile()
    primary_llm, fast_llm = get_default_llms()
    search_request = SearchRequest(
        query="what can you do with onyx or danswer?",
    )
    with get_session_context_manager() as db_session:
        agent_search_config, search_tool = get_test_config(
            db_session, primary_llm, fast_llm, search_request
        )
        inputs = AnswerQuestionInput(
            question="what can you do with onyx?",
            question_id="0_0",
            log_messages=[],
        )
        for thing in compiled_graph.stream(
            input=inputs,
            config={"configurable": {"config": agent_search_config}},
            # debug=True,
            # subgraphs=True,
        ):
            logger.debug(thing)

@@ -0,0 +1,8 @@
from pydantic import BaseModel


### Models ###


class AnswerRetrievalStats(BaseModel):
    answer_retrieval_stats: dict[str, float | int]

@@ -0,0 +1,59 @@
from datetime import datetime
from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
    AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
    QACheckUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_CHECK_NO
from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id


def answer_check(state: AnswerQuestionState, config: RunnableConfig) -> QACheckUpdate:
    now_start = datetime.now()

    level, question_num = parse_question_id(state.question_id)
    if state.answer == UNKNOWN_ANSWER:
        now_end = datetime.now()
        return QACheckUpdate(
            answer_quality=SUB_CHECK_NO,
            log_messages=[
                f"{now_end} -- Answer check SQ-{level}-{question_num} - unknown answer, Time taken: {now_end - now_start}"
            ],
        )
    msg = [
        HumanMessage(
            content=SUB_CHECK_PROMPT.format(
                question=state.question,
                base_answer=state.answer,
            )
        )
    ]

    agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
    fast_llm = agent_search_config.fast_llm
    response = list(
        fast_llm.stream(
            prompt=msg,
        )
    )

    quality_str = merge_message_runs(response, chunk_separator="")[0].content

    now_end = datetime.now()
    return QACheckUpdate(
        answer_quality=quality_str,
        log_messages=[
            f"""{now_end} -- Answer check SQ-{level}-{question_num} - Answer quality: {quality_str},
            Time taken: {now_end - now_start}"""
        ],
    )

@@ -0,0 +1,116 @@
from datetime import datetime
from typing import Any
from typing import cast

from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
    AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
    QAGenerationUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    build_sub_question_answer_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
    ASSISTANT_SYSTEM_PROMPT_DEFAULT,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
    ASSISTANT_SYSTEM_PROMPT_PERSONA,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_prompt
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.utils.logger import setup_logger

logger = setup_logger()


def answer_generation(
    state: AnswerQuestionState, config: RunnableConfig
) -> QAGenerationUpdate:
    now_start = datetime.now()
    logger.debug(f"--------{now_start}--------START ANSWER GENERATION---")

    agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
    question = state.question
    docs = state.documents
    level, question_nr = parse_question_id(state.question_id)
    context_docs = state.context_documents
    persona_prompt = get_persona_prompt(agent_search_config.search_request.persona)

    if len(context_docs) == 0:
        answer_str = UNKNOWN_ANSWER
        dispatch_custom_event(
            "sub_answers",
            AgentAnswerPiece(
                answer_piece=answer_str,
                level=level,
                level_question_nr=question_nr,
                answer_type="agent_sub_answer",
            ),
        )
    else:
        # Use the default system prompt when no persona prompt is configured
        if len(persona_prompt) == 0:
            persona_specification = ASSISTANT_SYSTEM_PROMPT_DEFAULT
        else:
            persona_specification = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
                persona_prompt=persona_prompt
            )

        logger.debug(f"Number of verified retrieval docs: {len(docs)}")

        fast_llm = agent_search_config.fast_llm
        msg = build_sub_question_answer_prompt(
            question=question,
            original_question=agent_search_config.search_request.query,
            docs=docs,
            persona_specification=persona_specification,
            config=fast_llm.config,
        )

        response: list[str | list[str | dict[str, Any]]] = []
        for message in fast_llm.stream(
            prompt=msg,
        ):
            # TODO: in principle, the answer here COULD contain images, but we don't support that yet
            content = message.content
            if not isinstance(content, str):
                raise ValueError(
                    f"Expected content to be a string, but got {type(content)}"
                )
            dispatch_custom_event(
                "sub_answers",
                AgentAnswerPiece(
                    answer_piece=content,
                    level=level,
                    level_question_nr=question_nr,
                    answer_type="agent_sub_answer",
                ),
            )
            response.append(content)

        answer_str = merge_message_runs(response, chunk_separator="")[0].content

    stop_event = StreamStopInfo(
        stop_reason=StreamStopReason.FINISHED,
        stream_type="sub_answer",
        level=level,
        level_question_nr=question_nr,
    )
    dispatch_custom_event("stream_finished", stop_event)

    now_end = datetime.now()
    return QAGenerationUpdate(
        answer=answer_str,
        log_messages=[
            f"{now_end} -- Answer generation SQ-{level} - Q{question_nr} - Time taken: {now_end - now_start}"
        ],
    )

@@ -0,0 +1,27 @@
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
    AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
    AnswerQuestionState,
)
from onyx.agents.agent_search.shared_graph_utils.models import (
    QuestionAnswerResults,
)


def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput:
    return AnswerQuestionOutput(
        answer_results=[
            QuestionAnswerResults(
                question=state.question,
                question_id=state.question_id,
                quality=state.answer_quality
                if hasattr(state, "answer_quality")
                else "No",
                answer=state.answer,
                expanded_retrieval_results=state.expanded_retrieval_results,
                documents=state.documents,
                sub_question_retrieval_stats=state.sub_question_retrieval_stats,
            )
        ],
    )

@@ -0,0 +1,22 @@
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
    RetrievalIngestionUpdate,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
    ExpandedRetrievalOutput,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats


def ingest_retrieval(state: ExpandedRetrievalOutput) -> RetrievalIngestionUpdate:
    sub_question_retrieval_stats = (
        state.expanded_retrieval_result.sub_question_retrieval_stats
    )
    if sub_question_retrieval_stats is None:
        sub_question_retrieval_stats = AgentChunkStats()

    return RetrievalIngestionUpdate(
        expanded_retrieval_results=state.expanded_retrieval_result.expanded_queries_results,
        documents=state.expanded_retrieval_result.all_documents,
        context_documents=state.expanded_retrieval_result.context_documents,
        sub_question_retrieval_stats=sub_question_retrieval_stats,
    )

@@ -0,0 +1,71 @@
from operator import add
from typing import Annotated

from pydantic import BaseModel

from onyx.agents.agent_search.core_state import SubgraphCoreState
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.agents.agent_search.shared_graph_utils.models import (
    QuestionAnswerResults,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
    dedup_inference_sections,
)
from onyx.context.search.models import InferenceSection


## Update States
class QACheckUpdate(BaseModel):
    answer_quality: str = ""
    log_messages: list[str] = []


class QAGenerationUpdate(BaseModel):
    answer: str = ""
    log_messages: list[str] = []
    # answer_stat: AnswerStats


class RetrievalIngestionUpdate(BaseModel):
    expanded_retrieval_results: list[QueryResult] = []
    documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
    context_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
    sub_question_retrieval_stats: AgentChunkStats = AgentChunkStats()


## Graph Input State


class AnswerQuestionInput(SubgraphCoreState):
    question: str = ""
    question_id: str = (
        ""  # 0_0 is original question, everything else is <level>_<question_num>.
    )
    # level 0 is original question and first decomposition, level 1 is follow up, etc
    # question_num is a unique number per original question per level.


## Graph State


class AnswerQuestionState(
    AnswerQuestionInput,
    QAGenerationUpdate,
    QACheckUpdate,
    RetrievalIngestionUpdate,
):
    pass


## Graph Output State


class AnswerQuestionOutput(BaseModel):
    """
    This is a list of results even though each call of this subgraph only returns one result.
    This is because if we parallelize the answer query subgraph, there will be multiple
    results in a list so the add operator is used to add them together.
    """

    answer_results: Annotated[list[QuestionAnswerResults], add] = []

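parse_question_id is imported from shared_graph_utils by the nodes above, but its body is not part of this diff; below is a minimal sketch of the behavior implied by the question_id comment (a "<level>_<question_num>" string), for illustration only.

def parse_question_id(question_id: str) -> tuple[int, int]:
    # "0_0" is the original question; "1_2" would be level 1, question 2.
    level, question_num = question_id.split("_")
    return int(level), int(question_num)


assert parse_question_id("0_0") == (0, 0)
assert parse_question_id("1_2") == (1, 2)
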
@@ -0,0 +1,28 @@
from collections.abc import Hashable
from datetime import datetime

from langgraph.types import Send

from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
    AnswerQuestionInput,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
    ExpandedRetrievalInput,
)
from onyx.utils.logger import setup_logger

logger = setup_logger()


def send_to_expanded_refined_retrieval(state: AnswerQuestionInput) -> Send | Hashable:
    logger.debug("sending to expanded retrieval for follow up question via edge")
    return Send(
        "refined_sub_question_expanded_retrieval",
        ExpandedRetrievalInput(
            question=state.question,
            sub_question_id=state.question_id,
            base_search=False,
            log_messages=[f"{datetime.now()} -- Sending to expanded retrieval"],
        ),
    )

@@ -0,0 +1,123 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.answer_check import (
    answer_check,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.answer_generation import (
    answer_generation,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.format_answer import (
    format_answer,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.ingest_retrieval import (
    ingest_retrieval,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
    AnswerQuestionInput,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
    AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
    AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search_a.answer_refinement_sub_question.edges import (
    send_to_expanded_refined_retrieval,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.graph_builder import (
    expanded_retrieval_graph_builder,
)
from onyx.utils.logger import setup_logger

logger = setup_logger()


def answer_refined_query_graph_builder() -> StateGraph:
    graph = StateGraph(
        state_schema=AnswerQuestionState,
        input=AnswerQuestionInput,
        output=AnswerQuestionOutput,
    )

    ### Add nodes ###

    expanded_retrieval = expanded_retrieval_graph_builder().compile()
    graph.add_node(
        node="refined_sub_question_expanded_retrieval",
        action=expanded_retrieval,
    )
    graph.add_node(
        node="refined_sub_answer_check",
        action=answer_check,
    )
    graph.add_node(
        node="refined_sub_answer_generation",
        action=answer_generation,
    )
    graph.add_node(
        node="format_refined_sub_answer",
        action=format_answer,
    )
    graph.add_node(
        node="ingest_refined_retrieval",
        action=ingest_retrieval,
    )

    ### Add edges ###

    graph.add_conditional_edges(
        source=START,
        path=send_to_expanded_refined_retrieval,
        path_map=["refined_sub_question_expanded_retrieval"],
    )
    graph.add_edge(
        start_key="refined_sub_question_expanded_retrieval",
        end_key="ingest_refined_retrieval",
    )
    graph.add_edge(
        start_key="ingest_refined_retrieval",
        end_key="refined_sub_answer_generation",
    )
    graph.add_edge(
        start_key="refined_sub_answer_generation",
        end_key="refined_sub_answer_check",
    )
    graph.add_edge(
        start_key="refined_sub_answer_check",
        end_key="format_refined_sub_answer",
    )
    graph.add_edge(
        start_key="format_refined_sub_answer",
        end_key=END,
    )

    return graph


if __name__ == "__main__":
    from onyx.db.engine import get_session_context_manager
    from onyx.llm.factory import get_default_llms
    from onyx.context.search.models import SearchRequest

    graph = answer_refined_query_graph_builder()
    compiled_graph = graph.compile()
    primary_llm, fast_llm = get_default_llms()
    search_request = SearchRequest(
        query="what can you do with onyx or danswer?",
    )
    with get_session_context_manager() as db_session:
        inputs = AnswerQuestionInput(
            question="what can you do with onyx?",
            question_id="0_0",
            log_messages=[],
        )
        for thing in compiled_graph.stream(
            input=inputs,
            # debug=True,
            # subgraphs=True,
        ):
            logger.debug(thing)
        # output = compiled_graph.invoke(inputs)
        # logger.debug(output)

@@ -0,0 +1,19 @@
from pydantic import BaseModel

from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.context.search.models import InferenceSection

### Models ###


class AnswerRetrievalStats(BaseModel):
    answer_retrieval_stats: dict[str, float | int]


class QuestionAnswerResults(BaseModel):
    question: str
    answer: str
    quality: str
    # expanded_retrieval_results: list[QueryResult]
    documents: list[InferenceSection]
    sub_question_retrieval_stats: AgentChunkStats
@@ -0,0 +1,76 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agents.agent_search.deep_search_a.base_raw_search.nodes.format_raw_search_results import (
    format_raw_search_results,
)
from onyx.agents.agent_search.deep_search_a.base_raw_search.nodes.generate_raw_search_data import (
    generate_raw_search_data,
)
from onyx.agents.agent_search.deep_search_a.base_raw_search.states import (
    BaseRawSearchInput,
)
from onyx.agents.agent_search.deep_search_a.base_raw_search.states import (
    BaseRawSearchOutput,
)
from onyx.agents.agent_search.deep_search_a.base_raw_search.states import (
    BaseRawSearchState,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.graph_builder import (
    expanded_retrieval_graph_builder,
)


def base_raw_search_graph_builder() -> StateGraph:
    graph = StateGraph(
        state_schema=BaseRawSearchState,
        input=BaseRawSearchInput,
        output=BaseRawSearchOutput,
    )

    ### Add nodes ###

    graph.add_node(
        node="generate_raw_search_data",
        action=generate_raw_search_data,
    )

    expanded_retrieval = expanded_retrieval_graph_builder().compile()
    graph.add_node(
        node="expanded_retrieval_base_search",
        action=expanded_retrieval,
    )
    graph.add_node(
        node="format_raw_search_results",
        action=format_raw_search_results,
    )

    ### Add edges ###

    graph.add_edge(start_key=START, end_key="generate_raw_search_data")

    graph.add_edge(
        start_key="generate_raw_search_data",
        end_key="expanded_retrieval_base_search",
    )
    graph.add_edge(
        start_key="expanded_retrieval_base_search",
        end_key="format_raw_search_results",
    )

    # graph.add_edge(
    #     start_key="expanded_retrieval_base_search",
    #     end_key=END,
    # )

    graph.add_edge(
        start_key="format_raw_search_results",
        end_key=END,
    )

    return graph


if __name__ == "__main__":
    pass
@@ -0,0 +1,20 @@
from pydantic import BaseModel

from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.context.search.models import InferenceSection

### Models ###


class AnswerRetrievalStats(BaseModel):
    answer_retrieval_stats: dict[str, float | int]


class QuestionAnswerResults(BaseModel):
    question: str
    answer: str
    quality: str
    expanded_retrieval_results: list[QueryResult]
    documents: list[InferenceSection]
    sub_question_retrieval_stats: list[AgentChunkStats]
@@ -0,0 +1,18 @@
from onyx.agents.agent_search.deep_search_a.base_raw_search.states import (
    BaseRawSearchOutput,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
    ExpandedRetrievalOutput,
)
from onyx.utils.logger import setup_logger

logger = setup_logger()


def format_raw_search_results(state: ExpandedRetrievalOutput) -> BaseRawSearchOutput:
    logger.debug("format_raw_search_results")
    return BaseRawSearchOutput(
        base_expanded_retrieval_result=state.expanded_retrieval_result,
        # base_retrieval_results=[state.expanded_retrieval_result],
        # base_search_documents=[],
    )
@@ -0,0 +1,25 @@
from typing import cast

from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
    ExpandedRetrievalInput,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.utils.logger import setup_logger

logger = setup_logger()


def generate_raw_search_data(
    state: CoreState, config: RunnableConfig
) -> ExpandedRetrievalInput:
    logger.debug("generate_raw_search_data")
    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    return ExpandedRetrievalInput(
        question=agent_a_config.search_request.query,
        base_search=True,
        sub_question_id=None,  # This graph is always and only used for the original question
        log_messages=[],
    )
@@ -0,0 +1,43 @@
from pydantic import BaseModel

from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import (
    ExpandedRetrievalResult,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
    ExpandedRetrievalInput,
)


## Update States


## Graph Input State


class BaseRawSearchInput(ExpandedRetrievalInput):
    pass


## Graph Output State


class BaseRawSearchOutput(BaseModel):
    """
    This is a list of results even though each call of this subgraph only returns one result.
    This is because if we parallelize the answer query subgraph, there will be multiple
    results in a list so the add operator is used to add them together.
    """

    # base_search_documents: Annotated[list[InferenceSection], dedup_inference_sections]
    # base_retrieval_results: Annotated[list[ExpandedRetrievalResult], add]
    base_expanded_retrieval_result: ExpandedRetrievalResult = ExpandedRetrievalResult()


## Graph State


class BaseRawSearchState(
    BaseRawSearchInput,
    BaseRawSearchOutput,
):
    pass
@@ -0,0 +1,37 @@
from collections.abc import Hashable
from typing import cast

from langchain_core.runnables.config import RunnableConfig
from langgraph.types import Send

from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
    ExpandedRetrievalState,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
    RetrievalInput,
)
from onyx.agents.agent_search.models import AgentSearchConfig


def parallel_retrieval_edge(
    state: ExpandedRetrievalState, config: RunnableConfig
) -> list[Send | Hashable]:
    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    question = state.question if state.question else agent_a_config.search_request.query

    # Always include the question itself as one of the queries to retrieve on.
    # (Parenthesized so the question is appended even when expansions exist.)
    query_expansions = (state.expanded_queries if state.expanded_queries else []) + [
        question
    ]
    return [
        Send(
            "doc_retrieval",
            RetrievalInput(
                query_to_retrieve=query,
                question=question,
                base_search=False,
                sub_question_id=state.sub_question_id,
                log_messages=[],
            ),
        )
        for query in query_expansions
    ]
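A note on the fan-out above: returning a list of `Send` objects from a conditional edge launches one `doc_retrieval` invocation per query, each with its own `RetrievalInput`. A minimal sketch of the same map-style pattern, with a toy dict payload standing in for the real state:

    from langgraph.types import Send

    def fan_out(queries: list[str]) -> list[Send]:
        # One Send per query; LangGraph runs the target node once per Send,
        # passing the second argument as that invocation's input state.
        return [Send("doc_retrieval", {"query_to_retrieve": q}) for q in queries]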
@@ -0,0 +1,147 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agents.agent_search.deep_search_a.expanded_retrieval.edges import (
    parallel_retrieval_edge,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.doc_reranking import (
    doc_reranking,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.doc_retrieval import (
    doc_retrieval,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.doc_verification import (
    doc_verification,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.dummy import (
    dummy,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.expand_queries import (
    expand_queries,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.format_results import (
    format_results,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.verification_kickoff import (
    verification_kickoff,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
    ExpandedRetrievalInput,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
    ExpandedRetrievalOutput,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
    ExpandedRetrievalState,
)
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger

logger = setup_logger()


def expanded_retrieval_graph_builder() -> StateGraph:
    graph = StateGraph(
        state_schema=ExpandedRetrievalState,
        input=ExpandedRetrievalInput,
        output=ExpandedRetrievalOutput,
    )

    ### Add nodes ###

    graph.add_node(
        node="expand_queries",
        action=expand_queries,
    )

    graph.add_node(
        node="dummy",
        action=dummy,
    )

    graph.add_node(
        node="doc_retrieval",
        action=doc_retrieval,
    )
    graph.add_node(
        node="verification_kickoff",
        action=verification_kickoff,
    )
    graph.add_node(
        node="doc_verification",
        action=doc_verification,
    )
    graph.add_node(
        node="doc_reranking",
        action=doc_reranking,
    )
    graph.add_node(
        node="format_results",
        action=format_results,
    )

    ### Add edges ###
    graph.add_edge(
        start_key=START,
        end_key="expand_queries",
    )
    graph.add_edge(
        start_key="expand_queries",
        end_key="dummy",
    )

    graph.add_conditional_edges(
        source="dummy",
        path=parallel_retrieval_edge,
        path_map=["doc_retrieval"],
    )
    graph.add_edge(
        start_key="doc_retrieval",
        end_key="verification_kickoff",
    )
    graph.add_edge(
        start_key="doc_verification",
        end_key="doc_reranking",
    )
    graph.add_edge(
        start_key="doc_reranking",
        end_key="format_results",
    )
    graph.add_edge(
        start_key="format_results",
        end_key=END,
    )

    return graph


if __name__ == "__main__":
    from onyx.db.engine import get_session_context_manager
    from onyx.llm.factory import get_default_llms
    from onyx.context.search.models import SearchRequest

    graph = expanded_retrieval_graph_builder()
    compiled_graph = graph.compile()
    primary_llm, fast_llm = get_default_llms()
    search_request = SearchRequest(
        query="what can you do with onyx or danswer?",
    )

    with get_session_context_manager() as db_session:
        agent_a_config, search_tool = get_test_config(
            db_session, primary_llm, fast_llm, search_request
        )
        inputs = ExpandedRetrievalInput(
            question="what can you do with onyx?",
            base_search=False,
            sub_question_id=None,
            log_messages=[],
        )
        for thing in compiled_graph.stream(
            input=inputs,
            config={"configurable": {"config": agent_a_config}},
            # debug=True,
            subgraphs=True,
        ):
            logger.debug(thing)
@@ -0,0 +1,12 @@
from pydantic import BaseModel

from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.context.search.models import InferenceSection


class ExpandedRetrievalResult(BaseModel):
    expanded_queries_results: list[QueryResult] = []
    all_documents: list[InferenceSection] = []
    context_documents: list[InferenceSection] = []
    sub_question_retrieval_stats: AgentChunkStats = AgentChunkStats()
@@ -0,0 +1,74 @@
from datetime import datetime
from typing import cast

from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.deep_search_a.expanded_retrieval.operations import logger
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
    DocRerankingUpdate,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
    ExpandedRetrievalState,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.calculations import get_fit_scores
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.configs.dev_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS
from onyx.configs.dev_configs import AGENT_RERANKING_STATS
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import SearchRequest
from onyx.context.search.pipeline import retrieval_preprocessing
from onyx.context.search.postprocessing.postprocessing import rerank_sections
from onyx.db.engine import get_session_context_manager


def doc_reranking(
    state: ExpandedRetrievalState, config: RunnableConfig
) -> DocRerankingUpdate:
    now_start = datetime.now()
    verified_documents = state.verified_documents

    # Rerank post retrieval and verification. First, create a search query,
    # then create the list of reranked sections.

    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    question = state.question if state.question else agent_a_config.search_request.query
    with get_session_context_manager() as db_session:
        _search_query = retrieval_preprocessing(
            search_request=SearchRequest(query=question),
            user=agent_a_config.search_tool.user,  # bit of a hack
            llm=agent_a_config.fast_llm,
            db_session=db_session,
        )

    # skip section filtering

    if (
        _search_query.rerank_settings
        and _search_query.rerank_settings.rerank_model_name
        and _search_query.rerank_settings.num_rerank > 0
    ):
        reranked_documents = rerank_sections(
            _search_query,
            verified_documents,
        )
    else:
        logger.warning("No reranking settings found, using unranked documents")
        reranked_documents = verified_documents

    if AGENT_RERANKING_STATS:
        fit_scores = get_fit_scores(verified_documents, reranked_documents)
    else:
        fit_scores = RetrievalFitStats(fit_score_lift=0, rerank_effect=0, fit_scores={})

    # TODO: stream deduped docs here, or decide to use search tool ranking/verification
    now_end = datetime.now()
    return DocRerankingUpdate(
        reranked_documents=[
            doc for doc in reranked_documents if isinstance(doc, InferenceSection)
        ][:AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS],
        sub_question_retrieval_stats=fit_scores,
        log_messages=[
            f"{now_end} -- Expanded Retrieval - Reranking - Time taken: {now_end - now_start}"
        ],
    )
@@ -0,0 +1,103 @@
from datetime import datetime
from typing import cast

from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.deep_search_a.expanded_retrieval.operations import logger
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
    DocRetrievalUpdate,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
    RetrievalInput,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.calculations import get_fit_scores
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.configs.dev_configs import AGENT_MAX_QUERY_RETRIEVAL_RESULTS
from onyx.configs.dev_configs import AGENT_RETRIEVAL_STATS
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_context_manager
from onyx.tools.models import SearchQueryInfo
from onyx.tools.tool_implementations.search.search_tool import (
    SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary


def doc_retrieval(state: RetrievalInput, config: RunnableConfig) -> DocRetrievalUpdate:
    """
    Retrieve documents

    Args:
        state (RetrievalInput): Primary state + the query to retrieve
        config (RunnableConfig): Configuration containing ProSearchConfig

    Updates:
        expanded_retrieval_results: list[ExpandedRetrievalResult]
        retrieved_documents: list[InferenceSection]
    """
    now_start = datetime.now()
    query_to_retrieve = state.query_to_retrieve
    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    search_tool = agent_a_config.search_tool

    retrieved_docs: list[InferenceSection] = []
    if not query_to_retrieve.strip():
        logger.warning("Empty query, skipping retrieval")
        now_end = datetime.now()
        return DocRetrievalUpdate(
            expanded_retrieval_results=[],
            retrieved_documents=[],
            log_messages=[
                f"{now_end} -- Expanded Retrieval - Retrieval - Empty Query - Time taken: {now_end - now_start}"
            ],
        )

    query_info = None
    # new db session to avoid concurrency issues
    with get_session_context_manager() as db_session:
        for tool_response in search_tool.run(
            query=query_to_retrieve,
            force_no_rerank=True,
            alternate_db_session=db_session,
        ):
            # get retrieved docs to send to the rest of the graph
            if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
                response = cast(SearchResponseSummary, tool_response.response)
                retrieved_docs = response.top_sections
                query_info = SearchQueryInfo(
                    predicted_search=response.predicted_search,
                    final_filters=response.final_filters,
                    recency_bias_multiplier=response.recency_bias_multiplier,
                )
                break

    retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS]
    pre_rerank_docs = retrieved_docs
    if search_tool.search_pipeline is not None:
        pre_rerank_docs = (
            search_tool.search_pipeline._retrieved_sections or retrieved_docs
        )

    if AGENT_RETRIEVAL_STATS:
        fit_scores = get_fit_scores(
            pre_rerank_docs,
            retrieved_docs,
        )
    else:
        fit_scores = None

    expanded_retrieval_result = QueryResult(
        query=query_to_retrieve,
        search_results=retrieved_docs,
        stats=fit_scores,
        query_info=query_info,
    )
    now_end = datetime.now()
    return DocRetrievalUpdate(
        expanded_retrieval_results=[expanded_retrieval_result],
        retrieved_documents=retrieved_docs,
        log_messages=[
            f"{now_end} -- Expanded Retrieval - Retrieval - Time taken: {now_end - now_start}"
        ],
    )
@@ -0,0 +1,60 @@
from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
    DocVerificationInput,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
    DocVerificationUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT


def doc_verification(
    state: DocVerificationInput, config: RunnableConfig
) -> DocVerificationUpdate:
    """
    Check whether the document is relevant for the original user question

    Args:
        state (DocVerificationInput): The current state
        config (RunnableConfig): Configuration containing ProSearchConfig

    Updates:
        verified_documents: list[InferenceSection]
    """
    question = state.question
    doc_to_verify = state.doc_to_verify
    document_content = doc_to_verify.combined_content

    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    fast_llm = agent_a_config.fast_llm

    document_content = trim_prompt_piece(
        fast_llm.config, document_content, VERIFIER_PROMPT + question
    )

    msg = [
        HumanMessage(
            content=VERIFIER_PROMPT.format(
                question=question, document_content=document_content
            )
        )
    ]

    response = fast_llm.invoke(msg)

    verified_documents = []
    if isinstance(response.content, str) and "yes" in response.content.lower():
        verified_documents.append(doc_to_verify)

    return DocVerificationUpdate(
        verified_documents=verified_documents,
    )
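The verifier above relies on the fast LLM answering with a yes/no style response; any response containing "yes" (case-insensitive) keeps the document. A minimal sketch of that acceptance check, assuming the same substring heuristic:

    def is_verified(llm_answer: str) -> bool:
        # Substring match, so "Yes.", "yes, it is relevant", etc. all pass.
        # Coarse by design: a fast yes/no prompt keeps per-document cost low.
        return "yes" in llm_answer.lower()

    assert is_verified("Yes, the document is relevant.")
    assert not is_verified("No.")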
@@ -0,0 +1,16 @@
from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
    ExpandedRetrievalState,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
    QueryExpansionUpdate,
)


def dummy(
    state: ExpandedRetrievalState, config: RunnableConfig
) -> QueryExpansionUpdate:
    # Pass-through node: re-emits the expanded queries unchanged, presumably
    # to give the conditional retrieval fan-out a concrete source node.
    return QueryExpansionUpdate(
        expanded_queries=state.expanded_queries,
    )
@@ -0,0 +1,68 @@
from datetime import datetime
from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.deep_search_a.expanded_retrieval.operations import (
    dispatch_subquery,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
    ExpandedRetrievalInput,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
    QueryExpansionUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.prompts import (
    REWRITE_PROMPT_MULTI_ORIGINAL,
)
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id


def expand_queries(
    state: ExpandedRetrievalInput, config: RunnableConfig
) -> QueryExpansionUpdate:
    # Sometimes we want to expand the original question, sometimes we want to expand a sub-question.
    # When we are running this node on the original question, no question is explicitly passed in.
    # Instead, we use the original question from the search request.
    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    now_start = datetime.now()
    question = state.question if state.question else agent_a_config.search_request.query
    llm = agent_a_config.fast_llm
    chat_session_id = agent_a_config.chat_session_id
    sub_question_id = state.sub_question_id
    if sub_question_id is None:
        level, question_nr = 0, 0
    else:
        level, question_nr = parse_question_id(sub_question_id)

    if chat_session_id is None:
        raise ValueError("chat_session_id must be provided for agent search")

    msg = [
        HumanMessage(
            content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question),
        )
    ]

    llm_response_list = dispatch_separated(
        llm.stream(prompt=msg), dispatch_subquery(level, question_nr)
    )

    llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content

    rewritten_queries = llm_response.split("\n")
    now_end = datetime.now()
    return QueryExpansionUpdate(
        expanded_queries=rewritten_queries,
        log_messages=[
            f"{now_end} -- Expanded Retrieval - Query Expansion - Time taken: {now_end - now_start}"
        ],
    )
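`expand_queries` assumes the rewrite prompt yields one sub-query per line, so the merged LLM text is simply split on newlines. A toy illustration of that parsing step (sample output assumed, not a real model response):

    llm_response = "onyx features\nonyx use cases\ndanswer capabilities"
    rewritten_queries = llm_response.split("\n")
    assert rewritten_queries == [
        "onyx features",
        "onyx use cases",
        "danswer capabilities",
    ]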
@@ -0,0 +1,82 @@
from typing import cast

from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import (
    ExpandedRetrievalResult,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.operations import (
    calculate_sub_question_retrieval_stats,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
    ExpandedRetrievalState,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
    ExpandedRetrievalUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.chat.models import ExtendedToolResponse
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses


def format_results(
    state: ExpandedRetrievalState, config: RunnableConfig
) -> ExpandedRetrievalUpdate:
    level, question_nr = parse_question_id(state.sub_question_id or "0_0")
    query_infos = [
        result.query_info
        for result in state.expanded_retrieval_results
        if result.query_info is not None
    ]
    if len(query_infos) == 0:
        raise ValueError("No query info found")

    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    # main question docs will be sent later after aggregation and deduping with sub-question docs

    stream_documents = state.reranked_documents

    if not (level == 0 and question_nr == 0):
        if len(stream_documents) == 0:
            # The sub-question is used as the last query. If no verified documents are found, stream
            # the top 3 for that one. We may want to revisit this.
            stream_documents = state.expanded_retrieval_results[-1].search_results[:3]

        for tool_response in yield_search_responses(
            query=state.question,
            reranked_sections=state.retrieved_documents,  # TODO: rename params. (sections pre-merging here.)
            final_context_sections=stream_documents,
            search_query_info=query_infos[0],  # TODO: handle differing query infos?
            get_section_relevance=lambda: None,  # TODO: add relevance
            search_tool=agent_a_config.search_tool,
        ):
            dispatch_custom_event(
                "tool_response",
                ExtendedToolResponse(
                    id=tool_response.id,
                    response=tool_response.response,
                    level=level,
                    level_question_nr=question_nr,
                ),
            )
    sub_question_retrieval_stats = calculate_sub_question_retrieval_stats(
        verified_documents=state.verified_documents,
        expanded_retrieval_results=state.expanded_retrieval_results,
    )

    if sub_question_retrieval_stats is None:
        sub_question_retrieval_stats = AgentChunkStats()
    # else:
    #     sub_question_retrieval_stats = [sub_question_retrieval_stats]

    return ExpandedRetrievalUpdate(
        expanded_retrieval_result=ExpandedRetrievalResult(
            expanded_queries_results=state.expanded_retrieval_results,
            all_documents=state.reranked_documents,
            context_documents=state.reranked_documents,
            sub_question_retrieval_stats=sub_question_retrieval_stats,
        ),
    )
@@ -0,0 +1,44 @@
from typing import cast
from typing import Literal

from langchain_core.runnables.config import RunnableConfig
from langgraph.types import Command
from langgraph.types import Send

from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
    DocVerificationInput,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
    ExpandedRetrievalState,
)
from onyx.agents.agent_search.models import AgentSearchConfig


def verification_kickoff(
    state: ExpandedRetrievalState,
    config: RunnableConfig,
) -> Command[Literal["doc_verification"]]:
    documents = state.retrieved_documents
    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    verification_question = (
        state.question if state.question else agent_a_config.search_request.query
    )
    sub_question_id = state.sub_question_id
    return Command(
        update={},
        goto=[
            Send(
                node="doc_verification",
                arg=DocVerificationInput(
                    doc_to_verify=doc,
                    question=verification_question,
                    base_search=False,
                    sub_question_id=sub_question_id,
                    log_messages=[],
                ),
            )
            for doc in documents
        ],
    )
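`verification_kickoff` uses the `Command` return type to combine a state update (empty here) with a `goto` list of `Send`s, fanning out one `doc_verification` call per retrieved document. A stripped-down sketch of the same control flow, with toy dict payloads in place of the real input model:

    from typing import Literal

    from langgraph.types import Command, Send

    def kickoff(docs: list[str]) -> Command[Literal["doc_verification"]]:
        return Command(
            update={},  # no state change at this node
            goto=[Send("doc_verification", {"doc_to_verify": d}) for d in docs],
        )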
@@ -0,0 +1,97 @@
from collections import defaultdict
from collections.abc import Callable

import numpy as np
from langchain_core.callbacks.manager import dispatch_custom_event

from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.chat.models import SubQueryPiece
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger

logger = setup_logger()


def dispatch_subquery(level: int, question_nr: int) -> Callable[[str, int], None]:
    def helper(token: str, num: int) -> None:
        dispatch_custom_event(
            "subqueries",
            SubQueryPiece(
                sub_query=token,
                level=level,
                level_question_nr=question_nr,
                query_id=num,
            ),
        )

    return helper


def calculate_sub_question_retrieval_stats(
    verified_documents: list[InferenceSection],
    expanded_retrieval_results: list[QueryResult],
) -> AgentChunkStats:
    chunk_scores: dict[str, dict[str, list[int | float]]] = defaultdict(
        lambda: defaultdict(list)
    )

    for expanded_retrieval_result in expanded_retrieval_results:
        for doc in expanded_retrieval_result.search_results:
            doc_chunk_id = f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}"
            if doc.center_chunk.score is not None:
                chunk_scores[doc_chunk_id]["score"].append(doc.center_chunk.score)

    verified_doc_chunk_ids = [
        f"{verified_document.center_chunk.document_id}_{verified_document.center_chunk.chunk_id}"
        for verified_document in verified_documents
    ]
    dismissed_doc_chunk_ids = []

    raw_chunk_stats_counts: dict[str, int] = defaultdict(int)
    raw_chunk_stats_scores: dict[str, float] = defaultdict(float)
    for doc_chunk_id, chunk_data in chunk_scores.items():
        if doc_chunk_id in verified_doc_chunk_ids:
            raw_chunk_stats_counts["verified_count"] += 1

            valid_chunk_scores = [
                score for score in chunk_data["score"] if score is not None
            ]
            raw_chunk_stats_scores["verified_scores"] += float(
                np.mean(valid_chunk_scores)
            )
        else:
            raw_chunk_stats_counts["rejected_count"] += 1
            valid_chunk_scores = [
                score for score in chunk_data["score"] if score is not None
            ]
            raw_chunk_stats_scores["rejected_scores"] += float(
                np.mean(valid_chunk_scores)
            )
            dismissed_doc_chunk_ids.append(doc_chunk_id)

    if raw_chunk_stats_counts["verified_count"] == 0:
        verified_avg_scores = 0.0
    else:
        verified_avg_scores = raw_chunk_stats_scores["verified_scores"] / float(
            raw_chunk_stats_counts["verified_count"]
        )

    rejected_scores = raw_chunk_stats_scores.get("rejected_scores", None)
    if rejected_scores is not None:
        rejected_avg_scores = rejected_scores / float(
            raw_chunk_stats_counts["rejected_count"]
        )
    else:
        rejected_avg_scores = None

    chunk_stats = AgentChunkStats(
        verified_count=raw_chunk_stats_counts["verified_count"],
        verified_avg_scores=verified_avg_scores,
        rejected_count=raw_chunk_stats_counts["rejected_count"],
        rejected_avg_scores=rejected_avg_scores,
        verified_doc_chunk_ids=verified_doc_chunk_ids,
        dismissed_doc_chunk_ids=dismissed_doc_chunk_ids,
    )

    return chunk_stats
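To make the averaging in `calculate_sub_question_retrieval_stats` concrete: per-chunk mean scores are summed into a bucket ("verified" or "rejected") and then divided by the bucket count. A worked toy example with hypothetical scores:

    # Two verified chunks with mean scores 0.8 and 0.6:
    verified_scores_sum = 0.8 + 0.6
    verified_count = 2
    verified_avg_scores = verified_scores_sum / float(verified_count)
    assert round(verified_avg_scores, 6) == 0.7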
@@ -0,0 +1,91 @@
from operator import add
from typing import Annotated

from pydantic import BaseModel

from onyx.agents.agent_search.core_state import SubgraphCoreState
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import (
    ExpandedRetrievalResult,
)
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
    dedup_inference_sections,
)
from onyx.context.search.models import InferenceSection


### States ###

## Graph Input State


class ExpandedRetrievalInput(SubgraphCoreState):
    question: str = ""
    base_search: bool = False
    sub_question_id: str | None = None


## Update/Return States


class QueryExpansionUpdate(BaseModel):
    expanded_queries: list[str] = ["aaa", "bbb"]
    log_messages: list[str] = []


class DocVerificationUpdate(BaseModel):
    verified_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []


class DocRetrievalUpdate(BaseModel):
    expanded_retrieval_results: Annotated[list[QueryResult], add] = []
    retrieved_documents: Annotated[
        list[InferenceSection], dedup_inference_sections
    ] = []
    log_messages: list[str] = []


class DocRerankingUpdate(BaseModel):
    reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
    sub_question_retrieval_stats: RetrievalFitStats | None = None
    log_messages: list[str] = []


class ExpandedRetrievalUpdate(BaseModel):
    expanded_retrieval_result: ExpandedRetrievalResult


## Graph Output State


class ExpandedRetrievalOutput(BaseModel):
    expanded_retrieval_result: ExpandedRetrievalResult = ExpandedRetrievalResult()
    base_expanded_retrieval_result: ExpandedRetrievalResult = ExpandedRetrievalResult()
    log_messages: list[str] = []


## Graph State


class ExpandedRetrievalState(
    # This includes the core state
    ExpandedRetrievalInput,
    QueryExpansionUpdate,
    DocRetrievalUpdate,
    DocVerificationUpdate,
    DocRerankingUpdate,
    ExpandedRetrievalOutput,
):
    pass


## Conditional Input States


class DocVerificationInput(ExpandedRetrievalInput):
    doc_to_verify: InferenceSection


class RetrievalInput(ExpandedRetrievalInput):
    query_to_retrieve: str = ""
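These `Annotated` reducers are what let the parallel `doc_retrieval` branches write to shared state safely: `add` concatenates the per-query results, while `dedup_inference_sections` merges document lists without duplicates. A minimal sketch of a custom dedup-style reducer in the same spirit (a toy reducer over string ids, not the Onyx helper):

    def dedup_by_id(existing: list[str], new: list[str]) -> list[str]:
        # Keep the first occurrence of each id, preserving order.
        seen = set(existing)
        merged = list(existing)
        for x in new:
            if x not in seen:
                seen.add(x)
                merged.append(x)
        return merged

    assert dedup_by_id(["a", "b"], ["b", "c"]) == ["a", "b", "c"]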
100
backend/onyx/agents/agent_search/deep_search_a/main/edges.py
Normal file
@@ -0,0 +1,100 @@
from collections.abc import Hashable
from datetime import datetime
from typing import Literal

from langgraph.types import Send

from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
    AnswerQuestionInput,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
    AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.deep_search_a.main.states import (
    RequireRefinedAnswerUpdate,
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.utils.logger import setup_logger

logger = setup_logger()


def parallelize_initial_sub_question_answering(
    state: MainState,
) -> list[Send | Hashable]:
    now_start = datetime.now()
    if len(state.initial_decomp_questions) > 0:
        # sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]]
        # if len(state["sub_question_records"]) == 0:
        #     if state["config"].use_persistence:
        #         raise ValueError("No sub-questions found for initial decomposed questions")
        #     else:
        #         # in this case, we are doing retrieval on the original question.
        #         # to make all the logic consistent, we create a new sub-question
        #         # with the same content as the original question
        #         sub_question_record_ids = [1] * len(state["initial_decomp_questions"])

        return [
            Send(
                "answer_query_subgraph",
                AnswerQuestionInput(
                    question=question,
                    question_id=make_question_id(0, question_nr + 1),
                    log_messages=[
                        f"{now_start} -- Main Edge - Parallelize Initial Sub-question Answering"
                    ],
                ),
            )
            for question_nr, question in enumerate(state.initial_decomp_questions)
        ]

    else:
        return [
            Send(
                "ingest_answers",
                AnswerQuestionOutput(
                    answer_results=[],
                ),
            )
        ]


# Define the function that determines whether to continue or not
def continue_to_refined_answer_or_end(
    state: RequireRefinedAnswerUpdate,
) -> Literal["refined_sub_question_creation", "logging_node"]:
    if state.require_refined_answer:
        return "refined_sub_question_creation"
    else:
        return "logging_node"


def parallelize_refined_sub_question_answering(
    state: MainState,
) -> list[Send | Hashable]:
    now_start = datetime.now()
    if len(state.refined_sub_questions) > 0:
        return [
            Send(
                "answer_refined_question",
                AnswerQuestionInput(
                    question=question_data.sub_question,
                    question_id=make_question_id(1, question_nr),
                    log_messages=[
                        f"{now_start} -- Main Edge - Parallelize Refined Sub-question Answering"
                    ],
                ),
            )
            for question_nr, question_data in state.refined_sub_questions.items()
        ]

    else:
        return [
            Send(
                "ingest_refined_sub_answers",
                AnswerQuestionOutput(
                    answer_results=[],
                ),
            )
        ]
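Both edges above rely on the `<level>_<question_num>` id convention documented in the sub-question states file; `make_question_id` and `parse_question_id` presumably just join and split on the underscore. A hedged sketch of what those helpers would look like under that convention:

    def make_question_id(level: int, question_nr: int) -> str:
        # "0_0" is the original question; level 1 ids are refined sub-questions.
        return f"{level}_{question_nr}"

    def parse_question_id(question_id: str) -> tuple[int, int]:
        level, question_nr = question_id.split("_")
        return int(level), int(question_nr)

    assert make_question_id(0, 1) == "0_1"
    assert parse_question_id("1_2") == (1, 2)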
@@ -0,0 +1,365 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.graph_builder import (
    answer_query_graph_builder,
)
from onyx.agents.agent_search.deep_search_a.answer_refinement_sub_question.graph_builder import (
    answer_refined_query_graph_builder,
)
from onyx.agents.agent_search.deep_search_a.base_raw_search.graph_builder import (
    base_raw_search_graph_builder,
)
from onyx.agents.agent_search.deep_search_a.main.edges import (
    continue_to_refined_answer_or_end,
)
from onyx.agents.agent_search.deep_search_a.main.edges import (
    parallelize_initial_sub_question_answering,
)
from onyx.agents.agent_search.deep_search_a.main.edges import (
    parallelize_refined_sub_question_answering,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.agent_logging import (
    agent_logging,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.agent_path_decision import (
    agent_path_decision,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.agent_path_routing import (
    agent_path_routing,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.agent_search_start import (
    agent_search_start,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.direct_llm_handling import (
    direct_llm_handling,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.entity_term_extraction_llm import (
    entity_term_extraction_llm,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.generate_initial_answer import (
    generate_initial_answer,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.generate_refined_answer import (
    generate_refined_answer,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.ingest_initial_base_retrieval import (
    ingest_initial_base_retrieval,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.ingest_initial_sub_question_answers import (
    ingest_initial_sub_question_answers,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.ingest_refined_answers import (
    ingest_refined_answers,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.initial_answer_quality_check import (
    initial_answer_quality_check,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.initial_sub_question_creation import (
    initial_sub_question_creation,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.refined_answer_decision import (
    refined_answer_decision,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.refined_sub_question_creation import (
    refined_sub_question_creation,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.retrieval_consolidation import (
    retrieval_consolidation,
)
from onyx.agents.agent_search.deep_search_a.main.states import MainInput
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger

logger = setup_logger()

test_mode = False


def main_graph_builder(test_mode: bool = False) -> StateGraph:
    graph = StateGraph(
        state_schema=MainState,
        input=MainInput,
    )

    graph.add_node(
        node="agent_path_decision",
        action=agent_path_decision,
    )

    graph.add_node(
        node="agent_path_routing",
        action=agent_path_routing,
    )

    graph.add_node(
        node="LLM",
        action=direct_llm_handling,
    )

    graph.add_node(
        node="agent_search_start",
        action=agent_search_start,
    )

    graph.add_node(
        node="initial_sub_question_creation",
        action=initial_sub_question_creation,
    )
    answer_query_subgraph = answer_query_graph_builder().compile()
    graph.add_node(
        node="answer_query_subgraph",
        action=answer_query_subgraph,
    )

    base_raw_search_subgraph = base_raw_search_graph_builder().compile()
    graph.add_node(
        node="base_raw_search_subgraph",
        action=base_raw_search_subgraph,
    )

    # refined_answer_subgraph = refined_answers_graph_builder().compile()
    # graph.add_node(
    #     node="refined_answer_subgraph",
    #     action=refined_answer_subgraph,
    # )

    graph.add_node(
        node="refined_sub_question_creation",
        action=refined_sub_question_creation,
    )

    answer_refined_question = answer_refined_query_graph_builder().compile()
    graph.add_node(
        node="answer_refined_question",
        action=answer_refined_question,
    )

    graph.add_node(
        node="ingest_refined_answers",
        action=ingest_refined_answers,
    )

    graph.add_node(
        node="generate_refined_answer",
        action=generate_refined_answer,
    )

    # graph.add_node(
    #     node="check_refined_answer",
    #     action=check_refined_answer,
    # )

    graph.add_node(
        node="ingest_initial_retrieval",
        action=ingest_initial_base_retrieval,
    )

    graph.add_node(
        node="retrieval_consolidation",
        action=retrieval_consolidation,
    )

    graph.add_node(
        node="ingest_initial_sub_question_answers",
        action=ingest_initial_sub_question_answers,
    )
    graph.add_node(
        node="generate_initial_answer",
        action=generate_initial_answer,
    )

    graph.add_node(
        node="initial_answer_quality_check",
        action=initial_answer_quality_check,
    )

    graph.add_node(
        node="entity_term_extraction_llm",
        action=entity_term_extraction_llm,
    )
    graph.add_node(
        node="refined_answer_decision",
        action=refined_answer_decision,
    )

    graph.add_node(
        node="logging_node",
        action=agent_logging,
    )
    # if test_mode:
    #     graph.add_node(
    #         node="generate_initial_base_answer",
    #         action=generate_initial_base_answer,
    #     )

    ### Add edges ###

    # graph.add_edge(start_key=START, end_key="base_raw_search_subgraph")

    graph.add_edge(
        start_key=START,
        end_key="agent_path_decision",
    )

    graph.add_edge(
        start_key="agent_path_decision",
        end_key="agent_path_routing",
    )

    graph.add_edge(
        start_key="agent_search_start",
        end_key="base_raw_search_subgraph",
    )

    graph.add_edge(
        start_key="agent_search_start",
        end_key="initial_sub_question_creation",
    )

    graph.add_edge(
        start_key="base_raw_search_subgraph",
        end_key="ingest_initial_retrieval",
    )

    graph.add_edge(
        start_key=["ingest_initial_retrieval", "ingest_initial_sub_question_answers"],
        end_key="retrieval_consolidation",
    )

    graph.add_edge(
        start_key="retrieval_consolidation",
        end_key="entity_term_extraction_llm",
    )

    graph.add_edge(
        start_key="retrieval_consolidation",
        end_key="generate_initial_answer",
    )

    graph.add_edge(
        start_key="LLM",
        end_key=END,
    )

    # graph.add_edge(
    #     start_key=START,
    #     end_key="initial_sub_question_creation",
    # )

    graph.add_conditional_edges(
        source="initial_sub_question_creation",
        path=parallelize_initial_sub_question_answering,
        path_map=["answer_query_subgraph"],
    )
    graph.add_edge(
        start_key="answer_query_subgraph",
        end_key="ingest_initial_sub_question_answers",
    )

    graph.add_edge(
        start_key="retrieval_consolidation",
        end_key="generate_initial_answer",
    )

    # graph.add_edge(
    #     start_key="generate_initial_answer",
    #     end_key="entity_term_extraction_llm",
    # )

    graph.add_edge(
        start_key="generate_initial_answer",
        end_key="initial_answer_quality_check",
    )

    graph.add_edge(
        start_key=["initial_answer_quality_check", "entity_term_extraction_llm"],
        end_key="refined_answer_decision",
    )

    graph.add_conditional_edges(
        source="refined_answer_decision",
        path=continue_to_refined_answer_or_end,
        path_map=["refined_sub_question_creation", "logging_node"],
    )

    graph.add_conditional_edges(
        source="refined_sub_question_creation",  # DONE
        path=parallelize_refined_sub_question_answering,
        path_map=["answer_refined_question"],
    )
    graph.add_edge(
        start_key="answer_refined_question",  # HERE
        end_key="ingest_refined_answers",
    )

    graph.add_edge(
        start_key="ingest_refined_answers",
        end_key="generate_refined_answer",
    )

    # graph.add_conditional_edges(
    #     source="refined_answer_decision",
    #     path=continue_to_refined_answer_or_end,
    #     path_map=["refined_answer_subgraph", END],
    # )

    # graph.add_edge(
    #     start_key="refined_answer_subgraph",
    #     end_key="generate_refined_answer",
    # )

    graph.add_edge(
        start_key="generate_refined_answer",
        end_key="logging_node",
    )

    graph.add_edge(
        start_key="logging_node",
        end_key=END,
    )

    # graph.add_edge(
    #     start_key="generate_refined_answer",
    #     end_key="check_refined_answer",
    # )

    # graph.add_edge(
    #     start_key="check_refined_answer",
    #     end_key=END,
    # )

    return graph


if __name__ == "__main__":
    from onyx.db.engine import get_session_context_manager
    from onyx.llm.factory import get_default_llms
    from onyx.context.search.models import SearchRequest

    graph = main_graph_builder()
    compiled_graph = graph.compile()
    primary_llm, fast_llm = get_default_llms()

    with get_session_context_manager() as db_session:
        search_request = SearchRequest(query="Who created Excel?")
        agent_a_config, search_tool = get_test_config(
            db_session, primary_llm, fast_llm, search_request
        )

        inputs = MainInput(
            base_question=agent_a_config.search_request.query, log_messages=[]
        )

        for thing in compiled_graph.stream(
            input=inputs,
            config={"configurable": {"config": agent_a_config}},
            # stream_mode="debug",
            # debug=True,
            subgraphs=True,
        ):
            logger.debug(thing)
@@ -0,0 +1,36 @@
from pydantic import BaseModel


class FollowUpSubQuestion(BaseModel):
    sub_question: str
    sub_question_id: str
    verified: bool
    answered: bool
    answer: str


class AgentTimings(BaseModel):
    base_duration__s: float | None
    refined_duration__s: float | None
    full_duration__s: float | None


class AgentBaseMetrics(BaseModel):
    num_verified_documents_total: int | None
    num_verified_documents_core: int | None
    verified_avg_score_core: float | None
    num_verified_documents_base: int | float | None
    verified_avg_score_base: float | None = None
    base_doc_boost_factor: float | None = None
    support_boost_factor: float | None = None
    duration__s: float | None = None


class AgentRefinedMetrics(BaseModel):
    refined_doc_boost_factor: float | None = None
    refined_question_boost_factor: float | None = None
    duration__s: float | None = None


class AgentAdditionalMetrics(BaseModel):
    pass
@@ -0,0 +1,113 @@
from datetime import datetime
from typing import cast

from langchain_core.runnables import RunnableConfig

from onyx.agents.agent_search.deep_search_a.main.models import AgentAdditionalMetrics
from onyx.agents.agent_search.deep_search_a.main.models import AgentTimings
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import MainOutput
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.models import CombinedAgentMetrics
from onyx.db.chat import log_agent_metrics
from onyx.db.chat import log_agent_sub_question_results


def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput:
    now_start = datetime.now()

    logger.debug(f"--------{now_start}--------LOGGING NODE---")

    agent_start_time = state.agent_start_time
    agent_base_end_time = state.agent_base_end_time
    agent_refined_start_time = state.agent_refined_start_time or None
    agent_refined_end_time = state.agent_refined_end_time or None
    agent_end_time = agent_refined_end_time or agent_base_end_time

    agent_base_duration = None
    if agent_base_end_time:
        agent_base_duration = (agent_base_end_time - agent_start_time).total_seconds()

    agent_refined_duration = None
    if agent_refined_start_time and agent_refined_end_time:
        agent_refined_duration = (
            agent_refined_end_time - agent_refined_start_time
        ).total_seconds()

    agent_full_duration = None
    if agent_end_time:
        agent_full_duration = (agent_end_time - agent_start_time).total_seconds()

    agent_type = "refined" if agent_refined_duration else "base"

    agent_base_metrics = state.agent_base_metrics
    agent_refined_metrics = state.agent_refined_metrics

    combined_agent_metrics = CombinedAgentMetrics(
        timings=AgentTimings(
            base_duration__s=agent_base_duration,
            refined_duration__s=agent_refined_duration,
            full_duration__s=agent_full_duration,
        ),
        base_metrics=agent_base_metrics,
        refined_metrics=agent_refined_metrics,
        additional_metrics=AgentAdditionalMetrics(),
    )

    persona_id = None
    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    if agent_a_config.search_request.persona:
        persona_id = agent_a_config.search_request.persona.id

    user_id = None
    user = agent_a_config.search_tool.user
    if user:
        user_id = user.id

    # log the agent metrics
    if agent_a_config.db_session is not None:
        log_agent_metrics(
            db_session=agent_a_config.db_session,
            user_id=user_id,
            persona_id=persona_id,
            agent_type=agent_type,
            start_time=agent_start_time,
            agent_metrics=combined_agent_metrics,
        )

    if agent_a_config.use_persistence:
        # Persist the sub-answers in the database
        db_session = agent_a_config.db_session
        chat_session_id = agent_a_config.chat_session_id
        primary_message_id = agent_a_config.message_id
        sub_question_answer_results = state.decomp_answer_results

        log_agent_sub_question_results(
            db_session=db_session,
            chat_session_id=chat_session_id,
            primary_message_id=primary_message_id,
            sub_question_answer_results=sub_question_answer_results,
        )

        # if chat_session_id is not None and primary_message_id is not None and sub_question_id is not None:
        #     create_sub_answer(
        #         db_session=db_session,
        #         chat_session_id=chat_session_id,
        #         primary_message_id=primary_message_id,
        #         sub_question_id=sub_question_id,
        #         answer=answer_str,
        #     )

    now_end = datetime.now()
    main_output = MainOutput(
        log_messages=[
            f"{now_end} -- Main - Logging, Time taken: {now_end - now_start}"
        ],
    )

    logger.debug(f"--------{now_end}--{now_end - now_start}--------LOGGING NODE END---")

    return main_output
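A worked example (editor's illustration, not part of the diff) of the duration bookkeeping above, with invented timestamps:

from datetime import datetime

agent_start_time = datetime(2025, 1, 5, 13, 0, 0)
agent_base_end_time = datetime(2025, 1, 5, 13, 0, 20)
agent_refined_start_time = datetime(2025, 1, 5, 13, 0, 20)
agent_refined_end_time = datetime(2025, 1, 5, 13, 0, 50)

base_duration = (agent_base_end_time - agent_start_time).total_seconds()  # 20.0
refined_duration = (
    agent_refined_end_time - agent_refined_start_time
).total_seconds()  # 30.0
full_duration = (agent_refined_end_time - agent_start_time).total_seconds()  # 50.0
# agent_type resolves to "refined" because refined_duration is truthy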
@@ -0,0 +1,99 @@
from datetime import datetime
from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig

from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.deep_search_a.main.states import RoutingDecision
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import AGENT_DECISION_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import (
    AGENT_DECISION_PROMPT_AFTER_SEARCH,
)
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_context_manager
from onyx.llm.utils import check_number_of_tokens
from onyx.tools.tool_implementations.search.search_tool import (
    SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary


def agent_path_decision(state: MainState, config: RunnableConfig) -> RoutingDecision:
    now_start = datetime.now()

    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    question = agent_a_config.search_request.query
    perform_initial_search_path_decision = (
        agent_a_config.perform_initial_search_path_decision
    )

    history = build_history_prompt(agent_a_config.prompt_builder)

    logger.debug(f"--------{now_start}--------DECIDING TO SEARCH OR GO TO LLM---")

    if perform_initial_search_path_decision:
        search_tool = agent_a_config.search_tool
        retrieved_docs: list[InferenceSection] = []

        # new db session to avoid concurrency issues
        with get_session_context_manager() as db_session:
            for tool_response in search_tool.run(
                query=question,
                force_no_rerank=True,
                alternate_db_session=db_session,
            ):
                # get retrieved docs to send to the rest of the graph
                if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
                    response = cast(SearchResponseSummary, tool_response.response)
                    retrieved_docs = response.top_sections
                    break

        sample_doc_str = "\n\n".join(
            [doc.combined_content for doc in retrieved_docs[:3]]
        )

        agent_decision_prompt = AGENT_DECISION_PROMPT_AFTER_SEARCH.format(
            question=question, sample_doc_str=sample_doc_str, history=history
        )

    else:
        sample_doc_str = ""
        agent_decision_prompt = AGENT_DECISION_PROMPT.format(
            question=question, history=history
        )

    msg = [HumanMessage(content=agent_decision_prompt)]

    # Get the routing decision in a defined format
    model = agent_a_config.fast_llm

    # no need to stream this
    resp = model.invoke(msg)

    if isinstance(resp.content, str) and "research" in resp.content.lower():
        routing = "agent_search"
    else:
        routing = "LLM"

    now_end = datetime.now()

    logger.debug(
        f"--------{now_end}--{now_end - now_start}--------DECIDING TO SEARCH OR GO TO LLM END---"
    )

    check_number_of_tokens(agent_decision_prompt)

    return RoutingDecision(
        # Decide which route to take
        routing=routing,
        sample_doc_str=sample_doc_str,
        log_messages=[
            f"{now_end} -- Path decision: {routing}, Time taken: {now_end - now_start}"
        ],
    )
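The branch above hinges on a plain-text contract with the decision prompt: any reply that mentions "research" routes to agent search, everything else to the direct LLM path. A minimal sketch of that predicate (editor's illustration, the function name is made up):

def route_from_llm_reply(reply: str) -> str:
    # mirrors agent_path_decision: a reply containing "research" selects agent search
    return "agent_search" if "research" in reply.lower() else "LLM"


assert route_from_llm_reply("Research is needed here.") == "agent_search"
assert route_from_llm_reply("A direct answer suffices.") == "LLM"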
@@ -0,0 +1,31 @@
from datetime import datetime
from typing import Literal

from langgraph.types import Command

from onyx.agents.agent_search.deep_search_a.main.states import MainState


def agent_path_routing(
    state: MainState,
) -> Command[Literal["agent_search_start", "LLM"]]:
    now_start = datetime.now()
    routing = state.routing if hasattr(state, "routing") else "agent_search"

    if routing == "agent_search":
        agent_path = "agent_search_start"
    else:
        agent_path = "LLM"

    now_end = datetime.now()

    return Command(
        # state update
        update={
            "log_messages": [
                f"{now_end} -- Main - Path routing: {agent_path}, Time taken: {now_end - now_start}"
            ]
        },
        # control flow
        goto=agent_path,
    )
@@ -0,0 +1,8 @@
from datetime import datetime

from onyx.agents.agent_search.core_state import CoreState


def agent_search_start(state: CoreState) -> CoreState:
    now_end = datetime.now()
    return CoreState(log_messages=[f"{now_end} -- Main - Agent search start"])
@@ -0,0 +1,89 @@
from datetime import datetime
from typing import Any
from typing import cast

from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig

from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import InitialAnswerUpdate
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.prompts import (
    ASSISTANT_SYSTEM_PROMPT_DEFAULT,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
    ASSISTANT_SYSTEM_PROMPT_PERSONA,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import DIRECT_LLM_PROMPT
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_prompt
from onyx.chat.models import AgentAnswerPiece


def direct_llm_handling(
    state: MainState, config: RunnableConfig
) -> InitialAnswerUpdate:
    now_start = datetime.now()

    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    question = agent_a_config.search_request.query
    persona_prompt = get_persona_prompt(agent_a_config.search_request.persona)

    if len(persona_prompt) == 0:
        persona_specification = ASSISTANT_SYSTEM_PROMPT_DEFAULT
    else:
        persona_specification = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
            persona_prompt=persona_prompt
        )

    logger.debug(f"--------{now_start}--------LLM HANDLING START---")

    model = agent_a_config.fast_llm

    msg = [
        HumanMessage(
            content=DIRECT_LLM_PROMPT.format(
                persona_specification=persona_specification, question=question
            )
        )
    ]

    streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]

    for message in model.stream(msg):
        # TODO: in principle, the answer here COULD contain images, but we don't support that yet
        content = message.content
        if not isinstance(content, str):
            raise ValueError(
                f"Expected content to be a string, but got {type(content)}"
            )
        dispatch_custom_event(
            "initial_agent_answer",
            AgentAnswerPiece(
                answer_piece=content,
                level=0,
                level_question_nr=0,
                answer_type="agent_level_answer",
            ),
        )
        streamed_tokens.append(content)

    response = merge_content(*streamed_tokens)
    answer = cast(str, response)

    now_end = datetime.now()

    logger.debug(f"--------{now_end}--{now_end - now_start}--------LLM HANDLING END---")

    return InitialAnswerUpdate(
        initial_answer=answer,
        initial_agent_stats=None,
        generated_sub_questions=[],
        agent_base_end_time=now_end,
        agent_base_metrics=None,
        log_messages=[
            f"{now_end} -- Main - LLM handling: {answer}, Time taken: {now_end - now_start}"
        ],
    )
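The stream-then-merge pattern above leans on langchain_core's merge_content to join the streamed string chunks back into one answer; roughly (editor's sketch, string inputs are expected to concatenate):

from langchain_core.messages import merge_content

chunks = ["", "Excel was ", "created by ", "Microsoft."]
answer = merge_content(*chunks)
# expected: "Excel was created by Microsoft."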
@@ -0,0 +1,137 @@
import json
import re
from datetime import datetime
from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from langchain_core.runnables import RunnableConfig

from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import (
    EntityTermExtractionUpdate,
)
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.models import Entity
from onyx.agents.agent_search.shared_graph_utils.models import (
    EntityRelationshipTermExtraction,
)
from onyx.agents.agent_search.shared_graph_utils.models import Relationship
from onyx.agents.agent_search.shared_graph_utils.models import Term
from onyx.agents.agent_search.shared_graph_utils.operators import (
    dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROMPT
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs


def entity_term_extraction_llm(
    state: MainState, config: RunnableConfig
) -> EntityTermExtractionUpdate:
    now_start = datetime.now()

    logger.debug(f"--------{now_start}--------GENERATE ENTITIES & TERMS---")

    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    if not agent_a_config.allow_refinement:
        now_end = datetime.now()
        return EntityTermExtractionUpdate(
            entity_retlation_term_extractions=EntityRelationshipTermExtraction(
                entities=[],
                relationships=[],
                terms=[],
            ),
            log_messages=[
                f"{now_end} -- Main - ETR Extraction, Time taken: {now_end - now_start}"
            ],
        )

    # first four lines duplicated from generate_initial_answer
    question = agent_a_config.search_request.query
    sub_question_docs = state.documents
    all_original_question_documents = state.all_original_question_documents
    relevant_docs = dedup_inference_sections(
        sub_question_docs, all_original_question_documents
    )

    # start with the entity/term extraction
    doc_context = format_docs(relevant_docs)

    doc_context = trim_prompt_piece(
        agent_a_config.fast_llm.config, doc_context, ENTITY_TERM_PROMPT + question
    )
    msg = [
        HumanMessage(
            content=ENTITY_TERM_PROMPT.format(question=question, context=doc_context),
        )
    ]
    fast_llm = agent_a_config.fast_llm
    # Grader
    llm_response_list = list(
        fast_llm.stream(
            prompt=msg,
        )
    )
    llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content

    cleaned_response = re.sub(r"```json\n|\n```", "", llm_response)
    parsed_response = json.loads(cleaned_response)

    entities = []
    relationships = []
    terms = []
    for entity in parsed_response.get("retrieved_entities_relationships", {}).get(
        "entities", {}
    ):
        entity_name = entity.get("entity_name", "")
        entity_type = entity.get("entity_type", "")
        entities.append(Entity(entity_name=entity_name, entity_type=entity_type))

    for relationship in parsed_response.get("retrieved_entities_relationships", {}).get(
        "relationships", {}
    ):
        relationship_name = relationship.get("relationship_name", "")
        relationship_type = relationship.get("relationship_type", "")
        relationship_entities = relationship.get("relationship_entities", [])
        relationships.append(
            Relationship(
                relationship_name=relationship_name,
                relationship_type=relationship_type,
                relationship_entities=relationship_entities,
            )
        )

    for term in parsed_response.get("retrieved_entities_relationships", {}).get(
        "terms", {}
    ):
        term_name = term.get("term_name", "")
        term_type = term.get("term_type", "")
        term_similar_to = term.get("term_similar_to", [])
        terms.append(
            Term(
                term_name=term_name,
                term_type=term_type,
                term_similar_to=term_similar_to,
            )
        )

    now_end = datetime.now()

    logger.debug(
        f"--------{now_end}--{now_end - now_start}--------ENTITY TERM EXTRACTION END---"
    )

    return EntityTermExtractionUpdate(
        entity_retlation_term_extractions=EntityRelationshipTermExtraction(
            entities=entities,
            relationships=relationships,
            terms=terms,
        ),
        log_messages=[
            f"{now_end} -- Main - ETR Extraction, Time taken: {now_end - now_start}"
        ],
    )
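The parsing above implies a JSON schema for the LLM reply. A sketch of the expected shape (field names are taken from the .get() calls in the code; the values are invented):

EXAMPLE_ETR_RESPONSE = {
    "retrieved_entities_relationships": {
        "entities": [{"entity_name": "Microsoft", "entity_type": "organization"}],
        "relationships": [
            {
                "relationship_name": "created",
                "relationship_type": "creator_of",
                "relationship_entities": ["Microsoft", "Excel"],
            }
        ],
        "terms": [
            {"term_name": "spreadsheet", "term_type": "concept", "term_similar_to": []}
        ],
    }
}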
@@ -0,0 +1,257 @@
from datetime import datetime
from typing import Any
from typing import cast

from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig

from onyx.agents.agent_search.deep_search_a.main.models import AgentBaseMetrics
from onyx.agents.agent_search.deep_search_a.main.operations import (
    calculate_initial_agent_stats,
)
from onyx.agents.agent_search.deep_search_a.main.operations import get_query_info
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.operations import (
    remove_document_citations,
)
from onyx.agents.agent_search.deep_search_a.main.states import InitialAnswerUpdate
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
    dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
    ASSISTANT_SYSTEM_PROMPT_DEFAULT,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
    ASSISTANT_SYSTEM_PROMPT_PERSONA,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import INITIAL_RAG_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import (
    INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
    SUB_QUESTION_ANSWER_TEMPLATE,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_prompt
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses


def generate_initial_answer(
    state: MainState, config: RunnableConfig
) -> InitialAnswerUpdate:
    now_start = datetime.now()

    logger.debug(f"--------{now_start}--------GENERATE INITIAL---")

    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    question = agent_a_config.search_request.query
    persona_prompt = get_persona_prompt(agent_a_config.search_request.persona)

    history = build_history_prompt(agent_a_config.prompt_builder)

    sub_question_docs = state.documents
    all_original_question_documents = state.all_original_question_documents

    relevant_docs = dedup_inference_sections(
        sub_question_docs, all_original_question_documents
    )
    decomp_questions = []

    if len(relevant_docs) == 0:
        dispatch_custom_event(
            "initial_agent_answer",
            AgentAnswerPiece(
                answer_piece=UNKNOWN_ANSWER,
                level=0,
                level_question_nr=0,
                answer_type="agent_level_answer",
            ),
        )

        answer = UNKNOWN_ANSWER
        initial_agent_stats = InitialAgentResultStats(
            sub_questions={},
            original_question={},
            agent_effectiveness={},
        )

    else:
        # Use the query info from the base document retrieval
        query_info = get_query_info(state.original_question_retrieval_results)

        for tool_response in yield_search_responses(
            query=question,
            reranked_sections=relevant_docs,
            final_context_sections=relevant_docs,
            search_query_info=query_info,
            get_section_relevance=lambda: None,  # TODO: add relevance
            search_tool=agent_a_config.search_tool,
        ):
            dispatch_custom_event(
                "tool_response",
                ExtendedToolResponse(
                    id=tool_response.id,
                    response=tool_response.response,
                    level=0,
                    level_question_nr=0,  # 0, 0 is the base question
                ),
            )

        net_new_original_question_docs = []
        for all_original_question_doc in all_original_question_documents:
            if all_original_question_doc not in sub_question_docs:
                net_new_original_question_docs.append(all_original_question_doc)

        decomp_answer_results = state.decomp_answer_results

        good_qa_list: list[str] = []

        sub_question_nr = 1

        for decomp_answer_result in decomp_answer_results:
            decomp_questions.append(decomp_answer_result.question)
            _, question_nr = parse_question_id(decomp_answer_result.question_id)
            if (
                decomp_answer_result.quality.lower().startswith("yes")
                and len(decomp_answer_result.answer) > 0
                and decomp_answer_result.answer != UNKNOWN_ANSWER
            ):
                good_qa_list.append(
                    SUB_QUESTION_ANSWER_TEMPLATE.format(
                        sub_question=decomp_answer_result.question,
                        sub_answer=decomp_answer_result.answer,
                        sub_question_nr=sub_question_nr,
                    )
                )
            sub_question_nr += 1

        if len(good_qa_list) > 0:
            sub_question_answer_str = "\n\n------\n\n".join(good_qa_list)
        else:
            sub_question_answer_str = ""

        # Determine which persona-specification prompt to use
        if len(persona_prompt) == 0:
            persona_specification = ASSISTANT_SYSTEM_PROMPT_DEFAULT
        else:
            persona_specification = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
                persona_prompt=persona_prompt
            )

        # Determine which base prompt to use given the sub-question information
        if len(good_qa_list) > 0:
            base_prompt = INITIAL_RAG_PROMPT
        else:
            base_prompt = INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS

        model = agent_a_config.fast_llm

        doc_context = format_docs(relevant_docs)
        doc_context = trim_prompt_piece(
            model.config,
            doc_context,
            base_prompt + sub_question_answer_str + persona_specification + history,
        )

        msg = [
            HumanMessage(
                content=base_prompt.format(
                    question=question,
                    answered_sub_questions=remove_document_citations(
                        sub_question_answer_str
                    ),
                    relevant_docs=format_docs(relevant_docs),
                    persona_specification=persona_specification,
                    history=history,
                )
            )
        ]

        streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
        for message in model.stream(msg):
            # TODO: in principle, the answer here COULD contain images, but we don't support that yet
            content = message.content
            if not isinstance(content, str):
                raise ValueError(
                    f"Expected content to be a string, but got {type(content)}"
                )
            dispatch_custom_event(
                "initial_agent_answer",
                AgentAnswerPiece(
                    answer_piece=content,
                    level=0,
                    level_question_nr=0,
                    answer_type="agent_level_answer",
                ),
            )
            streamed_tokens.append(content)

        response = merge_content(*streamed_tokens)
        answer = cast(str, response)

        initial_agent_stats = calculate_initial_agent_stats(
            state.decomp_answer_results, state.original_question_retrieval_stats
        )

        logger.debug(
            f"\n\nSub-Questions:\n\n{sub_question_answer_str}\n\nStats:\n\n"
        )

        if initial_agent_stats:
            logger.debug(initial_agent_stats.original_question)
            logger.debug(initial_agent_stats.sub_questions)
            logger.debug(initial_agent_stats.agent_effectiveness)

    now_end = datetime.now()

    logger.debug(
        f"--------{now_end}--{now_end - now_start}--------INITIAL AGENT ANSWER END---\n\n"
    )

    agent_base_end_time = datetime.now()

    agent_base_metrics = AgentBaseMetrics(
        num_verified_documents_total=len(relevant_docs),
        num_verified_documents_core=state.original_question_retrieval_stats.verified_count,
        verified_avg_score_core=state.original_question_retrieval_stats.verified_avg_scores,
        num_verified_documents_base=initial_agent_stats.sub_questions.get(
            "num_verified_documents", None
        ),
        verified_avg_score_base=initial_agent_stats.sub_questions.get(
            "verified_avg_score", None
        ),
        base_doc_boost_factor=initial_agent_stats.agent_effectiveness.get(
            "utilized_chunk_ratio", None
        ),
        support_boost_factor=initial_agent_stats.agent_effectiveness.get(
            "support_ratio", None
        ),
        duration__s=(agent_base_end_time - state.agent_start_time).total_seconds(),
    )

    return InitialAnswerUpdate(
        initial_answer=answer,
        initial_agent_stats=initial_agent_stats,
        generated_sub_questions=decomp_questions,
        agent_base_end_time=agent_base_end_time,
        agent_base_metrics=agent_base_metrics,
        log_messages=[
            f"{now_end} -- Main - Initial Answer generation, Time taken: {now_end - now_start}"
        ],
    )
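The sub-answer gating in the loop above reduces to a small predicate, extracted here for clarity (editor's sketch, the function name is made up):

def is_usable_sub_answer(quality: str, answer: str, unknown_answer: str) -> bool:
    # a sub-answer feeds the initial answer only if the grader said "yes...",
    # the answer is non-empty, and it is not the designated unknown marker
    return (
        quality.lower().startswith("yes")
        and len(answer) > 0
        and answer != unknown_answer
    )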
@@ -0,0 +1,56 @@
from datetime import datetime
from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig

from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import InitialAnswerBASEUpdate
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import INITIAL_RAG_BASE_PROMPT
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs


def generate_initial_base_search_only_answer(
    state: MainState,
    config: RunnableConfig,
) -> InitialAnswerBASEUpdate:
    now_start = datetime.now()

    logger.debug(f"--------{now_start}--------GENERATE INITIAL BASE ANSWER---")

    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    question = agent_a_config.search_request.query
    original_question_docs = state.all_original_question_documents

    model = agent_a_config.fast_llm

    doc_context = format_docs(original_question_docs)
    doc_context = trim_prompt_piece(
        model.config, doc_context, INITIAL_RAG_BASE_PROMPT + question
    )

    msg = [
        HumanMessage(
            content=INITIAL_RAG_BASE_PROMPT.format(
                question=question,
                context=doc_context,
            )
        )
    ]

    # Answer generation (single, non-streamed call)
    response = model.invoke(msg)
    answer = response.pretty_repr()

    now_end = datetime.now()

    logger.debug(
        f"--------{now_end}--{now_end - now_start}--------INITIAL BASE ANSWER END---\n\n"
    )

    return InitialAnswerBASEUpdate(initial_base_answer=answer)
@@ -0,0 +1,324 @@
from datetime import datetime
from typing import Any
from typing import cast

from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig

from onyx.agents.agent_search.deep_search_a.main.models import AgentRefinedMetrics
from onyx.agents.agent_search.deep_search_a.main.operations import get_query_info
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.operations import (
    remove_document_citations,
)
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.deep_search_a.main.states import RefinedAnswerUpdate
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
    dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
    ASSISTANT_SYSTEM_PROMPT_DEFAULT,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
    ASSISTANT_SYSTEM_PROMPT_PERSONA,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import REVISED_RAG_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import (
    REVISED_RAG_PROMPT_NO_SUB_QUESTIONS,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
    SUB_QUESTION_ANSWER_TEMPLATE,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_prompt
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses


def generate_refined_answer(
    state: MainState, config: RunnableConfig
) -> RefinedAnswerUpdate:
    now_start = datetime.now()

    logger.debug(f"--------{now_start}--------GENERATE REFINED ANSWER---")

    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    question = agent_a_config.search_request.query
    persona_prompt = get_persona_prompt(agent_a_config.search_request.persona)

    history = build_history_prompt(agent_a_config.prompt_builder)

    initial_documents = state.documents
    revised_documents = state.refined_documents

    combined_documents = dedup_inference_sections(initial_documents, revised_documents)

    query_info = get_query_info(state.original_question_retrieval_results)
    # stream refined answer docs
    for tool_response in yield_search_responses(
        query=question,
        reranked_sections=combined_documents,
        final_context_sections=combined_documents,
        search_query_info=query_info,
        get_section_relevance=lambda: None,  # TODO: add relevance
        search_tool=agent_a_config.search_tool,
    ):
        dispatch_custom_event(
            "tool_response",
            ExtendedToolResponse(
                id=tool_response.id,
                response=tool_response.response,
                level=1,
                level_question_nr=0,  # level 1, question 0 is the refined base answer
            ),
        )

    if len(initial_documents) > 0:
        revision_doc_effectiveness = len(combined_documents) / len(initial_documents)
    elif len(revised_documents) == 0:
        revision_doc_effectiveness = 0.0
    else:
        revision_doc_effectiveness = 10.0

    decomp_answer_results = state.decomp_answer_results
    # revised_answer_results = state.refined_decomp_answer_results

    good_qa_list: list[str] = []
    decomp_questions = []

    initial_good_sub_questions: list[str] = []
    new_revised_good_sub_questions: list[str] = []

    sub_question_nr = 1

    for decomp_answer_result in decomp_answer_results:
        question_level, question_nr = parse_question_id(
            decomp_answer_result.question_id
        )

        decomp_questions.append(decomp_answer_result.question)
        if (
            decomp_answer_result.quality.lower().startswith("yes")
            and len(decomp_answer_result.answer) > 0
            and decomp_answer_result.answer != UNKNOWN_ANSWER
        ):
            good_qa_list.append(
                SUB_QUESTION_ANSWER_TEMPLATE.format(
                    sub_question=decomp_answer_result.question,
                    sub_answer=decomp_answer_result.answer,
                    sub_question_nr=sub_question_nr,
                )
            )
            if question_level == 0:
                initial_good_sub_questions.append(decomp_answer_result.question)
            else:
                new_revised_good_sub_questions.append(decomp_answer_result.question)

        sub_question_nr += 1

    initial_good_sub_questions = list(set(initial_good_sub_questions))
    new_revised_good_sub_questions = list(set(new_revised_good_sub_questions))
    total_good_sub_questions = list(
        set(initial_good_sub_questions + new_revised_good_sub_questions)
    )
    if len(initial_good_sub_questions) > 0:
        revision_question_efficiency: float = len(total_good_sub_questions) / len(
            initial_good_sub_questions
        )
    elif len(new_revised_good_sub_questions) > 0:
        revision_question_efficiency = 10.0
    else:
        revision_question_efficiency = 1.0

    sub_question_answer_str = "\n\n------\n\n".join(list(set(good_qa_list)))

    # original answer
    initial_answer = state.initial_answer

    # Determine which persona-specification prompt to use
    if len(persona_prompt) == 0:
        persona_specification = ASSISTANT_SYSTEM_PROMPT_DEFAULT
    else:
        persona_specification = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
            persona_prompt=persona_prompt
        )

    # Determine which base prompt to use given the sub-question information
    if len(good_qa_list) > 0:
        base_prompt = REVISED_RAG_PROMPT
    else:
        base_prompt = REVISED_RAG_PROMPT_NO_SUB_QUESTIONS

    model = agent_a_config.fast_llm
    relevant_docs = format_docs(combined_documents)
    relevant_docs = trim_prompt_piece(
        model.config,
        relevant_docs,
        base_prompt
        + question
        + sub_question_answer_str
        + relevant_docs
        + initial_answer
        + persona_specification
        + history,
    )

    msg = [
        HumanMessage(
            content=base_prompt.format(
                question=question,
                history=history,
                answered_sub_questions=remove_document_citations(
                    sub_question_answer_str
                ),
                relevant_docs=relevant_docs,
                initial_answer=remove_document_citations(initial_answer),
                persona_specification=persona_specification,
            )
        )
    ]

    streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
    for message in model.stream(msg):
        # TODO: in principle, the answer here COULD contain images, but we don't support that yet
        content = message.content
        if not isinstance(content, str):
            raise ValueError(
                f"Expected content to be a string, but got {type(content)}"
            )
        dispatch_custom_event(
            "refined_agent_answer",
            AgentAnswerPiece(
                answer_piece=content,
                level=1,
                level_question_nr=0,
                answer_type="agent_level_answer",
            ),
        )
        streamed_tokens.append(content)

    response = merge_content(*streamed_tokens)
    answer = cast(str, response)

    # refined_agent_stats = _calculate_refined_agent_stats(
    #     state.decomp_answer_results, state.original_question_retrieval_stats
    # )

    initial_good_sub_questions_str = "\n".join(list(set(initial_good_sub_questions)))
    new_revised_good_sub_questions_str = "\n".join(
        list(set(new_revised_good_sub_questions))
    )

    refined_agent_stats = RefinedAgentStats(
        revision_doc_efficiency=revision_doc_effectiveness,
        revision_question_efficiency=revision_question_efficiency,
    )

    logger.debug(
        f"\n\n---INITIAL ANSWER START---\n\n Answer:\n Agent: {initial_answer}"
    )
    logger.debug("-" * 10)
    logger.debug(f"\n\n---REVISED AGENT ANSWER START---\n\n Answer:\n Agent: {answer}")

    logger.debug("-" * 100)
    logger.debug(f"\n\nINITIAL Sub-Questions\n\n{initial_good_sub_questions_str}\n\n")
    logger.debug("-" * 10)
    logger.debug(
        f"\n\nNEW REVISED Sub-Questions\n\n{new_revised_good_sub_questions_str}\n\n"
    )

    logger.debug("-" * 100)

    logger.debug(
        f"\n\nINITIAL & REVISED Sub-Questions & Answers:\n\n{sub_question_answer_str}\n\nStats:\n\n"
    )

    logger.debug("-" * 100)

    if state.initial_agent_stats:
        initial_doc_boost_factor = state.initial_agent_stats.agent_effectiveness.get(
            "utilized_chunk_ratio", "--"
        )
        initial_support_boost_factor = (
            state.initial_agent_stats.agent_effectiveness.get("support_ratio", "--")
        )
        num_initial_verified_docs = state.initial_agent_stats.original_question.get(
            "num_verified_documents", "--"
        )
        initial_verified_docs_avg_score = (
            state.initial_agent_stats.original_question.get("verified_avg_score", "--")
        )
        initial_sub_questions_verified_docs = (
            state.initial_agent_stats.sub_questions.get("num_verified_documents", "--")
        )

        logger.debug("INITIAL AGENT STATS")
        logger.debug(f"Document Boost Factor: {initial_doc_boost_factor}")
        logger.debug(f"Support Boost Factor: {initial_support_boost_factor}")
        logger.debug(f"Originally Verified Docs: {num_initial_verified_docs}")
        logger.debug(
            f"Originally Verified Docs Avg Score: {initial_verified_docs_avg_score}"
        )
        logger.debug(
            f"Sub-Questions Verified Docs: {initial_sub_questions_verified_docs}"
        )
    if refined_agent_stats:
        logger.debug("-" * 10)
        logger.debug("REFINED AGENT STATS")
        logger.debug(
            f"Revision Doc Factor: {refined_agent_stats.revision_doc_efficiency}"
        )
        logger.debug(
            f"Revision Question Factor: {refined_agent_stats.revision_question_efficiency}"
        )

    agent_refined_end_time = datetime.now()
    if state.agent_refined_start_time:
        agent_refined_duration = (
            agent_refined_end_time - state.agent_refined_start_time
        ).total_seconds()
    else:
        agent_refined_duration = None

    agent_refined_metrics = AgentRefinedMetrics(
        refined_doc_boost_factor=refined_agent_stats.revision_doc_efficiency,
        refined_question_boost_factor=refined_agent_stats.revision_question_efficiency,
        duration__s=agent_refined_duration,
    )

    now_end = datetime.now()

    logger.debug(
        f"--------{now_end}--{now_end - now_start}--------REFINED ANSWER UPDATE END---"
    )

    return RefinedAnswerUpdate(
        refined_answer=answer,
        refined_answer_quality=True,  # TODO: replace this with the actual check value
        refined_agent_stats=refined_agent_stats,
        agent_refined_end_time=agent_refined_end_time,
        agent_refined_metrics=agent_refined_metrics,
    )
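A worked example (editor's illustration, not part of the diff) of the two revision-efficiency ratios computed above, with invented counts:

initial_docs_n = 8    # deduped docs available before refinement
combined_docs_n = 12  # initial + refined docs after dedup
revision_doc_effectiveness = combined_docs_n / initial_docs_n  # 1.5

initial_good_n = 2  # "good" sub-questions from the initial pass (level 0)
total_good_n = 5    # good sub-questions across both levels
revision_question_efficiency = total_good_n / initial_good_n  # 2.5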
@@ -0,0 +1,39 @@
from datetime import datetime

from onyx.agents.agent_search.deep_search_a.base_raw_search.states import (
    BaseRawSearchOutput,
)
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import ExpandedRetrievalUpdate
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats


def ingest_initial_base_retrieval(
    state: BaseRawSearchOutput,
) -> ExpandedRetrievalUpdate:
    now_start = datetime.now()

    logger.debug(f"--------{now_start}--------INGEST INITIAL RETRIEVAL---")

    sub_question_retrieval_stats = (
        state.base_expanded_retrieval_result.sub_question_retrieval_stats
    )
    if sub_question_retrieval_stats is None:
        sub_question_retrieval_stats = AgentChunkStats()

    now_end = datetime.now()

    logger.debug(
        f"--------{now_end}--{now_end - now_start}--------INGEST INITIAL RETRIEVAL END---"
    )

    return ExpandedRetrievalUpdate(
        original_question_retrieval_results=state.base_expanded_retrieval_result.expanded_queries_results,
        all_original_question_documents=state.base_expanded_retrieval_result.all_documents,
        original_question_retrieval_stats=sub_question_retrieval_stats,
        log_messages=[
            f"{now_end} -- Main - Ingestion base retrieval, Time taken: {now_end - now_start}"
        ],
    )
@@ -0,0 +1,38 @@
from datetime import datetime

from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
    AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import DecompAnswersUpdate
from onyx.agents.agent_search.shared_graph_utils.operators import (
    dedup_inference_sections,
)


def ingest_initial_sub_question_answers(
    state: AnswerQuestionOutput,
) -> DecompAnswersUpdate:
    now_start = datetime.now()

    logger.debug(f"--------{now_start}--------INGEST ANSWERS---")
    documents = []
    answer_results = state.answer_results if hasattr(state, "answer_results") else []
    for answer_result in answer_results:
        documents.extend(answer_result.documents)

    now_end = datetime.now()

    logger.debug(
        f"--------{now_end}--{now_end - now_start}--------INGEST ANSWERS END---"
    )

    return DecompAnswersUpdate(
        # Deduping is done by the documents operator for the main graph
        # so we might not need to dedup here
        documents=dedup_inference_sections(documents, []),
        decomp_answer_results=answer_results,
        log_messages=[
            f"{now_end} -- Main - Ingest initial processed sub questions, Time taken: {now_end - now_start}"
        ],
    )
@@ -0,0 +1,39 @@
from datetime import datetime

from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
    AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import DecompAnswersUpdate
from onyx.agents.agent_search.shared_graph_utils.operators import (
    dedup_inference_sections,
)


def ingest_refined_answers(
    state: AnswerQuestionOutput,
) -> DecompAnswersUpdate:
    now_start = datetime.now()

    logger.debug(f"--------{now_start}--------INGEST FOLLOW UP ANSWERS---")

    documents = []
    answer_results = state.answer_results if hasattr(state, "answer_results") else []
    for answer_result in answer_results:
        documents.extend(answer_result.documents)

    now_end = datetime.now()

    logger.debug(
        f"--------{now_end}--{now_end - now_start}--------INGEST FOLLOW UP ANSWERS END---"
    )

    return DecompAnswersUpdate(
        # Deduping is done by the documents operator for the main graph
        # so we might not need to dedup here
        documents=dedup_inference_sections(documents, []),
        decomp_answer_results=answer_results,
        log_messages=[
            f"{now_end} -- Main - Ingest refined answers, Time taken: {now_end - now_start}"
        ],
    )
@@ -0,0 +1,40 @@
from datetime import datetime

from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import (
    InitialAnswerQualityUpdate,
)
from onyx.agents.agent_search.deep_search_a.main.states import MainState


def initial_answer_quality_check(state: MainState) -> InitialAnswerQualityUpdate:
    """
    Check whether the final output satisfies the original user question.

    Args:
        state (messages): The current state

    Returns:
        InitialAnswerQualityUpdate
    """

    now_start = datetime.now()

    logger.debug(
        f"--------{now_start}--------Checking for base answer validity - for now, set True/False manually"
    )

    verdict = True

    now_end = datetime.now()

    logger.debug(
        f"--------{now_end}--{now_end - now_start}--------INITIAL ANSWER QUALITY CHECK END---"
    )

    return InitialAnswerQualityUpdate(
        initial_answer_quality=verdict,
        log_messages=[
            f"{now_end} -- Main - Initial answer quality check, Time taken: {now_end - now_start}"
        ],
    )
@@ -0,0 +1,150 @@
from datetime import datetime
from typing import cast

from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig

from onyx.agents.agent_search.deep_search_a.main.models import AgentRefinedMetrics
from onyx.agents.agent_search.deep_search_a.main.operations import dispatch_subquestion
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import BaseDecompUpdate
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
    INITIAL_DECOMPOSITION_PROMPT_QUESTIONS,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
    INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH,
)
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import SubQuestionPiece
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_context_manager
from onyx.tools.tool_implementations.search.search_tool import (
    SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary


def initial_sub_question_creation(
    state: MainState, config: RunnableConfig
) -> BaseDecompUpdate:
    now_start = datetime.now()

    logger.debug(f"--------{now_start}--------BASE DECOMP START---")

    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    question = agent_a_config.search_request.query
    chat_session_id = agent_a_config.chat_session_id
    primary_message_id = agent_a_config.message_id
    perform_initial_search_decomposition = (
        agent_a_config.perform_initial_search_decomposition
    )
    perform_initial_search_path_decision = (
        agent_a_config.perform_initial_search_path_decision
    )
    history = build_history_prompt(agent_a_config.prompt_builder)

    # Use the initial search results to inform the decomposition
    sample_doc_str = state.sample_doc_str if hasattr(state, "sample_doc_str") else ""

    if not chat_session_id or not primary_message_id:
        raise ValueError(
            "chat_session_id and message_id must be provided for agent search"
        )
    agent_start_time = datetime.now()

    # Initial search to inform decomposition. Just get the top 3 fits

    if perform_initial_search_decomposition:
        if not perform_initial_search_path_decision:
            search_tool = agent_a_config.search_tool
            retrieved_docs: list[InferenceSection] = []

            # new db session to avoid concurrency issues
            with get_session_context_manager() as db_session:
                for tool_response in search_tool.run(
                    query=question,
                    force_no_rerank=True,
                    alternate_db_session=db_session,
                ):
                    # get retrieved docs to send to the rest of the graph
                    if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
                        response = cast(SearchResponseSummary, tool_response.response)
                        retrieved_docs = response.top_sections
                        break

            sample_doc_str = "\n\n".join(
                [doc.combined_content for doc in retrieved_docs[:3]]
            )

        decomposition_prompt = (
            INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH.format(
                question=question, sample_doc_str=sample_doc_str, history=history
            )
        )

    else:
        decomposition_prompt = INITIAL_DECOMPOSITION_PROMPT_QUESTIONS.format(
            question=question, history=history
        )

    # Start decomposition

    msg = [HumanMessage(content=decomposition_prompt)]

    # Get the sub-questions in a defined format
    model = agent_a_config.fast_llm

    # Send the initial question as a subquestion with number 0
    dispatch_custom_event(
        "decomp_qs",
        SubQuestionPiece(
            sub_question=question,
            level=0,
            level_question_nr=0,
        ),
    )
    # dispatches custom events for subquestion tokens, adding in subquestion ids.
    streamed_tokens = dispatch_separated(model.stream(msg), dispatch_subquestion(0))

    stop_event = StreamStopInfo(
        stop_reason=StreamStopReason.FINISHED,
        stream_type="sub_questions",
        level=0,
    )
    dispatch_custom_event("stream_finished", stop_event)

    decomposition_response = merge_content(*streamed_tokens)

    # this call should only return strings. Commenting out for efficiency
    # assert [type(tok) == str for tok in streamed_tokens]

    # use no-op cast() instead of str() which runs code
    # list_of_subquestions = clean_and_parse_list_string(cast(str, response))
    list_of_subqs = cast(str, decomposition_response).split("\n")

    decomp_list: list[str] = [sq.strip() for sq in list_of_subqs if sq.strip() != ""]

    now_end = datetime.now()

    logger.debug(f"--------{now_end}--{now_end - now_start}--------BASE DECOMP END---")

    return BaseDecompUpdate(
        initial_decomp_questions=decomp_list,
        agent_start_time=agent_start_time,
        agent_refined_start_time=None,
        agent_refined_end_time=None,
        agent_refined_metrics=AgentRefinedMetrics(
            refined_doc_boost_factor=None,
            refined_question_boost_factor=None,
            duration__s=None,
        ),
    )
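The decomposition step assumes the LLM emits one sub-question per line; the parsing at the end of the node reduces to this (editor's sketch with an invented reply):

raw_response = "Who created Excel?\n\nWhen was Excel first released?\n"
decomp_list = [sq.strip() for sq in raw_response.split("\n") if sq.strip() != ""]
assert decomp_list == ["Who created Excel?", "When was Excel first released?"]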
@@ -0,0 +1,47 @@
from datetime import datetime
from typing import cast

from langchain_core.runnables import RunnableConfig

from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.deep_search_a.main.states import (
    RequireRefinedAnswerUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig


def refined_answer_decision(
    state: MainState, config: RunnableConfig
) -> RequireRefinedAnswerUpdate:
    now_start = datetime.now()

    logger.debug(f"--------{now_start}--------REFINED ANSWER DECISION---")

    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])

    # Placeholder heuristic, immediately overridden below: as written,
    # refinement is always requested when allow_refinement is set.
    if "?" in agent_a_config.search_request.query:
        decision = False
    else:
        decision = True

    decision = True

    now_end = datetime.now()

    logger.debug(
        f"--------{now_end}--{now_end - now_start}--------REFINED ANSWER DECISION END---"
    )
    log_messages = [
        f"{now_end} -- Main - Refined answer decision: {decision}, Time taken: {now_end - now_start}"
    ]
    if agent_a_config.allow_refinement:
        return RequireRefinedAnswerUpdate(
            require_refined_answer=decision,
            log_messages=log_messages,
        )

    else:
        return RequireRefinedAnswerUpdate(
            require_refined_answer=False,
            log_messages=log_messages,
        )
@@ -0,0 +1,116 @@
from datetime import datetime
from typing import cast

from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig

from onyx.agents.agent_search.deep_search_a.main.models import FollowUpSubQuestion
from onyx.agents.agent_search.deep_search_a.main.operations import dispatch_subquestion
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import (
    FollowUpSubQuestionsUpdate,
)
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import DEEP_DECOMPOSE_PROMPT
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
    format_entity_term_extraction,
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.tools.models import ToolCallKickoff


def refined_sub_question_creation(
    state: MainState, config: RunnableConfig
) -> FollowUpSubQuestionsUpdate:
    """Generate refined sub-questions, informed by the initial answer and by
    which of the earlier sub-questions succeeded or failed."""
    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    dispatch_custom_event(
        "start_refined_answer_creation",
        ToolCallKickoff(
            tool_name="agent_search_1",
            tool_args={
                "query": agent_a_config.search_request.query,
                "answer": state.initial_answer,
            },
        ),
    )

    now_start = datetime.now()

    logger.debug(f"--------{now_start}--------FOLLOW UP DECOMPOSE---")

    agent_refined_start_time = datetime.now()

    question = agent_a_config.search_request.query
    base_answer = state.initial_answer
    history = build_history_prompt(agent_a_config.prompt_builder)
    # get the entity term extraction dict and properly format it
    entity_retlation_term_extractions = state.entity_retlation_term_extractions

    entity_term_extraction_str = format_entity_term_extraction(
        entity_retlation_term_extractions
    )

    initial_question_answers = state.decomp_answer_results

    addressed_question_list = [
        x.question for x in initial_question_answers if "yes" in x.quality.lower()
    ]

    failed_question_list = [
        x.question for x in initial_question_answers if "no" in x.quality.lower()
    ]

    msg = [
        HumanMessage(
            content=DEEP_DECOMPOSE_PROMPT.format(
                question=question,
                history=history,
                entity_term_extraction_str=entity_term_extraction_str,
                base_answer=base_answer,
                answered_sub_questions="\n - ".join(addressed_question_list),
                failed_sub_questions="\n - ".join(failed_question_list),
            ),
        )
    ]

    # Grader
    model = agent_a_config.fast_llm

    streamed_tokens = dispatch_separated(model.stream(msg), dispatch_subquestion(1))
    response = merge_content(*streamed_tokens)

    if isinstance(response, str):
        parsed_response = [q for q in response.split("\n") if q.strip() != ""]
    else:
        raise ValueError("LLM response is not a string")
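    # Illustrative parsing (hypothetical LLM output): a response of
    # "What are sales for A?\n\nWhat are sales for B?\n" becomes
    # ["What are sales for A?", "What are sales for B?"] -- blank lines dropped.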

    refined_sub_question_dict = {}
    for sub_question_nr, sub_question in enumerate(parsed_response):
        refined_sub_question = FollowUpSubQuestion(
            sub_question=sub_question,
            sub_question_id=make_question_id(1, sub_question_nr + 1),
            verified=False,
            answered=False,
            answer="",
        )

        refined_sub_question_dict[sub_question_nr + 1] = refined_sub_question

    now_end = datetime.now()

    logger.debug(
        f"--------{now_end}--{now_end - now_start}--------FOLLOW UP DECOMPOSE END---"
    )

    return FollowUpSubQuestionsUpdate(
        refined_sub_questions=refined_sub_question_dict,
        agent_refined_start_time=agent_refined_start_time,
    )
@@ -0,0 +1,12 @@
from datetime import datetime

from onyx.agents.agent_search.deep_search_a.main.states import LoggerUpdate
from onyx.agents.agent_search.deep_search_a.main.states import MainState


def retrieval_consolidation(
    state: MainState,
) -> LoggerUpdate:
    now_start = datetime.now()

    return LoggerUpdate(log_messages=[f"{now_start} -- Retrieval consolidation"])
@@ -0,0 +1,145 @@
import re
from collections.abc import Callable

from langchain_core.callbacks.manager import dispatch_custom_event

from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.agents.agent_search.shared_graph_utils.models import (
    QuestionAnswerResults,
)
from onyx.chat.models import SubQuestionPiece
from onyx.tools.models import SearchQueryInfo
from onyx.utils.logger import setup_logger

logger = setup_logger()


def remove_document_citations(text: str) -> str:
    """
    Removes citation expressions of format '[[D1]]()' or '[[Q1]]()' from text.
    The number after D/Q can vary.

    Args:
        text: Input text containing citations

    Returns:
        Text with citations removed
    """
    # Pattern explanation:
    # \[\[(?:D|Q)\d+\]\]\(\) matches:
    #   \[\[    - literal [[ characters
    #   (?:D|Q) - a literal D or Q character
    #   \d+     - one or more digits
    #   \]\]    - literal ]] characters
    #   \(\)    - literal () characters
    return re.sub(r"\[\[(?:D|Q)\d+\]\]\(\)", "", text)
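# Illustrative usage (hypothetical strings):
#   remove_document_citations("See [[D1]]() and [[Q12]]() for details.")
#   -> "See  and  for details."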


def dispatch_subquestion(level: int) -> Callable[[str, int], None]:
    def _helper(sub_question_part: str, num: int) -> None:
        dispatch_custom_event(
            "decomp_qs",
            SubQuestionPiece(
                sub_question=sub_question_part,
                level=level,
                level_question_nr=num,
            ),
        )

    return _helper
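# A minimal usage sketch: dispatch_subquestion returns a closure bound to a
# level, so e.g. dispatch_subquestion(1)("What is X?", 2) emits a "decomp_qs"
# custom event carrying SubQuestionPiece(level=1, level_question_nr=2).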


def calculate_initial_agent_stats(
    decomp_answer_results: list[QuestionAnswerResults],
    original_question_stats: AgentChunkStats,
) -> InitialAgentResultStats:
    initial_agent_result_stats: InitialAgentResultStats = InitialAgentResultStats(
        sub_questions={},
        original_question={},
        agent_effectiveness={},
    )

    orig_verified = original_question_stats.verified_count
    orig_support_score = original_question_stats.verified_avg_scores

    verified_document_chunk_ids = []
    support_scores = 0.0

    for decomp_answer_result in decomp_answer_results:
        verified_document_chunk_ids += (
            decomp_answer_result.sub_question_retrieval_stats.verified_doc_chunk_ids
        )
        if (
            decomp_answer_result.sub_question_retrieval_stats.verified_avg_scores
            is not None
        ):
            support_scores += (
                decomp_answer_result.sub_question_retrieval_stats.verified_avg_scores
            )

    verified_document_chunk_ids = list(set(verified_document_chunk_ids))

    # Calculate sub-question stats
    if (
        verified_document_chunk_ids
        and len(verified_document_chunk_ids) > 0
        and support_scores is not None
    ):
        sub_question_stats: dict[str, float | int | None] = {
            "num_verified_documents": len(verified_document_chunk_ids),
            "verified_avg_score": float(support_scores / len(decomp_answer_results)),
        }
    else:
        sub_question_stats = {"num_verified_documents": 0, "verified_avg_score": None}

    initial_agent_result_stats.sub_questions.update(sub_question_stats)

    # Get original question stats
    initial_agent_result_stats.original_question.update(
        {
            "num_verified_documents": original_question_stats.verified_count,
            "verified_avg_score": original_question_stats.verified_avg_scores,
        }
    )

    # Calculate chunk utilization ratio
    sub_verified = initial_agent_result_stats.sub_questions["num_verified_documents"]

    chunk_ratio: float | None = None
    if sub_verified is not None and orig_verified is not None and orig_verified > 0:
        chunk_ratio = (float(sub_verified) / orig_verified) if sub_verified > 0 else 0.0
    elif sub_verified is not None and sub_verified > 0:
        chunk_ratio = 10.0

    initial_agent_result_stats.agent_effectiveness["utilized_chunk_ratio"] = chunk_ratio

    # NOTE: due to operator precedence this condition evaluates as
    # `orig is None OR (orig == 0.0 AND sub-question score is None)`
    if (
        orig_support_score is None
        or orig_support_score == 0.0
        and initial_agent_result_stats.sub_questions["verified_avg_score"] is None
    ):
        initial_agent_result_stats.agent_effectiveness["support_ratio"] = None
    elif orig_support_score is None or orig_support_score == 0.0:
        initial_agent_result_stats.agent_effectiveness["support_ratio"] = 10
    elif initial_agent_result_stats.sub_questions["verified_avg_score"] is None:
        initial_agent_result_stats.agent_effectiveness["support_ratio"] = 0
    else:
        initial_agent_result_stats.agent_effectiveness["support_ratio"] = (
            initial_agent_result_stats.sub_questions["verified_avg_score"]
            / orig_support_score
        )

    return initial_agent_result_stats
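# Worked example (hypothetical numbers): with 4 deduplicated verified chunk ids
# across the sub-questions and orig_verified = 8, utilized_chunk_ratio comes out
# as 4 / 8 = 0.5; if the original question verified nothing at all, the ratio is
# set to the sentinel value 10.0 instead.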


def get_query_info(results: list[QueryResult]) -> SearchQueryInfo:
    # Use the query info from the base document retrieval
    # TODO: see if this is the right way to do this
    query_infos = [
        result.query_info for result in results if result.query_info is not None
    ]
    if len(query_infos) == 0:
        raise ValueError("No query info found")
    return query_infos[0]
187
backend/onyx/agents/agent_search/deep_search_a/main/states.py
Normal file
@@ -0,0 +1,187 @@
from datetime import datetime
from operator import add
from typing import Annotated
from typing import TypedDict

from pydantic import BaseModel

from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import (
    ExpandedRetrievalResult,
)
from onyx.agents.agent_search.deep_search_a.main.models import AgentBaseMetrics
from onyx.agents.agent_search.deep_search_a.main.models import AgentRefinedMetrics
from onyx.agents.agent_search.deep_search_a.main.models import FollowUpSubQuestion
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.models import (
    EntityRelationshipTermExtraction,
)
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.agents.agent_search.shared_graph_utils.models import (
    QuestionAnswerResults,
)
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
    dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
    dedup_question_answer_results,
)
from onyx.context.search.models import InferenceSection

### States ###

## Update States


class LoggerUpdate(BaseModel):
    log_messages: Annotated[list[str], add] = []
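# The Annotated[..., add] annotation is the reducer LangGraph uses to merge
# concurrent state updates: when several nodes each return log_messages, the
# lists are concatenated with operator.add rather than overwritten. A sketch of
# the semantics (not Onyx code):
#   add(["node1 ran"], ["node2 ran"]) -> ["node1 ran", "node2 ran"]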


class RefinedAgentStartStats(BaseModel):
    agent_refined_start_time: datetime | None = None


class RefinedAgentEndStats(BaseModel):
    agent_refined_end_time: datetime | None = None
    agent_refined_metrics: AgentRefinedMetrics = AgentRefinedMetrics()


class BaseDecompUpdateBase(BaseModel):
    agent_start_time: datetime = datetime.now()
    initial_decomp_questions: list[str] = []


class RoutingDecisionBase(BaseModel):
    routing: str = ""
    sample_doc_str: str = ""


class RoutingDecision(RoutingDecisionBase, LoggerUpdate):
    pass


class LoggingUpdate(BaseModel):
    log_messages: list[str] = []


class BaseDecompUpdate(
    RefinedAgentStartStats, RefinedAgentEndStats, BaseDecompUpdateBase
):
    pass


class InitialAnswerBASEUpdate(BaseModel):
    initial_base_answer: str = ""


class InitialAnswerUpdateBase(BaseModel):
    initial_answer: str = ""
    initial_agent_stats: InitialAgentResultStats | None = None
    generated_sub_questions: list[str] = []
    agent_base_end_time: datetime | None = None
    agent_base_metrics: AgentBaseMetrics | None = None


class InitialAnswerUpdate(InitialAnswerUpdateBase, LoggerUpdate):
    pass


class RefinedAnswerUpdateBase(BaseModel):
    refined_answer: str = ""
    refined_agent_stats: RefinedAgentStats | None = None
    refined_answer_quality: bool = False


class RefinedAnswerUpdate(RefinedAgentEndStats, RefinedAnswerUpdateBase):
    pass


class InitialAnswerQualityUpdate(LoggingUpdate):
    initial_answer_quality: bool = False


class RequireRefinedAnswerUpdate(LoggingUpdate):
    require_refined_answer: bool = True


class DecompAnswersUpdate(LoggingUpdate):
    documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
    decomp_answer_results: Annotated[
        list[QuestionAnswerResults], dedup_question_answer_results
    ] = []


class FollowUpDecompAnswersUpdate(LoggingUpdate):
    refined_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
    refined_decomp_answer_results: Annotated[list[QuestionAnswerResults], add] = []


class ExpandedRetrievalUpdate(LoggingUpdate):
    all_original_question_documents: Annotated[
        list[InferenceSection], dedup_inference_sections
    ]
    original_question_retrieval_results: list[QueryResult] = []
    original_question_retrieval_stats: AgentChunkStats = AgentChunkStats()


class EntityTermExtractionUpdateBase(LoggingUpdate):
    entity_retlation_term_extractions: EntityRelationshipTermExtraction = (
        EntityRelationshipTermExtraction()
    )


class EntityTermExtractionUpdate(EntityTermExtractionUpdateBase, LoggerUpdate):
    pass


class FollowUpSubQuestionsUpdateBase(BaseModel):
    refined_sub_questions: dict[int, FollowUpSubQuestion] = {}


class FollowUpSubQuestionsUpdate(
    RefinedAgentStartStats, FollowUpSubQuestionsUpdateBase
):
    pass


## Graph Input State


class MainInput(CoreState):
    pass


## Graph State


class MainState(
    # This includes the core state
    MainInput,
    LoggerUpdate,
    BaseDecompUpdateBase,
    InitialAnswerUpdateBase,
    InitialAnswerBASEUpdate,
    DecompAnswersUpdate,
    ExpandedRetrievalUpdate,
    EntityTermExtractionUpdateBase,
    InitialAnswerQualityUpdate,
    RequireRefinedAnswerUpdate,
    FollowUpSubQuestionsUpdateBase,
    FollowUpDecompAnswersUpdate,
    RefinedAnswerUpdateBase,
    RefinedAgentStartStats,
    RefinedAgentEndStats,
    RoutingDecisionBase,
):
    # expanded_retrieval_result: Annotated[list[ExpandedRetrievalResult], add]
    base_raw_search_result: Annotated[list[ExpandedRetrievalResult], add]


## Graph Output State - presently not used


class MainOutput(TypedDict):
    log_messages: list[str]
86
backend/onyx/agents/agent_search/models.py
Normal file
@@ -0,0 +1,86 @@
from dataclasses import dataclass
from uuid import UUID

from pydantic import BaseModel
from pydantic import model_validator
from sqlalchemy.orm import Session

from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.context.search.models import SearchRequest
from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool


@dataclass
class AgentSearchConfig:
    """
    Configuration for the Agent Search feature.
    """

    # The search request that was used to generate the Pro Search
    search_request: SearchRequest

    primary_llm: LLM
    fast_llm: LLM
    search_tool: SearchTool

    # Whether to force use of a tool, or to
    # force tool args IF the tool is used
    force_use_tool: ForceUseTool

    # contains message history for the current chat session
    # has the following (at most one is non-None)
    # message_history: list[PreviousMessage] | None = None
    # single_message_history: str | None = None
    prompt_builder: AnswerPromptBuilder

    use_agentic_search: bool = False

    # For persisting agent search data
    chat_session_id: UUID | None = None

    # The message ID of the user message that triggered the Pro Search
    message_id: int | None = None

    # Whether to persist data for the Pro Search (turned off for testing)
    use_persistence: bool = True

    # The database session for the Pro Search
    db_session: Session | None = None

    # Whether to perform an initial search to inform the path decision
    perform_initial_search_path_decision: bool = True

    # Whether to perform an initial search to inform decomposition
    perform_initial_search_decomposition: bool = True

    # Whether to allow creation of refinement questions (and entity extraction, etc.)
    allow_refinement: bool = True

    # Tools available for use
    tools: list[Tool] | None = None

    using_tool_calling_llm: bool = False

    files: list[InMemoryChatFile] | None = None

    structured_response_format: dict | None = None

    skip_gen_ai_answer_generation: bool = False

    @model_validator(mode="after")
    def validate_db_session(self) -> "AgentSearchConfig":
        if self.use_persistence and self.db_session is None:
            raise ValueError(
                "db_session must be provided for pro search when using persistence"
            )
        return self
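# A minimal construction sketch (argument values are assumptions shown for
# illustration; see get_test_config in shared_graph_utils.utils for real wiring):
# config = AgentSearchConfig(
#     search_request=SearchRequest(query="What is Onyx?"),
#     primary_llm=primary_llm,
#     fast_llm=fast_llm,
#     search_tool=search_tool,
#     force_use_tool=force_use_tool,
#     prompt_builder=prompt_builder,
#     use_persistence=False,  # avoids the db_session requirement in the validator
# )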


class AgentDocumentCitations(BaseModel):
    document_id: str
    document_title: str
    link: str
277
backend/onyx/agents/agent_search/run_graph.py
Normal file
@@ -0,0 +1,277 @@
import asyncio
from asyncio import AbstractEventLoop
from collections.abc import AsyncIterable
from collections.abc import Iterable
from datetime import datetime
from typing import cast

from langchain_core.runnables.schema import StreamEvent
from langgraph.graph.state import CompiledStateGraph

from onyx.agents.agent_search.basic.graph_builder import basic_graph_builder
from onyx.agents.agent_search.basic.states import BasicInput
from onyx.agents.agent_search.deep_search_a.main.graph_builder import (
    main_graph_builder as main_graph_builder_a,
)
from onyx.agents.agent_search.deep_search_a.main.states import MainInput as MainInput_a
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStream
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionPiece
from onyx.chat.models import ToolResponse
from onyx.configs.dev_configs import GRAPH_NAME
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager
from onyx.tools.tool_runner import ToolCallKickoff
from onyx.utils.logger import setup_logger

logger = setup_logger()

_COMPILED_GRAPH: CompiledStateGraph | None = None


def _set_combined_token_value(
    combined_token: str, parsed_object: AgentAnswerPiece
) -> AgentAnswerPiece:
    parsed_object.answer_piece = combined_token

    return parsed_object


def _parse_agent_event(
    event: StreamEvent,
) -> AnswerPacket | None:
    """
    Parse the event into a typed object.
    Return None if we are not interested in the event.
    """

    event_type = event["event"]

    # We always just yield the event data, but this piece is useful for two development reasons:
    # 1. It's a list of the names of every place we dispatch a custom event
    # 2. We maintain the intended types yielded by each event
    if event_type == "on_custom_event":
        # TODO: different AnswerStream types for different events
        if event["name"] == "decomp_qs":
            return cast(SubQuestionPiece, event["data"])
        elif event["name"] == "subqueries":
            return cast(SubQueryPiece, event["data"])
        elif event["name"] == "sub_answers":
            return cast(AgentAnswerPiece, event["data"])
        elif event["name"] == "stream_finished":
            return cast(StreamStopInfo, event["data"])
        elif event["name"] == "initial_agent_answer":
            return cast(AgentAnswerPiece, event["data"])
        elif event["name"] == "refined_agent_answer":
            return cast(AgentAnswerPiece, event["data"])
        elif event["name"] == "start_refined_answer_creation":
            return cast(ToolCallKickoff, event["data"])
        elif event["name"] == "tool_response":
            return cast(ToolResponse, event["data"])
        elif event["name"] == "basic_response":
            return cast(AnswerPacket, event["data"])
    return None


async def tear_down(event_loop: AbstractEventLoop) -> None:
    # Collect all tasks and cancel those that are not 'done'.
    tasks = asyncio.all_tasks(event_loop)
    for task in tasks:
        task.cancel()

    # Wait for all tasks to complete, ignoring any CancelledErrors
    try:
        await asyncio.wait(tasks)
    except asyncio.exceptions.CancelledError:
        pass


def _manage_async_event_streaming(
    compiled_graph: CompiledStateGraph,
    config: AgentSearchConfig | None,
    graph_input: MainInput_a | BasicInput,
) -> Iterable[StreamEvent]:
    async def _run_async_event_stream(
        loop: AbstractEventLoop,
    ) -> AsyncIterable[StreamEvent]:
        try:
            message_id = config.message_id if config else None
            async for event in compiled_graph.astream_events(
                input=graph_input,
                config={"metadata": {"config": config, "thread_id": str(message_id)}},
                # debug=True,
                # indicating v2 here deserves further scrutiny
                version="v2",
            ):
                yield event
        finally:
            await tear_down(loop)

    # This might be able to be simplified
    def _yield_async_to_sync() -> Iterable[StreamEvent]:
        loop = asyncio.new_event_loop()
        try:
            # Get the async generator
            async_gen = _run_async_event_stream(loop)
            # Convert to AsyncIterator
            async_iter = async_gen.__aiter__()
            while True:
                try:
                    # Create a coroutine by calling anext with the async iterator
                    next_coro = anext(async_iter)
                    # Run the coroutine to get the next event
                    event = loop.run_until_complete(next_coro)
                    yield event
                except StopAsyncIteration:
                    break
        finally:
            loop.close()

    return _yield_async_to_sync()


def run_graph(
    compiled_graph: CompiledStateGraph,
    config: AgentSearchConfig,
    input: BasicInput | MainInput_a,
) -> AnswerStream:
    # TODO: add these to the environment
    config.perform_initial_search_path_decision = True
    config.perform_initial_search_decomposition = True
    config.allow_refinement = True

    for event in _manage_async_event_streaming(
        compiled_graph=compiled_graph, config=config, graph_input=input
    ):
        if not (parsed_object := _parse_agent_event(event)):
            continue

        yield parsed_object


# TODO: call this once on startup, TBD where and if it should be gated based
# on dev mode or not
def load_compiled_graph(graph_name: str) -> CompiledStateGraph:
    # NOTE: both branches currently resolve to graph "a"
    main_graph_builder = (
        main_graph_builder_a if graph_name == "a" else main_graph_builder_a
    )
    global _COMPILED_GRAPH
    if _COMPILED_GRAPH is None:
        graph = main_graph_builder()
        _COMPILED_GRAPH = graph.compile()
    return _COMPILED_GRAPH


def run_main_graph(
    config: AgentSearchConfig,
    graph_name: str = "a",
) -> AnswerStream:
    compiled_graph = load_compiled_graph(graph_name)
    # NOTE: both branches currently build the same input for graph "a"
    if graph_name == "a":
        input = MainInput_a(base_question=config.search_request.query, log_messages=[])
    else:
        input = MainInput_a(base_question=config.search_request.query, log_messages=[])

    # Agent search is not a Tool per se, but this is helpful for the frontend
    yield ToolCallKickoff(
        tool_name="agent_search_0",
        tool_args={"query": config.search_request.query},
    )
    yield from run_graph(compiled_graph, config, input)


# TODO: unify input types, especially prosearchconfig
def run_basic_graph(
    config: AgentSearchConfig,
) -> AnswerStream:
    graph = basic_graph_builder()
    compiled_graph = graph.compile()
    # TODO: unify basic input
    input = BasicInput(
        should_stream_answer=True,
    )
    return run_graph(compiled_graph, config, input)


if __name__ == "__main__":
    from onyx.llm.factory import get_default_llms

    now_start = datetime.now()
    logger.debug(f"Start at {now_start}")

    if GRAPH_NAME == "a":
        graph = main_graph_builder_a()
    else:
        graph = main_graph_builder_a()
    compiled_graph = graph.compile()
    now_end = datetime.now()
    logger.debug(f"Graph compiled in {now_end - now_start} seconds")
    primary_llm, fast_llm = get_default_llms()
    search_request = SearchRequest(
        # query="what can you do with gitlab?",
        # query="What are the guiding principles behind the development of cockroachDB",
        # query="What are the temperatures in Munich, Hawaii, and New York?",
        # query="When was Washington born?",
        query="What is Onyx?",
    )
    # Joachim custom persona

    with get_session_context_manager() as db_session:
        config, search_tool = get_test_config(
            db_session, primary_llm, fast_llm, search_request
        )
        # search_request.persona = get_persona_by_id(1, None, db_session)
        config.use_persistence = True
        config.perform_initial_search_path_decision = True
        config.perform_initial_search_decomposition = True
        if GRAPH_NAME == "a":
            input = MainInput_a(
                base_question=config.search_request.query, log_messages=[]
            )
        else:
            input = MainInput_a(
                base_question=config.search_request.query, log_messages=[]
            )
        # with open("output.txt", "w") as f:
        tool_responses: list = []
        for output in run_graph(compiled_graph, config, input):
            # pass

            if isinstance(output, ToolCallKickoff):
                pass
            elif isinstance(output, ExtendedToolResponse):
                tool_responses.append(output.response)
                logger.info(
                    f" ---- ET {output.level} - {output.level_question_nr} | "
                )
            elif isinstance(output, SubQueryPiece):
                logger.info(
                    f"Sq {output.level} - {output.level_question_nr} - {output.sub_query} | "
                )
            elif isinstance(output, SubQuestionPiece):
                logger.info(
                    f"SQ {output.level} - {output.level_question_nr} - {output.sub_question} | "
                )
            elif (
                isinstance(output, AgentAnswerPiece)
                and output.answer_type == "agent_sub_answer"
            ):
                logger.info(
                    f" ---- SA {output.level} - {output.level_question_nr} {output.answer_piece} | "
                )
            elif (
                isinstance(output, AgentAnswerPiece)
                and output.answer_type == "agent_level_answer"
            ):
                logger.info(
                    f" ---------- FA {output.level} - {output.level_question_nr} {output.answer_piece} | "
                )

        # for tool_response in tool_responses:
        #     logger.debug(tool_response)
@@ -0,0 +1,90 @@
from langchain.schema import AIMessage
from langchain.schema import HumanMessage
from langchain.schema import SystemMessage
from langchain_core.messages.tool import ToolMessage

from onyx.agents.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT_v2
from onyx.agents.agent_search.shared_graph_utils.prompts import HISTORY_PROMPT
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.context.search.models import InferenceSection
from onyx.llm.interfaces import LLMConfig
from onyx.llm.utils import get_max_input_tokens
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.natural_language_processing.utils import tokenizer_trim_content


def build_sub_question_answer_prompt(
    question: str,
    original_question: str,
    docs: list[InferenceSection],
    persona_specification: str,
    config: LLMConfig,
) -> list[SystemMessage | HumanMessage | AIMessage | ToolMessage]:
    system_message = SystemMessage(
        content=persona_specification,
    )

    docs_format_list = [
        f"""Document Number: [D{doc_nr + 1}]\n
Content: {doc.combined_content}\n\n"""
        for doc_nr, doc in enumerate(docs)
    ]

    docs_str = "\n\n".join(docs_format_list)

    docs_str = trim_prompt_piece(
        config, docs_str, BASE_RAG_PROMPT_v2 + question + original_question
    )
    human_message = HumanMessage(
        content=BASE_RAG_PROMPT_v2.format(
            question=question, original_question=original_question, context=docs_str
        )
    )

    return [system_message, human_message]


def trim_prompt_piece(config: LLMConfig, prompt_piece: str, reserved_str: str) -> str:
    # TODO: this truncating might add latency. We could do a rougher + faster check
    # first to determine whether truncation is needed

    # TODO: maybe save the tokenizer and max input tokens if this is getting called multiple times?
    llm_tokenizer = get_tokenizer(
        provider_type=config.model_provider,
        model_name=config.model_name,
    )

    max_tokens = get_max_input_tokens(
        model_provider=config.model_provider,
        model_name=config.model_name,
    )

    # slightly conservative trimming
    return tokenizer_trim_content(
        content=prompt_piece,
        desired_length=max_tokens - len(llm_tokenizer.encode(reserved_str)),
        tokenizer=llm_tokenizer,
    )
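# A usage sketch (hypothetical numbers): with a model max input of 4096 tokens
# and a reserved_str that encodes to 596 tokens, the content is trimmed to at
# most 4096 - 596 = 3500 tokens before being interpolated into the prompt.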


def build_history_prompt(prompt_builder: AnswerPromptBuilder | None) -> str:
    if prompt_builder is None:
        return ""

    if prompt_builder.single_message_history is not None:
        history = prompt_builder.single_message_history
    else:
        history = ""
        previous_message_type = None
        for message in prompt_builder.raw_message_history:
            if "user" in message.message_type:
                history += f"User: {message.message}\n"
                previous_message_type = "user"
            elif "assistant" in message.message_type:
                # only use the initial agent answer for the history
                if previous_message_type != "assistant":
                    history += f"You/Agent: {message.message}\n"
                previous_message_type = "assistant"
            else:
                continue
    return HISTORY_PROMPT.format(history=history) if history else ""
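# Illustrative result (hypothetical messages): a raw history of
# [user: "Hi", assistant: "Hello!", assistant: "Anything else?"] yields
# "User: Hi\nYou/Agent: Hello!\n" -- the second consecutive assistant message
# is skipped -- which is then wrapped in HISTORY_PROMPT.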
@@ -0,0 +1,98 @@
import numpy as np

from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitScoreMetrics
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.chat.models import SectionRelevancePiece
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger

logger = setup_logger()


def unique_chunk_id(doc: InferenceSection) -> str:
    return f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}"


def calculate_rank_shift(list1: list, list2: list, top_n: int = 20) -> float:
    shift = 0
    for rank_first, doc_id in enumerate(list1[:top_n], 1):
        try:
            rank_second = list2.index(doc_id) + 1
        except ValueError:
            rank_second = len(list2)  # Document not found in second list

        shift += np.abs(rank_first - rank_second) / np.log(1 + rank_first * rank_second)

    return shift / top_n
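# Worked example (hypothetical ids): calculate_rank_shift(["a", "b", "c"],
# ["b", "a", "c"], top_n=3) sums |1-2|/ln(3) + |2-1|/ln(3) + 0/ln(10)
# ~= 0.91 + 0.91 + 0, then divides by top_n, giving ~= 0.61.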


def get_fit_scores(
    pre_reranked_results: list[InferenceSection],
    post_reranked_results: list[InferenceSection] | list[SectionRelevancePiece],
) -> RetrievalFitStats | None:
    """
    Calculate retrieval metrics for search purposes
    """

    if len(pre_reranked_results) == 0 or len(post_reranked_results) == 0:
        return None

    ranked_sections = {
        "initial": pre_reranked_results,
        "reranked": post_reranked_results,
    }

    fit_eval: RetrievalFitStats = RetrievalFitStats(
        fit_score_lift=0,
        rerank_effect=0,
        fit_scores={
            "initial": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
            "reranked": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
        },
    )

    for rank_type, docs in ranked_sections.items():
        logger.debug(f"rank_type: {rank_type}")

        for i in [1, 5, 10]:
            fit_eval.fit_scores[rank_type].scores[str(i)] = (
                sum(
                    [
                        float(doc.center_chunk.score)
                        for doc in docs[:i]
                        if type(doc) == InferenceSection
                        and doc.center_chunk.score is not None
                    ]
                )
                / i
            )

        fit_eval.fit_scores[rank_type].scores["fit_score"] = (
            1
            / 3
            * (
                fit_eval.fit_scores[rank_type].scores["1"]
                + fit_eval.fit_scores[rank_type].scores["5"]
                + fit_eval.fit_scores[rank_type].scores["10"]
            )
        )

        # NOTE: the averaged fit_score above is immediately overwritten here
        # with the top-1 score
        fit_eval.fit_scores[rank_type].scores["fit_score"] = fit_eval.fit_scores[
            rank_type
        ].scores["1"]

        fit_eval.fit_scores[rank_type].chunk_ids = [
            unique_chunk_id(doc) for doc in docs if type(doc) == InferenceSection
        ]

    fit_eval.fit_score_lift = (
        fit_eval.fit_scores["reranked"].scores["fit_score"]
        / fit_eval.fit_scores["initial"].scores["fit_score"]
    )

    fit_eval.rerank_effect = calculate_rank_shift(
        fit_eval.fit_scores["initial"].chunk_ids,
        fit_eval.fit_scores["reranked"].chunk_ids,
    )

    return fit_eval
112
backend/onyx/agents/agent_search/shared_graph_utils/models.py
Normal file
@@ -0,0 +1,112 @@
from typing import Literal

from pydantic import BaseModel

from onyx.agents.agent_search.deep_search_a.main.models import AgentAdditionalMetrics
from onyx.agents.agent_search.deep_search_a.main.models import AgentBaseMetrics
from onyx.agents.agent_search.deep_search_a.main.models import AgentRefinedMetrics
from onyx.agents.agent_search.deep_search_a.main.models import AgentTimings
from onyx.context.search.models import InferenceSection
from onyx.tools.models import SearchQueryInfo


# Pydantic models for structured outputs
class RewrittenQueries(BaseModel):
    rewritten_queries: list[str]


class BinaryDecision(BaseModel):
    decision: Literal["yes", "no"]


class BinaryDecisionWithReasoning(BaseModel):
    reasoning: str
    decision: Literal["yes", "no"]


class RetrievalFitScoreMetrics(BaseModel):
    scores: dict[str, float]
    chunk_ids: list[str]


class RetrievalFitStats(BaseModel):
    fit_score_lift: float
    rerank_effect: float
    fit_scores: dict[str, RetrievalFitScoreMetrics]


class AgentChunkScores(BaseModel):
    scores: dict[str, dict[str, list[int | float]]]


class AgentChunkStats(BaseModel):
    verified_count: int | None = None
    verified_avg_scores: float | None = None
    rejected_count: int | None = None
    rejected_avg_scores: float | None = None
    verified_doc_chunk_ids: list[str] = []
    dismissed_doc_chunk_ids: list[str] = []


class InitialAgentResultStats(BaseModel):
    sub_questions: dict[str, float | int | None]
    original_question: dict[str, float | int | None]
    agent_effectiveness: dict[str, float | int | None]


class RefinedAgentStats(BaseModel):
    revision_doc_efficiency: float | None
    revision_question_efficiency: float | None


class Term(BaseModel):
    term_name: str = ""
    term_type: str = ""
    term_similar_to: list[str] = []


### Models ###


class Entity(BaseModel):
    entity_name: str = ""
    entity_type: str = ""


class Relationship(BaseModel):
    relationship_name: str = ""
    relationship_type: str = ""
    relationship_entities: list[str] = []


class EntityRelationshipTermExtraction(BaseModel):
    entities: list[Entity] = []
    relationships: list[Relationship] = []
    terms: list[Term] = []


### Models ###


class QueryResult(BaseModel):
    query: str
    search_results: list[InferenceSection]
    stats: RetrievalFitStats | None
    query_info: SearchQueryInfo | None


class QuestionAnswerResults(BaseModel):
    question: str
    question_id: str
    answer: str
    quality: str
    expanded_retrieval_results: list[QueryResult]
    documents: list[InferenceSection]
    sub_question_retrieval_stats: AgentChunkStats


class CombinedAgentMetrics(BaseModel):
    timings: AgentTimings
    base_metrics: AgentBaseMetrics | None
    refined_metrics: AgentRefinedMetrics
    additional_metrics: AgentAdditionalMetrics
@@ -0,0 +1,31 @@
from onyx.agents.agent_search.shared_graph_utils.models import (
    QuestionAnswerResults,
)
from onyx.chat.prune_and_merge import _merge_sections
from onyx.context.search.models import InferenceSection


def dedup_inference_sections(
    list1: list[InferenceSection], list2: list[InferenceSection]
) -> list[InferenceSection]:
    deduped = _merge_sections(list1 + list2)
    return deduped


def dedup_question_answer_results(
    question_answer_results_1: list[QuestionAnswerResults],
    question_answer_results_2: list[QuestionAnswerResults],
) -> list[QuestionAnswerResults]:
    deduped_question_answer_results: list[
        QuestionAnswerResults
    ] = question_answer_results_1
    utilized_question_ids: set[str] = set(
        [x.question_id for x in question_answer_results_1]
    )

    for question_answer_result in question_answer_results_2:
        if question_answer_result.question_id not in utilized_question_ids:
            deduped_question_answer_results.append(question_answer_result)
            utilized_question_ids.add(question_answer_result.question_id)

    return deduped_question_answer_results
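# Illustrative behavior (hypothetical question_ids): merging
# [QA("0_1"), QA("0_2")] with [QA("0_2"), QA("1_1")] keeps "0_1", "0_2", "1_1";
# the duplicate "0_2" from the second list is dropped. Note that the first list
# is extended in place rather than copied.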
936
backend/onyx/agents/agent_search/shared_graph_utils/prompts.py
Normal file
936
backend/onyx/agents/agent_search/shared_graph_utils/prompts.py
Normal file
@@ -0,0 +1,936 @@
|
||||
UNKNOWN_ANSWER = "I do not have enough information to answer this question."
|
||||
|
||||
NO_RECOVERED_DOCS = "No relevant documents recovered"
|
||||
|
||||
HISTORY_PROMPT = """\n
|
||||
For more context, here is the history of the conversation so far that preceeded this question:
|
||||
\n ------- \n
|
||||
{history}
|
||||
\n ------- \n\n
|
||||
"""
|
||||
|
||||
REWRITE_PROMPT_MULTI_ORIGINAL = """ \n
|
||||
Please convert an initial user question into a 2-3 more appropriate short and pointed search queries for retrievel from a
|
||||
document store. Particularly, try to think about resolving ambiguities and make the search queries more specific,
|
||||
enabling the system to search more broadly.
|
||||
Also, try to make the search queries not redundant, i.e. not too similar! \n\n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
Formulate the queries separated by newlines (Do not say 'Query 1: ...', just write the querytext) as follows:
|
||||
<query 1>
|
||||
<query 2>
|
||||
...
|
||||
queries: """
|
||||
|
||||
REWRITE_PROMPT_MULTI = """ \n
|
||||
Please create a list of 2-3 sample documents that could answer an original question. Each document
|
||||
should be about as long as the original question. \n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
Formulate the sample documents separated by '--' (Do not say 'Document 1: ...', just write the text): """
|
||||
|
||||
# The prompt is only used if there is no persona prompt, so the placeholder is ''
|
||||
BASE_RAG_PROMPT = (
|
||||
""" \n
|
||||
{persona_specification}
|
||||
Use the context provided below - and only the
|
||||
provided context - to answer the given question. (Note that the answer is in service of anserwing a broader
|
||||
question, given below as 'motivation'.)
|
||||
|
||||
Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
|
||||
question based on the context, say """
|
||||
+ f'"{UNKNOWN_ANSWER}"'
|
||||
+ """. It is a matter of life and death that you do NOT
|
||||
use your internal knowledge, just the provided information!
|
||||
|
||||
Make sure that you keep all relevant information, specifically as it concerns to the ultimate goal.
|
||||
(But keep other details as well.)
|
||||
|
||||
\nContext:\n {context} \n
|
||||
|
||||
Motivation:\n {original_question} \n\n
|
||||
\n\n
|
||||
And here is the question I want you to answer based on the context above (with the motivation in mind):
|
||||
\n--\n {question} \n--\n
|
||||
"""
|
||||
)
|
||||
|
||||
BASE_RAG_PROMPT_v2 = (
|
||||
""" \n
|
||||
Use the context provided below - and only the
|
||||
provided context - to answer the given question. (Note that the answer is in service of answering a broader
|
||||
question, given below as 'motivation'.)
|
||||
|
||||
Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
|
||||
question based on the context, say """
|
||||
+ f'"{UNKNOWN_ANSWER}"'
|
||||
+ """. It is a matter of life and death that you do NOT
|
||||
use your internal knowledge, just the provided information!
|
||||
|
||||
Make sure that you keep all relevant information, specifically as it concerns to the ultimate goal.
|
||||
(But keep other details as well.)
|
||||
|
||||
Please remember to provide inline citations in the format [[D1]](), [[D2]](), [[D3]](), etc.
|
||||
Proper citations are very important to the user!\n\n\n
|
||||
|
||||
For your general information, here is the ultimate motivation:
|
||||
\n--\n {original_question} \n--\n
|
||||
\n\n
|
||||
And here is the actual question I want you to answer based on the context above (with the motivation in mind):
|
||||
\n--\n {question} \n--\n
|
||||
|
||||
Here is the context:
|
||||
\n\n\n--\n {context} \n--\n
|
||||
"""
|
||||
)
|
||||
|
||||
SUB_CHECK_YES = "yes"
|
||||
SUB_CHECK_NO = "no"
|
||||
|
||||
SUB_CHECK_PROMPT = (
|
||||
"""
|
||||
Your task is to see whether a given answer addresses a given question.
|
||||
Please do not use any internal knowledge you may have - just focus on whether the answer
|
||||
as given seems to largely address the question as given, or at least addresses part of the question.
|
||||
Here is the question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
Here is the suggested answer:
|
||||
\n ------- \n
|
||||
{base_answer}
|
||||
\n ------- \n
|
||||
Does the suggested answer address the question? Please answer with """
|
||||
+ f'"{SUB_CHECK_YES}" or "{SUB_CHECK_NO}".'
|
||||
)
|
||||
|
||||
|
||||
BASE_CHECK_PROMPT = """ \n
|
||||
Please check whether 1) the suggested answer seems to fully address the original question AND 2)the
|
||||
original question requests a simple, factual answer, and there are no ambiguities, judgements,
|
||||
aggregations, or any other complications that may require extra context. (I.e., if the question is
|
||||
somewhat addressed, but the answer would benefit from more context, then answer with 'no'.)
|
||||
|
||||
Please only answer with 'yes' or 'no' \n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
Here is the proposed answer:
|
||||
\n ------- \n
|
||||
{initial_answer}
|
||||
\n ------- \n
|
||||
Please answer with yes or no:"""
|
||||
|
||||
VERIFIER_PROMPT = """
|
||||
You are supposed to judge whether a document text contains data or information that is potentially relevant for a question.
|
||||
|
||||
Here is a document text that you can take as a fact:
|
||||
--
|
||||
DOCUMENT INFORMATION:
|
||||
{document_content}
|
||||
--
|
||||
|
||||
Do you think that this information is useful and relevant to answer the following question?
|
||||
(Other documents may supply additional information, so do not worry if the provided information
|
||||
is not enough to answer the question, but it needs to be relevant to the question.)
|
||||
--
|
||||
QUESTION:
|
||||
{question}
|
||||
--
|
||||
|
||||
Please answer with 'yes' or 'no':
|
||||
|
||||
Answer:
|
||||
|
||||
"""
|
||||
|
||||
INITIAL_DECOMPOSITION_PROMPT_BASIC = """ \n
|
||||
If you think it is helpful, please decompose an initial user question into not more
|
||||
than 4 appropriate sub-questions that help to answer the original question.
|
||||
The purpose for this decomposition is to isolate individulal entities
|
||||
(i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
|
||||
for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
|
||||
sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
|
||||
for us'), etc. Each sub-question should be realistically be answerable by a good RAG system.
|
||||
|
||||
Importantly, if you think it is not needed or helpful, please just return an empty list. That is ok too.
|
||||
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
|
||||
Please formulate your answer as a list of subquestions:
|
||||
|
||||
Answer:
|
||||
"""
|
||||
|
||||
REWRITE_PROMPT_SINGLE = """ \n
|
||||
Please convert an initial user question into a more appropriate search query for retrievel from a
|
||||
document store. \n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
|
||||
Formulate the query: """
|
||||
|
||||
MODIFIED_RAG_PROMPT = (
|
||||
"""You are an assistant for question-answering tasks. Use the context provided below
|
||||
- and only this context - to answer the question. It is a matter of life and death that you do NOT
|
||||
use your internal knowledge, just the provided information!
|
||||
If you don't have enough infortmation to generate an answer, just say """
|
||||
+ f'"{UNKNOWN_ANSWER}"'
|
||||
+ """.
|
||||
Use three sentences maximum and keep the answer concise.
|
||||
Pay also particular attention to the sub-questions and their answers, at least it may enrich the answer.
|
||||
Again, only use the provided context and do not use your internal knowledge!
|
||||
|
||||
\nQuestion: {question}
|
||||
\nContext: {combined_context} \n
|
||||
|
||||
Answer:"""
|
||||
)
|
||||
|
||||
ORIG_DEEP_DECOMPOSE_PROMPT = """ \n
|
||||
An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
|
||||
good enough. Also, some sub-questions had been answered and this information has been used to provide
|
||||
the initial answer. Some other subquestions may have been suggested based on little knowledge, but they
|
||||
were not directly answerable. Also, some entities, relationships and terms are givenm to you so that
|
||||
you have an idea of how the avaiolable data looks like.
|
||||
|
||||
Your role is to generate 3-5 new sub-questions that would help to answer the initial question,
|
||||
considering:
|
||||
|
||||
1) The initial question
|
||||
2) The initial answer that was found to be unsatisfactory
|
||||
3) The sub-questions that were answered
|
||||
4) The sub-questions that were suggested but not answered
|
||||
5) The entities, relationships and terms that were extracted from the context
|
||||
|
||||
The individual questions should be answerable by a good RAG system.
|
||||
So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the
|
||||
question for different entities that may be involved in the original question, but in a way that does
|
||||
not duplicate questions that were already tried.
|
||||
|
||||
Additional Guidelines:
|
||||
- The sub-questions should be specific to the question and provide richer context for the question,
|
||||
resolve ambiguities, or address shortcoming of the initial answer
|
||||
- Each sub-question - when answered - should be relevant for the answer to the original question
|
||||
- The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any
|
||||
other complications that may require extra context.
|
||||
- The sub-questions MUST have the full context of the original question so that it can be executed by
|
||||
a RAG system independently without the original question available
|
||||
(Example:
|
||||
- initial question: "What is the capital of France?"
|
||||
- bad sub-question: "What is the name of the river there?"
|
||||
- good sub-question: "What is the name of the river that flows through Paris?"
|
||||
- For each sub-question, please provide a short explanation for why it is a good sub-question. So
|
||||
generate a list of dictionaries with the following format:
|
||||
[{{"sub_question": <sub-question>, "explanation": <explanation>, "search_term": <rewrite the
|
||||
sub-question using as a search phrase for the document store>}}, ...]
|
||||
|
||||
\n\n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
|
||||
Here is the initial sub-optimal answer:
|
||||
\n ------- \n
|
||||
{base_answer}
|
||||
\n ------- \n
|
||||
|
||||
Here are the sub-questions that were answered:
|
||||
\n ------- \n
|
||||
{answered_sub_questions}
|
||||
\n ------- \n
|
||||
|
||||
Here are the sub-questions that were suggested but not answered:
|
||||
\n ------- \n
|
||||
{failed_sub_questions}
|
||||
\n ------- \n
|
||||
|
||||
And here are the entities, relationships and terms extracted from the context:
|
||||
\n ------- \n
|
||||
{entity_term_extraction_str}
|
||||
\n ------- \n
|
||||
|
||||
Please generate the list of good, fully contextualized sub-questions that would help to address the
|
||||
main question. Again, please find questions that are NOT overlapping too much with the already answered
|
||||
sub-questions or those that already were suggested and failed.
|
||||
In other words - what can we try in addition to what has been tried so far?
|
||||
|
||||
Please think through it step by step and then generate the list of json dictionaries with the following
|
||||
format:
|
||||
|
||||
{{"sub_questions": [{{"sub_question": <sub-question>,
|
||||
"explanation": <explanation>,
|
||||
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
|
||||
...]}} """
|
||||
|
||||
DEEP_DECOMPOSE_PROMPT = """ \n
|
||||
An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
|
||||
good enough. Also, some sub-questions had been answered and this information has been used to provide
|
||||
the initial answer. Some other subquestions may have been suggested based on little knowledge, but they
|
||||
were not directly answerable. Also, some entities, relationships and terms are givenm to you so that
|
||||
you have an idea of how the avaiolable data looks like.
|
||||
|
||||
Your role is to generate 2-4 new sub-questions that would help to answer the initial question,
|
||||
considering:
|
||||
|
||||
1) The initial question
|
||||
2) The initial answer that was found to be unsatisfactory
|
||||
3) The sub-questions that were answered
|
||||
4) The sub-questions that were suggested but not answered
|
||||
5) The entities, relationships and terms that were extracted from the context
|
||||
|
||||
The individual questions should be answerable by a good RAG system.
|
||||
So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the
|
||||
question for different entities that may be involved in the original question, but in a way that does
|
||||
not duplicate questions that were already tried.
|
||||
|
||||
Additional Guidelines:
|
||||
- The sub-questions should be specific to the question and provide richer context for the question,
|
||||
resolve ambiguities, or address shortcoming of the initial answer
|
||||
- Each sub-question - when answered - should be relevant for the answer to the original question
|
||||
- The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any
|
||||
other complications that may require extra context.
|
||||
- The sub-questions MUST have the full context of the original question so that it can be executed by
|
||||
a RAG system independently without the original question available
|
||||
(Example:
|
||||
- initial question: "What is the capital of France?"
|
||||
- bad sub-question: "What is the name of the river there?"
|
||||
- good sub-question: "What is the name of the river that flows through Paris?"
|
||||
- For each sub-question, please also provide a search term that can be used to retrieve relevant
|
||||
documents from a document store.
|
||||
- Consider specifically the sub-questions that were suggested but not answered. This is a sign that they are not
|
||||
answerable with the available context, and you should not ask similar questions.
|
||||
\n\n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
{history}
|
||||
|
||||
Here is the initial sub-optimal answer:
|
||||
\n ------- \n
|
||||
{base_answer}
|
||||
\n ------- \n
|
||||
|
||||
Here are the sub-questions that were answered:
|
||||
\n ------- \n
|
||||
{answered_sub_questions}
|
||||
\n ------- \n
|
||||
|
||||
Here are the sub-questions that were suggested but not answered:
|
||||
\n ------- \n
|
||||
{failed_sub_questions}
|
||||
\n ------- \n
|
||||
|
||||
And here are the entities, relationships and terms extracted from the context:
|
||||
\n ------- \n
|
||||
{entity_term_extraction_str}
|
||||
\n ------- \n
|
||||
|
||||
Please generate the list of good, fully contextualized sub-questions that would help to address the
|
||||
main question.
|
||||
|
||||
Specifically pay attention also to the entities, relationships and terms extracted, as these indicate what type of
|
||||
objects/relationships/terms you can ask about! Do not ask about entities, terms or relationships that are not
|
||||
mentioned in the 'entities, relationships and terms' section.
|
||||
|
||||
Again, please find questions that are NOT overlapping too much with the already answered
|
||||
sub-questions or those that already were suggested and failed.
|
||||
In other words - what can we try in addition to what has been tried so far?
|
||||
|
||||
Generate the list of questions separated by one new line like this:
|
||||
<sub-question 1>
|
||||
<sub-question 2>
|
||||
<sub-question 3>
|
||||
...
|
||||
"""
|
||||
|
||||
DECOMPOSE_PROMPT = """ \n
For an initial user question, please generate 5-10 individual sub-questions whose answers would help
\n to answer the initial question. The individual questions should be answerable by a good RAG system.
So a good idea would be to \n use the sub-questions to resolve ambiguities and/or to separate the
question for different entities that may be involved in the original question.

In order to arrive at meaningful sub-questions, please also consider the context retrieved from the
document store, expressed as entities, relationships and terms. You can also think about the types
mentioned in brackets.

Guidelines:
- The sub-questions should be specific to the question and provide richer context for the question,
and/or resolve ambiguities
- Each sub-question - when answered - should be relevant for the answer to the original question
- The sub-questions should be free from comparisons, ambiguities, judgments, aggregations, or any
other complications that may require extra context.
- The sub-questions MUST have the full context of the original question so that it can be executed by
a RAG system independently without the original question available
(Example:
- initial question: "What is the capital of France?"
- bad sub-question: "What is the name of the river there?"
- good sub-question: "What is the name of the river that flows through Paris?")
- For each sub-question, please provide a short explanation for why it is a good sub-question. So
generate a list of dictionaries with the following format:
[{{"sub_question": <sub-question>, "explanation": <explanation>}}, ...]

\n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n

And here are the entities, relationships and terms extracted from the context:
\n ------- \n
{entity_term_extraction_str}
\n ------- \n

Please generate the list of good, fully contextualized sub-questions that would help to address the
main question. Don't be too specific unless the original question is specific.
Please think through it step by step and then generate the list of json dictionaries with the following
format:
{{"sub_questions": [{{"sub_question": <sub-question>,
"explanation": <explanation>,
"search_term": <rewrite the sub-question to use as a search phrase for the document store>}},
...]}} """

#### Consolidations
COMBINED_CONTEXT = """-------
Below you will find useful information to answer the original question. First, you see a number of
sub-questions with their answers. This information should be considered to be more focused and
somewhat more specific to the original question as it tries to contextualize facts.
After that you will see the documents that were considered to be relevant to answer the original question.

Here are the sub-questions and their answers:
\n\n {deep_answer_context} \n\n
\n\n Here are the documents that were considered to be relevant to answer the original question:
\n\n {formated_docs} \n\n
----------------
"""

SUB_QUESTION_EXPLANATION_RANKER_PROMPT = """-------
Below you will find a question that we ultimately want to answer (the original question) and a list of
motivations in arbitrary order for generated sub-questions that are supposed to help us answer the
original question. The motivations are formatted as <motivation number>: <motivation explanation>.
(Again, the numbering is arbitrary and does not necessarily mean that 1 is the most relevant
motivation and 2 is less relevant.)

Please rank the motivations in order of relevance for answering the original question. Also, try to
ensure that the top questions do not duplicate too much, i.e. that they are not too similar.
Ultimately, create a list with the motivation numbers where the number of the most relevant
motivations comes first.

Here is the original question:
\n\n {original_question} \n\n
\n\n Here is the list of sub-question motivations:
\n\n {sub_question_explanations} \n\n
----------------

Please think step by step and then generate the ranked list of motivations.

Please format your answer as a json object in the following format:
{{"reasonning": <explain your reasoning for the ranking>,
"ranked_motivations": <ranked list of motivation numbers>}}
"""

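# Illustrative sketch (not part of the original file): a response following the
# ranker contract above might look like the following. Note that the JSON key
# "reasonning" is kept exactly as spelled in the prompt, since any downstream
# parsing would have to match it.
#
#   {"reasonning": "Motivation 3 directly addresses the metric in question...",
#    "ranked_motivations": [3, 1, 4, 2]}
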
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS = """
If you think it is helpful, please decompose an initial user question into no more than 3 appropriate sub-questions that help to
answer the original question. The purpose for this decomposition may be to
1) isolate individual entities (i.e., 'compare sales of company A and company B' -> ['what are sales for company A',
'what are sales for company B'])
2) clarify or disambiguate ambiguous terms (i.e., 'what is our success with company A' -> ['what are our sales with company A',
'what is our market share with company A', 'is company A a reference customer for us', etc.])
3) if a term or a metric is essentially clear, but it could relate to various components of an entity and you are generally
familiar with the entity, then you can decompose the question into sub-questions that are more specific to components
(i.e., 'what do we do to improve scalability of product X' -> ['what do we do to improve scalability of product X',
'what do we do to improve stability of product X', ...])
4) research an area that could really help to answer the question. (But clarifications or disambiguations are more important.)

If you think that a decomposition is not needed or helpful, please just return an empty string. That is ok too.

Here is the initial question:
-------
{question}
-------
{history}

Please formulate your answer as a newline-separated list of questions like so:
<sub-question>
<sub-question>
<sub-question>

Answer:"""

INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH = """
If you think it is helpful, please decompose an initial user question into no more than 3 appropriate sub-questions that help to
answer the original question. The purpose for this decomposition may be to
1) isolate individual entities (i.e., 'compare sales of company A and company B' -> ['what are sales for company A',
'what are sales for company B'])
2) clarify or disambiguate ambiguous terms (i.e., 'what is our success with company A' -> ['what are our sales with company A',
'what is our market share with company A', 'is company A a reference customer for us', etc.])
3) if a term or a metric is essentially clear, but it could relate to various components of an entity and you are generally
familiar with the entity, then you can decompose the question into sub-questions that are more specific to components
(i.e., 'what do we do to improve scalability of product X' -> ['what do we do to improve scalability of product X',
'what do we do to improve stability of product X', ...])
4) research an area that could really help to answer the question. (But clarifications or disambiguations are more important.)

Here are some other rules:

1) To give you some context, you will see below also some documents that relate to the question. Please only
use this information to learn what the question is approximately asking about, but do not focus on the details
to construct the sub-questions.
2) If you think that a decomposition is not needed or helpful, please just return an empty string. That is very much ok too.

Here are the sample docs to give you some context:
-------
{sample_doc_str}
-------

And here is the initial question that you should think about decomposing:
-------
{question}
-------

{history}

Please formulate your answer as a newline-separated list of questions like so:
<sub-question>
<sub-question>
<sub-question>

Answer:"""

INITIAL_DECOMPOSITION_PROMPT = """ \n
Please decompose an initial user question into 2 or 3 appropriate sub-questions that help to
answer the original question. The purpose for this decomposition is to isolate individual entities
(i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
for us'), etc. Each sub-question should realistically be answerable by a good RAG system. \n

For each sub-question, please also create one search term that can be used to retrieve relevant
documents from a document store.

Here is the initial question:
\n ------- \n
{question}
\n ------- \n

Please formulate your answer as a list of json objects with the following format:

[{{"sub_question": <sub-question>, "search_term": <search term>}}, ...]

Answer:
"""

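# Illustrative sketch (not part of the original file): the list-of-objects
# contract above could be parsed with the clean_and_parse_list_string helper
# introduced in shared_graph_utils/utils.py later in this diff; the wiring
# shown here is hypothetical.
#
#   raw = llm.invoke(INITIAL_DECOMPOSITION_PROMPT.format(question=question))
#   for entry in clean_and_parse_list_string(str(raw.content)):
#       sub_question, search_term = entry["sub_question"], entry["search_term"]
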
INITIAL_RAG_BASE_PROMPT = (
    """ \n
You are an assistant for question-answering tasks. Use the information provided below - and only the
provided information - to answer the provided question.

The information provided below consists of a number of documents that were deemed relevant for the question.

IMPORTANT RULES:
- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer.
You may give some additional facts you learned, but do not try to invent an answer.
- If the information is empty or irrelevant, just say """
    + f'"{UNKNOWN_ANSWER}"'
    + """.
- If the information is relevant but not fully conclusive, specify that the information is not conclusive and say why.

Try to keep your answer concise.

Here is the contextual information from the document store:
\n ------- \n
{context} \n\n\n
\n ------- \n
And here is the question I want you to answer based on the context above (with the motivation in mind):
\n--\n {question} \n--\n
Answer:"""
)

AGENT_DECISION_PROMPT = """
You are a large language model assistant helping users address their information needs. You are tasked with deciding
whether to use a thorough agent search ('research') of a document store to answer a question or request, or whether you want to
address the question or request yourself as an LLM.

Here are some rules:
- If you think that a thorough search through a document store will help answer the question
or address the request, you should choose the 'research' option.
- If the question asks you to do something ('please create...', 'write for me...', etc.), you should choose the 'LLM' option.
- If you think the question is very general and does not refer to the contents of a document store, you should choose
the 'LLM' option.
- Otherwise, you should choose the 'research' option.
{history}

Here is the initial question:
-------
{question}
-------

Please decide whether to use the agent search or the LLM to answer the question. Choose from two choices,
'research' or 'LLM'.

Answer:"""

AGENT_DECISION_PROMPT_AFTER_SEARCH = """
You are a large language model assistant helping users address their information needs. You are given an initial question
or request and a very small sample of documents that a preliminary, fast search of a document store returned.
You are tasked with deciding whether to use a thorough agent search ('research') of the document store to answer a question
or request, or whether you want to address the question or request yourself as an LLM.

Here are some rules:
- If based on the retrieved documents you think there may be useful information in the document
store to answer or materially help with the request, you should choose the 'research' option.
- If you think that the retrieved documents do not help to answer the question or do not help with the request, AND
you know the answer/can handle the request, you should choose the 'LLM' option.
- If the question asks you to do something ('please create...', 'write for me...', etc.), you should choose the 'LLM' option.
- If in doubt, choose the 'research' option.
{history}

Here is the initial question:
-------
{question}
-------

Here is the sample of documents that were retrieved from a document store:
-------
{sample_doc_str}
-------

Please decide whether to use the agent search ('research') or the LLM to answer the question. Choose from two choices,
'research' or 'LLM'.

Answer:"""

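# Illustrative sketch (not part of the original file): the decision prompts
# above return a bare 'research' or 'LLM' token, so routing could be as simple
# as the following (the graph node names here are hypothetical):
#
#   decision = llm.invoke(AGENT_DECISION_PROMPT.format(question=q, history=h))
#   if "research" in str(decision.content).lower():
#       next_node = "agent_search"
#   else:
#       next_node = "direct_llm_answer"
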
### ANSWER GENERATION PROMPTS

# Persona specification
ASSISTANT_SYSTEM_PROMPT_DEFAULT = """
You are an assistant for question-answering tasks."""

ASSISTANT_SYSTEM_PROMPT_PERSONA = """
You are an assistant for question-answering tasks. Here is more information about you:
\n ------- \n
{persona_prompt}
\n ------- \n
"""

SUB_QUESTION_ANSWER_TEMPLATE = """
Sub-Question: Q{sub_question_nr}\n Sub-Question:\n - \n{sub_question}\n --\nAnswer:\n -\n {sub_answer}\n\n
"""

SUB_QUESTION_ANSWER_TEMPLATE_REVISED = """
Sub-Question: Q{sub_question_nr}\n Type: {level_type}\n Sub-Question:\n
- \n{sub_question}\n --\nAnswer:\n -\n {sub_answer}\n\n
"""

SUB_QUESTION_SEARCH_RESULTS_TEMPLATE = """
Sub-Question: Q{sub_question_nr}\n Sub-Question:\n - \n{sub_question}\n --\nRelevant Documents:\n
-\n {formatted_sub_question_docs}\n\n
"""

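# Illustrative sketch (not part of the original file): filling one of the
# templates above for a single answered sub-question.
#
#   block = SUB_QUESTION_ANSWER_TEMPLATE.format(
#       sub_question_nr=1,
#       sub_question="What is the name of the river that flows through Paris?",
#       sub_answer="The Seine.",
#   )
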
INITIAL_RAG_PROMPT_SUB_QUESTION_SEARCH = (
    """ \n
{persona_specification}

Use the information provided below - and only the provided information - to answer the main question that will be provided.

The information provided below consists of:
1) a number of sub-questions and supporting document information that would help answer them.
2) a broader collection of documents that were deemed relevant for the question. These documents contain information
that was also provided in the sub-questions and often more.

IMPORTANT RULES:
- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer.
You may give some additional facts you learned, but do not try to invent an answer.
- If the information is empty or irrelevant, just say """
    + f'"{UNKNOWN_ANSWER}"'
    + """.
- If the information is relevant but not fully conclusive, specify that the information is not conclusive and say why.
- The answers to the sub-questions should help you to structure your thoughts in order to answer the question.

Please provide inline citations of documents in the format [[D1]](), [[D2]](), [[D3]](), etc. If you have multiple citations,
please cite for example as [[D1]]()[[D3]](), or [[D2]]()[[D4]](), etc. Feel free to cite documents in addition
to the sub-questions! Proper citations are important for the final answer to be verifiable! \n\n\n

Again, you should be sure that the answer is supported by the information provided!

Try to keep your answer concise. But also highlight uncertainties you may have should there be substantial ones,
or assumptions you made.

Here is the contextual information:
\n-------\n
*Answered Sub-questions (these should really help to organize your thoughts):
{answered_sub_questions}

And here is relevant document information that supports the sub-question answers, or that is relevant for the actual question:\n

{relevant_docs}

\n-------\n
\n
And here is the main question I want you to answer based on the information above:
\n--\n
{question}
\n--\n\n
Answer:"""
)


DIRECT_LLM_PROMPT = """ \n
{persona_specification}

Please answer the following question/address the request:
\n--\n
{question}
\n--\n\n
Answer:"""

INITIAL_RAG_PROMPT = (
    """ \n
{persona_specification}

Use the information provided below - and only the provided information - to answer the provided main question.

The information provided below consists of:
1) a number of answered sub-questions - these are very important to help you organize your thoughts and your
answer
2) a number of documents that were deemed relevant for the question.

{history}

Please provide inline citations to documents in the format [[D1]](), [[D2]](), [[D3]](), etc. If you have multiple
citations, please cite for example as [[D1]]()[[D3]](), or [[D2]]()[[D4]](), etc.
Feel free to also cite sub-questions in addition to documents, but make sure that you have documents cited with the sub-question
citation. If you want to cite both a document and a sub-question, please use [[D1]]()[[Q3]](), or [[D2]]()[[D7]]()[[Q4]](), etc.
Again, please do not cite sub-questions without a document citation!
Citations are very important for the user!

IMPORTANT RULES:
- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer.
You may give some additional facts you learned, but do not try to invent an answer.
- If the information is empty or irrelevant, just say """
    + f'"{UNKNOWN_ANSWER}"'
    + """.
- If the information is relevant but not fully conclusive, specify that the information is not conclusive and say why.

Again, you should be sure that the answer is supported by the information provided!

Try to keep your answer concise. But also highlight uncertainties you may have should there be substantial ones,
or assumptions you made.

Here is the contextual information:
\n-------\n
*Answered Sub-questions (these should really matter!):
{answered_sub_questions}

And here is relevant document information that supports the sub-question answers, or that is relevant for the actual question:\n

{relevant_docs}

\n-------\n
\n
And here is the question I want you to answer based on the information above:
\n--\n
{question}
\n--\n\n
Answer:"""
)

# sub_question_answer_str is empty
INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS = (
    """{answered_sub_questions}
{persona_specification}
Use the information provided below
- and only the provided information - to answer the provided question.
The information provided below consists of a number of documents that were deemed relevant for the question.
{history}

IMPORTANT RULES:
- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer.
You may give some additional facts you learned, but do not try to invent an answer.
- If the information is irrelevant, just say """
    + f'"{UNKNOWN_ANSWER}"'
    + """.
- If the information is relevant but not fully conclusive, specify that the information is not conclusive and say why.

Again, you should be sure that the answer is supported by the information provided!

Please provide inline citations to documents in the format [[D1]](), [[D2]](), [[D3]](), etc. If you have multiple
citations, please cite for example as [[D1]]()[[D3]](), or [[D2]]()[[D4]](), etc. Citations are very important for the
user!

Try to keep your answer concise.

Here is the relevant context information:
\n-------\n
{relevant_docs}
\n-------\n

And here is the question I want you to answer based on the context above
\n--\n
{question}
\n--\n

Answer:"""
)

REVISED_RAG_PROMPT = (
    """\n
{persona_specification}
Use the information provided below - and only the provided information - to answer the provided main question.

The information provided below consists of:
1) an initial answer that was given but found to be lacking in some way.
2) a number of answered sub-questions - these are very important(!) and definitely should help you to answer
the main question. Note that the sub-questions have a type, 'initial' and 'revised'. The 'initial'
ones were available for the initial answer, and the 'revised' were not. So please use the 'revised' sub-questions in
particular to update/extend/correct the initial answer!
3) a number of documents that were deemed relevant for the question. This is the context that you largely use for
citations (see below).

Please provide inline citations to documents in the format [[D1]](), [[D2]](), [[D3]](), etc. If you have multiple
citations, please cite for example as [[D1]]()[[D3]](), or [[D2]]()[[D4]](), etc.
Feel free to also cite sub-questions in addition to documents, but make sure that you have documents cited with the sub-question
citation. If you want to cite both a document and a sub-question, please use [[D1]]()[[Q3]](), or [[D2]]()[[D7]]()[[Q4]](), etc.
Again, please do not cite sub-questions without a document citation!
Citations are very important for the user!\n\n

{history}

IMPORTANT RULES:
- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer.
You may give some additional facts you learned, but do not try to invent an answer.
- If the information is empty or irrelevant, just say """
    + f'"{UNKNOWN_ANSWER}"'
    + """.
- If the information is relevant but not fully conclusive, provide an answer to the extent you can but also
specify that the information is not conclusive and why.
- Ignore any existing citations within the answered sub-questions, like [[D1]]()... and [[Q2]]()!
The citations you will need to use will need to refer to the documents (and sub-questions) that you are explicitly
presented with below!

Again, you should be sure that the answer is supported by the information provided!

Try to keep your answer concise. But also highlight uncertainties you may have should there be substantial ones,
or assumptions you made.

Here is the contextual information:
\n-------\n

*Initial Answer that was found to be lacking:
{initial_answer}

*Answered Sub-questions (these should really help you to research your answer! They also contain questions/answers
that were not available when the original answer was constructed):
{answered_sub_questions}

And here are the relevant documents that support the sub-question answers, and that are relevant for the actual question:\n

{relevant_docs}

\n-------\n
\n
Lastly, here is the main question I want you to answer based on the information above:
\n--\n
{question}
\n--\n\n
Answer:"""
)

# sub_question_answer_str is empty
REVISED_RAG_PROMPT_NO_SUB_QUESTIONS = (
    """{answered_sub_questions}\n
{persona_specification}
Use the information provided below - and only the
provided information - to answer the provided question.

The information provided below consists of:
1) an initial answer that was given but found to be lacking in some way.
2) a number of documents that were also deemed relevant for the question.

Please provide inline citations to documents in the format [[D1]](), [[D2]](), [[D3]](), etc. If you have multiple
citations, please cite for example as [[D1]]()[[D3]](), or [[D2]]()[[D4]](), etc. Citations are very important for the user!\n\n

{history}

IMPORTANT RULES:
- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer.
You may give some additional facts you learned, but do not try to invent an answer.
- If the information is empty or irrelevant, just say """
    + f'"{UNKNOWN_ANSWER}"'
    + """.
- If the information is relevant but not fully conclusive, provide an answer to the extent you can but also
specify that the information is not conclusive and why.

Again, you should be sure that the answer is supported by the information provided!

Try to keep your answer concise. But also highlight uncertainties you may have should there be substantial ones,
or assumptions you made.

Here is the contextual information:
\n-------\n

*Initial Answer that was found to be lacking:
{initial_answer}

And here is relevant document information that supports the sub-question answers, or that is relevant for the actual question:\n

{relevant_docs}

\n-------\n
\n
Lastly, here is the question I want you to answer based on the information above:
\n--\n
{question}
\n--\n\n
Answer:"""
)


ENTITY_TERM_PROMPT = """ \n
Based on the original question and the context retrieved from a dataset, please generate a list of
entities (e.g. companies, organizations, industries, products, locations, etc.), terms and concepts
(e.g. sales, revenue, etc.) that are relevant for the question, plus their relations to each other.

\n\n
Here is the original question:
\n ------- \n
{question}
\n ------- \n
And here is the context retrieved:
\n ------- \n
{context}
\n ------- \n

Please format your answer as a json object in the following format:

{{"retrieved_entities_relationships": {{
    "entities": [{{
        "entity_name": <assign a name for the entity>,
        "entity_type": <specify a short type name for the entity, such as 'company', 'location',...>
    }}],
    "relationships": [{{
        "relationship_name": <assign a name for the relationship>,
        "relationship_type": <specify a short type name for the relationship, such as 'sales_to', 'is_location_of',...>,
        "relationship_entities": [<related entity name 1>, <related entity name 2>, ...]
    }}],
    "terms": [{{
        "term_name": <assign a name for the term>,
        "term_type": <specify a short type name for the term, such as 'revenue', 'market_share',...>,
        "term_similar_to": <list terms that are similar to this term>
    }}]
}}
}}
"""
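# Illustrative sketch (not part of the original file): a minimal response
# following the extraction contract above might look like this.
#
#   {"retrieved_entities_relationships": {
#       "entities": [{"entity_name": "Company A", "entity_type": "company"}],
#       "relationships": [{"relationship_name": "sells_to",
#                          "relationship_type": "sales_to",
#                          "relationship_entities": ["Company A", "Company B"]}],
#       "terms": [{"term_name": "revenue", "term_type": "revenue",
#                  "term_similar_to": ["sales"]}]}}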
278	backend/onyx/agents/agent_search/shared_graph_utils/utils.py	Normal file
@@ -0,0 +1,278 @@
import ast
import json
import re
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from typing import Any
from typing import cast
from uuid import UUID

from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from sqlalchemy.orm import Session

from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.models import (
    EntityRelationshipTermExtraction,
)
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationConfig
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SearchRequest
from onyx.db.persona import get_persona_by_id
from onyx.db.persona import Persona
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.tool_constructor import SearchToolConfig
from onyx.tools.tool_implementations.search.search_tool import SearchTool


def normalize_whitespace(text: str) -> str:
    """Normalize whitespace in text to single spaces and strip leading/trailing whitespace."""
    return re.sub(r"\s+", " ", text.strip())


# Post-processing
def format_docs(docs: Sequence[InferenceSection]) -> str:
    formatted_doc_list = []

    for doc_nr, doc in enumerate(docs):
        formatted_doc_list.append(f"Document D{doc_nr + 1}:\n{doc.combined_content}")

    return "\n\n".join(formatted_doc_list)


def format_docs_content_flat(docs: Sequence[InferenceSection]) -> str:
    formatted_doc_list = []

    for _, doc in enumerate(docs):
        formatted_doc_list.append(f"\n...{doc.combined_content}\n")

    return "\n\n".join(formatted_doc_list)


def clean_and_parse_list_string(json_string: str) -> list[dict]:
    # Remove any prefixes/labels before the actual JSON content
    json_string = re.sub(r"^.*?(?=\[)", "", json_string, flags=re.DOTALL)

    # Remove markdown code block markers and any newline prefixes
    cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
    cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
    cleaned_string = " ".join(cleaned_string.split())

    # Try parsing with json.loads first, fall back to ast.literal_eval
    try:
        return json.loads(cleaned_string)
    except json.JSONDecodeError:
        try:
            return ast.literal_eval(cleaned_string)
        except (ValueError, SyntaxError) as e:
            raise ValueError(f"Failed to parse JSON string: {cleaned_string}") from e

def clean_and_parse_json_string(json_string: str) -> dict[str, Any]:
    # Remove markdown code block markers and any newline prefixes
    cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
    cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
    cleaned_string = " ".join(cleaned_string.split())
    # Parse the cleaned string into a Python dictionary
    return json.loads(cleaned_string)

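# Illustrative usage (not part of the original file): both cleaners tolerate
# markdown-fenced LLM output.
#
#   clean_and_parse_json_string('```json\n{"a": 1}\n```')   -> {"a": 1}
#   clean_and_parse_list_string('Here you go: [{"b": 2}]')  -> [{"b": 2}]
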
def format_entity_term_extraction(
    entity_term_extraction_dict: EntityRelationshipTermExtraction,
) -> str:
    entities = entity_term_extraction_dict.entities
    terms = entity_term_extraction_dict.terms
    relationships = entity_term_extraction_dict.relationships

    entity_strs = ["\nEntities:\n"]
    for entity in entities:
        entity_str = f"{entity.entity_name} ({entity.entity_type})"
        entity_strs.append(entity_str)

    entity_str = "\n - ".join(entity_strs)

    relationship_strs = ["\n\nRelationships:\n"]
    for relationship in relationships:
        relationship_name = relationship.relationship_name
        relationship_type = relationship.relationship_type
        relationship_entities = relationship.relationship_entities
        relationship_str = (
            f"""{relationship_name} ({relationship_type}): {relationship_entities}"""
        )
        relationship_strs.append(relationship_str)

    relationship_str = "\n - ".join(relationship_strs)

    term_strs = ["\n\nTerms:\n"]
    for term in terms:
        term_str = f"{term.term_name} ({term.term_type}): similar to {', '.join(term.term_similar_to)}"
        term_strs.append(term_str)

    term_str = "\n - ".join(term_strs)

    return "\n".join(entity_strs + relationship_strs + term_strs)

def _format_time_delta(time: timedelta) -> str:
    seconds_from_start = f"{((time).seconds):03d}"
    microseconds_from_start = f"{((time).microseconds):06d}"
    return f"{seconds_from_start}.{microseconds_from_start}"


def generate_log_message(
    message: str,
    node_start_time: datetime,
    graph_start_time: datetime | None = None,
) -> str:
    current_time = datetime.now()

    if graph_start_time is not None:
        graph_time_str = _format_time_delta(current_time - graph_start_time)
    else:
        graph_time_str = "N/A"

    node_time_str = _format_time_delta(current_time - node_start_time)

    return f"{graph_time_str} ({node_time_str} s): {message}"

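# Illustrative output (not part of the original file): with a graph that
# started 2.5 s ago and a node that started 0.25 s ago, generate_log_message
# would produce something like:
#
#   "002.500000 (000.250000 s): expanded queries generated"
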
def get_test_config(
    db_session: Session, primary_llm: LLM, fast_llm: LLM, search_request: SearchRequest
) -> tuple[AgentSearchConfig, SearchTool]:
    persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session)
    document_pruning_config = DocumentPruningConfig(
        max_chunks=int(
            persona.num_chunks
            if persona.num_chunks is not None
            else MAX_CHUNKS_FED_TO_CHAT
        ),
        max_window_percentage=CHAT_TARGET_CHUNK_PERCENTAGE,
    )

    answer_style_config = AnswerStyleConfig(
        citation_config=CitationConfig(
            # The docs retrieved by this flow are already relevance-filtered
            all_docs_useful=True
        ),
        document_pruning_config=document_pruning_config,
        structured_response_format=None,
    )

    search_tool_config = SearchToolConfig(
        answer_style_config=answer_style_config,
        document_pruning_config=document_pruning_config,
        retrieval_options=RetrievalDetails(),  # may want to set dedupe_docs=True
        rerank_settings=None,  # Can use this to change reranking model
        selected_sections=None,
        latest_query_files=None,
        bypass_acl=False,
    )

    prompt_config = PromptConfig.from_model(persona.prompts[0])

    search_tool = SearchTool(
        db_session=db_session,
        user=None,
        persona=persona,
        retrieval_options=search_tool_config.retrieval_options,
        prompt_config=prompt_config,
        llm=primary_llm,
        fast_llm=fast_llm,
        pruning_config=search_tool_config.document_pruning_config,
        answer_style_config=search_tool_config.answer_style_config,
        selected_sections=search_tool_config.selected_sections,
        chunks_above=search_tool_config.chunks_above,
        chunks_below=search_tool_config.chunks_below,
        full_doc=search_tool_config.full_doc,
        evaluation_type=(
            LLMEvaluationType.BASIC
            if persona.llm_relevance_filter
            else LLMEvaluationType.SKIP
        ),
        rerank_settings=search_tool_config.rerank_settings,
        bypass_acl=search_tool_config.bypass_acl,
    )

    config = AgentSearchConfig(
        search_request=search_request,
        primary_llm=primary_llm,
        fast_llm=fast_llm,
        search_tool=search_tool,
        force_use_tool=ForceUseTool(force_use=False, tool_name=""),
        prompt_builder=AnswerPromptBuilder(
            user_message=HumanMessage(content=search_request.query),
            message_history=[],
            llm_config=primary_llm.config,
            raw_user_query=search_request.query,
            raw_user_uploaded_files=[],
        ),
        # chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
        chat_session_id=UUID("edda10d5-6cef-45d8-acfb-39317552a1f4"),  # Joachim
        # chat_session_id=UUID("d1acd613-2692-4bc3-9d65-c6d3da62e58e"),  # Evan
        message_id=1,
        use_persistence=True,
        db_session=db_session,
    )

    return config, search_tool


def get_persona_prompt(persona: Persona | None) -> str:
    if persona is None:
        return ""
    else:
        return "\n".join([x.system_prompt for x in persona.prompts])

def make_question_id(level: int, question_nr: int) -> str:
    return f"{level}_{question_nr}"


def parse_question_id(question_id: str) -> tuple[int, int]:
    level, question_nr = question_id.split("_")
    return int(level), int(question_nr)

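# Illustrative round trip (not part of the original file):
#
#   make_question_id(1, 3)    -> "1_3"
#   parse_question_id("1_3")  -> (1, 3)
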
def _dispatch_nonempty(
    content: str, dispatch_event: Callable[[str, int], None], num: int
) -> None:
    if content != "":
        dispatch_event(content, num)


def dispatch_separated(
    token_itr: Iterator[BaseMessage],
    dispatch_event: Callable[[str, int], None],
    sep: str = "\n",
) -> list[str | list[str | dict[str, Any]]]:
    num = 1
    streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
    for message in token_itr:
        content = cast(str, message.content)
        if sep in content:
            sub_question_parts = content.split(sep)
            _dispatch_nonempty(sub_question_parts[0], dispatch_event, num)
            num += 1
            _dispatch_nonempty(
                "".join(sub_question_parts[1:]).strip(), dispatch_event, num
            )
        else:
            _dispatch_nonempty(content, dispatch_event, num)
        streamed_tokens.append(content)

    return streamed_tokens

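# Illustrative usage of dispatch_separated (not part of the original file; the
# token stream below is a stand-in for a real LLM token iterator):
#
#   from langchain_core.messages import AIMessageChunk
#
#   def on_piece(piece: str, num: int) -> None:
#       print(f"sub-question {num}: {piece}")
#
#   tokens = [AIMessageChunk(content="What are sales"),
#             AIMessageChunk(content=" for company A?\nWhat are sales for company B?")]
#   dispatch_separated(iter(tokens), on_piece)
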
@@ -23,7 +23,6 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
        preferences_data = cast(
            Mapping[str, Any], store.load(KV_NO_AUTH_USER_PREFERENCES_KEY)
        )
-       print("preferences_data", preferences_data)
        return UserPreferences(**preferences_data)
    except KvKeyNotFoundError:
        return UserPreferences(

@@ -55,6 +55,7 @@ from onyx.auth.invited_users import get_invited_users
 from onyx.auth.schemas import UserCreate
 from onyx.auth.schemas import UserRole
 from onyx.auth.schemas import UserUpdate
+from onyx.configs.app_configs import AUTH_COOKIE_EXPIRE_TIME_SECONDS
 from onyx.configs.app_configs import AUTH_TYPE
 from onyx.configs.app_configs import DISABLE_AUTH
 from onyx.configs.app_configs import EMAIL_CONFIGURED
@@ -209,6 +210,7 @@ def verify_email_domain(email: str) -> None:
 class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
     reset_password_token_secret = USER_AUTH_SECRET
     verification_token_secret = USER_AUTH_SECRET
+    verification_token_lifetime_seconds = AUTH_COOKIE_EXPIRE_TIME_SECONDS

     user_db: SQLAlchemyUserDatabase[User, uuid.UUID]

@@ -23,8 +23,7 @@ from onyx.background.celery.celery_utils import celery_is_worker_primary
 from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
 from onyx.configs.constants import OnyxRedisLocks
 from onyx.db.engine import get_sqlalchemy_engine
-from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
-from onyx.document_index.vespa_constants import VESPA_CONFIG_SERVER_URL
+from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
 from onyx.redis.redis_connector import RedisConnector
 from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
 from onyx.redis.redis_connector_delete import RedisConnectorDelete

@@ -280,51 +279,6 @@ def wait_for_db(sender: Any, **kwargs: Any) -> None:
     return


-def wait_for_vespa(sender: Any, **kwargs: Any) -> None:
-    """Waits for Vespa to become ready subject to a hardcoded timeout.
-    Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
-
-    WAIT_INTERVAL = 5
-    WAIT_LIMIT = 60
-
-    ready = False
-    time_start = time.monotonic()
-    logger.info("Vespa: Readiness probe starting.")
-    while True:
-        try:
-            client = get_vespa_http_client()
-            response = client.get(f"{VESPA_CONFIG_SERVER_URL}/state/v1/health")
-            response.raise_for_status()
-
-            response_dict = response.json()
-            if response_dict["status"]["code"] == "up":
-                ready = True
-                break
-        except Exception:
-            pass
-
-        time_elapsed = time.monotonic() - time_start
-        if time_elapsed > WAIT_LIMIT:
-            break
-
-        logger.info(
-            f"Vespa: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
-        )
-
-        time.sleep(WAIT_INTERVAL)
-
-    if not ready:
-        msg = (
-            f"Vespa: Readiness probe did not succeed within the timeout "
-            f"({WAIT_LIMIT} seconds). Exiting..."
-        )
-        logger.error(msg)
-        raise WorkerShutdown(msg)
-
-    logger.info("Vespa: Readiness probe succeeded. Continuing...")
-    return


 def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
     logger.info("Running as a secondary celery worker.")

@@ -510,3 +464,13 @@ def reset_tenant_id(
 ) -> None:
     """Signal handler to reset tenant ID in context var after task ends."""
     CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA)
+
+
+def wait_for_vespa_or_shutdown(sender: Any, **kwargs: Any) -> None:
+    """Waits for Vespa to become ready subject to a timeout.
+    Raises WorkerShutdown if the timeout is reached."""
+
+    if not wait_for_vespa_with_timeout():
+        msg = "Vespa: Readiness probe did not succeed within the timeout. Exiting..."
+        logger.error(msg)
+        raise WorkerShutdown(msg)

@@ -62,7 +62,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:

     app_base.wait_for_redis(sender, **kwargs)
     app_base.wait_for_db(sender, **kwargs)
-    app_base.wait_for_vespa(sender, **kwargs)
+    app_base.wait_for_vespa_or_shutdown(sender, **kwargs)

     # Less startup checks in multi-tenant case
     if MULTI_TENANT:

@@ -68,7 +68,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:

     app_base.wait_for_redis(sender, **kwargs)
     app_base.wait_for_db(sender, **kwargs)
-    app_base.wait_for_vespa(sender, **kwargs)
+    app_base.wait_for_vespa_or_shutdown(sender, **kwargs)

     # Less startup checks in multi-tenant case
     if MULTI_TENANT:

@@ -63,7 +63,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:

     app_base.wait_for_redis(sender, **kwargs)
     app_base.wait_for_db(sender, **kwargs)
-    app_base.wait_for_vespa(sender, **kwargs)
+    app_base.wait_for_vespa_or_shutdown(sender, **kwargs)

     # Less startup checks in multi-tenant case
     if MULTI_TENANT:

@@ -86,7 +86,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:

     app_base.wait_for_redis(sender, **kwargs)
     app_base.wait_for_db(sender, **kwargs)
-    app_base.wait_for_vespa(sender, **kwargs)
+    app_base.wait_for_vespa_or_shutdown(sender, **kwargs)

     logger.info("Running as the primary celery worker.")

@@ -29,6 +29,16 @@ cloud_tasks_to_schedule = [
             "expires": BEAT_EXPIRES_DEFAULT,
         },
     },
+    {
+        "name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-alembic",
+        "task": OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
+        "schedule": timedelta(hours=1),
+        "options": {
+            "priority": OnyxCeleryPriority.HIGH,
+            "expires": BEAT_EXPIRES_DEFAULT,
+            "queue": OnyxCeleryQueues.MONITORING,
+        },
+    },
 ]

 # tasks that run in either self-hosted or cloud

@@ -674,6 +674,9 @@ def connector_indexing_proxy_task(
     while True:
         sleep(5)

+        # renew watchdog signal (this has a shorter timeout than set_active)
+        redis_connector_index.set_watchdog(True)
+
         # renew active signal
         redis_connector_index.set_active()

@@ -780,6 +783,7 @@ def connector_indexing_proxy_task(
         )
         continue

+    redis_connector_index.set_watchdog(False)
     task_logger.info(
         f"Indexing watchdog - finished: attempt={index_attempt_id} "
         f"cc_pair={cc_pair_id} "

@@ -1,6 +1,8 @@
 import json
+import time
 from collections.abc import Callable
 from datetime import timedelta
+from itertools import islice
 from typing import Any

 from celery import shared_task
@@ -10,13 +12,17 @@ from pydantic import BaseModel
 from redis import Redis
 from redis.lock import Lock as RedisLock
 from sqlalchemy import select
 from sqlalchemy import text
 from sqlalchemy.orm import Session

 from onyx.background.celery.apps.app_base import task_logger
 from onyx.background.celery.tasks.vespa.tasks import celery_get_queue_length
 from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
 from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
 from onyx.configs.constants import OnyxCeleryQueues
 from onyx.configs.constants import OnyxCeleryTask
 from onyx.configs.constants import OnyxRedisLocks
 from onyx.db.engine import get_all_tenant_ids
 from onyx.db.engine import get_db_current_time
 from onyx.db.engine import get_session_with_tenant
 from onyx.db.enums import IndexingStatus
@@ -27,6 +33,7 @@ from onyx.db.models import IndexAttempt
 from onyx.db.models import SyncRecord
 from onyx.db.models import UserGroup
 from onyx.redis.redis_pool import get_redis_client
+from onyx.redis.redis_pool import redis_lock_dump
 from onyx.utils.telemetry import optional_telemetry
 from onyx.utils.telemetry import RecordType

@@ -456,3 +463,116 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
        lock_monitoring.release()

    task_logger.info("Background monitoring task finished")


@shared_task(
    name=OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
)
def cloud_check_alembic() -> bool | None:
    """A task to verify that all tenants are on the same alembic revision.

    This check is expected to fail if a cloud alembic migration is currently running
    across all tenants.

    TODO: have the cloud migration script set an activity signal that this check
    uses to know it doesn't make sense to run a check at the present time.
    """
    time_start = time.monotonic()

    redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)

    lock_beat: RedisLock = redis_client.lock(
        OnyxRedisLocks.CLOUD_CHECK_ALEMBIC_BEAT_LOCK,
        timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
    )

    # these tasks should never overlap
    if not lock_beat.acquire(blocking=False):
        return None

    last_lock_time = time.monotonic()

    tenant_to_revision: dict[str, str | None] = {}
    revision_counts: dict[str, int] = {}
    out_of_date_tenants: dict[str, str | None] = {}
    top_revision: str = ""

    try:
        # map each tenant_id to its revision
        tenant_ids = get_all_tenant_ids()
        for tenant_id in tenant_ids:
            current_time = time.monotonic()
            if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
                lock_beat.reacquire()
                last_lock_time = current_time

            if tenant_id is None:
                continue

            with get_session_with_tenant(tenant_id=None) as session:
                result = session.execute(
                    text(f'SELECT * FROM "{tenant_id}".alembic_version LIMIT 1')
                )

                result_scalar: str | None = result.scalar_one_or_none()
                tenant_to_revision[tenant_id] = result_scalar

        # get the total count of each revision
        for k, v in tenant_to_revision.items():
            if v is None:
                continue

            revision_counts[v] = revision_counts.get(v, 0) + 1

        # get the revision with the most counts
        sorted_revision_counts = sorted(
            revision_counts.items(), key=lambda item: item[1], reverse=True
        )

        if len(sorted_revision_counts) == 0:
            task_logger.error(
                f"cloud_check_alembic - No revisions found for {len(tenant_ids)} tenant ids!"
            )
        else:
            top_revision, _ = sorted_revision_counts[0]

            # build a list of out of date tenants
            for k, v in tenant_to_revision.items():
                if v == top_revision:
                    continue

                out_of_date_tenants[k] = v

    except SoftTimeLimitExceeded:
        task_logger.info(
            "Soft time limit exceeded, task is being terminated gracefully."
        )
    except Exception:
        task_logger.exception("Unexpected exception during cloud alembic check")
        raise
    finally:
        if lock_beat.owned():
            lock_beat.release()
        else:
            task_logger.error("cloud_check_alembic - Lock not owned on completion")
            redis_lock_dump(lock_beat, redis_client)

    if len(out_of_date_tenants) > 0:
        task_logger.error(
            f"Found out of date tenants: "
            f"num_out_of_date_tenants={len(out_of_date_tenants)} "
            f"num_tenants={len(tenant_ids)} "
            f"revision={top_revision}"
        )
        for k, v in islice(out_of_date_tenants.items(), 5):
            task_logger.info(f"Out of date tenant: tenant={k} revision={v}")
    else:
        task_logger.info(
            f"All tenants are up to date: num_tenants={len(tenant_ids)} revision={top_revision}"
        )

    time_elapsed = time.monotonic() - time_start
    task_logger.info(
        f"cloud_check_alembic finished: num_tenants={len(tenant_ids)} elapsed={time_elapsed:.2f}"
    )
    return True

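# Illustrative sketch (not part of the original file): the majority-revision
# logic above reduces to picking the most common revision and flagging the rest.
#
#   tenant_to_revision = {"t1": "abc123", "t2": "abc123", "t3": "def456"}
#   counts: dict[str, int] = {}
#   for rev in tenant_to_revision.values():
#       counts[rev] = counts.get(rev, 0) + 1
#   top_revision = max(counts, key=counts.get)          # "abc123"
#   out_of_date = {t: r for t, r in tenant_to_revision.items() if r != top_revision}
#   # out_of_date == {"t3": "def456"}
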
@@ -735,7 +735,7 @@ def monitor_ccpair_indexing_taskset(
     composite_id = RedisConnector.get_id_from_fence_key(fence_key)
     if composite_id is None:
         task_logger.warning(
-            f"monitor_ccpair_indexing_taskset: could not parse composite_id from {fence_key}"
+            f"Connector indexing: could not parse composite_id from {fence_key}"
         )
         return

@@ -785,6 +785,7 @@ def monitor_ccpair_indexing_taskset(
     # inner/outer/inner double check pattern to avoid race conditions when checking for
     # bad state

+    # Verify: if the generator isn't complete, the task must not be in READY state
     # inner = get_completion / generator_complete not signaled
     # outer = result.state in READY state
     status_int = redis_connector_index.get_completion()
@@ -830,7 +831,7 @@ def monitor_ccpair_indexing_taskset(
         )
     except Exception:
         task_logger.exception(
-            "monitor_ccpair_indexing_taskset - transient exception marking index attempt as failed: "
+            "Connector indexing - Transient exception marking index attempt as failed: "
             f"attempt={payload.index_attempt_id} "
             f"tenant={tenant_id} "
             f"cc_pair={cc_pair_id} "
@@ -840,6 +841,20 @@ def monitor_ccpair_indexing_taskset(
         redis_connector_index.reset()
         return

+    if redis_connector_index.watchdog_signaled():
+        # if the generator is complete, don't clean up until the watchdog has exited
+        task_logger.info(
+            f"Connector indexing - Delaying finalization until watchdog has exited: "
+            f"attempt={payload.index_attempt_id} "
+            f"cc_pair={cc_pair_id} "
+            f"search_settings={search_settings_id} "
+            f"progress={progress} "
+            f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
+            f"elapsed_started={elapsed_started_str}"
+        )
+
+        return
+
     status_enum = HTTPStatus(status_int)

     task_logger.info(
@@ -858,9 +873,13 @@ def monitor_ccpair_indexing_taskset(

 @shared_task(name=OnyxCeleryTask.MONITOR_VESPA_SYNC, soft_time_limit=300, bind=True)
 def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
-    """This is a celery beat task that monitors and finalizes metadata sync tasksets.
+    """This is a celery beat task that monitors and finalizes various long running tasks.
+
+    The name monitor_vespa_sync is a bit of a misnomer since it checks many different tasks
+    now. Should change that at some point.

     It scans for fence values and then gets the counts of any associated tasksets.
-    If the count is 0, that means all tasks finished and we should clean up.
+    For many tasks, if the count is 0, that means all tasks finished and we should clean up.

     This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
     do anything too expensive in this function!
@@ -1045,6 +1064,8 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
 def vespa_metadata_sync_task(
     self: Task, document_id: str, tenant_id: str | None
 ) -> bool:
+    start = time.monotonic()
+
     try:
         with get_session_with_tenant(tenant_id) as db_session:
             curr_ind_name, sec_ind_name = get_both_index_names(db_session)
@@ -1095,7 +1116,13 @@ def vespa_metadata_sync_task(
         # r = get_redis_client(tenant_id=tenant_id)
         # r.delete(redis_syncing_key)

-        task_logger.info(f"doc={document_id} action=sync chunks={chunks_affected}")
+        elapsed = time.monotonic() - start
+        task_logger.info(
+            f"doc={document_id} "
+            f"action=sync "
+            f"chunks={chunks_affected} "
+            f"elapsed={elapsed:.2f}"
+        )
     except SoftTimeLimitExceeded:
         task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
     except Exception as ex:

@@ -1,50 +1,35 @@
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Iterator
from uuid import uuid4

from langchain.schema.messages import BaseMessage
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import ToolCall
from sqlalchemy.orm import Session

from onyx.chat.llm_response_handler import LLMResponseHandlerManager
from onyx.chat.models import AnswerQuestionPossibleReturn
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.run_graph import run_basic_graph
from onyx.agents.agent_search.run_graph import run_main_graph
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStream
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
from onyx.chat.stream_processing.answer_response_handler import (
    CitationResponseHandler,
)
from onyx.chat.stream_processing.answer_response_handler import (
    DummyAnswerResponseHandler,
)
from onyx.chat.stream_processing.utils import (
    map_document_id_order,
)
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.configs.constants import BASIC_KEY
from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.tools.force import ForceUseTool
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_runner import ToolCallKickoff
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.logger import setup_logger


logger = setup_logger()


AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse]


class Answer:
    def __init__(
        self,
@@ -53,13 +38,13 @@ class Answer:
        llm: LLM,
        prompt_config: PromptConfig,
        force_use_tool: ForceUseTool,
        agent_search_config: AgentSearchConfig,
        # must be the same length as `docs`. If None, all docs are considered "relevant"
        message_history: list[PreviousMessage] | None = None,
        single_message_history: str | None = None,
        # newly passed in files to include as part of this question
        # TODO THIS NEEDS TO BE HANDLED
        latest_query_files: list[InMemoryChatFile] | None = None,
        files: list[InMemoryChatFile] | None = None,
        tools: list[Tool] | None = None,
        # NOTE: for native tool-calling, this is only supported by OpenAI atm,
        # but we only support them anyways
@@ -69,6 +54,8 @@ class Answer:
        return_contexts: bool = False,
        skip_gen_ai_answer_generation: bool = False,
        is_connected: Callable[[], bool] | None = None,
        fast_llm: LLM | None = None,
        db_session: Session | None = None,
    ) -> None:
        if single_message_history and message_history:
            raise ValueError(
@@ -79,7 +66,6 @@ class Answer:
        self.is_connected: Callable[[], bool] | None = is_connected

        self.latest_query_files = latest_query_files or []
        self.file_id_to_file = {file.file_id: file for file in (files or [])}

        self.tools = tools or []
        self.force_use_tool = force_use_tool
@@ -92,6 +78,7 @@ class Answer:
        self.prompt_config = prompt_config

        self.llm = llm
        self.fast_llm = fast_llm
        self.llm_tokenizer = get_tokenizer(
            provider_type=llm.config.model_provider,
            model_name=llm.config.model_name,
@@ -100,9 +87,7 @@ class Answer:
        self._final_prompt: list[BaseMessage] | None = None

        self._streamed_output: list[str] | None = None
        self._processed_stream: (
            list[AnswerQuestionPossibleReturn | ToolResponse | ToolCallKickoff] | None
        ) = None
        self._processed_stream: (list[AnswerPacket] | None) = None

        self._return_contexts = return_contexts
        self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
@@ -115,131 +100,150 @@ class Answer:
            and not skip_explicit_tool_calling
        )

        self.agent_search_config = agent_search_config
        self.db_session = db_session

    def _get_tools_list(self) -> list[Tool]:
        if not self.force_use_tool.force_use:
            return self.tools

        tool = next(
            (t for t in self.tools if t.name == self.force_use_tool.tool_name), None
        )
        if tool is None:
            raise RuntimeError(f"Tool '{self.force_use_tool.tool_name}' not found")
        tool = get_tool_by_name(self.tools, self.force_use_tool.tool_name)

        logger.info(
            f"Forcefully using tool='{tool.name}'"
            + (
                f" with args='{self.force_use_tool.args}'"
                if self.force_use_tool.args is not None
                else ""
            )
        args_str = (
            f" with args='{self.force_use_tool.args}'"
            if self.force_use_tool.args
            else ""
        )
        logger.info(f"Forcefully using tool='{tool.name}'{args_str}")
        return [tool]

    def _handle_specified_tool_call(
        self, llm_calls: list[LLMCall], tool: Tool, tool_args: dict
    ) -> AnswerStream:
        current_llm_call = llm_calls[-1]
    # TODO: delete the function and move the full body to processed_streamed_output
    def _get_response(self) -> AnswerStream:
        # current_llm_call = llm_calls[-1]

        # make a dummy tool handler
        tool_handler = ToolResponseHandler([tool])
        # tool, tool_args = None, None
        # # handle the case where no decision has to be made; we simply run the tool
        # if (
        #     current_llm_call.force_use_tool.force_use
        #     and current_llm_call.force_use_tool.args is not None
        # ):
        #     tool_name, tool_args = (
        #         current_llm_call.force_use_tool.tool_name,
        #         current_llm_call.force_use_tool.args,
        #     )
        #     tool = get_tool_by_name(current_llm_call.tools, tool_name)

        dummy_tool_call_chunk = AIMessageChunk(content="")
        dummy_tool_call_chunk.tool_calls = [
            ToolCall(name=tool.name, args=tool_args, id=str(uuid4()))
        ]
        # # special pre-logic for non-tool calling LLM case
        # elif not self.using_tool_calling_llm and current_llm_call.tools:
        #     chosen_tool_and_args = (
        #         ToolResponseHandler.get_tool_call_for_non_tool_calling_llm(
        #             current_llm_call, self.llm
        #         )
        #     )
        #     if chosen_tool_and_args:
        #         tool, tool_args = chosen_tool_and_args

        response_handler_manager = LLMResponseHandlerManager(
            tool_handler, DummyAnswerResponseHandler(), self.is_cancelled
        )
        yield from response_handler_manager.handle_llm_response(
            iter([dummy_tool_call_chunk])
        )
        # if tool and tool_args:
        #     dummy_tool_call_chunk = AIMessageChunk(content="")
        #     dummy_tool_call_chunk.tool_calls = [
        #         ToolCall(name=tool.name, args=tool_args, id=str(uuid4()))
        #     ]

        new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
        if new_llm_call:
            yield from self._get_response(llm_calls + [new_llm_call])
        else:
            raise RuntimeError("Tool call handler did not return a new LLM call")
        # response_handler_manager = LLMResponseHandlerManager(
        #     ToolResponseHandler([tool]), None, self.is_cancelled
        # )
        # yield from response_handler_manager.handle_llm_response(
        #     iter([dummy_tool_call_chunk])
        # )

    def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream:
        current_llm_call = llm_calls[-1]
        # tmp_call = response_handler_manager.next_llm_call(current_llm_call)
        # if tmp_call is None:
        #     return  # no more LLM calls to process
        # current_llm_call = tmp_call

        # handle the case where no decision has to be made; we simply run the tool
        if (
            current_llm_call.force_use_tool.force_use
            and current_llm_call.force_use_tool.args is not None
        ):
            tool_name, tool_args = (
                current_llm_call.force_use_tool.tool_name,
                current_llm_call.force_use_tool.args,
            )
            tool = next(
                (t for t in current_llm_call.tools if t.name == tool_name), None
            )
            if not tool:
                raise RuntimeError(f"Tool '{tool_name}' not found")
        # # if we're skipping gen ai answer generation, we should break
        # # out unless we're forcing a tool call. If we don't, we might generate an
        # # answer, which is a no-no!
        # if (
        #     self.skip_gen_ai_answer_generation
        #     and not current_llm_call.force_use_tool.force_use
        # ):
        #     return

            yield from self._handle_specified_tool_call(llm_calls, tool, tool_args)
            return
        # # set up "handlers" to listen to the LLM response stream and
        # # feed back the processed results + handle tool call requests
        # # + figure out what the next LLM call should be
        # tool_call_handler = ToolResponseHandler(current_llm_call.tools)

        # special pre-logic for non-tool calling LLM case
        if not self.using_tool_calling_llm and current_llm_call.tools:
            chosen_tool_and_args = (
                ToolResponseHandler.get_tool_call_for_non_tool_calling_llm(
                    current_llm_call, self.llm
        # final_search_results, displayed_search_results = SearchTool.get_search_result(
        #     current_llm_call
        # ) or ([], [])

        # # NEXT: we still want to handle the LLM response stream, but it is now:
        # # 1. handle the tool call requests
        # # 2. feed back the processed results
        # # 3. handle the citations

        # answer_handler = CitationResponseHandler(
        #     context_docs=final_search_results,
        #     final_doc_id_to_rank_map=map_document_id_order(final_search_results),
        #     display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
        # )

        # # At the moment, this wrapper class passes streamed stuff through citation and tool handlers.
        # # In the future, we'll want to handle citations and tool calls in the langgraph graph.
        # response_handler_manager = LLMResponseHandlerManager(
        #     tool_call_handler, answer_handler, self.is_cancelled
        # )

        # In langgraph, whether we do the basic thing (call llm stream) or pro search
        # is based on a flag in the pro search config

        if self.agent_search_config.use_agentic_search:
            if (
                self.agent_search_config.db_session is None
                and self.agent_search_config.use_persistence
            ):
                raise ValueError(
                    "db_session must be provided for pro search when using persistence"
                )

            stream = run_main_graph(
                config=self.agent_search_config,
            )
        else:
            stream = run_basic_graph(
                config=self.agent_search_config,
            )
            if chosen_tool_and_args:
                tool, tool_args = chosen_tool_and_args
                yield from self._handle_specified_tool_call(llm_calls, tool, tool_args)
                return

        # if we're skipping gen ai answer generation, we should break
        # out unless we're forcing a tool call. If we don't, we might generate an
        # answer, which is a no-no!
        if (
            self.skip_gen_ai_answer_generation
            and not current_llm_call.force_use_tool.force_use
        ):
            return

        # set up "handlers" to listen to the LLM response stream and
        # feed back the processed results + handle tool call requests
        # + figure out what the next LLM call should be
        tool_call_handler = ToolResponseHandler(current_llm_call.tools)

        final_search_results, displayed_search_results = SearchTool.get_search_result(
            current_llm_call
        ) or ([], [])

        answer_handler = CitationResponseHandler(
            context_docs=final_search_results,
            final_doc_id_to_rank_map=map_document_id_order(final_search_results),
            display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
        )

        response_handler_manager = LLMResponseHandlerManager(
            tool_call_handler, answer_handler, self.is_cancelled
        )

        processed_stream = []
        for packet in stream:
            if self.is_cancelled():
                packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
                yield packet
                break
            processed_stream.append(packet)
            yield packet
        self._processed_stream = processed_stream
        return
        # DEBUG: good breakpoint
        stream = self.llm.stream(
            # For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
            # may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
            prompt=current_llm_call.prompt_builder.build(),
            tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
            tool_choice=(
                "required"
                if current_llm_call.tools and current_llm_call.force_use_tool.force_use
                else None
            ),
            structured_response_format=self.answer_style_config.structured_response_format,
        )
        yield from response_handler_manager.handle_llm_response(stream)
        # stream = self.llm.stream(
        #     # For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
        #     # may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
        #     prompt=current_llm_call.prompt_builder.build(),
        #     tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
        #     tool_choice=(
        #         "required"
        #         if current_llm_call.tools and current_llm_call.force_use_tool.force_use
        #         else None
        #     ),
        #     structured_response_format=self.answer_style_config.structured_response_format,
        # )
        # yield from response_handler_manager.handle_llm_response(stream)

        new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
        if new_llm_call:
            yield from self._get_response(llm_calls + [new_llm_call])
        # new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
        # if new_llm_call:
        #     yield from self._get_response(llm_calls + [new_llm_call])

    @property
    def processed_streamed_output(self) -> AnswerStream:
@@ -247,33 +251,33 @@ class Answer:
        yield from self._processed_stream
        return

        prompt_builder = AnswerPromptBuilder(
            user_message=default_build_user_message(
                user_query=self.question,
                prompt_config=self.prompt_config,
                files=self.latest_query_files,
                single_message_history=self.single_message_history,
            ),
            message_history=self.message_history,
            llm_config=self.llm.config,
            raw_user_query=self.question,
            raw_user_uploaded_files=self.latest_query_files or [],
            single_message_history=self.single_message_history,
        )
        prompt_builder.update_system_prompt(
            default_build_system_message(self.prompt_config)
        )
        llm_call = LLMCall(
            prompt_builder=prompt_builder,
            tools=self._get_tools_list(),
            force_use_tool=self.force_use_tool,
            files=self.latest_query_files,
            tool_call_info=[],
            using_tool_calling_llm=self.using_tool_calling_llm,
        )
        # prompt_builder = AnswerPromptBuilder(
        #     user_message=default_build_user_message(
        #         user_query=self.question,
        #         prompt_config=self.prompt_config,
        #         files=self.latest_query_files,
        #         single_message_history=self.single_message_history,
        #     ),
        #     message_history=self.message_history,
        #     llm_config=self.llm.config,
        #     raw_user_query=self.question,
        #     raw_user_uploaded_files=self.latest_query_files or [],
        #     single_message_history=self.single_message_history,
        # )
        # prompt_builder.update_system_prompt(
        #     default_build_system_message(self.prompt_config)
        # )
        # llm_call = LLMCall(
        #     prompt_builder=prompt_builder,
        #     tools=self._get_tools_list(),
        #     force_use_tool=self.force_use_tool,
        #     files=self.latest_query_files,
        #     tool_call_info=[],
        #     using_tool_calling_llm=self.using_tool_calling_llm,
        # )

        processed_stream = []
        for processed_packet in self._get_response([llm_call]):
        for processed_packet in self._get_response():
            processed_stream.append(processed_packet)
            yield processed_packet

@@ -283,20 +287,56 @@ class Answer:
    def llm_answer(self) -> str:
        answer = ""
        for packet in self.processed_streamed_output:
            if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
            # handle basic answer flow, plus level 0 agent answer flow
            # since level 0 is the first answer the user sees and therefore the
            # child message of the user message in the db (so it is handled
            # like a basic flow answer)
            if (isinstance(packet, OnyxAnswerPiece) and packet.answer_piece) or (
                isinstance(packet, AgentAnswerPiece)
                and packet.answer_piece
                and packet.answer_type == "agent_level_answer"
                and packet.level == 0
            ):
                answer += packet.answer_piece

        return answer

    def llm_answer_by_level(self) -> dict[int, str]:
        answer_by_level: dict[int, str] = defaultdict(str)
        for packet in self.processed_streamed_output:
            if (
                isinstance(packet, AgentAnswerPiece)
                and packet.answer_piece
                and packet.answer_type == "agent_level_answer"
            ):
                answer_by_level[packet.level] += packet.answer_piece
            elif isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
                answer_by_level[BASIC_KEY[0]] += packet.answer_piece
        return answer_by_level

    @property
    def citations(self) -> list[CitationInfo]:
        citations: list[CitationInfo] = []
        for packet in self.processed_streamed_output:
            if isinstance(packet, CitationInfo):
            if isinstance(packet, CitationInfo) and packet.level is None:
                citations.append(packet)

        return citations

    def citations_by_subquestion(self) -> dict[tuple[int, int], list[CitationInfo]]:
        citations_by_subquestion: dict[
            tuple[int, int], list[CitationInfo]
        ] = defaultdict(list)
        for packet in self.processed_streamed_output:
            if isinstance(packet, CitationInfo):
                if packet.level_question_nr is not None and packet.level is not None:
                    citations_by_subquestion[
                        (packet.level, packet.level_question_nr)
                    ].append(packet)
                elif packet.level is None:
                    citations_by_subquestion[BASIC_KEY].append(packet)
        return citations_by_subquestion

    def is_cancelled(self) -> bool:
        if self._is_cancelled:
            return True
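As a rough usage sketch of the level-keyed accessors above (assuming an `Answer` instance built as in this diff; `BASIC_KEY` is the sentinel `(level, question_nr)` tuple used for the non-agentic flow):

answers_by_level = answer.llm_answer_by_level()     # e.g. {0: "...", 1: "..."} for agent runs
subq_citations = answer.citations_by_subquestion()  # {(level, question_nr): [CitationInfo, ...]}
for (level, question_nr), cites in subq_citations.items():
    print(level, question_nr, [c.document_id for c in cites])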
@@ -48,6 +48,8 @@ def prepare_chat_message_request(
    retrieval_details: RetrievalDetails | None,
    rerank_settings: RerankingDetails | None,
    db_session: Session,
    use_agentic_search: bool = False,
    skip_gen_ai_answer_generation: bool = False,
) -> CreateChatMessageRequest:
    # Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
    new_chat_session = create_chat_session(
@@ -72,6 +74,8 @@ def prepare_chat_message_request(
        search_doc_ids=None,
        retrieval_options=retrieval_details,
        rerank_settings=rerank_settings,
        use_agentic_search=use_agentic_search,
        skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
    )


@@ -9,25 +9,37 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
    DummyAnswerResponseHandler,
)
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler


class LLMResponseHandlerManager:
    """
    This class is responsible for postprocessing the LLM response stream.
    In particular, we:
    1. handle the tool call requests
    2. handle citations
    3. pass through answers generated by the LLM
    4. stop yielding if the client disconnects
    """

    def __init__(
        self,
        tool_handler: ToolResponseHandler,
        answer_handler: AnswerResponseHandler,
        tool_handler: ToolResponseHandler | None,
        answer_handler: AnswerResponseHandler | None,
        is_cancelled: Callable[[], bool],
    ):
        self.tool_handler = tool_handler
        self.answer_handler = answer_handler
        self.tool_handler = tool_handler or ToolResponseHandler([])
        self.answer_handler = answer_handler or DummyAnswerResponseHandler()
        self.is_cancelled = is_cancelled

    def handle_llm_response(
        self,
        stream: Iterator[BaseMessage],
    ) -> Generator[ResponsePart, None, None]:
        all_messages: list[BaseMessage] = []
        all_messages: list[BaseMessage | str] = []
        for message in stream:
            if self.is_cancelled():
                yield StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
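Wiring the manager up might look like the following sketch (construction only; `llm_stream` is a placeholder for an `Iterator[BaseMessage]` produced elsewhere, not a name from this diff):

manager = LLMResponseHandlerManager(
    tool_handler=None,    # falls back to ToolResponseHandler([])
    answer_handler=None,  # falls back to DummyAnswerResponseHandler()
    is_cancelled=lambda: False,
)
for part in manager.handle_llm_response(llm_stream):
    ...  # each part is a ResponsePart; the loop stops early once is_cancelled() flips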
@@ -3,6 +3,7 @@ from collections.abc import Iterator
from datetime import datetime
from enum import Enum
from typing import Any
from typing import Literal
from typing import TYPE_CHECKING

from pydantic import BaseModel
@@ -48,6 +49,8 @@ class QADocsResponse(RetrievalDocs):
    applied_source_filters: list[DocumentSource] | None
    applied_time_cutoff: datetime | None
    recency_bias_multiplier: float
    level: int | None = None
    level_question_nr: int | None = None

    def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]:  # type: ignore
        initial_dict = super().model_dump(mode="json", *args, **kwargs)  # type: ignore
@@ -61,11 +64,17 @@ class QADocsResponse(RetrievalDocs):
class StreamStopReason(Enum):
    CONTEXT_LENGTH = "context_length"
    CANCELLED = "cancelled"
    FINISHED = "finished"


class StreamStopInfo(BaseModel):
    stop_reason: StreamStopReason

    stream_type: Literal["", "sub_questions", "sub_answer"] = ""
    # used to identify the stream that was stopped for agent search
    level: int | None = None
    level_question_nr: int | None = None

    def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]:  # type: ignore
        data = super().model_dump(mode="json", *args, **kwargs)  # type: ignore
        data["stop_reason"] = self.stop_reason.name
@@ -108,6 +117,8 @@ class OnyxAnswerPiece(BaseModel):
class CitationInfo(BaseModel):
    citation_num: int
    document_id: str
    level: int | None = None
    level_question_nr: int | None = None


class AllCitations(BaseModel):
@@ -299,6 +310,40 @@ class PromptConfig(BaseModel):
    model_config = ConfigDict(frozen=True)


class SubQueryPiece(BaseModel):
    sub_query: str
    level: int
    level_question_nr: int
    query_id: int


class AgentAnswerPiece(BaseModel):
    answer_piece: str
    level: int
    level_question_nr: int
    answer_type: Literal["agent_sub_answer", "agent_level_answer"]


class SubQuestionPiece(BaseModel):
    sub_question: str
    level: int
    level_question_nr: int


class ExtendedToolResponse(ToolResponse):
    level: int
    level_question_nr: int


ProSearchPacket = (
    SubQuestionPiece | AgentAnswerPiece | SubQueryPiece | ExtendedToolResponse
)

AnswerPacket = (
    AnswerQuestionPossibleReturn | ProSearchPacket | ToolCallKickoff | ToolResponse
)


ResponsePart = (
    OnyxAnswerPiece
    | CitationInfo
@@ -306,4 +351,7 @@ ResponsePart = (
    | ToolResponse
    | ToolCallFinalResult
    | StreamStopInfo
    | ProSearchPacket
)

AnswerStream = Iterator[AnswerPacket]
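A consumer of `AnswerStream` discriminates the union members with `isinstance` checks; here is a minimal sketch mirroring the `llm_answer` logic from this diff (the `collect_visible_answer` helper name is illustrative):

def collect_visible_answer(stream: AnswerStream) -> str:
    answer = ""
    for packet in stream:
        if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
            answer += packet.answer_piece  # basic flow
        elif (
            isinstance(packet, AgentAnswerPiece)
            and packet.answer_type == "agent_level_answer"
            and packet.level == 0
        ):
            answer += packet.answer_piece  # level-0 agent answer mirrors the basic flow
    return answer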
@@ -1,11 +1,14 @@
import traceback
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Iterator
from dataclasses import dataclass
from functools import partial
from typing import cast

from sqlalchemy.orm import Session

from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.chat.answer import Answer
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import create_temporary_persona
@@ -16,6 +19,7 @@ from onyx.chat.models import CitationConfig
from onyx.chat.models import CitationInfo
from onyx.chat.models import CustomToolResponse
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import FileChatDisplay
from onyx.chat.models import FinalUsedContextDocsResponse
from onyx.chat.models import LLMRelevanceFilterResponse
@@ -24,20 +28,29 @@ from onyx.chat.models import MessageSpecificCitations
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.models import ProSearchPacket
from onyx.chat.models import QADocsResponse
from onyx.chat.models import StreamingError
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import AGENT_SEARCH_INITIAL_KEY
from onyx.configs.constants import BASIC_KEY
from onyx.configs.constants import MessageType
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import NO_AUTH_USER_ID
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.enums import OptionalSearchSetting
from onyx.context.search.enums import QueryFlow
from onyx.context.search.enums import SearchType
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SearchRequest
from onyx.context.search.retrieval.search_runner import inference_sections_from_ids
from onyx.context.search.utils import chunks_or_sections_to_search_docs
from onyx.context.search.utils import dedupe_documents
@@ -120,6 +133,7 @@ from onyx.tools.tool_implementations.search.search_tool import (
    SECTION_RELEVANCE_LIST_ID,
)
from onyx.tools.tool_runner import ToolCallFinalResult
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger
from onyx.utils.telemetry import mt_cloud_telemetry
@@ -127,7 +141,6 @@ from onyx.utils.timing import log_function_time
from onyx.utils.timing import log_generator_function_time
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR


logger = setup_logger()


@@ -159,12 +172,15 @@ def _handle_search_tool_response_summary(
) -> tuple[QADocsResponse, list[DbSearchDoc], list[int] | None]:
    response_sumary = cast(SearchResponseSummary, packet.response)

    is_extended = isinstance(packet, ExtendedToolResponse)
    dropped_inds = None
    if not selected_search_docs:
        top_docs = chunks_or_sections_to_search_docs(response_sumary.top_sections)

        deduped_docs = top_docs
        if dedupe_docs:
        if (
            dedupe_docs and not is_extended
        ):  # Extended tool responses are already deduped
            deduped_docs, dropped_inds = dedupe_documents(top_docs)

        reference_db_search_docs = [
@@ -178,6 +194,10 @@ def _handle_search_tool_response_summary(
        translate_db_search_doc_to_server_search_doc(db_search_doc)
        for db_search_doc in reference_db_search_docs
    ]

    level, question_nr = None, None
    if isinstance(packet, ExtendedToolResponse):
        level, question_nr = packet.level, packet.level_question_nr
    return (
        QADocsResponse(
            rephrased_query=response_sumary.rephrased_query,
@@ -187,6 +207,8 @@ def _handle_search_tool_response_summary(
            applied_source_filters=response_sumary.final_filters.source_type,
            applied_time_cutoff=response_sumary.final_filters.time_cutoff,
            recency_bias_multiplier=response_sumary.recency_bias_multiplier,
            level=level,
            level_question_nr=question_nr,
        ),
        reference_db_search_docs,
        dropped_inds,
@@ -281,10 +303,22 @@ ChatPacket = (
    | MessageSpecificCitations
    | MessageResponseIDInfo
    | StreamStopInfo
    | ProSearchPacket
)
ChatPacketStream = Iterator[ChatPacket]


# can't store a DbSearchDoc in a Pydantic BaseModel
@dataclass
class AnswerPostInfo:
    ai_message_files: list[FileDescriptor]
    qa_docs_response: QADocsResponse | None = None
    reference_db_search_docs: list[DbSearchDoc] | None = None
    dropped_indices: list[int] | None = None
    tool_result: ToolCallFinalResult | None = None
    message_specific_citations: MessageSpecificCitations | None = None


def stream_chat_message_objects(
    new_msg_req: CreateChatMessageRequest,
    user: User | None,
@@ -323,6 +357,7 @@ def stream_chat_message_objects(
        new_msg_req.chunks_above = 0
        new_msg_req.chunks_below = 0

    llm = None
    try:
        user_id = user.id if user is not None else None

@@ -511,11 +546,8 @@ def stream_chat_message_objects(
        files = load_all_chat_files(
            history_msgs, new_msg_req.file_descriptors, db_session
        )
        latest_query_files = [
            file
            for file in files
            if file.file_id in [f["id"] for f in new_msg_req.file_descriptors]
        ]
        req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
        latest_query_files = [file for file in files if file.file_id in req_file_ids]

        if user_message:
            attach_files_to_chat_message(
@@ -688,6 +720,83 @@ def stream_chat_message_objects(
        for tool_list in tool_dict.values():
            tools.extend(tool_list)

        message_history = [
            PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
        ]

        search_request = SearchRequest(
            query=final_msg.message,
            evaluation_type=(
                LLMEvaluationType.BASIC
                if persona.llm_relevance_filter
                else LLMEvaluationType.SKIP
            ),
            human_selected_filters=(
                retrieval_options.filters if retrieval_options else None
            ),
            persona=persona,
            offset=(retrieval_options.offset if retrieval_options else None),
            limit=retrieval_options.limit if retrieval_options else None,
            rerank_settings=new_msg_req.rerank_settings,
            chunks_above=new_msg_req.chunks_above,
            chunks_below=new_msg_req.chunks_below,
            full_doc=new_msg_req.full_doc,
            enable_auto_detect_filters=(
                retrieval_options.enable_auto_detect_filters
                if retrieval_options
                else None
            ),
        )
        # TODO: Since we're deleting the current main path in Answer,
        # we should construct this unconditionally inside Answer instead
        # Leaving it here for the time being to avoid breaking changes
        search_tools = [tool for tool in tools if isinstance(tool, SearchTool)]
        if len(search_tools) == 0:
            raise ValueError("No search tool found")
        elif len(search_tools) > 1:
            # TODO: handle multiple search tools
            raise ValueError("Multiple search tools found")
        search_tool = search_tools[0]
        using_tool_calling_llm = explicit_tool_calling_supported(
            llm.config.model_provider, llm.config.model_name
        )
        force_use_tool = _get_force_search_settings(new_msg_req, tools)
        prompt_builder = AnswerPromptBuilder(
            user_message=default_build_user_message(
                user_query=final_msg.message,
                prompt_config=prompt_config,
                files=latest_query_files,
                single_message_history=single_message_history,
            ),
            system_message=default_build_system_message(prompt_config),
            message_history=message_history,
            llm_config=llm.config,
            raw_user_query=final_msg.message,
            raw_user_uploaded_files=latest_query_files or [],
            single_message_history=single_message_history,
        )
        agent_search_config = AgentSearchConfig(
            search_request=search_request,
            primary_llm=llm,
            fast_llm=fast_llm,
            search_tool=search_tool,
            force_use_tool=force_use_tool,
            use_agentic_search=new_msg_req.use_agentic_search,
            chat_session_id=chat_session_id,
            message_id=reserved_message_id,
            use_persistence=True,
            allow_refinement=True,
            db_session=db_session,
            prompt_builder=prompt_builder,
            tools=tools,
            using_tool_calling_llm=using_tool_calling_llm,
            files=latest_query_files,
            structured_response_format=new_msg_req.structured_response_format,
            skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
        )

        # TODO: add previous messages, answer style config, tools, etc.

        # LLM prompt building, response capturing, etc.
        answer = Answer(
            is_connected=is_connected,
@@ -707,28 +816,40 @@ def stream_chat_message_objects(
                    )
                )
            ),
            message_history=[
                PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
            ],
            fast_llm=fast_llm,
            message_history=message_history,
            tools=tools,
            force_use_tool=_get_force_search_settings(new_msg_req, tools),
            force_use_tool=force_use_tool,
            single_message_history=single_message_history,
            agent_search_config=agent_search_config,
            db_session=db_session,
        )

        reference_db_search_docs = None
        qa_docs_response = None
        # any files to associate with the AI message e.g. dall-e generated images
        ai_message_files = []
        dropped_indices = None
        tool_result = None
        # reference_db_search_docs = None
        # qa_docs_response = None
        # # any files to associate with the AI message e.g. dall-e generated images
        # ai_message_files = []
        # dropped_indices = None
        # tool_result = None

        # TODO: different channels for stored info when it's coming from the agent flow
        info_by_subq: dict[tuple[int, int], AnswerPostInfo] = defaultdict(
            lambda: AnswerPostInfo(ai_message_files=[])
        )
        for packet in answer.processed_streamed_output:
            if isinstance(packet, ToolResponse):
                level, level_question_nr = (
                    (packet.level, packet.level_question_nr)
                    if isinstance(packet, ExtendedToolResponse)
                    else BASIC_KEY
                )
                info = info_by_subq[(level, level_question_nr)]
                # TODO: don't need to dedupe here when we do it in agent flow
                if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
                    (
                        qa_docs_response,
                        reference_db_search_docs,
                        dropped_indices,
                        info.qa_docs_response,
                        info.reference_db_search_docs,
                        info.dropped_indices,
                    ) = _handle_search_tool_response_summary(
                        packet=packet,
                        db_session=db_session,
@@ -740,29 +861,34 @@ def stream_chat_message_objects(
                            else False
                        ),
                    )
                    yield qa_docs_response
                    yield info.qa_docs_response
                elif packet.id == SECTION_RELEVANCE_LIST_ID:
                    relevance_sections = packet.response

                    if reference_db_search_docs is not None:
                        llm_indices = relevant_sections_to_indices(
                            relevance_sections=relevance_sections,
                            items=[
                                translate_db_search_doc_to_server_search_doc(doc)
                                for doc in reference_db_search_docs
                            ],
                    if info.reference_db_search_docs is None:
                        logger.warning(
                            "No reference docs found for relevance filtering"
                        )
                        continue

                    llm_indices = relevant_sections_to_indices(
                        relevance_sections=relevance_sections,
                        items=[
                            translate_db_search_doc_to_server_search_doc(doc)
                            for doc in info.reference_db_search_docs
                        ],
                    )

                    if info.dropped_indices:
                        llm_indices = drop_llm_indices(
                            llm_indices=llm_indices,
                            search_docs=info.reference_db_search_docs,
                            dropped_indices=info.dropped_indices,
                        )

                        if dropped_indices:
                            llm_indices = drop_llm_indices(
                                llm_indices=llm_indices,
                                search_docs=reference_db_search_docs,
                                dropped_indices=dropped_indices,
                            )

                        yield LLMRelevanceFilterResponse(
                            llm_selected_doc_indices=llm_indices
                        )
                    yield LLMRelevanceFilterResponse(
                        llm_selected_doc_indices=llm_indices
                    )
                elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
                    yield FinalUsedContextDocsResponse(
                        final_context_docs=packet.response
@@ -782,22 +908,24 @@ def stream_chat_message_objects(
                        ],
                        tenant_id=tenant_id,
                    )
                    ai_message_files = [
                        FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
                        for file_id in file_ids
                    ]
                    info.ai_message_files.extend(
                        [
                            FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
                            for file_id in file_ids
                        ]
                    )
                    yield FileChatDisplay(
                        file_ids=[str(file_id) for file_id in file_ids]
                    )
                elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
                    (
                        qa_docs_response,
                        reference_db_search_docs,
                        info.qa_docs_response,
                        info.reference_db_search_docs,
                    ) = _handle_internet_search_tool_response_summary(
                        packet=packet,
                        db_session=db_session,
                    )
                    yield qa_docs_response
                    yield info.qa_docs_response
                elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
                    custom_tool_response = cast(CustomToolCallSummary, packet.response)

@@ -806,7 +934,7 @@ def stream_chat_message_objects(
                        or custom_tool_response.response_type == "csv"
                    ):
                        file_ids = custom_tool_response.tool_result.file_ids
                        ai_message_files.extend(
                        info.ai_message_files.extend(
                            [
                                FileDescriptor(
                                    id=str(file_id),
@@ -831,10 +959,18 @@ def stream_chat_message_objects(
                    yield cast(OnyxContexts, packet.response)

            elif isinstance(packet, StreamStopInfo):
                pass
                if packet.stop_reason == StreamStopReason.FINISHED:
                    yield packet
            else:
                if isinstance(packet, ToolCallFinalResult):
                    tool_result = packet
                    level, level_question_nr = (
                        (packet.level, packet.level_question_nr)
                        if packet.level is not None
                        and packet.level_question_nr is not None
                        else BASIC_KEY
                    )
                    info = info_by_subq[(level, level_question_nr)]
                    info.tool_result = packet
                yield cast(ChatPacket, packet)
        logger.debug("Reached end of stream")
    except ValueError as e:
@@ -850,59 +986,98 @@ def stream_chat_message_objects(

        error_msg = str(e)
        stack_trace = traceback.format_exc()
        client_error_msg = litellm_exception_to_error_msg(e, llm)
        if llm.config.api_key and len(llm.config.api_key) > 2:
            error_msg = error_msg.replace(llm.config.api_key, "[REDACTED_API_KEY]")
            stack_trace = stack_trace.replace(llm.config.api_key, "[REDACTED_API_KEY]")
        if llm:
            client_error_msg = litellm_exception_to_error_msg(e, llm)
            if llm.config.api_key and len(llm.config.api_key) > 2:
                error_msg = error_msg.replace(llm.config.api_key, "[REDACTED_API_KEY]")
                stack_trace = stack_trace.replace(
                    llm.config.api_key, "[REDACTED_API_KEY]"
                )

        yield StreamingError(error=client_error_msg, stack_trace=stack_trace)
        yield StreamingError(error=client_error_msg, stack_trace=stack_trace)
        db_session.rollback()
        return

    # Post-LLM answer processing
    try:
        logger.debug("Post-LLM answer processing")
        message_specific_citations: MessageSpecificCitations | None = None
        if reference_db_search_docs:
            message_specific_citations = _translate_citations(
                citations_list=answer.citations,
                db_docs=reference_db_search_docs,
            )
        if not answer.is_cancelled():
            yield AllCitations(citations=answer.citations)

        # Saving Gen AI answer and responding with message info
        tool_name_to_tool_id: dict[str, int] = {}
        for tool_id, tool_list in tool_dict.items():
            for tool in tool_list:
                tool_name_to_tool_id[tool.name] = tool_id

        subq_citations = answer.citations_by_subquestion()
        for pair in subq_citations:
            level, level_question_nr = pair
            info = info_by_subq[(level, level_question_nr)]
            logger.debug("Post-LLM answer processing")
            if info.reference_db_search_docs:
                info.message_specific_citations = _translate_citations(
                    citations_list=subq_citations[pair],
                    db_docs=info.reference_db_search_docs,
                )

            # TODO: AllCitations should contain subq info?
            if not answer.is_cancelled():
                yield AllCitations(citations=subq_citations[pair])

        # Saving Gen AI answer and responding with message info

        info = (
            info_by_subq[BASIC_KEY]
            if BASIC_KEY in info_by_subq
            else info_by_subq[AGENT_SEARCH_INITIAL_KEY]
        )
        gen_ai_response_message = partial_response(
            message=answer.llm_answer,
            rephrased_query=(
                qa_docs_response.rephrased_query if qa_docs_response else None
                info.qa_docs_response.rephrased_query if info.qa_docs_response else None
            ),
            reference_docs=reference_db_search_docs,
            files=ai_message_files,
            reference_docs=info.reference_db_search_docs,
            files=info.ai_message_files,
            token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
            citations=(
                message_specific_citations.citation_map
                if message_specific_citations
                info.message_specific_citations.citation_map
                if info.message_specific_citations
                else None
            ),
            error=None,
            tool_call=(
                ToolCall(
                    tool_id=tool_name_to_tool_id[tool_result.tool_name],
                    tool_name=tool_result.tool_name,
                    tool_arguments=tool_result.tool_args,
                    tool_result=tool_result.tool_result,
                    tool_id=tool_name_to_tool_id[info.tool_result.tool_name],
                    tool_name=info.tool_result.tool_name,
                    tool_arguments=info.tool_result.tool_args,
                    tool_result=info.tool_result.tool_result,
                )
                if tool_result
                if info.tool_result
                else None
            ),
        )

        # TODO: add answers for levels >= 1, where each level has the previous as its parent. Use
        # the answer_by_level method in answer.py to get the answers for each level
        next_level = 1
        prev_message = gen_ai_response_message
        agent_answers = answer.llm_answer_by_level()
        while next_level in agent_answers:
            next_answer = agent_answers[next_level]
            info = info_by_subq[(next_level, AGENT_SEARCH_INITIAL_KEY[1])]
            next_answer_message = create_new_chat_message(
                chat_session_id=chat_session_id,
                parent_message=prev_message,
                message=next_answer,
                prompt_id=None,
                token_count=len(llm_tokenizer_encode_func(next_answer)),
                message_type=MessageType.ASSISTANT,
                db_session=db_session,
                files=info.ai_message_files,
                reference_docs=info.reference_db_search_docs,
                citations=info.message_specific_citations.citation_map
                if info.message_specific_citations
                else None,
            )
            next_level += 1
            prev_message = next_answer_message

        logger.debug("Committing messages")
        db_session.commit()  # actually save user / assistant message
Some files were not shown because too many files have changed in this diff.