Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-16 23:35:46 +00:00)
Compare commits
86 Commits
nightly-la...asdfasdfas
| SHA1 |
|---|
| 9ee4c3daa9 |
| a0d1d95df6 |
| f546e85cff |
| 5c82f920ec |
| ec9888fca2 |
| 8962c768a0 |
| 22565d7334 |
| 5d9bab05e0 |
| 8a99d24434 |
| e5f1f67a71 |
| bee0137509 |
| f6afc95e2f |
| 1c9bc48705 |
| e6242f7438 |
| ad16d66684 |
| 76c72f88e9 |
| 4d218c8a5d |
| 36d4e0a1c6 |
| f233ed071e |
| 9770db967b |
| 38a63d74a9 |
| 17fb4219da |
| d019e14d00 |
| 70f61d53e0 |
| c2d9ad9143 |
| 7b6901a8c1 |
| 16300b42a1 |
| 5635b32ac4 |
| 13ec194153 |
| db0dee1aac |
| 23dbb521c3 |
| b81f03131f |
| e896786693 |
| 1ebddb1ebe |
| 0146d3c66d |
| f2a54b79ac |
| 27982a1d5a |
| ff3b4d28f4 |
| 081efa0831 |
| 4dec31c93d |
| b693fe8248 |
| 9c2b152dd7 |
| 3afb79a7d8 |
| a1c3ba7eba |
| 1bb7f782c8 |
| ee8d9ddc3d |
| 54d6f31f0d |
| 3306479674 |
| c08ed464a9 |
| d07588d1ce |
| 9d494adc3e |
| be7f3f6eed |
| 8561a50eac |
| 3ce783a1fe |
| e9ea2a1b1f |
| b6caedf50b |
| 4ad1dca233 |
| df9f808c4a |
| d68bcb6ac2 |
| e9f2d6468d |
| 0ca153b651 |
| 279f0e9374 |
| ea73f12844 |
| 1007d4684a |
| d336fa31b4 |
| 0b7719c158 |
| c41f4599bf |
| 6bc232f040 |
| b10166d4d2 |
| 836f84f946 |
| ffb627621e |
| ca776932a9 |
| 64e13986bb |
| fcaee1979f |
| 105a0f7b49 |
| afce385fd0 |
| 978650028a |
| c87270a279 |
| bc13a6caa7 |
| f14bc9ab8a |
| b95c566c15 |
| 9f5c9142b5 |
| 0c5b47e36f |
| 09020a202e |
| 2a662c40b8 |
| 35593cae0d |
4 .gitignore (vendored)
@@ -7,4 +7,6 @@
.vscode/
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml
/web/test-results/
/web/test-results/
backend/onyx/agent_search/main/test_data.json
backend/tests/regression/answer_quality/test_data.json
6 .vscode/env_template.txt (vendored)
@@ -52,3 +52,9 @@ BING_API_KEY=<REPLACE THIS>
# Enable the full set of Danswer Enterprise Edition features
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False

# Agent Search configs  # TODO: remove or give these proper names
AGENT_RETRIEVAL_STATS=False  # Note: This setting will incur substantial re-ranking effort
AGENT_RERANKING_STATS=True
AGENT_MAX_QUERY_RETRIEVAL_RESULTS=20
AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS=20
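For context, a hedged sketch of how flags like the ones added above might be read at runtime (the variable names come from the template; the helper and the defaults below are illustrative assumptions, not the repository's actual config module):

```python
import os

# Illustrative only: parse the agent-search flags added to the env template above.
# The helper and the defaults are assumptions, not Onyx's real config code.
def _env_bool(name: str, default: str) -> bool:
    return os.environ.get(name, default).strip().lower() in ("true", "1", "yes")

AGENT_RETRIEVAL_STATS = _env_bool("AGENT_RETRIEVAL_STATS", "False")
AGENT_RERANKING_STATS = _env_bool("AGENT_RERANKING_STATS", "True")
AGENT_MAX_QUERY_RETRIEVAL_RESULTS = int(
    os.environ.get("AGENT_MAX_QUERY_RETRIEVAL_RESULTS", "20")
)
AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS = int(
    os.environ.get("AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS", "20")
)
```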
@@ -0,0 +1,29 @@
"""agent_doc_result_col

Revision ID: 1adf5ea20d2b
Revises: e9cf2bd7baed
Create Date: 2025-01-05 13:14:58.344316

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "1adf5ea20d2b"
down_revision = "e9cf2bd7baed"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add the new column with JSONB type
    op.add_column(
        "sub_question",
        sa.Column("sub_question_doc_results", postgresql.JSONB(), nullable=True),
    )


def downgrade() -> None:
    # Drop the column
    op.drop_column("sub_question", "sub_question_doc_results")
@@ -0,0 +1,31 @@
"""refined answer improvement

Revision ID: 211b14ab5a91
Revises: 925b58bd75b6
Create Date: 2025-01-24 14:05:03.334309

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "211b14ab5a91"
down_revision = "925b58bd75b6"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "chat_message",
        sa.Column(
            "refined_answer_improvement",
            sa.Boolean(),
            nullable=True,
        ),
    )


def downgrade() -> None:
    op.drop_column("chat_message", "refined_answer_improvement")
@@ -1,80 +0,0 @@
|
||||
"""foreign key input prompts
|
||||
|
||||
Revision ID: 33ea50e88f24
|
||||
Revises: a6df6b88ef81
|
||||
Create Date: 2025-01-29 10:54:22.141765
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "33ea50e88f24"
|
||||
down_revision = "a6df6b88ef81"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Safely drop constraints if exists
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE inputprompt__user
|
||||
DROP CONSTRAINT IF EXISTS inputprompt__user_input_prompt_id_fkey
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE inputprompt__user
|
||||
DROP CONSTRAINT IF EXISTS inputprompt__user_user_id_fkey
|
||||
"""
|
||||
)
|
||||
|
||||
# Recreate with ON DELETE CASCADE
|
||||
op.create_foreign_key(
|
||||
"inputprompt__user_input_prompt_id_fkey",
|
||||
"inputprompt__user",
|
||||
"inputprompt",
|
||||
["input_prompt_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
op.create_foreign_key(
|
||||
"inputprompt__user_user_id_fkey",
|
||||
"inputprompt__user",
|
||||
"user",
|
||||
["user_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the new FKs with ondelete
|
||||
op.drop_constraint(
|
||||
"inputprompt__user_input_prompt_id_fkey",
|
||||
"inputprompt__user",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint(
|
||||
"inputprompt__user_user_id_fkey",
|
||||
"inputprompt__user",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
# Recreate them without cascading
|
||||
op.create_foreign_key(
|
||||
"inputprompt__user_input_prompt_id_fkey",
|
||||
"inputprompt__user",
|
||||
"inputprompt",
|
||||
["input_prompt_id"],
|
||||
["id"],
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"inputprompt__user_user_id_fkey",
|
||||
"inputprompt__user",
|
||||
"user",
|
||||
["user_id"],
|
||||
["id"],
|
||||
)
|
||||
@@ -0,0 +1,35 @@
"""agent_metric_col_rename__s

Revision ID: 925b58bd75b6
Revises: 9787be927e58
Create Date: 2025-01-06 11:20:26.752441

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "925b58bd75b6"
down_revision = "9787be927e58"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Rename columns using PostgreSQL syntax
    op.alter_column(
        "agent__search_metrics", "base_duration_s", new_column_name="base_duration__s"
    )
    op.alter_column(
        "agent__search_metrics", "full_duration_s", new_column_name="full_duration__s"
    )


def downgrade() -> None:
    # Revert the column renames
    op.alter_column(
        "agent__search_metrics", "base_duration__s", new_column_name="base_duration_s"
    )
    op.alter_column(
        "agent__search_metrics", "full_duration__s", new_column_name="full_duration_s"
    )
@@ -0,0 +1,25 @@
"""agent_metric_table_renames__agent__

Revision ID: 9787be927e58
Revises: bceb76d618ec
Create Date: 2025-01-06 11:01:44.210160

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "9787be927e58"
down_revision = "bceb76d618ec"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Rename table from agent_search_metrics to agent__search_metrics
    op.rename_table("agent_search_metrics", "agent__search_metrics")


def downgrade() -> None:
    # Rename table back from agent__search_metrics to agent_search_metrics
    op.rename_table("agent__search_metrics", "agent_search_metrics")
42 backend/alembic/versions/98a5008d8711_agent_tracking.py (Normal file)
@@ -0,0 +1,42 @@
"""agent_tracking

Revision ID: 98a5008d8711
Revises: 4d58345da04a
Create Date: 2025-01-04 14:41:52.732238

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "98a5008d8711"
down_revision = "4d58345da04a"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "agent_search_metrics",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=True),
        sa.Column("persona_id", sa.Integer(), nullable=True),
        sa.Column("agent_type", sa.String(), nullable=False),
        sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
        sa.Column("base_duration_s", sa.Float(), nullable=False),
        sa.Column("full_duration_s", sa.Float(), nullable=False),
        sa.Column("base_metrics", postgresql.JSONB(), nullable=True),
        sa.Column("refined_metrics", postgresql.JSONB(), nullable=True),
        sa.Column("all_metrics", postgresql.JSONB(), nullable=True),
        sa.ForeignKeyConstraint(
            ["persona_id"],
            ["persona.id"],
        ),
        sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
    )


def downgrade() -> None:
    op.drop_table("agent_search_metrics")
@@ -1,29 +0,0 @@
"""remove recent assistants

Revision ID: a6df6b88ef81
Revises: 4d58345da04a
Create Date: 2025-01-29 10:25:52.790407

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "a6df6b88ef81"
down_revision = "4d58345da04a"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.drop_column("user", "recent_assistants")


def downgrade() -> None:
    op.add_column(
        "user",
        sa.Column(
            "recent_assistants", postgresql.JSONB(), server_default="[]", nullable=False
        ),
    )
@@ -0,0 +1,84 @@
|
||||
"""agent_table_renames__agent__
|
||||
|
||||
Revision ID: bceb76d618ec
|
||||
Revises: c0132518a25b
|
||||
Create Date: 2025-01-06 10:50:48.109285
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "bceb76d618ec"
|
||||
down_revision = "c0132518a25b"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_constraint(
|
||||
"sub_query__search_doc_sub_query_id_fkey",
|
||||
"sub_query__search_doc",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint(
|
||||
"sub_query__search_doc_search_doc_id_fkey",
|
||||
"sub_query__search_doc",
|
||||
type_="foreignkey",
|
||||
)
|
||||
# Rename tables
|
||||
op.rename_table("sub_query", "agent__sub_query")
|
||||
op.rename_table("sub_question", "agent__sub_question")
|
||||
op.rename_table("sub_query__search_doc", "agent__sub_query__search_doc")
|
||||
|
||||
# Update both foreign key constraints for agent__sub_query__search_doc
|
||||
|
||||
# Create new foreign keys with updated names
|
||||
op.create_foreign_key(
|
||||
"agent__sub_query__search_doc_sub_query_id_fkey",
|
||||
"agent__sub_query__search_doc",
|
||||
"agent__sub_query",
|
||||
["sub_query_id"],
|
||||
["id"],
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"agent__sub_query__search_doc_search_doc_id_fkey",
|
||||
"agent__sub_query__search_doc",
|
||||
"search_doc", # This table name doesn't change
|
||||
["search_doc_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Update foreign key constraints for sub_query__search_doc
|
||||
op.drop_constraint(
|
||||
"agent__sub_query__search_doc_sub_query_id_fkey",
|
||||
"agent__sub_query__search_doc",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint(
|
||||
"agent__sub_query__search_doc_search_doc_id_fkey",
|
||||
"agent__sub_query__search_doc",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
# Rename tables back
|
||||
op.rename_table("agent__sub_query__search_doc", "sub_query__search_doc")
|
||||
op.rename_table("agent__sub_question", "sub_question")
|
||||
op.rename_table("agent__sub_query", "sub_query")
|
||||
|
||||
op.create_foreign_key(
|
||||
"sub_query__search_doc_sub_query_id_fkey",
|
||||
"sub_query__search_doc",
|
||||
"sub_query",
|
||||
["sub_query_id"],
|
||||
["id"],
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"sub_query__search_doc_search_doc_id_fkey",
|
||||
"sub_query__search_doc",
|
||||
"search_doc", # This table name doesn't change
|
||||
["search_doc_id"],
|
||||
["id"],
|
||||
)
|
||||
@@ -0,0 +1,40 @@
"""agent_table_changes_rename_level

Revision ID: c0132518a25b
Revises: 1adf5ea20d2b
Create Date: 2025-01-05 16:38:37.660152

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "c0132518a25b"
down_revision = "1adf5ea20d2b"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add level and level_question_nr columns with NOT NULL constraint
    op.add_column(
        "sub_question",
        sa.Column("level", sa.Integer(), nullable=False, server_default="0"),
    )
    op.add_column(
        "sub_question",
        sa.Column(
            "level_question_nr", sa.Integer(), nullable=False, server_default="0"
        ),
    )

    # Remove the server_default after the columns are created
    op.alter_column("sub_question", "level", server_default=None)
    op.alter_column("sub_question", "level_question_nr", server_default=None)


def downgrade() -> None:
    # Remove the columns
    op.drop_column("sub_question", "level_question_nr")
    op.drop_column("sub_question", "level")
@@ -0,0 +1,68 @@
|
||||
"""create pro search persistence tables
|
||||
|
||||
Revision ID: e9cf2bd7baed
|
||||
Revises: 98a5008d8711
|
||||
Create Date: 2025-01-02 17:55:56.544246
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "e9cf2bd7baed"
|
||||
down_revision = "98a5008d8711"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create sub_question table
|
||||
op.create_table(
|
||||
"sub_question",
|
||||
sa.Column("id", sa.Integer, primary_key=True),
|
||||
sa.Column("primary_question_id", sa.Integer, sa.ForeignKey("chat_message.id")),
|
||||
sa.Column(
|
||||
"chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id")
|
||||
),
|
||||
sa.Column("sub_question", sa.Text),
|
||||
sa.Column(
|
||||
"time_created", sa.DateTime(timezone=True), server_default=sa.func.now()
|
||||
),
|
||||
sa.Column("sub_answer", sa.Text),
|
||||
)
|
||||
|
||||
# Create sub_query table
|
||||
op.create_table(
|
||||
"sub_query",
|
||||
sa.Column("id", sa.Integer, primary_key=True),
|
||||
sa.Column("parent_question_id", sa.Integer, sa.ForeignKey("sub_question.id")),
|
||||
sa.Column(
|
||||
"chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id")
|
||||
),
|
||||
sa.Column("sub_query", sa.Text),
|
||||
sa.Column(
|
||||
"time_created", sa.DateTime(timezone=True), server_default=sa.func.now()
|
||||
),
|
||||
)
|
||||
|
||||
# Create sub_query__search_doc association table
|
||||
op.create_table(
|
||||
"sub_query__search_doc",
|
||||
sa.Column(
|
||||
"sub_query_id", sa.Integer, sa.ForeignKey("sub_query.id"), primary_key=True
|
||||
),
|
||||
sa.Column(
|
||||
"search_doc_id",
|
||||
sa.Integer,
|
||||
sa.ForeignKey("search_doc.id"),
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("sub_query__search_doc")
|
||||
op.drop_table("sub_query")
|
||||
op.drop_table("sub_question")
|
||||
@@ -179,6 +179,7 @@ def handle_simplified_chat_message(
        chunks_below=0,
        full_doc=chat_message_req.full_doc,
        structured_response_format=chat_message_req.structured_response_format,
        use_agentic_search=chat_message_req.use_agentic_search,
    )

    packets = stream_chat_message_objects(
@@ -301,6 +302,7 @@ def handle_send_message_simple_with_history(
        chunks_below=0,
        full_doc=req.full_doc,
        structured_response_format=req.structured_response_format,
        use_agentic_search=req.use_agentic_search,
    )

    packets = stream_chat_message_objects(
@@ -57,6 +57,9 @@ class BasicCreateChatMessageRequest(ChunkContext):
    # https://platform.openai.com/docs/guides/structured-outputs/introduction
    structured_response_format: dict | None = None

    # If True, uses agentic search instead of basic search
    use_agentic_search: bool = False


class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
    # Last element is the new query. All previous elements are historical context
@@ -71,6 +74,8 @@ class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
    # only works if using an OpenAI model. See the following for more details:
    # https://platform.openai.com/docs/guides/structured-outputs/introduction
    structured_response_format: dict | None = None
    # If True, uses agentic search instead of basic search
    use_agentic_search: bool = False


class SimpleDoc(BaseModel):
@@ -123,6 +128,9 @@ class OneShotQARequest(ChunkContext):
    # If True, skips generating an AI response to the search query
    skip_gen_ai_answer_generation: bool = False

    # If True, uses pro search instead of basic search
    use_agentic_search: bool = False

    @model_validator(mode="after")
    def check_persona_fields(self) -> "OneShotQARequest":
        if self.persona_override_config is None and self.persona_id is None:
@@ -196,6 +196,8 @@ def get_answer_stream(
        retrieval_details=query_request.retrieval_options,
        rerank_settings=query_request.rerank_settings,
        db_session=db_session,
        use_agentic_search=query_request.use_agentic_search,
        skip_gen_ai_answer_generation=query_request.skip_gen_ai_answer_generation,
    )

    packets = stream_chat_message_objects(
98
backend/onyx/agents/agent_search/basic/graph_builder.py
Normal file
98
backend/onyx/agents/agent_search/basic/graph_builder.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.basic.states import BasicInput
|
||||
from onyx.agents.agent_search.basic.states import BasicOutput
|
||||
from onyx.agents.agent_search.basic.states import BasicState
|
||||
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
|
||||
basic_use_tool_response,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
|
||||
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
|
||||
prepare_tool_input,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def basic_graph_builder() -> StateGraph:
|
||||
graph = StateGraph(
|
||||
state_schema=BasicState,
|
||||
input=BasicInput,
|
||||
output=BasicOutput,
|
||||
)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
graph.add_node(
|
||||
node="prepare_tool_input",
|
||||
action=prepare_tool_input,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="llm_tool_choice",
|
||||
action=llm_tool_choice,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="tool_call",
|
||||
action=tool_call,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="basic_use_tool_response",
|
||||
action=basic_use_tool_response,
|
||||
)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_edge(start_key=START, end_key="prepare_tool_input")
|
||||
|
||||
graph.add_edge(start_key="prepare_tool_input", end_key="llm_tool_choice")
|
||||
|
||||
graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])
|
||||
|
||||
graph.add_edge(
|
||||
start_key="tool_call",
|
||||
end_key="basic_use_tool_response",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="basic_use_tool_response",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
def should_continue(state: BasicState) -> str:
|
||||
return (
|
||||
# If there are no tool calls, basic graph already streamed the answer
|
||||
END
|
||||
if state.tool_choice is None
|
||||
else "tool_call"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
|
||||
graph = basic_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
# TODO: unify basic input
|
||||
input = BasicInput(logs="")
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
with get_session_context_manager() as db_session:
|
||||
config, _ = get_test_config(
|
||||
db_session=db_session,
|
||||
primary_llm=primary_llm,
|
||||
fast_llm=fast_llm,
|
||||
search_request=SearchRequest(query="How does onyx use FastAPI?"),
|
||||
)
|
||||
compiled_graph.invoke(input, config={"metadata": {"config": config}})
|
||||
42 backend/onyx/agents/agent_search/basic/states.py (Normal file)
@@ -0,0 +1,42 @@
from typing import TypedDict

from langchain_core.messages import AIMessageChunk
from pydantic import BaseModel

from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate

# States contain values that change over the course of graph execution,
# Config is for values that are set at the start and never change.
# If you are using a value from the config and realize it needs to change,
# you should add it to the state and use/update the version in the state.

## Graph Input State


class BasicInput(BaseModel):
    # TODO: subclass global log update state
    logs: str = ""


## Graph Output State


class BasicOutput(TypedDict):
    tool_call_chunk: AIMessageChunk


## Update States


## Graph State


class BasicState(
    BasicInput,
    ToolChoiceInput,
    ToolCallUpdate,
    ToolChoiceUpdate,
):
    pass
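The comment block in states.py above separates graph state (values that change while the graph runs) from config (values fixed at invocation). A loose illustration of that split, using invented names that are not part of this diff:

```python
from pydantic import BaseModel


# Illustration only: config-style values stay read-only for the whole run,
# while state-style values are produced and updated by nodes as the graph executes.
class ExampleConfig(BaseModel):  # hypothetical: fixed at invocation time
    model_name: str


class ExampleState(BaseModel):  # hypothetical: evolves across nodes
    logs: str = ""


def example_node(state: ExampleState, config: ExampleConfig) -> ExampleState:
    # A node reads the config but never mutates it; it returns an updated state.
    return ExampleState(logs=state.logs + f"ran with {config.model_name}\n")
```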
69
backend/onyx/agents/agent_search/basic/utils.py
Normal file
69
backend/onyx/agents/agent_search/basic/utils.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from collections.abc import Iterator
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
|
||||
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
|
||||
from onyx.chat.stream_processing.answer_response_handler import (
|
||||
PassThroughAnswerResponseHandler,
|
||||
)
|
||||
from onyx.chat.stream_processing.utils import map_document_id_order
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# TODO: handle citations here; below is what was previously passed in
|
||||
# see basic_use_tool_response.py for where these variables come from
|
||||
# answer_handler = CitationResponseHandler(
|
||||
# context_docs=final_search_results,
|
||||
# final_doc_id_to_rank_map=map_document_id_order(final_search_results),
|
||||
# display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
|
||||
# )
|
||||
|
||||
|
||||
def process_llm_stream(
|
||||
stream: Iterator[BaseMessage],
|
||||
should_stream_answer: bool,
|
||||
final_search_results: list[LlmDoc] | None = None,
|
||||
displayed_search_results: list[LlmDoc] | None = None,
|
||||
) -> AIMessageChunk:
|
||||
tool_call_chunk = AIMessageChunk(content="")
|
||||
# for response in response_handler_manager.handle_llm_response(stream):
|
||||
|
||||
if final_search_results and displayed_search_results:
|
||||
answer_handler: AnswerResponseHandler = CitationResponseHandler(
|
||||
context_docs=final_search_results,
|
||||
final_doc_id_to_rank_map=map_document_id_order(final_search_results),
|
||||
display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
|
||||
)
|
||||
else:
|
||||
answer_handler = PassThroughAnswerResponseHandler()
|
||||
|
||||
full_answer = ""
|
||||
# This stream will be the llm answer if no tool is chosen. When a tool is chosen,
|
||||
# the stream will contain AIMessageChunks with tool call information.
|
||||
for response in stream:
|
||||
answer_piece = response.content
|
||||
if not isinstance(answer_piece, str):
|
||||
# TODO: handle non-string content
|
||||
logger.warning(f"Received non-string content: {type(answer_piece)}")
|
||||
answer_piece = str(answer_piece)
|
||||
full_answer += answer_piece
|
||||
|
||||
if isinstance(response, AIMessageChunk) and (
|
||||
response.tool_call_chunks or response.tool_calls
|
||||
):
|
||||
tool_call_chunk += response # type: ignore
|
||||
elif should_stream_answer:
|
||||
for response_part in answer_handler.handle_response_part(response, []):
|
||||
dispatch_custom_event(
|
||||
"basic_response",
|
||||
response_part,
|
||||
)
|
||||
|
||||
logger.info(f"Full answer: {full_answer}")
|
||||
return cast(AIMessageChunk, tool_call_chunk)
|
||||
21 backend/onyx/agents/agent_search/core_state.py (Normal file)
@@ -0,0 +1,21 @@
from operator import add
from typing import Annotated

from pydantic import BaseModel


class CoreState(BaseModel):
    """
    This is the core state that is shared across all subgraphs.
    """

    base_question: str = ""
    log_messages: Annotated[list[str], add] = []


class SubgraphCoreState(BaseModel):
    """
    This is the core state that is shared across all subgraphs.
    """

    log_messages: Annotated[list[str], add]
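The `Annotated[list[str], add]` fields above follow LangGraph's reducer convention: when several nodes return partial updates for the same key, the annotated operator merges them instead of overwriting. A minimal sketch of the underlying idea in plain Python (no LangGraph dependency; the update values are made up):

```python
from operator import add

# Hypothetical partial updates returned by two different graph nodes.
update_from_node_a = ["node_a: retrieved 12 documents"]
update_from_node_b = ["node_b: generated sub-answer"]

# The reducer (operator.add) concatenates the lists, so log_messages accumulates
# entries from every node rather than keeping only the last write.
merged_log_messages = add(update_from_node_a, update_from_node_b)
print(merged_log_messages)
```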
66 backend/onyx/agents/agent_search/db_operations.py (Normal file)
@@ -0,0 +1,66 @@
from uuid import UUID

from sqlalchemy.orm import Session

from onyx.db.models import AgentSubQuery
from onyx.db.models import AgentSubQuestion


def create_sub_question(
    db_session: Session,
    chat_session_id: UUID,
    primary_message_id: int,
    sub_question: str,
    sub_answer: str,
) -> AgentSubQuestion:
    """Create a new sub-question record in the database."""
    sub_q = AgentSubQuestion(
        chat_session_id=chat_session_id,
        primary_question_id=primary_message_id,
        sub_question=sub_question,
        sub_answer=sub_answer,
    )
    db_session.add(sub_q)
    db_session.flush()
    return sub_q


def create_sub_query(
    db_session: Session,
    chat_session_id: UUID,
    parent_question_id: int,
    sub_query: str,
) -> AgentSubQuery:
    """Create a new sub-query record in the database."""
    sub_q = AgentSubQuery(
        chat_session_id=chat_session_id,
        parent_question_id=parent_question_id,
        sub_query=sub_query,
    )
    db_session.add(sub_q)
    db_session.flush()
    return sub_q


def get_sub_questions_for_message(
    db_session: Session,
    primary_message_id: int,
) -> list[AgentSubQuestion]:
    """Get all sub-questions for a given primary message."""
    return (
        db_session.query(AgentSubQuestion)
        .filter(AgentSubQuestion.primary_question_id == primary_message_id)
        .all()
    )


def get_sub_queries_for_question(
    db_session: Session,
    sub_question_id: int,
) -> list[AgentSubQuery]:
    """Get all sub-queries for a given sub-question."""
    return (
        db_session.query(AgentSubQuery)
        .filter(AgentSubQuery.parent_question_id == sub_question_id)
        .all()
    )
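A hedged usage sketch for the helpers above (the caller, session handling, and the literal strings are placeholders; it relies on the `flush()` calls in the helpers to make the generated primary key available before commit):

```python
from uuid import UUID

from sqlalchemy.orm import Session


def record_agent_breakdown(
    db_session: Session, chat_session_id: UUID, message_id: int
) -> None:
    # Hypothetical caller: persist one sub-question and one sub-query under it.
    sub_q = create_sub_question(
        db_session=db_session,
        chat_session_id=chat_session_id,
        primary_message_id=message_id,
        sub_question="What frameworks does the backend use?",
        sub_answer="The backend is built on FastAPI.",
    )
    # create_sub_question flushes, so sub_q.id can be used as the parent id here.
    create_sub_query(
        db_session=db_session,
        chat_session_id=chat_session_id,
        parent_question_id=sub_q.id,
        sub_query="backend web framework onyx",
    )
    db_session.commit()
```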
@@ -0,0 +1,55 @@
|
||||
from collections.abc import Hashable
|
||||
from datetime import datetime
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
|
||||
AnswerQuestionInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
|
||||
AnswerQuestionOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__retrieval_sub_answers__subgraph.states import (
|
||||
SearchSQState,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
|
||||
|
||||
|
||||
def parallelize_initial_sub_question_answering(
|
||||
state: SearchSQState,
|
||||
) -> list[Send | Hashable]:
|
||||
now_start = datetime.now()
|
||||
if len(state.initial_decomp_questions) > 0:
|
||||
# sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]]
|
||||
# if len(state["sub_question_records"]) == 0:
|
||||
# if state["config"].use_persistence:
|
||||
# raise ValueError("No sub-questions found for initial decompozed questions")
|
||||
# else:
|
||||
# # in this case, we are doing retrieval on the original question.
|
||||
# # to make all the logic consistent, we create a new sub-question
|
||||
# # with the same content as the original question
|
||||
# sub_question_record_ids = [1] * len(state["initial_decomp_questions"])
|
||||
|
||||
return [
|
||||
Send(
|
||||
"answer_query_subgraph",
|
||||
AnswerQuestionInput(
|
||||
question=question,
|
||||
question_id=make_question_id(0, question_nr + 1),
|
||||
log_messages=[
|
||||
f"{now_start} -- Main Edge - Parallelize Initial Sub-question Answering"
|
||||
],
|
||||
),
|
||||
)
|
||||
for question_nr, question in enumerate(state.initial_decomp_questions)
|
||||
]
|
||||
|
||||
else:
|
||||
return [
|
||||
Send(
|
||||
"ingest_answers",
|
||||
AnswerQuestionOutput(
|
||||
answer_results=[],
|
||||
),
|
||||
)
|
||||
]
|
||||
@@ -0,0 +1,100 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.initial__consolidate_sub_answers__subgraph.edges import (
|
||||
parallelize_initial_sub_question_answering,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__consolidate_sub_answers__subgraph.nodes.ingest_initial_sub_answers import (
|
||||
ingest_initial_sub_answers,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__consolidate_sub_answers__subgraph.nodes.initial_decomposition import (
|
||||
initial_sub_question_creation,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__consolidate_sub_answers__subgraph.states import (
|
||||
SQInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__consolidate_sub_answers__subgraph.states import (
|
||||
SQState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.graph_builder import (
|
||||
answer_query_graph_builder,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
test_mode = False
|
||||
|
||||
|
||||
def initial_sq_subgraph_builder(test_mode: bool = False) -> StateGraph:
|
||||
graph = StateGraph(
|
||||
state_schema=SQState,
|
||||
input=SQInput,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="initial_sub_question_creation",
|
||||
action=initial_sub_question_creation,
|
||||
)
|
||||
answer_query_subgraph = answer_query_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="answer_query_subgraph",
|
||||
action=answer_query_subgraph,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="ingest_initial_sub_question_answers",
|
||||
action=ingest_initial_sub_answers,
|
||||
)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
# raph.add_edge(start_key=START, end_key="base_raw_search_subgraph")
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="agent_search_start",
|
||||
# end_key="entity_term_extraction_llm",
|
||||
# )
|
||||
|
||||
graph.add_edge(
|
||||
start_key=START,
|
||||
end_key="initial_sub_question_creation",
|
||||
)
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="LLM",
|
||||
# end_key=END,
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key=START,
|
||||
# end_key="initial_sub_question_creation",
|
||||
# )
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source="initial_sub_question_creation",
|
||||
path=parallelize_initial_sub_question_answering,
|
||||
path_map=["answer_query_subgraph"],
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="answer_query_subgraph",
|
||||
end_key="ingest_initial_sub_question_answers",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="ingest_initial_sub_question_answers",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="generate_refined_answer",
|
||||
# end_key="check_refined_answer",
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="check_refined_answer",
|
||||
# end_key=END,
|
||||
# )
|
||||
|
||||
return graph
|
||||
@@ -0,0 +1,45 @@
|
||||
from datetime import datetime
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
|
||||
AnswerQuestionOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
|
||||
DecompAnswersUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
)
|
||||
|
||||
|
||||
def ingest_initial_sub_answers(
|
||||
state: AnswerQuestionOutput,
|
||||
) -> DecompAnswersUpdate:
|
||||
now_start = datetime.now()
|
||||
|
||||
logger.info(f"--------{now_start}--------INGEST ANSWERS---")
|
||||
documents = []
|
||||
context_documents = []
|
||||
cited_docs = []
|
||||
answer_results = state.answer_results if hasattr(state, "answer_results") else []
|
||||
for answer_result in answer_results:
|
||||
documents.extend(answer_result.documents)
|
||||
context_documents.extend(answer_result.context_documents)
|
||||
cited_docs.extend(answer_result.cited_docs)
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.debug(
|
||||
f"--------{now_end}--{now_end - now_start}--------INGEST ANSWERS END---"
|
||||
)
|
||||
|
||||
return DecompAnswersUpdate(
|
||||
# Deduping is done by the documents operator for the main graph
|
||||
# so we might not need to dedup here
|
||||
documents=dedup_inference_sections(documents, []),
|
||||
context_documents=dedup_inference_sections(context_documents, []),
|
||||
cited_docs=dedup_inference_sections(cited_docs, []),
|
||||
decomp_answer_results=answer_results,
|
||||
log_messages=[
|
||||
f"{now_start} -- Main - Ingest initial processed sub questions, Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,138 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.initial__retrieval_sub_answers__subgraph.states import (
|
||||
SearchSQState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.models import (
|
||||
AgentRefinedMetrics,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.operations import (
|
||||
dispatch_subquestion,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import BaseDecompUpdate
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_history_prompt,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
|
||||
|
||||
|
||||
def initial_sub_question_creation(
|
||||
state: SearchSQState, config: RunnableConfig
|
||||
) -> BaseDecompUpdate:
|
||||
now_start = datetime.now()
|
||||
|
||||
logger.info(f"--------{now_start}--------BASE DECOMP START---")
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = agent_a_config.search_request.query
|
||||
chat_session_id = agent_a_config.chat_session_id
|
||||
primary_message_id = agent_a_config.message_id
|
||||
perform_initial_search_decomposition = (
|
||||
agent_a_config.perform_initial_search_decomposition
|
||||
)
|
||||
# Get the rewritten queries in a defined format
|
||||
model = agent_a_config.fast_llm
|
||||
|
||||
history = build_history_prompt(agent_a_config, question)
|
||||
|
||||
# Use the initial search results to inform the decomposition
|
||||
sample_doc_str = state.sample_doc_str if hasattr(state, "sample_doc_str") else ""
|
||||
|
||||
if not chat_session_id or not primary_message_id:
|
||||
raise ValueError(
|
||||
"chat_session_id and message_id must be provided for agent search"
|
||||
)
|
||||
agent_start_time = datetime.now()
|
||||
|
||||
# Initial search to inform decomposition. Just get top 3 fits
|
||||
|
||||
if perform_initial_search_decomposition:
|
||||
sample_doc_str = "\n\n".join(
|
||||
[
|
||||
doc.combined_content
|
||||
for doc in state.exploratory_search_results[
|
||||
:AGENT_NUM_DOCS_FOR_DECOMPOSITION
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
decomposition_prompt = (
|
||||
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH.format(
|
||||
question=question, sample_doc_str=sample_doc_str, history=history
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
decomposition_prompt = INITIAL_DECOMPOSITION_PROMPT_QUESTIONS.format(
|
||||
question=question, history=history
|
||||
)
|
||||
|
||||
# Start decomposition
|
||||
|
||||
msg = [HumanMessage(content=decomposition_prompt)]
|
||||
|
||||
# Send the initial question as a subquestion with number 0
|
||||
dispatch_custom_event(
|
||||
"decomp_qs",
|
||||
SubQuestionPiece(
|
||||
sub_question=question,
|
||||
level=0,
|
||||
level_question_nr=0,
|
||||
),
|
||||
)
|
||||
# dispatches custom events for subquestion tokens, adding in subquestion ids.
|
||||
streamed_tokens = dispatch_separated(model.stream(msg), dispatch_subquestion(0))
|
||||
|
||||
stop_event = StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
stream_type="sub_questions",
|
||||
level=0,
|
||||
)
|
||||
dispatch_custom_event("stream_finished", stop_event)
|
||||
|
||||
deomposition_response = merge_content(*streamed_tokens)
|
||||
|
||||
# this call should only return strings. Commenting out for efficiency
|
||||
# assert [type(tok) == str for tok in streamed_tokens]
|
||||
|
||||
# use no-op cast() instead of str() which runs code
|
||||
# list_of_subquestions = clean_and_parse_list_string(cast(str, response))
|
||||
list_of_subqs = cast(str, deomposition_response).split("\n")
|
||||
|
||||
decomp_list: list[str] = [sq.strip() for sq in list_of_subqs if sq.strip() != ""]
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.info(
|
||||
f"{now_start} -- INITIAL SUBQUESTION ANSWERING - Base Decomposition, Time taken: {now_end - now_start}"
|
||||
)
|
||||
|
||||
return BaseDecompUpdate(
|
||||
initial_decomp_questions=decomp_list,
|
||||
agent_start_time=agent_start_time,
|
||||
agent_refined_start_time=None,
|
||||
agent_refined_end_time=None,
|
||||
agent_refined_metrics=AgentRefinedMetrics(
|
||||
refined_doc_boost_factor=None,
|
||||
refined_question_boost_factor=None,
|
||||
duration__s=None,
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,37 @@
|
||||
from typing import TypedDict
|
||||
|
||||
from onyx.agents.agent_search.core_state import CoreState
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import BaseDecompUpdate
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
|
||||
DecompAnswersUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
|
||||
InitialAnswerUpdate,
|
||||
)
|
||||
|
||||
### States ###
|
||||
|
||||
|
||||
class SQInput(CoreState):
|
||||
pass
|
||||
|
||||
|
||||
## Graph State
|
||||
|
||||
|
||||
class SQState(
|
||||
# This includes the core state
|
||||
SQInput,
|
||||
BaseDecompUpdate,
|
||||
InitialAnswerUpdate,
|
||||
DecompAnswersUpdate,
|
||||
):
|
||||
# expanded_retrieval_result: Annotated[list[ExpandedRetrievalResult], add]
|
||||
pass
|
||||
|
||||
|
||||
## Graph Output State - presently not used
|
||||
|
||||
|
||||
class SQOutput(TypedDict):
|
||||
log_messages: list[str]
|
||||
@@ -0,0 +1,29 @@
|
||||
from collections.abc import Hashable
|
||||
from datetime import datetime
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
|
||||
AnswerQuestionInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
|
||||
ExpandedRetrievalInput,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def send_to_expanded_retrieval(state: AnswerQuestionInput) -> Send | Hashable:
|
||||
logger.debug("sending to expanded retrieval via edge")
|
||||
now_start = datetime.now()
|
||||
|
||||
return Send(
|
||||
"initial_sub_question_expanded_retrieval",
|
||||
ExpandedRetrievalInput(
|
||||
question=state.question,
|
||||
base_search=False,
|
||||
sub_question_id=state.question_id,
|
||||
log_messages=[f"{now_start} -- Sending to expanded retrieval"],
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,126 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.edges import (
|
||||
send_to_expanded_retrieval,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.nodes.answer_check import (
|
||||
answer_check,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.nodes.answer_generation import (
|
||||
answer_generation,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.nodes.format_answer import (
|
||||
format_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.nodes.ingest_retrieval import (
|
||||
ingest_retrieval,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
|
||||
AnswerQuestionInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
|
||||
AnswerQuestionOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
|
||||
AnswerQuestionState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.graph_builder import (
|
||||
expanded_retrieval_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def answer_query_graph_builder() -> StateGraph:
|
||||
graph = StateGraph(
|
||||
state_schema=AnswerQuestionState,
|
||||
input=AnswerQuestionInput,
|
||||
output=AnswerQuestionOutput,
|
||||
)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
expanded_retrieval = expanded_retrieval_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="initial_sub_question_expanded_retrieval",
|
||||
action=expanded_retrieval,
|
||||
)
|
||||
graph.add_node(
|
||||
node="answer_check",
|
||||
action=answer_check,
|
||||
)
|
||||
graph.add_node(
|
||||
node="answer_generation",
|
||||
action=answer_generation,
|
||||
)
|
||||
graph.add_node(
|
||||
node="format_answer",
|
||||
action=format_answer,
|
||||
)
|
||||
graph.add_node(
|
||||
node="ingest_retrieval",
|
||||
action=ingest_retrieval,
|
||||
)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source=START,
|
||||
path=send_to_expanded_retrieval,
|
||||
path_map=["initial_sub_question_expanded_retrieval"],
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="initial_sub_question_expanded_retrieval",
|
||||
end_key="ingest_retrieval",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="ingest_retrieval",
|
||||
end_key="answer_generation",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="answer_generation",
|
||||
end_key="answer_check",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="answer_check",
|
||||
end_key="format_answer",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="format_answer",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.context.search.models import SearchRequest
|
||||
|
||||
graph = answer_query_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
search_request = SearchRequest(
|
||||
query="what can you do with onyx or danswer?",
|
||||
)
|
||||
with get_session_context_manager() as db_session:
|
||||
agent_search_config, search_tool = get_test_config(
|
||||
db_session, primary_llm, fast_llm, search_request
|
||||
)
|
||||
inputs = AnswerQuestionInput(
|
||||
question="what can you do with onyx?",
|
||||
question_id="0_0",
|
||||
log_messages=[],
|
||||
)
|
||||
for thing in compiled_graph.stream(
|
||||
input=inputs,
|
||||
config={"configurable": {"config": agent_search_config}},
|
||||
# debug=True,
|
||||
# subgraphs=True,
|
||||
):
|
||||
logger.debug(thing)
|
||||
@@ -0,0 +1,8 @@
from pydantic import BaseModel


### Models ###


class AnswerRetrievalStats(BaseModel):
    answer_retrieval_stats: dict[str, float | int]
@@ -0,0 +1,59 @@
from datetime import datetime
from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
    AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
    QACheckUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_CHECK_NO
from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id


def answer_check(state: AnswerQuestionState, config: RunnableConfig) -> QACheckUpdate:
    now_start = datetime.now()

    level, question_num = parse_question_id(state.question_id)
    if state.answer == UNKNOWN_ANSWER:
        now_end = datetime.now()
        return QACheckUpdate(
            answer_quality=SUB_CHECK_NO,
            log_messages=[
                f"{now_start} -- Answer check SQ-{level}-{question_num} - unknown answer, Time taken: {now_end - now_start}"
            ],
        )
    msg = [
        HumanMessage(
            content=SUB_CHECK_PROMPT.format(
                question=state.question,
                base_answer=state.answer,
            )
        )
    ]

    agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
    fast_llm = agent_search_config.fast_llm
    response = list(
        fast_llm.stream(
            prompt=msg,
        )
    )

    quality_str = merge_message_runs(response, chunk_separator="")[0].content

    now_end = datetime.now()
    return QACheckUpdate(
        answer_quality=quality_str,
        log_messages=[
            f"""{now_start} -- Answer check SQ-{level}-{question_num} - Answer quality: {quality_str},
            Time taken: {now_end - now_start}"""
        ],
    )
@@ -0,0 +1,126 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
from langchain_core.messages import merge_message_runs
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
|
||||
AnswerQuestionState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
|
||||
QAGenerationUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_sub_question_answer_prompt,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import NO_RECOVERED_DOCS
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_answer_citation_ids
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_persona_agent_prompt_expressions,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def answer_generation(
|
||||
state: AnswerQuestionState, config: RunnableConfig
|
||||
) -> QAGenerationUpdate:
|
||||
now_start = datetime.now()
|
||||
logger.info(f"--------{now_start}--------START ANSWER GENERATION---")
|
||||
|
||||
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = state.question
|
||||
state.documents
|
||||
level, question_nr = parse_question_id(state.question_id)
|
||||
context_docs = state.context_documents[:AGENT_MAX_ANSWER_CONTEXT_DOCS]
|
||||
persona_contextualized_prompt = get_persona_agent_prompt_expressions(
|
||||
agent_search_config.search_request.persona
|
||||
).contextualized_prompt
|
||||
|
||||
if len(context_docs) == 0:
|
||||
answer_str = NO_RECOVERED_DOCS
|
||||
dispatch_custom_event(
|
||||
"sub_answers",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=answer_str,
|
||||
level=level,
|
||||
level_question_nr=question_nr,
|
||||
answer_type="agent_sub_answer",
|
||||
),
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Number of verified retrieval docs: {len(context_docs)}")
|
||||
|
||||
fast_llm = agent_search_config.fast_llm
|
||||
msg = build_sub_question_answer_prompt(
|
||||
question=question,
|
||||
original_question=agent_search_config.search_request.query,
|
||||
docs=context_docs,
|
||||
persona_specification=persona_contextualized_prompt,
|
||||
config=fast_llm.config,
|
||||
)
|
||||
|
||||
response: list[str | list[str | dict[str, Any]]] = []
|
||||
dispatch_timings: list[float] = []
|
||||
for message in fast_llm.stream(
|
||||
prompt=msg,
|
||||
):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
)
|
||||
start_stream_token = datetime.now()
|
||||
dispatch_custom_event(
|
||||
"sub_answers",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=content,
|
||||
level=level,
|
||||
level_question_nr=question_nr,
|
||||
answer_type="agent_sub_answer",
|
||||
),
|
||||
)
|
||||
end_stream_token = datetime.now()
|
||||
dispatch_timings.append(
|
||||
(end_stream_token - start_stream_token).microseconds
|
||||
)
|
||||
response.append(content)
|
||||
|
||||
answer_str = merge_message_runs(response, chunk_separator="")[0].content
|
||||
logger.info(
|
||||
f"Average dispatch time: {sum(dispatch_timings) / len(dispatch_timings)}"
|
||||
)
|
||||
|
||||
answer_citation_ids = get_answer_citation_ids(answer_str)
|
||||
cited_docs = [context_docs[id] for id in answer_citation_ids]
|
||||
|
||||
stop_event = StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
stream_type="sub_answer",
|
||||
level=level,
|
||||
level_question_nr=question_nr,
|
||||
)
|
||||
dispatch_custom_event("stream_finished", stop_event)
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.info(
|
||||
f"{now_start} -- Answer generation SQ-{level} - Q{question_nr} - Time taken: {now_end - now_start}"
|
||||
)
|
||||
return QAGenerationUpdate(
|
||||
answer=answer_str,
|
||||
cited_docs=cited_docs,
|
||||
log_messages=[
|
||||
f"{now_start} -- Answer generation SQ-{level} - Q{question_nr} - Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,29 @@
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
    AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
    AnswerQuestionState,
)
from onyx.agents.agent_search.shared_graph_utils.models import (
    QuestionAnswerResults,
)


def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput:
    return AnswerQuestionOutput(
        answer_results=[
            QuestionAnswerResults(
                question=state.question,
                question_id=state.question_id,
                quality=state.answer_quality
                if hasattr(state, "answer_quality")
                else "No",
                answer=state.answer,
                expanded_retrieval_results=state.expanded_retrieval_results,
                documents=state.documents,
                context_documents=state.context_documents,
                cited_docs=state.cited_docs,
                sub_question_retrieval_stats=state.sub_question_retrieval_stats,
            )
        ],
    )
@@ -0,0 +1,22 @@
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
    RetrievalIngestionUpdate,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
    ExpandedRetrievalOutput,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats


def ingest_retrieval(state: ExpandedRetrievalOutput) -> RetrievalIngestionUpdate:
    sub_question_retrieval_stats = (
        state.expanded_retrieval_result.sub_question_retrieval_stats
    )
    if sub_question_retrieval_stats is None:
        sub_question_retrieval_stats = [AgentChunkStats()]

    return RetrievalIngestionUpdate(
        expanded_retrieval_results=state.expanded_retrieval_result.expanded_queries_results,
        documents=state.expanded_retrieval_result.reranked_documents,
        context_documents=state.expanded_retrieval_result.context_documents,
        sub_question_retrieval_stats=sub_question_retrieval_stats,
    )
@@ -0,0 +1,72 @@
from operator import add
from typing import Annotated

from pydantic import BaseModel

from onyx.agents.agent_search.core_state import SubgraphCoreState
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.agents.agent_search.shared_graph_utils.models import (
    QuestionAnswerResults,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
    dedup_inference_sections,
)
from onyx.context.search.models import InferenceSection


## Update States
class QACheckUpdate(BaseModel):
    answer_quality: str = ""
    log_messages: list[str] = []


class QAGenerationUpdate(BaseModel):
    answer: str = ""
    log_messages: list[str] = []
    cited_docs: Annotated[list[InferenceSection], dedup_inference_sections] = []
    # answer_stat: AnswerStats


class RetrievalIngestionUpdate(BaseModel):
    expanded_retrieval_results: list[QueryResult] = []
    documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
    context_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
    sub_question_retrieval_stats: AgentChunkStats = AgentChunkStats()


## Graph Input State


class AnswerQuestionInput(SubgraphCoreState):
    question: str = ""
    question_id: str = (
        ""  # 0_0 is original question, everything else is <level>_<question_num>.
    )
    # level 0 is original question and first decomposition, level 1 is follow up, etc
    # question_num is a unique number per original question per level.


## Graph State


class AnswerQuestionState(
    AnswerQuestionInput,
    QAGenerationUpdate,
    QACheckUpdate,
    RetrievalIngestionUpdate,
):
    pass


## Graph Output State


class AnswerQuestionOutput(BaseModel):
    """
    This is a list of results even though each call of this subgraph only returns one result.
    This is because if we parallelize the answer query subgraph, there will be multiple
    results in a list so the add operator is used to add them together.
    """

    answer_results: Annotated[list[QuestionAnswerResults], add] = []
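A minimal, self-contained sketch (not part of the diff) of why answer_results is annotated with operator.add: when the sub-answer subgraph is fanned out in parallel, each branch returns a one-element list, and the reducer concatenates the branch outputs into a single state value. SubAnswer and Output below are stand-ins for QuestionAnswerResults and AnswerQuestionOutput.

from operator import add
from typing import Annotated

from pydantic import BaseModel


class SubAnswer(BaseModel):  # stand-in for QuestionAnswerResults
    question: str
    answer: str


class Output(BaseModel):
    # the reducer (operator.add) tells LangGraph how to combine updates from parallel branches
    answer_results: Annotated[list[SubAnswer], add] = []


# two parallel branches each emit a one-element list ...
branch_1 = Output(answer_results=[SubAnswer(question="q1", answer="a1")])
branch_2 = Output(answer_results=[SubAnswer(question="q2", answer="a2")])

# ... and the merged state sees their concatenation, which is exactly what `add` does on lists
merged = add(branch_1.answer_results, branch_2.answer_results)
assert [sa.question for sa in merged] == ["q1", "q2"]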
@@ -0,0 +1,89 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.initial__retrieval__subgraph.nodes.format_raw_search_results import (
|
||||
format_raw_search_results,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__retrieval__subgraph.nodes.generate_raw_search_data import (
|
||||
generate_raw_search_data,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__retrieval__subgraph.nodes.ingest_initial_base_retrieval import (
|
||||
ingest_initial_base_retrieval,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__retrieval__subgraph.states import (
|
||||
BaseRawSearchInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__retrieval__subgraph.states import (
|
||||
BaseRawSearchOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__retrieval__subgraph.states import (
|
||||
BaseRawSearchState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.graph_builder import (
|
||||
expanded_retrieval_graph_builder,
|
||||
)
|
||||
|
||||
|
||||
def base_raw_search_graph_builder() -> StateGraph:
|
||||
graph = StateGraph(
|
||||
state_schema=BaseRawSearchState,
|
||||
input=BaseRawSearchInput,
|
||||
output=BaseRawSearchOutput,
|
||||
)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
graph.add_node(
|
||||
node="generate_raw_search_data",
|
||||
action=generate_raw_search_data,
|
||||
)
|
||||
|
||||
expanded_retrieval = expanded_retrieval_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="expanded_retrieval_base_search",
|
||||
action=expanded_retrieval,
|
||||
)
|
||||
graph.add_node(
|
||||
node="format_raw_search_results",
|
||||
action=format_raw_search_results,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="ingest_initial_base_retrieval",
|
||||
action=ingest_initial_base_retrieval,
|
||||
)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_edge(start_key=START, end_key="generate_raw_search_data")
|
||||
|
||||
graph.add_edge(
|
||||
start_key="generate_raw_search_data",
|
||||
end_key="expanded_retrieval_base_search",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="expanded_retrieval_base_search",
|
||||
end_key="format_raw_search_results",
|
||||
)
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="expanded_retrieval_base_search",
|
||||
# end_key=END,
|
||||
# )
|
||||
|
||||
graph.add_edge(
|
||||
start_key="format_raw_search_results",
|
||||
end_key="ingest_initial_base_retrieval",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="ingest_initial_base_retrieval",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
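A hedged usage sketch (not part of the diff) of driving the builder above on its own. It assumes BaseRawSearchInput only needs the fields that generate_raw_search_data sets further down, and it reuses the config plumbing shown in the main graph's __main__ block at the end of this PR.

from onyx.agents.agent_search.deep_search_a.initial__retrieval__subgraph.graph_builder import (
    base_raw_search_graph_builder,
)
from onyx.agents.agent_search.deep_search_a.initial__retrieval__subgraph.states import (
    BaseRawSearchInput,
)

compiled = base_raw_search_graph_builder().compile()
# agent_a_config would come from get_test_config(...) as in the main graph's __main__ block
for step in compiled.stream(
    input=BaseRawSearchInput(
        question="Who created Excel?",
        base_search=True,
        sub_question_id=None,
        log_messages=[],
    ),
    config={"configurable": {"config": agent_a_config}},
    subgraphs=True,
):
    print(step)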
@@ -0,0 +1,20 @@
from pydantic import BaseModel

from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.context.search.models import InferenceSection

### Models ###


class AnswerRetrievalStats(BaseModel):
    answer_retrieval_stats: dict[str, float | int]


class QuestionAnswerResults(BaseModel):
    question: str
    answer: str
    quality: str
    expanded_retrieval_results: list[QueryResult]
    documents: list[InferenceSection]
    sub_question_retrieval_stats: list[AgentChunkStats]
@@ -0,0 +1,18 @@
from onyx.agents.agent_search.deep_search_a.initial__retrieval__subgraph.states import (
    BaseRawSearchOutput,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
    ExpandedRetrievalOutput,
)
from onyx.utils.logger import setup_logger

logger = setup_logger()


def format_raw_search_results(state: ExpandedRetrievalOutput) -> BaseRawSearchOutput:
    logger.debug("format_raw_search_results")
    return BaseRawSearchOutput(
        base_expanded_retrieval_result=state.expanded_retrieval_result,
        # base_retrieval_results=[state.expanded_retrieval_result],
        # base_search_documents=[],
    )
@@ -0,0 +1,25 @@
from typing import cast

from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
    ExpandedRetrievalInput,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.utils.logger import setup_logger

logger = setup_logger()


def generate_raw_search_data(
    state: CoreState, config: RunnableConfig
) -> ExpandedRetrievalInput:
    logger.debug("generate_raw_search_data")
    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    return ExpandedRetrievalInput(
        question=agent_a_config.search_request.query,
        base_search=True,
        sub_question_id=None,  # This graph is always and only used for the original question
        log_messages=[],
    )
@@ -0,0 +1,41 @@
from datetime import datetime

from onyx.agents.agent_search.deep_search_a.initial__retrieval__subgraph.states import (
    BaseRawSearchOutput,
)
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
    ExpandedRetrievalUpdate,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats


def ingest_initial_base_retrieval(
    state: BaseRawSearchOutput,
) -> ExpandedRetrievalUpdate:
    now_start = datetime.now()

    logger.info(f"--------{now_start}--------INGEST INITIAL RETRIEVAL---")

    sub_question_retrieval_stats = (
        state.base_expanded_retrieval_result.sub_question_retrieval_stats
    )
    if sub_question_retrieval_stats is None:
        sub_question_retrieval_stats = AgentChunkStats()

    now_end = datetime.now()

    logger.debug(
        f"--------{now_end}--{now_end - now_start}--------INGEST INITIAL RETRIEVAL END---"
    )

    return ExpandedRetrievalUpdate(
        original_question_retrieval_results=state.base_expanded_retrieval_result.expanded_queries_results,
        all_original_question_documents=state.base_expanded_retrieval_result.context_documents,
        original_question_retrieval_stats=sub_question_retrieval_stats,
        log_messages=[
            f"{now_start} -- Main - Ingestion base retrieval, Time taken: {now_end - now_start}"
        ],
    )
@@ -0,0 +1,44 @@
from pydantic import BaseModel

from onyx.agents.agent_search.deep_search_a.main__graph.states import (
    ExpandedRetrievalUpdate,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.models import (
    ExpandedRetrievalResult,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
    ExpandedRetrievalInput,
)

## Update States


## Graph Input State


class BaseRawSearchInput(ExpandedRetrievalInput):
    pass


## Graph Output State


class BaseRawSearchOutput(BaseModel):
    """
    This is a list of results even though each call of this subgraph only returns one result.
    This is because if we parallelize the answer query subgraph, there will be multiple
    results in a list so the add operator is used to add them together.
    """

    # base_search_documents: Annotated[list[InferenceSection], dedup_inference_sections]
    # base_retrieval_results: Annotated[list[ExpandedRetrievalResult], add]
    base_expanded_retrieval_result: ExpandedRetrievalResult = ExpandedRetrievalResult()


## Graph State


class BaseRawSearchState(
    BaseRawSearchInput, BaseRawSearchOutput, ExpandedRetrievalUpdate
):
    pass
@@ -0,0 +1,55 @@
from collections.abc import Hashable
from datetime import datetime

from langgraph.types import Send

from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
    AnswerQuestionInput,
)
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
    AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search_a.initial__retrieval_sub_answers__subgraph.states import (
    SearchSQState,
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id


def parallelize_initial_sub_question_answering(
    state: SearchSQState,
) -> list[Send | Hashable]:
    now_start = datetime.now()
    if len(state.initial_decomp_questions) > 0:
        # sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]]
        # if len(state["sub_question_records"]) == 0:
        #     if state["config"].use_persistence:
        #         raise ValueError("No sub-questions found for initial decomposed questions")
        #     else:
        #         # in this case, we are doing retrieval on the original question.
        #         # to make all the logic consistent, we create a new sub-question
        #         # with the same content as the original question
        #         sub_question_record_ids = [1] * len(state["initial_decomp_questions"])

        return [
            Send(
                "answer_query_subgraph",
                AnswerQuestionInput(
                    question=question,
                    question_id=make_question_id(0, question_nr + 1),
                    log_messages=[
                        f"{now_start} -- Main Edge - Parallelize Initial Sub-question Answering"
                    ],
                ),
            )
            for question_nr, question in enumerate(state.initial_decomp_questions)
        ]

    else:
        return [
            Send(
                "ingest_answers",
                AnswerQuestionOutput(
                    answer_results=[],
                ),
            )
        ]
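A short wiring sketch (not part of the diff, mirroring the conditional-edge pattern used in the main graph builder below): the edge function above is attached with add_conditional_edges, and each Send it returns spawns one parallel run of the named node with its own input state. The path_map entries are the node names the Sends may target.

graph.add_conditional_edges(
    source="initial_sub_question_creation",
    path=parallelize_initial_sub_question_answering,
    path_map=["answer_query_subgraph", "ingest_answers"],
)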
@@ -0,0 +1,139 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.initial__consolidate_sub_answers__subgraph.graph_builder import (
|
||||
initial_sq_subgraph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__retrieval__subgraph.graph_builder import (
|
||||
base_raw_search_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__retrieval_sub_answers__subgraph.nodes.generate_initial_answer import (
|
||||
generate_initial_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__retrieval_sub_answers__subgraph.nodes.initial_answer_quality_check import (
|
||||
initial_answer_quality_check,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__retrieval_sub_answers__subgraph.nodes.retrieval_consolidation import (
|
||||
retrieval_consolidation,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__retrieval_sub_answers__subgraph.states import (
|
||||
SearchSQInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__retrieval_sub_answers__subgraph.states import (
|
||||
SearchSQState,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def initial_search_sq_subgraph_builder(test_mode: bool = False) -> StateGraph:
|
||||
graph = StateGraph(
|
||||
state_schema=SearchSQState,
|
||||
input=SearchSQInput,
|
||||
)
|
||||
|
||||
# graph.add_node(
|
||||
# node="initial_sub_question_creation",
|
||||
# action=initial_sub_question_creation,
|
||||
# )
|
||||
|
||||
sub_question_answering_subgraph = initial_sq_subgraph_builder().compile()
|
||||
graph.add_node(
|
||||
node="sub_question_answering_subgraph",
|
||||
action=sub_question_answering_subgraph,
|
||||
)
|
||||
|
||||
# answer_query_subgraph = answer_query_graph_builder().compile()
|
||||
# graph.add_node(
|
||||
# node="answer_query_subgraph",
|
||||
# action=answer_query_subgraph,
|
||||
# )
|
||||
|
||||
base_raw_search_subgraph = base_raw_search_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="base_raw_search_subgraph",
|
||||
action=base_raw_search_subgraph,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="retrieval_consolidation",
|
||||
action=retrieval_consolidation,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="generate_initial_answer",
|
||||
action=generate_initial_answer,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="initial_answer_quality_check",
|
||||
action=initial_answer_quality_check,
|
||||
)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
# graph.add_edge(start_key=START, end_key="base_raw_search_subgraph")
|
||||
|
||||
graph.add_edge(
|
||||
start_key=START,
|
||||
end_key="base_raw_search_subgraph",
|
||||
)
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="agent_search_start",
|
||||
# end_key="entity_term_extraction_llm",
|
||||
# )
|
||||
|
||||
graph.add_edge(
|
||||
start_key=START,
|
||||
end_key="sub_question_answering_subgraph",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key=["base_raw_search_subgraph", "sub_question_answering_subgraph"],
|
||||
end_key="retrieval_consolidation",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="retrieval_consolidation",
|
||||
end_key="generate_initial_answer",
|
||||
)
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="LLM",
|
||||
# end_key=END,
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key=START,
|
||||
# end_key="initial_sub_question_creation",
|
||||
# )
|
||||
|
||||
graph.add_edge(
|
||||
start_key="retrieval_consolidation",
|
||||
end_key="generate_initial_answer",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="generate_initial_answer",
|
||||
end_key="initial_answer_quality_check",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="initial_answer_quality_check",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="generate_refined_answer",
|
||||
# end_key="check_refined_answer",
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="check_refined_answer",
|
||||
# end_key=END,
|
||||
# )
|
||||
|
||||
return graph
|
||||
@@ -0,0 +1,276 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.initial__retrieval_sub_answers__subgraph.states import (
|
||||
SearchSQState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.models import AgentBaseMetrics
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.operations import (
|
||||
calculate_initial_agent_stats,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.operations import get_query_info
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.operations import (
|
||||
remove_document_citations,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
|
||||
InitialAnswerUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
get_prompt_enrichment_components,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import INITIAL_RAG_PROMPT
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
SUB_QUESTION_ANSWER_TEMPLATE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
dispatch_main_answer_stop_info,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
|
||||
|
||||
|
||||
def generate_initial_answer(
|
||||
state: SearchSQState, config: RunnableConfig
|
||||
) -> InitialAnswerUpdate:
|
||||
now_start = datetime.now()
|
||||
|
||||
logger.info(f"--------{now_start}--------GENERATE INITIAL---")
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = agent_a_config.search_request.query
|
||||
prompt_enrichment_components = get_prompt_enrichment_components(agent_a_config)
|
||||
|
||||
sub_questions_cited_docs = state.cited_docs
|
||||
all_original_question_documents = state.all_original_question_documents
|
||||
|
||||
consolidated_context_docs: list[InferenceSection] = sub_questions_cited_docs
|
||||
counter = 0
|
||||
for original_doc_number, original_doc in enumerate(all_original_question_documents):
|
||||
if original_doc_number not in sub_questions_cited_docs:
|
||||
if (
|
||||
counter <= AGENT_MIN_ORIG_QUESTION_DOCS
|
||||
or len(consolidated_context_docs) < AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
):
|
||||
consolidated_context_docs.append(original_doc)
|
||||
counter += 1
|
||||
|
||||
# sort docs by their scores - though the scores refer to different questions
|
||||
relevant_docs = dedup_inference_sections(
|
||||
consolidated_context_docs, consolidated_context_docs
|
||||
)
|
||||
|
||||
decomp_questions = []
|
||||
|
||||
# Use the query info from the base document retrieval
|
||||
query_info = get_query_info(state.original_question_retrieval_results)
|
||||
if agent_a_config.search_tool is None:
|
||||
raise ValueError("search_tool must be provided for agentic search")
|
||||
for tool_response in yield_search_responses(
|
||||
query=question,
|
||||
reranked_sections=relevant_docs,
|
||||
final_context_sections=relevant_docs,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: None, # TODO: add relevance
|
||||
search_tool=agent_a_config.search_tool,
|
||||
):
|
||||
dispatch_custom_event(
|
||||
"tool_response",
|
||||
ExtendedToolResponse(
|
||||
id=tool_response.id,
|
||||
response=tool_response.response,
|
||||
level=0,
|
||||
level_question_nr=0, # 0, 0 is the base question
|
||||
),
|
||||
)
|
||||
|
||||
if len(relevant_docs) == 0:
|
||||
dispatch_custom_event(
|
||||
"initial_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=UNKNOWN_ANSWER,
|
||||
level=0,
|
||||
level_question_nr=0,
|
||||
answer_type="agent_level_answer",
|
||||
),
|
||||
)
|
||||
dispatch_main_answer_stop_info(0)
|
||||
|
||||
answer = UNKNOWN_ANSWER
|
||||
initial_agent_stats = InitialAgentResultStats(
|
||||
sub_questions={},
|
||||
original_question={},
|
||||
agent_effectiveness={},
|
||||
)
|
||||
|
||||
else:
|
||||
decomp_answer_results = state.decomp_answer_results
|
||||
|
||||
good_qa_list: list[str] = []
|
||||
|
||||
sub_question_nr = 1
|
||||
|
||||
for decomp_answer_result in decomp_answer_results:
|
||||
decomp_questions.append(decomp_answer_result.question)
|
||||
_, question_nr = parse_question_id(decomp_answer_result.question_id)
|
||||
if (
|
||||
decomp_answer_result.quality.lower().startswith("yes")
|
||||
and len(decomp_answer_result.answer) > 0
|
||||
and decomp_answer_result.answer != UNKNOWN_ANSWER
|
||||
):
|
||||
good_qa_list.append(
|
||||
SUB_QUESTION_ANSWER_TEMPLATE.format(
|
||||
sub_question=decomp_answer_result.question,
|
||||
sub_answer=decomp_answer_result.answer,
|
||||
sub_question_nr=sub_question_nr,
|
||||
)
|
||||
)
|
||||
sub_question_nr += 1
|
||||
|
||||
if len(good_qa_list) > 0:
|
||||
sub_question_answer_str = "\n\n------\n\n".join(good_qa_list)
|
||||
else:
|
||||
sub_question_answer_str = ""
|
||||
|
||||
# Determine which base prompt to use given the sub-question information
|
||||
if len(good_qa_list) > 0:
|
||||
base_prompt = INITIAL_RAG_PROMPT
|
||||
else:
|
||||
base_prompt = INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS
|
||||
|
||||
model = agent_a_config.fast_llm
|
||||
|
||||
doc_context = format_docs(relevant_docs)
|
||||
doc_context = trim_prompt_piece(
|
||||
model.config,
|
||||
doc_context,
|
||||
base_prompt
|
||||
+ sub_question_answer_str
|
||||
+ prompt_enrichment_components.persona_prompts.contextualized_prompt
|
||||
+ prompt_enrichment_components.history
|
||||
+ prompt_enrichment_components.date_str,
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=base_prompt.format(
|
||||
question=question,
|
||||
answered_sub_questions=remove_document_citations(
|
||||
sub_question_answer_str
|
||||
),
|
||||
relevant_docs=format_docs(relevant_docs),
|
||||
persona_specification=prompt_enrichment_components.persona_prompts.contextualized_prompt,
|
||||
history=prompt_enrichment_components.history,
|
||||
date_prompt=prompt_enrichment_components.date_str,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
|
||||
dispatch_timings: list[float] = []
|
||||
for message in model.stream(msg):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
)
|
||||
start_stream_token = datetime.now()
|
||||
|
||||
dispatch_custom_event(
|
||||
"initial_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=content,
|
||||
level=0,
|
||||
level_question_nr=0,
|
||||
answer_type="agent_level_answer",
|
||||
),
|
||||
)
|
||||
end_stream_token = datetime.now()
|
||||
dispatch_timings.append(
|
||||
(end_stream_token - start_stream_token).microseconds
|
||||
)
|
||||
streamed_tokens.append(content)
|
||||
|
||||
logger.info(
|
||||
f"Average dispatch time for initial answer: {sum(dispatch_timings) / len(dispatch_timings)}"
|
||||
)
|
||||
|
||||
dispatch_main_answer_stop_info(0)
|
||||
response = merge_content(*streamed_tokens)
|
||||
answer = cast(str, response)
|
||||
|
||||
initial_agent_stats = calculate_initial_agent_stats(
|
||||
state.decomp_answer_results, state.original_question_retrieval_stats
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"\n\nYYYYY--Sub-Questions:\n\n{sub_question_answer_str}\n\nStats:\n\n"
|
||||
)
|
||||
|
||||
if initial_agent_stats:
|
||||
logger.debug(initial_agent_stats.original_question)
|
||||
logger.debug(initial_agent_stats.sub_questions)
|
||||
logger.debug(initial_agent_stats.agent_effectiveness)
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
agent_base_end_time = datetime.now()
|
||||
|
||||
agent_base_metrics = AgentBaseMetrics(
|
||||
num_verified_documents_total=len(relevant_docs),
|
||||
num_verified_documents_core=state.original_question_retrieval_stats.verified_count,
|
||||
verified_avg_score_core=state.original_question_retrieval_stats.verified_avg_scores,
|
||||
num_verified_documents_base=initial_agent_stats.sub_questions.get(
|
||||
"num_verified_documents", None
|
||||
),
|
||||
verified_avg_score_base=initial_agent_stats.sub_questions.get(
|
||||
"verified_avg_score", None
|
||||
),
|
||||
base_doc_boost_factor=initial_agent_stats.agent_effectiveness.get(
|
||||
"utilized_chunk_ratio", None
|
||||
),
|
||||
support_boost_factor=initial_agent_stats.agent_effectiveness.get(
|
||||
"support_ratio", None
|
||||
),
|
||||
duration__s=(agent_base_end_time - state.agent_start_time).total_seconds(),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"{now_start} -- Main - Initial Answer generation, Time taken: {now_end - now_start}"
|
||||
)
|
||||
|
||||
return InitialAnswerUpdate(
|
||||
initial_answer=answer,
|
||||
initial_agent_stats=initial_agent_stats,
|
||||
generated_sub_questions=decomp_questions,
|
||||
agent_base_end_time=agent_base_end_time,
|
||||
agent_base_metrics=agent_base_metrics,
|
||||
log_messages=[
|
||||
f"{now_start} -- Main - Initial Answer generation, Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,43 @@
from datetime import datetime

from onyx.agents.agent_search.deep_search_a.initial__retrieval__subgraph.states import (
    BaseRawSearchOutput,
)
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
    ExpandedRetrievalUpdate,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats


def ingest_initial_base_retrieval(
    state: BaseRawSearchOutput,
) -> ExpandedRetrievalUpdate:
    now_start = datetime.now()

    logger.info(f"--------{now_start}--------INGEST INITIAL RETRIEVAL---")

    sub_question_retrieval_stats = (
        state.base_expanded_retrieval_result.sub_question_retrieval_stats
    )
    # if sub_question_retrieval_stats is None:
    #     sub_question_retrieval_stats = AgentChunkStats()
    # else:
    #     sub_question_retrieval_stats = sub_question_retrieval_stats

    sub_question_retrieval_stats = sub_question_retrieval_stats or AgentChunkStats()

    now_end = datetime.now()

    logger.debug(
        f"--------{now_end}--{now_end - now_start}--------INGEST INITIAL RETRIEVAL END---"
    )

    return ExpandedRetrievalUpdate(
        original_question_retrieval_results=state.base_expanded_retrieval_result.expanded_queries_results,
        all_original_question_documents=state.base_expanded_retrieval_result.context_documents,
        original_question_retrieval_stats=sub_question_retrieval_stats,
        log_messages=[
            f"{now_start} -- Main - Ingestion base retrieval, Time taken: {now_end - now_start}"
        ],
    )
@@ -0,0 +1,42 @@
from datetime import datetime

from onyx.agents.agent_search.deep_search_a.initial__retrieval_sub_answers__subgraph.states import (
    SearchSQState,
)
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
    InitialAnswerQualityUpdate,
)


def initial_answer_quality_check(state: SearchSQState) -> InitialAnswerQualityUpdate:
    """
    Check whether the final output satisfies the original user question

    Args:
        state (messages): The current state

    Returns:
        InitialAnswerQualityUpdate
    """

    now_start = datetime.now()

    logger.debug(
        f"--------{now_start}--------Checking for base answer validity - for now the verdict is set to True manually"
    )

    verdict = True

    now_end = datetime.now()

    logger.debug(
        f"--------{now_end}--{now_end - now_start}--------INITIAL ANSWER QUALITY CHECK END---"
    )

    return InitialAnswerQualityUpdate(
        initial_answer_quality_eval=verdict,
        log_messages=[
            f"{now_start} -- Main - Initial answer quality check, Time taken: {now_end - now_start}"
        ],
    )
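The node above currently hardcodes verdict = True. A hedged, purely illustrative sketch (not part of the diff) of what an LLM-backed verdict could look like, reusing the fast_llm streaming pattern already used elsewhere in this PR; the prompt text, the helper name, and the "yes"/"no" convention are all assumptions.

from langchain_core.messages import HumanMessage

from onyx.agents.agent_search.models import AgentSearchConfig


def _llm_answer_verdict(
    agent_a_config: AgentSearchConfig, question: str, initial_answer: str
) -> bool:
    # assumption: a plain yes/no prompt; a real check would keep its prompt in shared_graph_utils.prompts
    prompt = (
        "Does the answer below fully address the question? Reply 'yes' or 'no'.\n\n"
        f"Question: {question}\n\nAnswer: {initial_answer}"
    )
    tokens: list[str] = []
    for message in agent_a_config.fast_llm.stream([HumanMessage(content=prompt)]):
        if isinstance(message.content, str):
            tokens.append(message.content)
    return "".join(tokens).strip().lower().startswith("yes")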
@@ -0,0 +1,14 @@
from datetime import datetime

from onyx.agents.agent_search.deep_search_a.initial__retrieval_sub_answers__subgraph.states import (
    SearchSQState,
)
from onyx.agents.agent_search.deep_search_a.main__graph.states import LoggerUpdate


def retrieval_consolidation(
    state: SearchSQState,
) -> LoggerUpdate:
    now_start = datetime.now()

    return LoggerUpdate(log_messages=[f"{now_start} -- Retrieval consolidation"])
@@ -0,0 +1,54 @@
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
from typing import TypedDict
|
||||
|
||||
from onyx.agents.agent_search.core_state import CoreState
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import BaseDecompUpdate
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
|
||||
DecompAnswersUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
|
||||
ExpandedRetrievalUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
|
||||
ExploratorySearchUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
|
||||
InitialAnswerQualityUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
|
||||
InitialAnswerUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.models import (
|
||||
ExpandedRetrievalResult,
|
||||
)
|
||||
|
||||
### States ###
|
||||
|
||||
|
||||
class SearchSQInput(CoreState):
|
||||
pass
|
||||
|
||||
|
||||
## Graph State
|
||||
|
||||
|
||||
class SearchSQState(
|
||||
# This includes the core state
|
||||
SearchSQInput,
|
||||
BaseDecompUpdate,
|
||||
InitialAnswerUpdate,
|
||||
DecompAnswersUpdate,
|
||||
ExpandedRetrievalUpdate,
|
||||
InitialAnswerQualityUpdate,
|
||||
ExploratorySearchUpdate,
|
||||
):
|
||||
# expanded_retrieval_result: Annotated[list[ExpandedRetrievalResult], add]
|
||||
base_raw_search_result: Annotated[list[ExpandedRetrievalResult], add]
|
||||
|
||||
|
||||
## Graph Output State - presently not used
|
||||
|
||||
|
||||
class SearchSQOutput(TypedDict):
|
||||
log_messages: list[str]
|
||||
@@ -0,0 +1,120 @@
|
||||
from collections.abc import Hashable
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
|
||||
AnswerQuestionInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
|
||||
AnswerQuestionOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import MainState
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
|
||||
RequireRefinedAnswerUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def route_initial_tool_choice(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> Literal["tool_call", "agent_search_start", "logging_node"]:
|
||||
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
if state.tool_choice is not None:
|
||||
if (
|
||||
agent_config.use_agentic_search
|
||||
and agent_config.search_tool is not None
|
||||
and state.tool_choice.tool.name == agent_config.search_tool.name
|
||||
):
|
||||
return "agent_search_start"
|
||||
else:
|
||||
return "tool_call"
|
||||
else:
|
||||
return "logging_node"
|
||||
|
||||
|
||||
def parallelize_initial_sub_question_answering(
|
||||
state: MainState,
|
||||
) -> list[Send | Hashable]:
|
||||
now_start = datetime.now()
|
||||
if len(state.initial_decomp_questions) > 0:
|
||||
# sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]]
|
||||
# if len(state["sub_question_records"]) == 0:
|
||||
# if state["config"].use_persistence:
|
||||
# raise ValueError("No sub-questions found for initial decomposed questions")
|
||||
# else:
|
||||
# # in this case, we are doing retrieval on the original question.
|
||||
# # to make all the logic consistent, we create a new sub-question
|
||||
# # with the same content as the original question
|
||||
# sub_question_record_ids = [1] * len(state["initial_decomp_questions"])
|
||||
|
||||
return [
|
||||
Send(
|
||||
"answer_query_subgraph",
|
||||
AnswerQuestionInput(
|
||||
question=question,
|
||||
question_id=make_question_id(0, question_nr + 1),
|
||||
log_messages=[
|
||||
f"{now_start} -- Main Edge - Parallelize Initial Sub-question Answering"
|
||||
],
|
||||
),
|
||||
)
|
||||
for question_nr, question in enumerate(state.initial_decomp_questions)
|
||||
]
|
||||
|
||||
else:
|
||||
return [
|
||||
Send(
|
||||
"ingest_answers",
|
||||
AnswerQuestionOutput(
|
||||
answer_results=[],
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
# Define the function that determines whether to continue or not
|
||||
def continue_to_refined_answer_or_end(
|
||||
state: RequireRefinedAnswerUpdate,
|
||||
) -> Literal["refined_sub_question_creation", "logging_node"]:
|
||||
if state.require_refined_answer_eval:
|
||||
return "refined_sub_question_creation"
|
||||
else:
|
||||
return "logging_node"
|
||||
|
||||
|
||||
def parallelize_refined_sub_question_answering(
|
||||
state: MainState,
|
||||
) -> list[Send | Hashable]:
|
||||
now_start = datetime.now()
|
||||
if len(state.refined_sub_questions) > 0:
|
||||
return [
|
||||
Send(
|
||||
"answer_refined_question",
|
||||
AnswerQuestionInput(
|
||||
question=question_data.sub_question,
|
||||
question_id=make_question_id(1, question_nr),
|
||||
log_messages=[
|
||||
f"{now_start} -- Main Edge - Parallelize Refined Sub-question Answering"
|
||||
],
|
||||
),
|
||||
)
|
||||
for question_nr, question_data in state.refined_sub_questions.items()
|
||||
]
|
||||
|
||||
else:
|
||||
return [
|
||||
Send(
|
||||
"ingest_refined_sub_answers",
|
||||
AnswerQuestionOutput(
|
||||
answer_results=[],
|
||||
),
|
||||
)
|
||||
]
|
||||
@@ -0,0 +1,410 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.initial__retrieval_sub_answers__subgraph.graph_builder import (
|
||||
initial_search_sq_subgraph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.edges import (
|
||||
continue_to_refined_answer_or_end,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.edges import (
|
||||
parallelize_refined_sub_question_answering,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.edges import (
|
||||
route_initial_tool_choice,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.nodes.agent_logging import (
|
||||
agent_logging,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.nodes.agent_search_start import (
|
||||
agent_search_start,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.nodes.answer_comparison import (
|
||||
answer_comparison,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.nodes.entity_term_extraction_llm import (
|
||||
entity_term_extraction_llm,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.nodes.generate_refined_answer import (
|
||||
generate_refined_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.nodes.ingest_refined_answers import (
|
||||
ingest_refined_answers,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.nodes.refined_answer_decision import (
|
||||
refined_answer_decision,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.nodes.refined_sub_question_creation import (
|
||||
refined_sub_question_creation,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import MainInput
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import MainState
|
||||
from onyx.agents.agent_search.deep_search_a.refinement__consolidate_sub_answers__subgraph.graph_builder import (
|
||||
answer_refined_query_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
|
||||
basic_use_tool_response,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
|
||||
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
|
||||
prepare_tool_input,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
test_mode = False
|
||||
|
||||
|
||||
def main_graph_builder(test_mode: bool = False) -> StateGraph:
|
||||
graph = StateGraph(
|
||||
state_schema=MainState,
|
||||
input=MainInput,
|
||||
)
|
||||
|
||||
# graph.add_node(
|
||||
# node="agent_path_decision",
|
||||
# action=agent_path_decision,
|
||||
# )
|
||||
|
||||
# graph.add_node(
|
||||
# node="agent_path_routing",
|
||||
# action=agent_path_routing,
|
||||
# )
|
||||
|
||||
# graph.add_node(
|
||||
# node="LLM",
|
||||
# action=direct_llm_handling,
|
||||
# )
|
||||
graph.add_node(
|
||||
node="prepare_tool_input",
|
||||
action=prepare_tool_input,
|
||||
)
|
||||
graph.add_node(
|
||||
node="initial_tool_choice",
|
||||
action=llm_tool_choice,
|
||||
)
|
||||
graph.add_node(
|
||||
node="tool_call",
|
||||
action=tool_call,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="basic_use_tool_response",
|
||||
action=basic_use_tool_response,
|
||||
)
|
||||
graph.add_node(
|
||||
node="agent_search_start",
|
||||
action=agent_search_start,
|
||||
)
|
||||
|
||||
# graph.add_node(
|
||||
# node="initial_sub_question_creation",
|
||||
# action=initial_sub_question_creation,
|
||||
# )
|
||||
|
||||
initial_search_sq_subgraph = initial_search_sq_subgraph_builder().compile()
|
||||
graph.add_node(
|
||||
node="initial_search_sq_subgraph",
|
||||
action=initial_search_sq_subgraph,
|
||||
)
|
||||
|
||||
# answer_query_subgraph = answer_query_graph_builder().compile()
|
||||
# graph.add_node(
|
||||
# node="answer_query_subgraph",
|
||||
# action=answer_query_subgraph,
|
||||
# )
|
||||
|
||||
# base_raw_search_subgraph = base_raw_search_graph_builder().compile()
|
||||
# graph.add_node(
|
||||
# node="base_raw_search_subgraph",
|
||||
# action=base_raw_search_subgraph,
|
||||
# )
|
||||
|
||||
# refined_answer_subgraph = refined_answers_graph_builder().compile()
|
||||
# graph.add_node(
|
||||
# node="refined_answer_subgraph",
|
||||
# action=refined_answer_subgraph,
|
||||
# )
|
||||
|
||||
graph.add_node(
|
||||
node="refined_sub_question_creation",
|
||||
action=refined_sub_question_creation,
|
||||
)
|
||||
|
||||
answer_refined_question = answer_refined_query_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="answer_refined_question",
|
||||
action=answer_refined_question,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="ingest_refined_answers",
|
||||
action=ingest_refined_answers,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="generate_refined_answer",
|
||||
action=generate_refined_answer,
|
||||
)
|
||||
|
||||
# graph.add_node(
|
||||
# node="check_refined_answer",
|
||||
# action=check_refined_answer,
|
||||
# )
|
||||
|
||||
# graph.add_node(
|
||||
# node="ingest_initial_retrieval",
|
||||
# action=ingest_initial_base_retrieval,
|
||||
# )
|
||||
|
||||
# graph.add_node(
|
||||
# node="retrieval_consolidation",
|
||||
# action=retrieval_consolidation,
|
||||
# )
|
||||
|
||||
# graph.add_node(
|
||||
# node="ingest_initial_sub_question_answers",
|
||||
# action=ingest_initial_sub_question_answers,
|
||||
# )
|
||||
# graph.add_node(
|
||||
# node="generate_initial_answer",
|
||||
# action=generate_initial_answer,
|
||||
# )
|
||||
|
||||
# graph.add_node(
|
||||
# node="initial_answer_quality_check",
|
||||
# action=initial_answer_quality_check,
|
||||
# )
|
||||
|
||||
graph.add_node(
|
||||
node="entity_term_extraction_llm",
|
||||
action=entity_term_extraction_llm,
|
||||
)
|
||||
graph.add_node(
|
||||
node="refined_answer_decision",
|
||||
action=refined_answer_decision,
|
||||
)
|
||||
graph.add_node(
|
||||
node="answer_comparison",
|
||||
action=answer_comparison,
|
||||
)
|
||||
graph.add_node(
|
||||
node="logging_node",
|
||||
action=agent_logging,
|
||||
)
|
||||
# if test_mode:
|
||||
# graph.add_node(
|
||||
# node="generate_initial_base_answer",
|
||||
# action=generate_initial_base_answer,
|
||||
# )
|
||||
|
||||
### Add edges ###
|
||||
|
||||
# graph.add_edge(start_key=START, end_key="base_raw_search_subgraph")
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key=START,
|
||||
# end_key="agent_path_decision",
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="agent_path_decision",
|
||||
# end_key="agent_path_routing",
|
||||
# )
|
||||
graph.add_edge(start_key=START, end_key="prepare_tool_input")
|
||||
|
||||
graph.add_edge(
|
||||
start_key="prepare_tool_input",
|
||||
end_key="initial_tool_choice",
|
||||
)
|
||||
|
||||
graph.add_conditional_edges(
|
||||
"initial_tool_choice",
|
||||
route_initial_tool_choice,
|
||||
["tool_call", "agent_search_start", "logging_node"],
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="tool_call",
|
||||
end_key="basic_use_tool_response",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="basic_use_tool_response",
|
||||
end_key="logging_node",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="agent_search_start",
|
||||
end_key="initial_search_sq_subgraph",
|
||||
)
|
||||
# graph.add_edge(
|
||||
# start_key="agent_search_start",
|
||||
# end_key="base_raw_search_subgraph",
|
||||
# )
|
||||
|
||||
graph.add_edge(
|
||||
start_key="agent_search_start",
|
||||
end_key="entity_term_extraction_llm",
|
||||
)
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="agent_search_start",
|
||||
# end_key="initial_sub_question_creation",
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="base_raw_search_subgraph",
|
||||
# end_key="ingest_initial_retrieval",
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key=["ingest_initial_retrieval", "ingest_initial_sub_question_answers"],
|
||||
# end_key="retrieval_consolidation",
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="retrieval_consolidation",
|
||||
# end_key="generate_initial_answer",
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="LLM",
|
||||
# end_key=END,
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key=START,
|
||||
# end_key="initial_sub_question_creation",
|
||||
# )
|
||||
|
||||
# graph.add_conditional_edges(
|
||||
# source="initial_sub_question_creation",
|
||||
# path=parallelize_initial_sub_question_answering,
|
||||
# path_map=["answer_query_subgraph"],
|
||||
# )
|
||||
# graph.add_edge(
|
||||
# start_key="answer_query_subgraph",
|
||||
# end_key="ingest_initial_sub_question_answers",
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="retrieval_consolidation",
|
||||
# end_key="generate_initial_answer",
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="generate_initial_answer",
|
||||
# end_key="entity_term_extraction_llm",
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="generate_initial_answer",
|
||||
# end_key="initial_answer_quality_check",
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key=["initial_answer_quality_check", "entity_term_extraction_llm"],
|
||||
# end_key="refined_answer_decision",
|
||||
# )
|
||||
# graph.add_edge(
|
||||
# start_key="initial_answer_quality_check",
|
||||
# end_key="refined_answer_decision",
|
||||
# )
|
||||
|
||||
graph.add_edge(
|
||||
start_key=["initial_search_sq_subgraph", "entity_term_extraction_llm"],
|
||||
end_key="refined_answer_decision",
|
||||
)
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source="refined_answer_decision",
|
||||
path=continue_to_refined_answer_or_end,
|
||||
path_map=["refined_sub_question_creation", "logging_node"],
|
||||
)
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source="refined_sub_question_creation", # DONE
|
||||
path=parallelize_refined_sub_question_answering,
|
||||
path_map=["answer_refined_question"],
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="answer_refined_question", # HERE
|
||||
end_key="ingest_refined_answers",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="ingest_refined_answers",
|
||||
end_key="generate_refined_answer",
|
||||
)
|
||||
|
||||
# graph.add_conditional_edges(
|
||||
# source="refined_answer_decision",
|
||||
# path=continue_to_refined_answer_or_end,
|
||||
# path_map=["refined_answer_subgraph", END],
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="refined_answer_subgraph",
|
||||
# end_key="generate_refined_answer",
|
||||
# )
|
||||
|
||||
graph.add_edge(
|
||||
start_key="generate_refined_answer",
|
||||
end_key="answer_comparison",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="answer_comparison",
|
||||
end_key="logging_node",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="logging_node",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="generate_refined_answer",
|
||||
# end_key="check_refined_answer",
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="check_refined_answer",
|
||||
# end_key=END,
|
||||
# )
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.context.search.models import SearchRequest
|
||||
|
||||
graph = main_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
|
||||
with get_session_context_manager() as db_session:
|
||||
search_request = SearchRequest(query="Who created Excel?")
|
||||
agent_a_config, search_tool = get_test_config(
|
||||
db_session, primary_llm, fast_llm, search_request
|
||||
)
|
||||
|
||||
inputs = MainInput(
|
||||
base_question=agent_a_config.search_request.query, log_messages=[]
|
||||
)
|
||||
|
||||
for thing in compiled_graph.stream(
|
||||
input=inputs,
|
||||
config={"configurable": {"config": agent_a_config}},
|
||||
# stream_mode="debug",
|
||||
# debug=True,
|
||||
subgraphs=True,
|
||||
):
|
||||
logger.debug(thing)
|
||||
@@ -0,0 +1,36 @@
from pydantic import BaseModel


class FollowUpSubQuestion(BaseModel):
    sub_question: str
    sub_question_id: str
    verified: bool
    answered: bool
    answer: str


class AgentTimings(BaseModel):
    base_duration__s: float | None
    refined_duration__s: float | None
    full_duration__s: float | None


class AgentBaseMetrics(BaseModel):
    num_verified_documents_total: int | None
    num_verified_documents_core: int | None
    verified_avg_score_core: float | None
    num_verified_documents_base: int | float | None
    verified_avg_score_base: float | None = None
    base_doc_boost_factor: float | None = None
    support_boost_factor: float | None = None
    duration__s: float | None = None


class AgentRefinedMetrics(BaseModel):
    refined_doc_boost_factor: float | None = None
    refined_question_boost_factor: float | None = None
    duration__s: float | None = None


class AgentAdditionalMetrics(BaseModel):
    pass
@@ -0,0 +1,135 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.models import (
|
||||
AgentAdditionalMetrics,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.models import AgentTimings
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import MainOutput
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import MainState
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import CombinedAgentMetrics
|
||||
from onyx.db.chat import log_agent_metrics
|
||||
from onyx.db.chat import log_agent_sub_question_results
|
||||
|
||||
|
||||
def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput:
|
||||
now_start = datetime.now()
|
||||
|
||||
logger.info(f"--------{now_start}--------LOGGING NODE---")
|
||||
|
||||
agent_start_time = state.agent_start_time
|
||||
agent_base_end_time = state.agent_base_end_time
|
||||
agent_refined_start_time = state.agent_refined_start_time
|
||||
agent_refined_end_time = state.agent_refined_end_time
|
||||
agent_end_time = agent_refined_end_time or agent_base_end_time
|
||||
|
||||
agent_base_duration = None
|
||||
if agent_base_end_time:
|
||||
agent_base_duration = (agent_base_end_time - agent_start_time).total_seconds()
|
||||
|
||||
agent_refined_duration = None
|
||||
if agent_refined_start_time and agent_refined_end_time:
|
||||
agent_refined_duration = (
|
||||
agent_refined_end_time - agent_refined_start_time
|
||||
).total_seconds()
|
||||
|
||||
agent_full_duration = None
|
||||
if agent_end_time:
|
||||
agent_full_duration = (agent_end_time - agent_start_time).total_seconds()
|
||||
|
||||
agent_type = "refined" if agent_refined_duration else "base"
|
||||
|
||||
agent_base_metrics = state.agent_base_metrics
|
||||
agent_refined_metrics = state.agent_refined_metrics
|
||||
|
||||
combined_agent_metrics = CombinedAgentMetrics(
|
||||
timings=AgentTimings(
|
||||
base_duration__s=agent_base_duration,
|
||||
refined_duration__s=agent_refined_duration,
|
||||
full_duration__s=agent_full_duration,
|
||||
),
|
||||
base_metrics=agent_base_metrics,
|
||||
refined_metrics=agent_refined_metrics,
|
||||
additional_metrics=AgentAdditionalMetrics(),
|
||||
)
|
||||
|
||||
persona_id = None
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
if agent_a_config.search_request.persona:
|
||||
persona_id = agent_a_config.search_request.persona.id
|
||||
|
||||
user_id = None
|
||||
if agent_a_config.search_tool is not None:
|
||||
user = agent_a_config.search_tool.user
|
||||
if user:
|
||||
user_id = user.id
|
||||
|
||||
# log the agent metrics
|
||||
if agent_a_config.db_session is not None:
|
||||
if agent_base_duration is not None:
|
||||
log_agent_metrics(
|
||||
db_session=agent_a_config.db_session,
|
||||
user_id=user_id,
|
||||
persona_id=persona_id,
|
||||
agent_type=agent_type,
|
||||
start_time=agent_start_time,
|
||||
agent_metrics=combined_agent_metrics,
|
||||
)
|
||||
|
||||
if agent_a_config.use_agentic_persistence:
|
||||
# Persist the sub-answer in the database
|
||||
db_session = agent_a_config.db_session
|
||||
chat_session_id = agent_a_config.chat_session_id
|
||||
primary_message_id = agent_a_config.message_id
|
||||
sub_question_answer_results = state.decomp_answer_results
|
||||
|
||||
log_agent_sub_question_results(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
primary_message_id=primary_message_id,
|
||||
sub_question_answer_results=sub_question_answer_results,
|
||||
)
|
||||
|
||||
# if chat_session_id is not None and primary_message_id is not None and sub_question_id is not None:
|
||||
# create_sub_answer(
|
||||
# db_session=db_session,
|
||||
# chat_session_id=chat_session_id,
|
||||
# primary_message_id=primary_message_id,
|
||||
# sub_question_id=sub_question_id,
|
||||
# answer=answer_str,
|
||||
# # )
|
||||
# pass
|
||||
|
||||
now_end = datetime.now()
|
||||
main_output = MainOutput(
|
||||
log_messages=[
|
||||
f"{now_start} -- Main - Logging, Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
|
||||
logger.info(f"--------{now_end}--{now_end - now_start}--------LOGGING NODE END---")
|
||||
|
||||
for log_message in state.log_messages:
|
||||
logger.info(log_message)
|
||||
|
||||
logger.info("")
|
||||
|
||||
if state.agent_base_metrics:
|
||||
logger.info(f"Initial loop: {state.agent_base_metrics.duration__s}")
|
||||
if state.agent_refined_metrics:
|
||||
logger.info(f"Refined loop: {state.agent_refined_metrics.duration__s}")
|
||||
if (
|
||||
state.agent_base_metrics
|
||||
and state.agent_refined_metrics
|
||||
and state.agent_base_metrics.duration__s
|
||||
and state.agent_refined_metrics.duration__s
|
||||
):
|
||||
logger.info(
|
||||
f"Total time: {float(state.agent_base_metrics.duration__s) + float(state.agent_refined_metrics.duration__s)}"
|
||||
)
|
||||
|
||||
return main_output
|
||||
@@ -0,0 +1,36 @@
from datetime import datetime
from typing import cast

from langchain_core.runnables import RunnableConfig

from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
from onyx.agents.agent_search.deep_search_a.main__graph.states import MainState
from onyx.agents.agent_search.deep_search_a.main__graph.states import RoutingDecision
from onyx.agents.agent_search.models import AgentSearchConfig


def agent_path_decision(state: MainState, config: RunnableConfig) -> RoutingDecision:
    now_start = datetime.now()

    cast(AgentSearchConfig, config["metadata"]["config"])

    # perform_initial_search_path_decision = (
    #     agent_a_config.perform_initial_search_path_decision
    # )

    logger.info(f"--------{now_start}--------DECIDING TO SEARCH OR GO TO LLM---")

    routing = "agent_search"

    now_end = datetime.now()

    logger.debug(
        f"--------{now_end}--{now_end - now_start}--------DECIDING TO SEARCH OR GO TO LLM END---"
    )
    return RoutingDecision(
        # Decide which route to take
        routing_decision=routing,
        log_messages=[
            f"{now_end} -- Path decision: {routing}, Time taken: {now_end - now_start}"
        ],
    )
@@ -0,0 +1,31 @@
from datetime import datetime
from typing import Literal

from langgraph.types import Command

from onyx.agents.agent_search.deep_search_a.main__graph.states import MainState


def agent_path_routing(
    state: MainState,
) -> Command[Literal["agent_search_start", "LLM"]]:
    now_start = datetime.now()
    # Fall back to the agent search path when no routing decision has been recorded.
    routing = state.routing_decision if state.routing_decision else "agent_search"

    if routing == "agent_search":
        agent_path = "agent_search_start"
    else:
        agent_path = "LLM"

    now_end = datetime.now()

    return Command(
        # state update
        update={
            "log_messages": [
                f"{now_start} -- Main - Path routing: {agent_path}, Time taken: {now_end - now_start}"
            ]
        },
        # control flow
        goto=agent_path,
    )
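A minimal sketch of how a `Command`-returning router such as `agent_path_routing` is wired into a LangGraph `StateGraph`. The registrations below are illustrative only (the real wiring lives in the main graph builder, which is not part of this hunk); `agent_search_start` and `direct_llm_handling` refer to the nodes defined in the following hunks.

from langgraph.graph import END, START, StateGraph

# Illustrative wiring only -- not taken from the actual main graph builder.
graph = StateGraph(MainState)
graph.add_node("agent_path_routing", agent_path_routing)
graph.add_node("agent_search_start", agent_search_start)
graph.add_node("LLM", direct_llm_handling)
graph.add_edge(START, "agent_path_routing")
graph.add_edge("agent_search_start", END)
graph.add_edge("LLM", END)
# No outgoing edge is declared for "agent_path_routing": the Command returned by
# the node carries both the state update and the goto target.
compiled = graph.compile()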
@@ -0,0 +1,59 @@
from datetime import datetime
from typing import cast

from langchain_core.runnables import RunnableConfig

from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
    ExploratorySearchUpdate,
)
from onyx.agents.agent_search.deep_search_a.main__graph.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.utils import retrieve_search_docs
from onyx.configs.agent_configs import AGENT_EXPLORATORY_SEARCH_RESULTS
from onyx.context.search.models import InferenceSection


def agent_search_start(
    state: MainState, config: RunnableConfig
) -> ExploratorySearchUpdate:
    now_start = datetime.now()

    logger.info(f"--------{now_start}--------EXPLORATORY SEARCH START---")

    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    question = agent_a_config.search_request.query
    chat_session_id = agent_a_config.chat_session_id
    primary_message_id = agent_a_config.message_id
    agent_a_config.fast_llm

    history = build_history_prompt(agent_a_config, question)

    if chat_session_id is None or primary_message_id is None:
        raise ValueError(
            "chat_session_id and message_id must be provided for agent search"
        )

    # Initial search to inform decomposition. Just get top 3 fits

    search_tool = agent_a_config.search_tool
    if search_tool is None:
        raise ValueError("search_tool must be provided for agentic search")
    retrieved_docs: list[InferenceSection] = retrieve_search_docs(search_tool, question)

    exploratory_search_results = retrieved_docs[:AGENT_EXPLORATORY_SEARCH_RESULTS]
    now_end = datetime.now()
    logger.debug(
        f"--------{now_end}--{now_end - now_start}--------EXPLORATORY SEARCH END---"
    )

    return ExploratorySearchUpdate(
        exploratory_search_results=exploratory_search_results,
        previous_history_summary=history,
        log_messages=[
            f"{now_start} -- Main - Exploratory Search, Time taken: {now_end - now_start}"
        ],
    )
@@ -0,0 +1,60 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import AnswerComparison
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import MainState
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import ANSWER_COMPARISON_PROMPT
|
||||
from onyx.chat.models import RefinedAnswerImprovement
|
||||
|
||||
|
||||
def answer_comparison(state: MainState, config: RunnableConfig) -> AnswerComparison:
|
||||
now_start = datetime.now()
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = agent_a_config.search_request.query
|
||||
initial_answer = state.initial_answer
|
||||
refined_answer = state.refined_answer
|
||||
|
||||
logger.info(f"--------{now_start}--------ANSWER COMPARISON STARTED--")
|
||||
|
||||
answer_comparison_prompt = ANSWER_COMPARISON_PROMPT.format(
|
||||
question=question, initial_answer=initial_answer, refined_answer=refined_answer
|
||||
)
|
||||
|
||||
msg = [HumanMessage(content=answer_comparison_prompt)]
|
||||
|
||||
# Get the rewritten queries in a defined format
|
||||
model = agent_a_config.fast_llm
|
||||
|
||||
# no need to stream this
|
||||
resp = model.invoke(msg)
|
||||
|
||||
refined_answer_improvement = (
|
||||
isinstance(resp.content, str) and "yes" in resp.content.lower()
|
||||
)
|
||||
|
||||
dispatch_custom_event(
|
||||
"refined_answer_improvement",
|
||||
RefinedAnswerImprovement(
|
||||
refined_answer_improvement=refined_answer_improvement,
|
||||
),
|
||||
)
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.info(
|
||||
f"{now_start} -- MAIN - Answer comparison, Time taken: {now_end - now_start}"
|
||||
)
|
||||
|
||||
return AnswerComparison(
|
||||
refined_answer_improvement_eval=refined_answer_improvement,
|
||||
log_messages=[
|
||||
f"{now_start} -- Answer comparison: {refined_answer_improvement}, Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,83 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
|
||||
InitialAnswerUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import MainState
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import DIRECT_LLM_PROMPT
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_persona_agent_prompt_expressions,
|
||||
)
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
|
||||
|
||||
def direct_llm_handling(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> InitialAnswerUpdate:
|
||||
now_start = datetime.now()
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = agent_a_config.search_request.query
|
||||
persona_contextualized_prompt = get_persona_agent_prompt_expressions(
|
||||
agent_a_config.search_request.persona
|
||||
).contextualized_prompt
|
||||
|
||||
logger.info(f"--------{now_start}--------LLM HANDLING START---")
|
||||
|
||||
model = agent_a_config.fast_llm
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=DIRECT_LLM_PROMPT.format(
|
||||
persona_specification=persona_contextualized_prompt,
|
||||
question=question,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
|
||||
|
||||
for message in model.stream(msg):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
)
|
||||
dispatch_custom_event(
|
||||
"initial_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=content,
|
||||
level=0,
|
||||
level_question_nr=0,
|
||||
answer_type="agent_level_answer",
|
||||
),
|
||||
)
|
||||
streamed_tokens.append(content)
|
||||
|
||||
response = merge_content(*streamed_tokens)
|
||||
answer = cast(str, response)
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.info(f"--------{now_end}--{now_end - now_start}--------LLM HANDLING END---")
|
||||
|
||||
return InitialAnswerUpdate(
|
||||
initial_answer=answer,
|
||||
initial_agent_stats=None,
|
||||
generated_sub_questions=[],
|
||||
agent_base_end_time=now_end,
|
||||
agent_base_metrics=None,
|
||||
log_messages=[
|
||||
f"{now_start} -- Main - LLM handling: {answer}, Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,126 @@
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
|
||||
EntityTermExtractionUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import MainState
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import Entity
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
EntityRelationshipTermExtraction,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import Relationship
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import Term
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROMPT
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
|
||||
|
||||
def entity_term_extraction_llm(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> EntityTermExtractionUpdate:
|
||||
now_start = datetime.now()
|
||||
|
||||
logger.info(f"--------{now_start}--------GENERATE ENTITIES & TERMS---")
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
if not agent_a_config.allow_refinement:
|
||||
now_end = datetime.now()
|
||||
return EntityTermExtractionUpdate(
|
||||
entity_relation_term_extractions=EntityRelationshipTermExtraction(
|
||||
entities=[],
|
||||
relationships=[],
|
||||
terms=[],
|
||||
),
|
||||
log_messages=[
|
||||
f"{now_start} -- Main - ETR Extraction, Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
|
||||
# first four lines duplicates from generate_initial_answer
|
||||
question = agent_a_config.search_request.query
|
||||
initial_search_docs = state.exploratory_search_results[:15]
|
||||
|
||||
# start with the entity/term/extraction
|
||||
doc_context = format_docs(initial_search_docs)
|
||||
|
||||
doc_context = trim_prompt_piece(
|
||||
agent_a_config.fast_llm.config, doc_context, ENTITY_TERM_PROMPT + question
|
||||
)
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=ENTITY_TERM_PROMPT.format(question=question, context=doc_context),
|
||||
)
|
||||
]
|
||||
fast_llm = agent_a_config.fast_llm
|
||||
# Grader
|
||||
llm_response = fast_llm.invoke(
|
||||
prompt=msg,
|
||||
)
|
||||
|
||||
cleaned_response = re.sub(r"```json\n|\n```", "", str(llm_response.content))
|
||||
parsed_response = json.loads(cleaned_response)
|
||||
|
||||
entities = []
|
||||
relationships = []
|
||||
terms = []
|
||||
for entity in parsed_response.get("retrieved_entities_relationships", {}).get(
|
||||
"entities", {}
|
||||
):
|
||||
entity_name = entity.get("entity_name", "")
|
||||
entity_type = entity.get("entity_type", "")
|
||||
entities.append(Entity(entity_name=entity_name, entity_type=entity_type))
|
||||
|
||||
for relationship in parsed_response.get("retrieved_entities_relationships", {}).get(
|
||||
"relationships", {}
|
||||
):
|
||||
relationship_name = relationship.get("relationship_name", "")
|
||||
relationship_type = relationship.get("relationship_type", "")
|
||||
relationship_entities = relationship.get("relationship_entities", [])
|
||||
relationships.append(
|
||||
Relationship(
|
||||
relationship_name=relationship_name,
|
||||
relationship_type=relationship_type,
|
||||
relationship_entities=relationship_entities,
|
||||
)
|
||||
)
|
||||
|
||||
for term in parsed_response.get("retrieved_entities_relationships", {}).get(
|
||||
"terms", {}
|
||||
):
|
||||
term_name = term.get("term_name", "")
|
||||
term_type = term.get("term_type", "")
|
||||
term_similar_to = term.get("term_similar_to", [])
|
||||
terms.append(
|
||||
Term(
|
||||
term_name=term_name,
|
||||
term_type=term_type,
|
||||
term_similar_to=term_similar_to,
|
||||
)
|
||||
)
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.info(
|
||||
f"{now_start} -- MAIN - Entity term extraction, Time taken: {now_end - now_start}"
|
||||
)
|
||||
|
||||
return EntityTermExtractionUpdate(
|
||||
entity_relation_term_extractions=EntityRelationshipTermExtraction(
|
||||
entities=entities,
|
||||
relationships=relationships,
|
||||
terms=terms,
|
||||
),
|
||||
log_messages=[
|
||||
f"{now_start} -- Main - ETR Extraction, Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,58 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
|
||||
InitialAnswerBASEUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import MainState
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import INITIAL_RAG_BASE_PROMPT
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
|
||||
|
||||
def generate_initial_base_search_only_answer(
|
||||
state: MainState,
|
||||
config: RunnableConfig,
|
||||
) -> InitialAnswerBASEUpdate:
|
||||
now_start = datetime.now()
|
||||
|
||||
logger.info(f"--------{now_start}--------GENERATE INITIAL BASE ANSWER---")
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = agent_a_config.search_request.query
|
||||
original_question_docs = state.all_original_question_documents
|
||||
|
||||
model = agent_a_config.fast_llm
|
||||
|
||||
doc_context = format_docs(original_question_docs)
|
||||
doc_context = trim_prompt_piece(
|
||||
model.config, doc_context, INITIAL_RAG_BASE_PROMPT + question
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=INITIAL_RAG_BASE_PROMPT.format(
|
||||
question=question,
|
||||
context=doc_context,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
# Grader
|
||||
response = model.invoke(msg)
|
||||
answer = response.pretty_repr()
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.debug(
|
||||
f"--------{now_end}--{now_end - now_start}--------INITIAL BASE ANSWER END---\n\n"
|
||||
)
|
||||
|
||||
return InitialAnswerBASEUpdate(initial_base_answer=answer)
|
||||
@@ -0,0 +1,335 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.models import (
|
||||
AgentRefinedMetrics,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.operations import get_query_info
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.operations import (
|
||||
remove_document_citations,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import MainState
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
|
||||
RefinedAnswerUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
get_prompt_enrichment_components,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import InferenceSection
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import REVISED_RAG_PROMPT
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
REVISED_RAG_PROMPT_NO_SUB_QUESTIONS,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
SUB_QUESTION_ANSWER_TEMPLATE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
dispatch_main_answer_stop_info,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
|
||||
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
|
||||
|
||||
|
||||
def generate_refined_answer(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> RefinedAnswerUpdate:
|
||||
now_start = datetime.now()
|
||||
|
||||
logger.info(f"--------{now_start}--------GENERATE REFINED ANSWER---")
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = agent_a_config.search_request.query
|
||||
prompt_enrichment_components = get_prompt_enrichment_components(agent_a_config)
|
||||
|
||||
persona_contextualized_prompt = (
|
||||
prompt_enrichment_components.persona_prompts.contextualized_prompt
|
||||
)
|
||||
|
||||
initial_documents = state.documents
|
||||
refined_documents = state.refined_documents
|
||||
sub_questions_cited_docs = state.cited_docs
|
||||
all_original_question_documents = state.all_original_question_documents
|
||||
|
||||
consolidated_context_docs: list[InferenceSection] = sub_questions_cited_docs
|
||||
|
||||
counter = 0
|
||||
for original_doc_number, original_doc in enumerate(all_original_question_documents):
|
||||
if original_doc_number not in sub_questions_cited_docs:
|
||||
if (
|
||||
counter <= AGENT_MIN_ORIG_QUESTION_DOCS
|
||||
or len(consolidated_context_docs)
|
||||
< 1.5
|
||||
* AGENT_MAX_ANSWER_CONTEXT_DOCS # allow for larger context in refinement
|
||||
):
|
||||
consolidated_context_docs.append(original_doc)
|
||||
counter += 1
|
||||
|
||||
# sort docs by their scores - though the scores refer to different questions
|
||||
relevant_docs = dedup_inference_sections(
|
||||
consolidated_context_docs, consolidated_context_docs
|
||||
)
|
||||
|
||||
query_info = get_query_info(state.original_question_retrieval_results)
|
||||
if agent_a_config.search_tool is None:
|
||||
raise ValueError("search_tool must be provided for agentic search")
|
||||
# stream refined answer docs
|
||||
for tool_response in yield_search_responses(
|
||||
query=question,
|
||||
reranked_sections=relevant_docs,
|
||||
final_context_sections=relevant_docs,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: None, # TODO: add relevance
|
||||
search_tool=agent_a_config.search_tool,
|
||||
):
|
||||
dispatch_custom_event(
|
||||
"tool_response",
|
||||
ExtendedToolResponse(
|
||||
id=tool_response.id,
|
||||
response=tool_response.response,
|
||||
level=1,
|
||||
level_question_nr=0, # 0, 0 is the base question
|
||||
),
|
||||
)
|
||||
|
||||
if len(initial_documents) > 0:
|
||||
revision_doc_effectiveness = len(relevant_docs) / len(initial_documents)
|
||||
elif len(refined_documents) == 0:
|
||||
revision_doc_effectiveness = 0.0
|
||||
else:
|
||||
revision_doc_effectiveness = 10.0
|
||||
|
||||
decomp_answer_results = state.decomp_answer_results
|
||||
# revised_answer_results = state.refined_decomp_answer_results
|
||||
|
||||
answered_qa_list: list[str] = []
|
||||
decomp_questions = []
|
||||
|
||||
initial_good_sub_questions: list[str] = []
|
||||
new_revised_good_sub_questions: list[str] = []
|
||||
|
||||
sub_question_nr = 1
|
||||
|
||||
for decomp_answer_result in decomp_answer_results:
|
||||
question_level, question_nr = parse_question_id(
|
||||
decomp_answer_result.question_id
|
||||
)
|
||||
|
||||
decomp_questions.append(decomp_answer_result.question)
|
||||
if (
|
||||
decomp_answer_result.quality.lower().startswith("yes")
|
||||
and len(decomp_answer_result.answer) > 0
|
||||
and decomp_answer_result.answer != UNKNOWN_ANSWER
|
||||
):
|
||||
answered_qa_list.append(
|
||||
SUB_QUESTION_ANSWER_TEMPLATE.format(
|
||||
sub_question=decomp_answer_result.question,
|
||||
sub_answer=decomp_answer_result.answer,
|
||||
sub_question_nr=sub_question_nr,
|
||||
)
|
||||
)
|
||||
if question_level == 0:
|
||||
initial_good_sub_questions.append(decomp_answer_result.question)
|
||||
else:
|
||||
new_revised_good_sub_questions.append(decomp_answer_result.question)
|
||||
|
||||
sub_question_nr += 1
|
||||
|
||||
initial_good_sub_questions = list(set(initial_good_sub_questions))
|
||||
new_revised_good_sub_questions = list(set(new_revised_good_sub_questions))
|
||||
total_good_sub_questions = list(
|
||||
set(initial_good_sub_questions + new_revised_good_sub_questions)
|
||||
)
|
||||
if len(initial_good_sub_questions) > 0:
|
||||
revision_question_efficiency: float = len(total_good_sub_questions) / len(
|
||||
initial_good_sub_questions
|
||||
)
|
||||
elif len(new_revised_good_sub_questions) > 0:
|
||||
revision_question_efficiency = 10.0
|
||||
else:
|
||||
revision_question_efficiency = 1.0
|
||||
|
||||
sub_question_answer_str = "\n\n------\n\n".join(list(set(answered_qa_list)))
|
||||
|
||||
# original answer
|
||||
|
||||
initial_answer = state.initial_answer
|
||||
|
||||
# Determine which persona-specification prompt to use
|
||||
|
||||
# Determine which base prompt to use given the sub-question information
|
||||
if len(answered_qa_list) > 0:
|
||||
base_prompt = REVISED_RAG_PROMPT
|
||||
else:
|
||||
base_prompt = REVISED_RAG_PROMPT_NO_SUB_QUESTIONS
|
||||
|
||||
model = agent_a_config.fast_llm
|
||||
relevant_docs_str = format_docs(relevant_docs)
|
||||
relevant_docs_str = trim_prompt_piece(
|
||||
model.config,
|
||||
relevant_docs_str,
|
||||
base_prompt
|
||||
+ question
|
||||
+ sub_question_answer_str
|
||||
+ initial_answer
|
||||
+ persona_contextualized_prompt
|
||||
+ prompt_enrichment_components.history,
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=base_prompt.format(
|
||||
question=question,
|
||||
history=prompt_enrichment_components.history,
|
||||
answered_sub_questions=remove_document_citations(
|
||||
sub_question_answer_str
|
||||
),
|
||||
relevant_docs=relevant_docs,
|
||||
initial_answer=remove_document_citations(initial_answer),
|
||||
persona_specification=persona_contextualized_prompt,
|
||||
date_prompt=prompt_enrichment_components.date_str,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
# Grader
|
||||
|
||||
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
|
||||
dispatch_timings: list[float] = []
|
||||
for message in model.stream(msg):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
)
|
||||
|
||||
start_stream_token = datetime.now()
|
||||
dispatch_custom_event(
|
||||
"refined_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=content,
|
||||
level=1,
|
||||
level_question_nr=0,
|
||||
answer_type="agent_level_answer",
|
||||
),
|
||||
)
|
||||
end_stream_token = datetime.now()
|
||||
dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
|
||||
streamed_tokens.append(content)
|
||||
|
||||
logger.info(
|
||||
f"Average dispatch time for refined answer: {sum(dispatch_timings) / len(dispatch_timings)}"
|
||||
)
|
||||
dispatch_main_answer_stop_info(1)
|
||||
response = merge_content(*streamed_tokens)
|
||||
answer = cast(str, response)
|
||||
|
||||
# refined_agent_stats = _calculate_refined_agent_stats(
|
||||
# state.decomp_answer_results, state.original_question_retrieval_stats
|
||||
# )
|
||||
|
||||
refined_agent_stats = RefinedAgentStats(
|
||||
revision_doc_efficiency=revision_doc_effectiveness,
|
||||
revision_question_efficiency=revision_question_efficiency,
|
||||
)
|
||||
|
||||
logger.debug(f"\n\n---INITIAL ANSWER ---\n\n Answer:\n Agent: {initial_answer}")
|
||||
logger.debug("-" * 10)
|
||||
logger.debug(f"\n\n---REVISED AGENT ANSWER ---\n\n Answer:\n Agent: {answer}")
|
||||
|
||||
logger.debug("-" * 100)
|
||||
|
||||
if state.initial_agent_stats:
|
||||
initial_doc_boost_factor = state.initial_agent_stats.agent_effectiveness.get(
|
||||
"utilized_chunk_ratio", "--"
|
||||
)
|
||||
initial_support_boost_factor = (
|
||||
state.initial_agent_stats.agent_effectiveness.get("support_ratio", "--")
|
||||
)
|
||||
num_initial_verified_docs = state.initial_agent_stats.original_question.get(
|
||||
"num_verified_documents", "--"
|
||||
)
|
||||
initial_verified_docs_avg_score = (
|
||||
state.initial_agent_stats.original_question.get("verified_avg_score", "--")
|
||||
)
|
||||
initial_sub_questions_verified_docs = (
|
||||
state.initial_agent_stats.sub_questions.get("num_verified_documents", "--")
|
||||
)
|
||||
|
||||
logger.debug("INITIAL AGENT STATS")
|
||||
logger.debug(f"Document Boost Factor: {initial_doc_boost_factor}")
|
||||
logger.debug(f"Support Boost Factor: {initial_support_boost_factor}")
|
||||
logger.debug(f"Originally Verified Docs: {num_initial_verified_docs}")
|
||||
logger.debug(
|
||||
f"Originally Verified Docs Avg Score: {initial_verified_docs_avg_score}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Sub-Questions Verified Docs: {initial_sub_questions_verified_docs}"
|
||||
)
|
||||
if refined_agent_stats:
|
||||
logger.debug("-" * 10)
|
||||
logger.debug("REFINED AGENT STATS")
|
||||
logger.debug(
|
||||
f"Revision Doc Factor: {refined_agent_stats.revision_doc_efficiency}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Revision Question Factor: {refined_agent_stats.revision_question_efficiency}"
|
||||
)
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.debug(
|
||||
f"--------{now_end}--{now_end - now_start}--------INITIAL AGENT ANSWER END---\n\n"
|
||||
)
|
||||
|
||||
agent_refined_end_time = datetime.now()
|
||||
if state.agent_refined_start_time:
|
||||
agent_refined_duration = (
|
||||
agent_refined_end_time - state.agent_refined_start_time
|
||||
).total_seconds()
|
||||
else:
|
||||
agent_refined_duration = None
|
||||
|
||||
agent_refined_metrics = AgentRefinedMetrics(
|
||||
refined_doc_boost_factor=refined_agent_stats.revision_doc_efficiency,
|
||||
refined_question_boost_factor=refined_agent_stats.revision_question_efficiency,
|
||||
duration__s=agent_refined_duration,
|
||||
)
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.info(
|
||||
f"{now_start} -- MAIN - Generate refined answer, Time taken: {now_end - now_start}"
|
||||
)
|
||||
|
||||
return RefinedAnswerUpdate(
|
||||
refined_answer=answer,
|
||||
refined_answer_quality=True, # TODO: replace this with the actual check value
|
||||
refined_agent_stats=refined_agent_stats,
|
||||
agent_refined_end_time=agent_refined_end_time,
|
||||
agent_refined_metrics=agent_refined_metrics,
|
||||
log_messages=[
|
||||
f"{now_start} -- MAIN - Generate refined answer, Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,41 @@
from datetime import datetime

from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
    AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
    DecompAnswersUpdate,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
    dedup_inference_sections,
)


def ingest_refined_answers(
    state: AnswerQuestionOutput,
) -> DecompAnswersUpdate:
    now_start = datetime.now()

    logger.info(f"--------{now_start}--------INGEST FOLLOW UP ANSWERS---")

    documents = []
    answer_results = state.answer_results if hasattr(state, "answer_results") else []
    for answer_result in answer_results:
        documents.extend(answer_result.documents)

    now_end = datetime.now()

    logger.debug(
        f"--------{now_end}--{now_end - now_start}--------INGEST FOLLOW UP ANSWERS END---"
    )

    return DecompAnswersUpdate(
        # Deduping is done by the documents operator for the main graph
        # so we might not need to dedup here
        documents=dedup_inference_sections(documents, []),
        decomp_answer_results=answer_results,
        log_messages=[
            f"{now_start} -- Main - Ingest refined answers, Time taken: {now_end - now_start}"
        ],
    )
@@ -0,0 +1,47 @@
from datetime import datetime
from typing import cast

from langchain_core.runnables import RunnableConfig

from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
from onyx.agents.agent_search.deep_search_a.main__graph.states import MainState
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
    RequireRefinedAnswerUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig


def refined_answer_decision(
    state: MainState, config: RunnableConfig
) -> RequireRefinedAnswerUpdate:
    now_start = datetime.now()

    logger.info(f"--------{now_start}--------REFINED ANSWER DECISION---")

    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    if "?" in agent_a_config.search_request.query:
        decision = False
    else:
        decision = True

    # NOTE: the question-mark heuristic above is currently overridden; a refined
    # answer is always requested (subject to allow_refinement below).
    decision = True

    now_end = datetime.now()

    logger.info(
        f"{now_start} -- MAIN - Refined answer decision, Time taken: {now_end - now_start}"
    )
    log_messages = [
        f"{now_start} -- Main - Refined answer decision: {decision}, Time taken: {now_end - now_start}"
    ]
    if agent_a_config.allow_refinement:
        return RequireRefinedAnswerUpdate(
            require_refined_answer_eval=decision,
            log_messages=log_messages,
        )

    else:
        return RequireRefinedAnswerUpdate(
            require_refined_answer_eval=False,
            log_messages=log_messages,
        )
@@ -0,0 +1,122 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.models import (
|
||||
FollowUpSubQuestion,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.operations import (
|
||||
dispatch_subquestion,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
|
||||
FollowUpSubQuestionsUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import MainState
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_history_prompt,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
DEEP_DECOMPOSE_PROMPT_WITH_ENTITIES,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
format_entity_term_extraction,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
|
||||
|
||||
def refined_sub_question_creation(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> FollowUpSubQuestionsUpdate:
|
||||
""" """
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
dispatch_custom_event(
|
||||
"start_refined_answer_creation",
|
||||
ToolCallKickoff(
|
||||
tool_name="agent_search_1",
|
||||
tool_args={
|
||||
"query": agent_a_config.search_request.query,
|
||||
"answer": state.initial_answer,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
now_start = datetime.now()
|
||||
|
||||
logger.info(f"--------{now_start}--------FOLLOW UP DECOMPOSE---")
|
||||
|
||||
agent_refined_start_time = datetime.now()
|
||||
|
||||
question = agent_a_config.search_request.query
|
||||
base_answer = state.initial_answer
|
||||
history = build_history_prompt(agent_a_config, question)
|
||||
# get the entity term extraction dict and properly format it
|
||||
entity_relation_term_extractions = state.entity_relation_term_extractions
|
||||
|
||||
entity_term_extraction_str = format_entity_term_extraction(
|
||||
entity_relation_term_extractions
|
||||
)
|
||||
|
||||
initial_question_answers = state.decomp_answer_results
|
||||
|
||||
addressed_question_list = [
|
||||
x.question for x in initial_question_answers if "yes" in x.quality.lower()
|
||||
]
|
||||
|
||||
failed_question_list = [
|
||||
x.question for x in initial_question_answers if "no" in x.quality.lower()
|
||||
]
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=DEEP_DECOMPOSE_PROMPT_WITH_ENTITIES.format(
|
||||
question=question,
|
||||
history=history,
|
||||
entity_term_extraction_str=entity_term_extraction_str,
|
||||
base_answer=base_answer,
|
||||
answered_sub_questions="\n - ".join(addressed_question_list),
|
||||
failed_sub_questions="\n - ".join(failed_question_list),
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
# Grader
|
||||
model = agent_a_config.fast_llm
|
||||
|
||||
streamed_tokens = dispatch_separated(model.stream(msg), dispatch_subquestion(1))
|
||||
response = merge_content(*streamed_tokens)
|
||||
|
||||
if isinstance(response, str):
|
||||
parsed_response = [q for q in response.split("\n") if q.strip() != ""]
|
||||
else:
|
||||
raise ValueError("LLM response is not a string")
|
||||
|
||||
refined_sub_question_dict = {}
|
||||
for sub_question_nr, sub_question in enumerate(parsed_response):
|
||||
refined_sub_question = FollowUpSubQuestion(
|
||||
sub_question=sub_question,
|
||||
sub_question_id=make_question_id(1, sub_question_nr + 1),
|
||||
verified=False,
|
||||
answered=False,
|
||||
answer="",
|
||||
)
|
||||
|
||||
refined_sub_question_dict[sub_question_nr + 1] = refined_sub_question
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.info(
|
||||
f"{now_start} -- MAIN - Refined sub question creation, Time taken: {now_end - now_start}"
|
||||
)
|
||||
|
||||
return FollowUpSubQuestionsUpdate(
|
||||
refined_sub_questions=refined_sub_question_dict,
|
||||
agent_refined_start_time=agent_refined_start_time,
|
||||
)
|
||||
@@ -0,0 +1,145 @@
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
QuestionAnswerResults,
|
||||
)
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.tools.models import SearchQueryInfo
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def remove_document_citations(text: str) -> str:
    """
    Removes citation expressions of the format '[[D1]]()' or '[[Q1]]()' from text.
    The number after D/Q can vary.

    Args:
        text: Input text containing citations

    Returns:
        Text with citations removed
    """
    # Pattern explanation:
    # \[\[(?:D|Q)\d+\]\]\(\) matches:
    # \[\[      - literal [[ characters
    # (?:D|Q)   - a literal D or Q character
    # \d+       - one or more digits
    # \]\]      - literal ]] characters
    # \(\)      - literal () characters
    return re.sub(r"\[\[(?:D|Q)\d+\]\]\(\)", "", text)
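
# Illustrative behaviour of the pattern above (sample string is hypothetical, not
# part of the original change):
#   remove_document_citations("see [[D1]]() and [[Q12]]() here") -> "see  and  here"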
|
||||
|
||||
|
||||
def dispatch_subquestion(level: int) -> Callable[[str, int], None]:
|
||||
def _helper(sub_question_part: str, num: int) -> None:
|
||||
dispatch_custom_event(
|
||||
"decomp_qs",
|
||||
SubQuestionPiece(
|
||||
sub_question=sub_question_part,
|
||||
level=level,
|
||||
level_question_nr=num,
|
||||
),
|
||||
)
|
||||
|
||||
return _helper
|
||||
|
||||
|
||||
def calculate_initial_agent_stats(
|
||||
decomp_answer_results: list[QuestionAnswerResults],
|
||||
original_question_stats: AgentChunkStats,
|
||||
) -> InitialAgentResultStats:
|
||||
initial_agent_result_stats: InitialAgentResultStats = InitialAgentResultStats(
|
||||
sub_questions={},
|
||||
original_question={},
|
||||
agent_effectiveness={},
|
||||
)
|
||||
|
||||
orig_verified = original_question_stats.verified_count
|
||||
orig_support_score = original_question_stats.verified_avg_scores
|
||||
|
||||
verified_document_chunk_ids = []
|
||||
support_scores = 0.0
|
||||
|
||||
for decomp_answer_result in decomp_answer_results:
|
||||
verified_document_chunk_ids += (
|
||||
decomp_answer_result.sub_question_retrieval_stats.verified_doc_chunk_ids
|
||||
)
|
||||
if (
|
||||
decomp_answer_result.sub_question_retrieval_stats.verified_avg_scores
|
||||
is not None
|
||||
):
|
||||
support_scores += (
|
||||
decomp_answer_result.sub_question_retrieval_stats.verified_avg_scores
|
||||
)
|
||||
|
||||
verified_document_chunk_ids = list(set(verified_document_chunk_ids))
|
||||
|
||||
# Calculate sub-question stats
|
||||
if (
|
||||
verified_document_chunk_ids
|
||||
and len(verified_document_chunk_ids) > 0
|
||||
and support_scores is not None
|
||||
):
|
||||
sub_question_stats: dict[str, float | int | None] = {
|
||||
"num_verified_documents": len(verified_document_chunk_ids),
|
||||
"verified_avg_score": float(support_scores / len(decomp_answer_results)),
|
||||
}
|
||||
else:
|
||||
sub_question_stats = {"num_verified_documents": 0, "verified_avg_score": None}
|
||||
|
||||
initial_agent_result_stats.sub_questions.update(sub_question_stats)
|
||||
|
||||
# Get original question stats
|
||||
initial_agent_result_stats.original_question.update(
|
||||
{
|
||||
"num_verified_documents": original_question_stats.verified_count,
|
||||
"verified_avg_score": original_question_stats.verified_avg_scores,
|
||||
}
|
||||
)
|
||||
|
||||
# Calculate chunk utilization ratio
|
||||
sub_verified = initial_agent_result_stats.sub_questions["num_verified_documents"]
|
||||
|
||||
chunk_ratio: float | None = None
|
||||
if sub_verified is not None and orig_verified is not None and orig_verified > 0:
|
||||
chunk_ratio = (float(sub_verified) / orig_verified) if sub_verified > 0 else 0.0
|
||||
elif sub_verified is not None and sub_verified > 0:
|
||||
chunk_ratio = 10.0
|
||||
|
||||
initial_agent_result_stats.agent_effectiveness["utilized_chunk_ratio"] = chunk_ratio
|
||||
|
||||
if (
|
||||
orig_support_score is None
|
||||
or orig_support_score == 0.0
|
||||
and initial_agent_result_stats.sub_questions["verified_avg_score"] is None
|
||||
):
|
||||
initial_agent_result_stats.agent_effectiveness["support_ratio"] = None
|
||||
elif orig_support_score is None or orig_support_score == 0.0:
|
||||
initial_agent_result_stats.agent_effectiveness["support_ratio"] = 10
|
||||
elif initial_agent_result_stats.sub_questions["verified_avg_score"] is None:
|
||||
initial_agent_result_stats.agent_effectiveness["support_ratio"] = 0
|
||||
else:
|
||||
initial_agent_result_stats.agent_effectiveness["support_ratio"] = (
|
||||
initial_agent_result_stats.sub_questions["verified_avg_score"]
|
||||
/ orig_support_score
|
||||
)
|
||||
|
||||
return initial_agent_result_stats
|
||||
|
||||
|
||||
def get_query_info(results: list[QueryResult]) -> SearchQueryInfo:
|
||||
# Use the query info from the base document retrieval
|
||||
# TODO: see if this is the right way to do this
|
||||
query_infos = [
|
||||
result.query_info for result in results if result.query_info is not None
|
||||
]
|
||||
if len(query_infos) == 0:
|
||||
raise ValueError("No query info found")
|
||||
return query_infos[0]
|
||||
@@ -0,0 +1,180 @@
|
||||
from datetime import datetime
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.core_state import CoreState
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.models import AgentBaseMetrics
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.models import (
|
||||
AgentRefinedMetrics,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.models import (
|
||||
FollowUpSubQuestion,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.models import (
|
||||
ExpandedRetrievalResult,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
EntityRelationshipTermExtraction,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
QuestionAnswerResults,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_question_answer_results,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
### States ###
|
||||
|
||||
## Update States
|
||||
|
||||
|
||||
class LoggerUpdate(BaseModel):
|
||||
log_messages: Annotated[list[str], add] = []
|
||||
|
||||
|
||||
class RefinedAgentStartStats(BaseModel):
|
||||
agent_refined_start_time: datetime | None = None
|
||||
|
||||
|
||||
class RefinedAgentEndStats(BaseModel):
|
||||
agent_refined_end_time: datetime | None = None
|
||||
agent_refined_metrics: AgentRefinedMetrics = AgentRefinedMetrics()
|
||||
|
||||
|
||||
class BaseDecompUpdate(RefinedAgentStartStats, RefinedAgentEndStats):
|
||||
agent_start_time: datetime = datetime.now()
|
||||
previous_history: str = ""
|
||||
initial_decomp_questions: list[str] = []
|
||||
|
||||
|
||||
class ExploratorySearchUpdate(LoggerUpdate):
|
||||
exploratory_search_results: list[InferenceSection] = []
|
||||
previous_history_summary: str = ""
|
||||
|
||||
|
||||
class AnswerComparison(LoggerUpdate):
|
||||
refined_answer_improvement_eval: bool = False
|
||||
|
||||
|
||||
class RoutingDecision(LoggerUpdate):
|
||||
routing_decision: str = ""
|
||||
|
||||
|
||||
# Not used in current graph
|
||||
class InitialAnswerBASEUpdate(BaseModel):
|
||||
initial_base_answer: str = ""
|
||||
|
||||
|
||||
class InitialAnswerUpdate(LoggerUpdate):
|
||||
initial_answer: str = ""
|
||||
initial_agent_stats: InitialAgentResultStats | None = None
|
||||
generated_sub_questions: list[str] = []
|
||||
agent_base_end_time: datetime | None = None
|
||||
agent_base_metrics: AgentBaseMetrics | None = None
|
||||
|
||||
|
||||
class RefinedAnswerUpdate(RefinedAgentEndStats, LoggerUpdate):
|
||||
refined_answer: str = ""
|
||||
refined_agent_stats: RefinedAgentStats | None = None
|
||||
refined_answer_quality: bool = False
|
||||
|
||||
|
||||
class InitialAnswerQualityUpdate(LoggerUpdate):
|
||||
initial_answer_quality_eval: bool = False
|
||||
|
||||
|
||||
class RequireRefinedAnswerUpdate(LoggerUpdate):
|
||||
require_refined_answer_eval: bool = True
|
||||
|
||||
|
||||
class DecompAnswersUpdate(LoggerUpdate):
|
||||
documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
|
||||
context_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
|
||||
cited_docs: Annotated[
|
||||
list[InferenceSection], dedup_inference_sections
|
||||
] = [] # cited docs from sub-answers are used for answer context
|
||||
decomp_answer_results: Annotated[
|
||||
list[QuestionAnswerResults], dedup_question_answer_results
|
||||
] = []
|
||||
|
||||
|
||||
class FollowUpDecompAnswersUpdate(LoggerUpdate):
|
||||
refined_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
|
||||
refined_decomp_answer_results: Annotated[list[QuestionAnswerResults], add] = []
|
||||
|
||||
|
||||
class ExpandedRetrievalUpdate(LoggerUpdate):
|
||||
all_original_question_documents: Annotated[
|
||||
list[InferenceSection], dedup_inference_sections
|
||||
]
|
||||
original_question_retrieval_results: list[QueryResult] = []
|
||||
original_question_retrieval_stats: AgentChunkStats = AgentChunkStats()
|
||||
|
||||
|
||||
class EntityTermExtractionUpdate(LoggerUpdate):
|
||||
entity_relation_term_extractions: EntityRelationshipTermExtraction = (
|
||||
EntityRelationshipTermExtraction()
|
||||
)
|
||||
|
||||
|
||||
class FollowUpSubQuestionsUpdate(RefinedAgentStartStats):
|
||||
refined_sub_questions: dict[int, FollowUpSubQuestion] = {}
|
||||
|
||||
|
||||
## Graph Input State
|
||||
|
||||
|
||||
class MainInput(CoreState):
|
||||
pass
|
||||
|
||||
|
||||
## Graph State
|
||||
|
||||
|
||||
class MainState(
|
||||
# This includes the core state
|
||||
MainInput,
|
||||
ToolChoiceInput,
|
||||
ToolCallUpdate,
|
||||
ToolChoiceUpdate,
|
||||
BaseDecompUpdate,
|
||||
InitialAnswerUpdate,
|
||||
InitialAnswerBASEUpdate,
|
||||
DecompAnswersUpdate,
|
||||
ExpandedRetrievalUpdate,
|
||||
EntityTermExtractionUpdate,
|
||||
InitialAnswerQualityUpdate,
|
||||
RequireRefinedAnswerUpdate,
|
||||
FollowUpSubQuestionsUpdate,
|
||||
FollowUpDecompAnswersUpdate,
|
||||
RefinedAnswerUpdate,
|
||||
RefinedAgentStartStats,
|
||||
RefinedAgentEndStats,
|
||||
RoutingDecision,
|
||||
AnswerComparison,
|
||||
ExploratorySearchUpdate,
|
||||
):
|
||||
# expanded_retrieval_result: Annotated[list[ExpandedRetrievalResult], add]
|
||||
base_raw_search_result: Annotated[list[ExpandedRetrievalResult], add]
|
||||
|
||||
|
||||
## Graph Output State - presently not used
|
||||
|
||||
|
||||
class MainOutput(TypedDict):
|
||||
log_messages: list[str]
|
||||
@@ -0,0 +1,28 @@
from collections.abc import Hashable
from datetime import datetime

from langgraph.types import Send

from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
    AnswerQuestionInput,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
    ExpandedRetrievalInput,
)
from onyx.utils.logger import setup_logger

logger = setup_logger()


def send_to_expanded_refined_retrieval(state: AnswerQuestionInput) -> Send | Hashable:
    logger.debug("sending to expanded retrieval for follow up question via edge")
    datetime.now()
    return Send(
        "refined_sub_question_expanded_retrieval",
        ExpandedRetrievalInput(
            question=state.question,
            sub_question_id=state.question_id,
            base_search=False,
            log_messages=[f"{datetime.now()} -- Sending to expanded retrieval"],
        ),
    )
@@ -0,0 +1,123 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.nodes.answer_check import (
|
||||
answer_check,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.nodes.answer_generation import (
|
||||
answer_generation,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.nodes.format_answer import (
|
||||
format_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.nodes.ingest_retrieval import (
|
||||
ingest_retrieval,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
|
||||
AnswerQuestionInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
|
||||
AnswerQuestionOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
|
||||
AnswerQuestionState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.refinement__consolidate_sub_answers__subgraph.edges import (
|
||||
send_to_expanded_refined_retrieval,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.graph_builder import (
|
||||
expanded_retrieval_graph_builder,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def answer_refined_query_graph_builder() -> StateGraph:
|
||||
graph = StateGraph(
|
||||
state_schema=AnswerQuestionState,
|
||||
input=AnswerQuestionInput,
|
||||
output=AnswerQuestionOutput,
|
||||
)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
expanded_retrieval = expanded_retrieval_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="refined_sub_question_expanded_retrieval",
|
||||
action=expanded_retrieval,
|
||||
)
|
||||
graph.add_node(
|
||||
node="refined_sub_answer_check",
|
||||
action=answer_check,
|
||||
)
|
||||
graph.add_node(
|
||||
node="refined_sub_answer_generation",
|
||||
action=answer_generation,
|
||||
)
|
||||
graph.add_node(
|
||||
node="format_refined_sub_answer",
|
||||
action=format_answer,
|
||||
)
|
||||
graph.add_node(
|
||||
node="ingest_refined_retrieval",
|
||||
action=ingest_retrieval,
|
||||
)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source=START,
|
||||
path=send_to_expanded_refined_retrieval,
|
||||
path_map=["refined_sub_question_expanded_retrieval"],
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="refined_sub_question_expanded_retrieval",
|
||||
end_key="ingest_refined_retrieval",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="ingest_refined_retrieval",
|
||||
end_key="refined_sub_answer_generation",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="refined_sub_answer_generation",
|
||||
end_key="refined_sub_answer_check",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="refined_sub_answer_check",
|
||||
end_key="format_refined_sub_answer",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="format_refined_sub_answer",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.context.search.models import SearchRequest
|
||||
|
||||
graph = answer_refined_query_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
search_request = SearchRequest(
|
||||
query="what can you do with onyx or danswer?",
|
||||
)
|
||||
with get_session_context_manager() as db_session:
|
||||
inputs = AnswerQuestionInput(
|
||||
question="what can you do with onyx?",
|
||||
question_id="0_0",
|
||||
log_messages=[],
|
||||
)
|
||||
for thing in compiled_graph.stream(
|
||||
input=inputs,
|
||||
# debug=True,
|
||||
# subgraphs=True,
|
||||
):
|
||||
logger.debug(thing)
|
||||
# output = compiled_graph.invoke(inputs)
|
||||
# logger.debug(output)
|
||||
@@ -0,0 +1,19 @@
from pydantic import BaseModel

from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.context.search.models import InferenceSection

### Models ###


class AnswerRetrievalStats(BaseModel):
    answer_retrieval_stats: dict[str, float | int]


class QuestionAnswerResults(BaseModel):
    question: str
    answer: str
    quality: str
    # expanded_retrieval_results: list[QueryResult]
    documents: list[InferenceSection]
    sub_question_retrieval_stats: AgentChunkStats
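A quick illustration of how `QuestionAnswerResults` is populated downstream (all field values here are hypothetical; `AgentChunkStats` is default-constructible, as seen in the main-graph states module):

example_result = QuestionAnswerResults(
    question="What connectors does Onyx support?",
    answer="Onyx supports a range of connectors ...",
    quality="yes",  # consolidation checks quality.lower().startswith("yes")
    documents=[],
    sub_question_retrieval_stats=AgentChunkStats(),
)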
@@ -0,0 +1,37 @@
from collections.abc import Hashable
from typing import cast

from langchain_core.runnables.config import RunnableConfig
from langgraph.types import Send

from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
    ExpandedRetrievalState,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
    RetrievalInput,
)
from onyx.agents.agent_search.models import AgentSearchConfig


def parallel_retrieval_edge(
    state: ExpandedRetrievalState, config: RunnableConfig
) -> list[Send | Hashable]:
    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    question = state.question if state.question else agent_a_config.search_request.query

    # Retrieve for the original question as well as for each of the
    # expanded queries generated upstream.
    query_expansions = (state.expanded_queries or []) + [question]
    return [
        Send(
            "doc_retrieval",
            RetrievalInput(
                query_to_retrieve=query,
                question=question,
                base_search=False,
                sub_question_id=state.sub_question_id,
                log_messages=[],
            ),
        )
        for query in query_expansions
    ]
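Reviewer note: parallel_retrieval_edge relies on LangGraph's Send API to fan a single state out into many parallel invocations of one node. A minimal, self-contained sketch of that pattern follows; the state shape, node names, and payloads are illustrative only, not the onyx ones.

from operator import add
from typing import Annotated, TypedDict

from langgraph.graph import END, START, StateGraph
from langgraph.types import Send


class FanOutState(TypedDict):
    queries: list[str]
    results: Annotated[list[str], add]  # reducer so parallel branches append, not overwrite


def fan_out(state: FanOutState) -> list[Send]:
    # One Send per query; each Send carries its own input payload to the "retrieve" node
    return [Send("retrieve", {"queries": [q], "results": []}) for q in state["queries"]]


def retrieve(state: FanOutState) -> dict:
    return {"results": [f"docs for '{state['queries'][0]}'"]}


builder = StateGraph(FanOutState)
builder.add_node("retrieve", retrieve)
builder.add_conditional_edges(START, fan_out, ["retrieve"])
builder.add_edge("retrieve", END)
graph = builder.compile()

print(graph.invoke({"queries": ["onyx features", "danswer features"], "results": []}))

Each Send runs "retrieve" with its own payload, and the annotated reducer merges the parallel "results" updates back into one list.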
@@ -0,0 +1,147 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.edges import (
|
||||
parallel_retrieval_edge,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.nodes.doc_reranking import (
|
||||
doc_reranking,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.nodes.doc_retrieval import (
|
||||
doc_retrieval,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.nodes.doc_verification import (
|
||||
doc_verification,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.nodes.dummy import (
|
||||
dummy,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.nodes.expand_queries import (
|
||||
expand_queries,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.nodes.format_results import (
|
||||
format_results,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.nodes.verification_kickoff import (
|
||||
verification_kickoff,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
|
||||
ExpandedRetrievalInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
|
||||
ExpandedRetrievalOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
|
||||
ExpandedRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def expanded_retrieval_graph_builder() -> StateGraph:
|
||||
graph = StateGraph(
|
||||
state_schema=ExpandedRetrievalState,
|
||||
input=ExpandedRetrievalInput,
|
||||
output=ExpandedRetrievalOutput,
|
||||
)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
graph.add_node(
|
||||
node="expand_queries",
|
||||
action=expand_queries,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="dummy",
|
||||
action=dummy,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="doc_retrieval",
|
||||
action=doc_retrieval,
|
||||
)
|
||||
graph.add_node(
|
||||
node="verification_kickoff",
|
||||
action=verification_kickoff,
|
||||
)
|
||||
graph.add_node(
|
||||
node="doc_verification",
|
||||
action=doc_verification,
|
||||
)
|
||||
graph.add_node(
|
||||
node="doc_reranking",
|
||||
action=doc_reranking,
|
||||
)
|
||||
graph.add_node(
|
||||
node="format_results",
|
||||
action=format_results,
|
||||
)
|
||||
|
||||
### Add edges ###
|
||||
graph.add_edge(
|
||||
start_key=START,
|
||||
end_key="expand_queries",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="expand_queries",
|
||||
end_key="dummy",
|
||||
)
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source="dummy",
|
||||
path=parallel_retrieval_edge,
|
||||
path_map=["doc_retrieval"],
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="doc_retrieval",
|
||||
end_key="verification_kickoff",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="doc_verification",
|
||||
end_key="doc_reranking",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="doc_reranking",
|
||||
end_key="format_results",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="format_results",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.context.search.models import SearchRequest
|
||||
|
||||
graph = expanded_retrieval_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
search_request = SearchRequest(
|
||||
query="what can you do with onyx or danswer?",
|
||||
)
|
||||
|
||||
with get_session_context_manager() as db_session:
|
||||
agent_a_config, search_tool = get_test_config(
|
||||
db_session, primary_llm, fast_llm, search_request
|
||||
)
|
||||
inputs = ExpandedRetrievalInput(
|
||||
question="what can you do with onyx?",
|
||||
base_search=False,
|
||||
sub_question_id=None,
|
||||
log_messages=[],
|
||||
)
|
||||
for thing in compiled_graph.stream(
|
||||
input=inputs,
|
||||
config={"configurable": {"config": agent_a_config}},
|
||||
# debug=True,
|
||||
subgraphs=True,
|
||||
):
|
||||
logger.debug(thing)
|
||||
@@ -0,0 +1,12 @@
from pydantic import BaseModel

from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.context.search.models import InferenceSection


class ExpandedRetrievalResult(BaseModel):
    expanded_queries_results: list[QueryResult] = []
    reranked_documents: list[InferenceSection] = []
    context_documents: list[InferenceSection] = []
    sub_question_retrieval_stats: AgentChunkStats = AgentChunkStats()
@@ -0,0 +1,96 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.operations import (
|
||||
logger,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
|
||||
DocRerankingUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
|
||||
ExpandedRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.calculations import get_fit_scores
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
|
||||
from onyx.configs.agent_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS
|
||||
from onyx.configs.agent_configs import AGENT_RERANKING_STATS
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.context.search.pipeline import retrieval_preprocessing
|
||||
from onyx.context.search.postprocessing.postprocessing import rerank_sections
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
|
||||
|
||||
def doc_reranking(
|
||||
state: ExpandedRetrievalState, config: RunnableConfig
|
||||
) -> DocRerankingUpdate:
|
||||
now_start = datetime.now()
|
||||
verified_documents = state.verified_documents
|
||||
|
||||
# Rerank post retrieval and verification. First, create a search query
|
||||
# then create the list of reranked sections
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = state.question if state.question else agent_a_config.search_request.query
|
||||
if agent_a_config.search_tool is None:
|
||||
raise ValueError("search_tool must be provided for agentic search")
|
||||
with get_session_context_manager() as db_session:
|
||||
# we ignore some of the user specified fields since this search is
|
||||
# internal to agentic search, but we still want to pass through
|
||||
# persona (for stuff like document sets) and rerank settings
|
||||
# (to not make an unnecessary db call).
|
||||
search_request = SearchRequest(
|
||||
query=question,
|
||||
persona=agent_a_config.search_request.persona,
|
||||
rerank_settings=agent_a_config.search_request.rerank_settings,
|
||||
)
|
||||
_search_query = retrieval_preprocessing(
|
||||
search_request=search_request,
|
||||
user=agent_a_config.search_tool.user, # bit of a hack
|
||||
llm=agent_a_config.fast_llm,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# skip section filtering
|
||||
|
||||
if (
|
||||
_search_query.rerank_settings
|
||||
and _search_query.rerank_settings.rerank_model_name
|
||||
and _search_query.rerank_settings.num_rerank > 0
|
||||
and len(verified_documents) > 0
|
||||
):
|
||||
if len(verified_documents) > 1:
|
||||
reranked_documents = rerank_sections(
|
||||
_search_query,
|
||||
verified_documents,
|
||||
)
|
||||
else:
|
||||
num = "No" if len(verified_documents) == 0 else "One"
|
||||
logger.warning(f"{num} verified document(s) found, skipping reranking")
|
||||
reranked_documents = verified_documents
|
||||
else:
|
||||
logger.warning("No reranking settings found, using unranked documents")
|
||||
reranked_documents = verified_documents
|
||||
|
||||
if AGENT_RERANKING_STATS:
|
||||
fit_scores = get_fit_scores(verified_documents, reranked_documents)
|
||||
else:
|
||||
fit_scores = RetrievalFitStats(fit_score_lift=0, rerank_effect=0, fit_scores={})
|
||||
|
||||
# TODO: stream deduped docs here, or decide to use search tool ranking/verification
|
||||
now_end = datetime.now()
|
||||
logger.info(
|
||||
f"{now_start} -- Expanded Retrieval - Reranking - Time taken: {now_end - now_start}"
|
||||
)
|
||||
    return DocRerankingUpdate(
        reranked_documents=[
            doc for doc in reranked_documents if isinstance(doc, InferenceSection)
        ][:AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS],
        sub_question_retrieval_stats=fit_scores,
        log_messages=[
            f"{now_start} -- Expanded Retrieval - Reranking - Time taken: {now_end - now_start}"
        ],
    )
@@ -0,0 +1,110 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.operations import (
|
||||
logger,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
|
||||
DocRetrievalUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
|
||||
RetrievalInput,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.calculations import get_fit_scores
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
|
||||
from onyx.configs.agent_configs import AGENT_MAX_QUERY_RETRIEVAL_RESULTS
|
||||
from onyx.configs.agent_configs import AGENT_RETRIEVAL_STATS
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.tools.models import SearchQueryInfo
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
|
||||
|
||||
def doc_retrieval(state: RetrievalInput, config: RunnableConfig) -> DocRetrievalUpdate:
|
||||
"""
|
||||
Retrieve documents
|
||||
|
||||
Args:
|
||||
state (RetrievalInput): Primary state + the query to retrieve
|
||||
config (RunnableConfig): Configuration containing ProSearchConfig
|
||||
|
||||
Updates:
|
||||
expanded_retrieval_results: list[ExpandedRetrievalResult]
|
||||
retrieved_documents: list[InferenceSection]
|
||||
"""
|
||||
now_start = datetime.now()
|
||||
query_to_retrieve = state.query_to_retrieve
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
search_tool = agent_a_config.search_tool
|
||||
|
||||
retrieved_docs: list[InferenceSection] = []
|
||||
if not query_to_retrieve.strip():
|
||||
logger.warning("Empty query, skipping retrieval")
|
||||
now_end = datetime.now()
|
||||
return DocRetrievalUpdate(
|
||||
expanded_retrieval_results=[],
|
||||
retrieved_documents=[],
|
||||
log_messages=[
|
||||
f"{now_start} -- Expanded Retrieval - Retrieval - Empty Query - Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
|
||||
query_info = None
|
||||
if search_tool is None:
|
||||
raise ValueError("search_tool must be provided for agentic search")
|
||||
# new db session to avoid concurrency issues
|
||||
with get_session_context_manager() as db_session:
|
||||
for tool_response in search_tool.run(
|
||||
query=query_to_retrieve,
|
||||
force_no_rerank=True,
|
||||
alternate_db_session=db_session,
|
||||
):
|
||||
# get retrieved docs to send to the rest of the graph
|
||||
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
response = cast(SearchResponseSummary, tool_response.response)
|
||||
retrieved_docs = response.top_sections
|
||||
query_info = SearchQueryInfo(
|
||||
predicted_search=response.predicted_search,
|
||||
final_filters=response.final_filters,
|
||||
recency_bias_multiplier=response.recency_bias_multiplier,
|
||||
)
|
||||
break
|
||||
|
||||
retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS]
|
||||
pre_rerank_docs = retrieved_docs
|
||||
if search_tool.search_pipeline is not None:
|
||||
pre_rerank_docs = (
|
||||
search_tool.search_pipeline._retrieved_sections or retrieved_docs
|
||||
)
|
||||
|
||||
if AGENT_RETRIEVAL_STATS:
|
||||
fit_scores = get_fit_scores(
|
||||
pre_rerank_docs,
|
||||
retrieved_docs,
|
||||
)
|
||||
else:
|
||||
fit_scores = None
|
||||
|
||||
expanded_retrieval_result = QueryResult(
|
||||
query=query_to_retrieve,
|
||||
search_results=retrieved_docs,
|
||||
stats=fit_scores,
|
||||
query_info=query_info,
|
||||
)
|
||||
now_end = datetime.now()
|
||||
logger.info(
|
||||
f"{now_start} -- Expanded Retrieval - Retrieval - Time taken: {now_end - now_start}"
|
||||
)
|
||||
return DocRetrievalUpdate(
|
||||
expanded_retrieval_results=[expanded_retrieval_result],
|
||||
retrieved_documents=retrieved_docs,
|
||||
log_messages=[
|
||||
f"{now_start} -- Expanded Retrieval - Retrieval - Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,60 @@
from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
    DocVerificationInput,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
    DocVerificationUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT


def doc_verification(
    state: DocVerificationInput, config: RunnableConfig
) -> DocVerificationUpdate:
    """
    Check whether the document is relevant for the original user question

    Args:
        state (DocVerificationInput): The current state
        config (RunnableConfig): Configuration containing ProSearchConfig

    Updates:
        verified_documents: list[InferenceSection]
    """

    question = state.question
    doc_to_verify = state.doc_to_verify
    document_content = doc_to_verify.combined_content

    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    fast_llm = agent_a_config.fast_llm

    document_content = trim_prompt_piece(
        fast_llm.config, document_content, VERIFIER_PROMPT + question
    )

    msg = [
        HumanMessage(
            content=VERIFIER_PROMPT.format(
                question=question, document_content=document_content
            )
        )
    ]

    response = fast_llm.invoke(msg)

    verified_documents = []
    if isinstance(response.content, str) and "yes" in response.content.lower():
        verified_documents.append(doc_to_verify)

    return DocVerificationUpdate(
        verified_documents=verified_documents,
    )
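For reference, the relevance check above reduces to one yes/no LLM call per document. A minimal sketch of that pattern with a stubbed prompt and a fake LLM (both are placeholders, not the real VERIFIER_PROMPT or the onyx LLM interface) so it runs standalone:

from dataclasses import dataclass

VERIFIER_PROMPT_SKETCH = (
    "Does the document below help answer the question?\n"
    "Question: {question}\nDocument: {document_content}\nAnswer yes or no."
)


@dataclass
class FakeResponse:
    content: str


class FakeFastLLM:
    def invoke(self, messages: list[str]) -> FakeResponse:
        # Pretend the model always finds the document relevant
        return FakeResponse(content="yes")


def verify(question: str, document_content: str, llm: FakeFastLLM) -> bool:
    msg = [VERIFIER_PROMPT_SKETCH.format(question=question, document_content=document_content)]
    response = llm.invoke(msg)
    return "yes" in response.content.lower()


print(verify("what is onyx?", "Onyx is an open-source AI platform.", FakeFastLLM()))  # True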
@@ -0,0 +1,16 @@
from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
    ExpandedRetrievalState,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
    QueryExpansionUpdate,
)


def dummy(
    state: ExpandedRetrievalState, config: RunnableConfig
) -> QueryExpansionUpdate:
    return QueryExpansionUpdate(
        expanded_queries=state.expanded_queries,
    )
@@ -0,0 +1,74 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_message_runs
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.operations import (
|
||||
dispatch_subquery,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.operations import (
|
||||
logger,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
|
||||
ExpandedRetrievalInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
|
||||
QueryExpansionUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
REWRITE_PROMPT_MULTI_ORIGINAL,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
|
||||
|
||||
def expand_queries(
|
||||
state: ExpandedRetrievalInput, config: RunnableConfig
|
||||
) -> QueryExpansionUpdate:
|
||||
    # Sometimes we want to expand the original question, and sometimes we want to expand a sub-question.
    # When we are running this node on the original question, no question is explicitly passed in.
    # Instead, we use the original question from the search request.
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
now_start = datetime.now()
|
||||
question = (
|
||||
state.question
|
||||
if hasattr(state, "question")
|
||||
else agent_a_config.search_request.query
|
||||
)
|
||||
llm = agent_a_config.fast_llm
|
||||
chat_session_id = agent_a_config.chat_session_id
|
||||
sub_question_id = state.sub_question_id
|
||||
if sub_question_id is None:
|
||||
level, question_nr = 0, 0
|
||||
else:
|
||||
level, question_nr = parse_question_id(sub_question_id)
|
||||
|
||||
if chat_session_id is None:
|
||||
raise ValueError("chat_session_id must be provided for agent search")
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question),
|
||||
)
|
||||
]
|
||||
|
||||
llm_response_list = dispatch_separated(
|
||||
llm.stream(prompt=msg), dispatch_subquery(level, question_nr)
|
||||
)
|
||||
|
||||
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
|
||||
|
||||
rewritten_queries = llm_response.split("\n")
|
||||
now_end = datetime.now()
|
||||
logger.info(
|
||||
f"{now_start} -- Expanded Retrieval - Query Expansion - Time taken: {now_end - now_start}"
|
||||
)
|
||||
return QueryExpansionUpdate(
|
||||
expanded_queries=rewritten_queries,
|
||||
log_messages=[
|
||||
f"{now_start} -- Expanded Retrieval - Query Expansion - Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,84 @@
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.models import (
|
||||
ExpandedRetrievalResult,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.operations import (
|
||||
calculate_sub_question_retrieval_stats,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
|
||||
ExpandedRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
|
||||
ExpandedRetrievalUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
|
||||
|
||||
|
||||
def format_results(
|
||||
state: ExpandedRetrievalState, config: RunnableConfig
|
||||
) -> ExpandedRetrievalUpdate:
|
||||
level, question_nr = parse_question_id(state.sub_question_id or "0_0")
|
||||
query_infos = [
|
||||
result.query_info
|
||||
for result in state.expanded_retrieval_results
|
||||
if result.query_info is not None
|
||||
]
|
||||
if len(query_infos) == 0:
|
||||
raise ValueError("No query info found")
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
# main question docs will be sent later after aggregation and deduping with sub-question docs
|
||||
|
||||
reranked_documents = state.reranked_documents
|
||||
|
||||
if not (level == 0 and question_nr == 0):
|
||||
if len(reranked_documents) == 0:
|
||||
# The sub-question is used as the last query. If no verified documents are found, stream
|
||||
# the top 3 for that one. We may want to revisit this.
|
||||
reranked_documents = state.expanded_retrieval_results[-1].search_results[:3]
|
||||
|
||||
if agent_a_config.search_tool is None:
|
||||
raise ValueError("search_tool must be provided for agentic search")
|
||||
for tool_response in yield_search_responses(
|
||||
query=state.question,
|
||||
reranked_sections=state.retrieved_documents, # TODO: rename params. (sections pre-merging here.)
|
||||
final_context_sections=reranked_documents,
|
||||
search_query_info=query_infos[0], # TODO: handle differing query infos?
|
||||
get_section_relevance=lambda: None, # TODO: add relevance
|
||||
search_tool=agent_a_config.search_tool,
|
||||
):
|
||||
dispatch_custom_event(
|
||||
"tool_response",
|
||||
ExtendedToolResponse(
|
||||
id=tool_response.id,
|
||||
response=tool_response.response,
|
||||
level=level,
|
||||
level_question_nr=question_nr,
|
||||
),
|
||||
)
|
||||
sub_question_retrieval_stats = calculate_sub_question_retrieval_stats(
|
||||
verified_documents=state.verified_documents,
|
||||
expanded_retrieval_results=state.expanded_retrieval_results,
|
||||
)
|
||||
|
||||
if sub_question_retrieval_stats is None:
|
||||
sub_question_retrieval_stats = AgentChunkStats()
|
||||
# else:
|
||||
# sub_question_retrieval_stats = [sub_question_retrieval_stats]
|
||||
|
||||
return ExpandedRetrievalUpdate(
|
||||
expanded_retrieval_result=ExpandedRetrievalResult(
|
||||
expanded_queries_results=state.expanded_retrieval_results,
|
||||
reranked_documents=reranked_documents,
|
||||
context_documents=state.reranked_documents,
|
||||
sub_question_retrieval_stats=sub_question_retrieval_stats,
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,44 @@
from typing import cast
from typing import Literal

from langchain_core.runnables.config import RunnableConfig
from langgraph.types import Command
from langgraph.types import Send

from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
    DocVerificationInput,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
    ExpandedRetrievalState,
)
from onyx.agents.agent_search.models import AgentSearchConfig


def verification_kickoff(
    state: ExpandedRetrievalState,
    config: RunnableConfig,
) -> Command[Literal["doc_verification"]]:
    documents = state.retrieved_documents
    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    verification_question = (
        state.question
        if hasattr(state, "question")
        else agent_a_config.search_request.query
    )
    sub_question_id = state.sub_question_id
    return Command(
        update={},
        goto=[
            Send(
                node="doc_verification",
                arg=DocVerificationInput(
                    doc_to_verify=doc,
                    question=verification_question,
                    base_search=False,
                    sub_question_id=sub_question_id,
                    log_messages=[],
                ),
            )
            for doc in documents
        ],
    )
@@ -0,0 +1,97 @@
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
|
||||
import numpy as np
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
|
||||
from onyx.chat.models import SubQueryPiece
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def dispatch_subquery(level: int, question_nr: int) -> Callable[[str, int], None]:
|
||||
def helper(token: str, num: int) -> None:
|
||||
dispatch_custom_event(
|
||||
"subqueries",
|
||||
SubQueryPiece(
|
||||
sub_query=token,
|
||||
level=level,
|
||||
level_question_nr=question_nr,
|
||||
query_id=num,
|
||||
),
|
||||
)
|
||||
|
||||
return helper
|
||||
|
||||
|
||||
def calculate_sub_question_retrieval_stats(
|
||||
verified_documents: list[InferenceSection],
|
||||
expanded_retrieval_results: list[QueryResult],
|
||||
) -> AgentChunkStats:
|
||||
chunk_scores: dict[str, dict[str, list[int | float]]] = defaultdict(
|
||||
lambda: defaultdict(list)
|
||||
)
|
||||
|
||||
for expanded_retrieval_result in expanded_retrieval_results:
|
||||
for doc in expanded_retrieval_result.search_results:
|
||||
doc_chunk_id = f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}"
|
||||
if doc.center_chunk.score is not None:
|
||||
chunk_scores[doc_chunk_id]["score"].append(doc.center_chunk.score)
|
||||
|
||||
verified_doc_chunk_ids = [
|
||||
f"{verified_document.center_chunk.document_id}_{verified_document.center_chunk.chunk_id}"
|
||||
for verified_document in verified_documents
|
||||
]
|
||||
dismissed_doc_chunk_ids = []
|
||||
|
||||
raw_chunk_stats_counts: dict[str, int] = defaultdict(int)
|
||||
raw_chunk_stats_scores: dict[str, float] = defaultdict(float)
|
||||
for doc_chunk_id, chunk_data in chunk_scores.items():
|
||||
if doc_chunk_id in verified_doc_chunk_ids:
|
||||
raw_chunk_stats_counts["verified_count"] += 1
|
||||
|
||||
valid_chunk_scores = [
|
||||
score for score in chunk_data["score"] if score is not None
|
||||
]
|
||||
raw_chunk_stats_scores["verified_scores"] += float(
|
||||
np.mean(valid_chunk_scores)
|
||||
)
|
||||
else:
|
||||
raw_chunk_stats_counts["rejected_count"] += 1
|
||||
valid_chunk_scores = [
|
||||
score for score in chunk_data["score"] if score is not None
|
||||
]
|
||||
raw_chunk_stats_scores["rejected_scores"] += float(
|
||||
np.mean(valid_chunk_scores)
|
||||
)
|
||||
dismissed_doc_chunk_ids.append(doc_chunk_id)
|
||||
|
||||
if raw_chunk_stats_counts["verified_count"] == 0:
|
||||
verified_avg_scores = 0.0
|
||||
else:
|
||||
verified_avg_scores = raw_chunk_stats_scores["verified_scores"] / float(
|
||||
raw_chunk_stats_counts["verified_count"]
|
||||
)
|
||||
|
||||
rejected_scores = raw_chunk_stats_scores.get("rejected_scores", None)
|
||||
if rejected_scores is not None:
|
||||
rejected_avg_scores = rejected_scores / float(
|
||||
raw_chunk_stats_counts["rejected_count"]
|
||||
)
|
||||
else:
|
||||
rejected_avg_scores = None
|
||||
|
||||
chunk_stats = AgentChunkStats(
|
||||
verified_count=raw_chunk_stats_counts["verified_count"],
|
||||
verified_avg_scores=verified_avg_scores,
|
||||
rejected_count=raw_chunk_stats_counts["rejected_count"],
|
||||
rejected_avg_scores=rejected_avg_scores,
|
||||
verified_doc_chunk_ids=verified_doc_chunk_ids,
|
||||
dismissed_doc_chunk_ids=dismissed_doc_chunk_ids,
|
||||
)
|
||||
|
||||
return chunk_stats
|
||||
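The verified/rejected bookkeeping in calculate_sub_question_retrieval_stats reduces to averaging chunk scores per bucket. A toy, self-contained version of that computation with made-up chunk ids and scores:

from collections import defaultdict

import numpy as np

chunk_scores = {"doc1_0": [0.9, 0.8], "doc2_3": [0.4], "doc3_1": [0.2, 0.3]}
verified_ids = {"doc1_0"}

counts: dict[str, int] = defaultdict(int)
totals: dict[str, float] = defaultdict(float)
for chunk_id, scores in chunk_scores.items():
    bucket = "verified" if chunk_id in verified_ids else "rejected"
    counts[bucket] += 1
    totals[bucket] += float(np.mean(scores))  # mean score per chunk, summed per bucket

verified_avg = totals["verified"] / counts["verified"] if counts["verified"] else 0.0
rejected_avg = totals["rejected"] / counts["rejected"] if counts["rejected"] else None
print(verified_avg, rejected_avg)  # 0.85 0.325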
@@ -0,0 +1,91 @@
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.core_state import SubgraphCoreState
|
||||
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.models import (
|
||||
ExpandedRetrievalResult,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
### States ###
|
||||
|
||||
## Graph Input State
|
||||
|
||||
|
||||
class ExpandedRetrievalInput(SubgraphCoreState):
|
||||
question: str = ""
|
||||
base_search: bool = False
|
||||
sub_question_id: str | None = None
|
||||
|
||||
|
||||
## Update/Return States
|
||||
|
||||
|
||||
class QueryExpansionUpdate(BaseModel):
|
||||
expanded_queries: list[str] = ["aaa", "bbb"]
|
||||
log_messages: list[str] = []
|
||||
|
||||
|
||||
class DocVerificationUpdate(BaseModel):
|
||||
verified_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
|
||||
|
||||
|
||||
class DocRetrievalUpdate(BaseModel):
|
||||
expanded_retrieval_results: Annotated[list[QueryResult], add] = []
|
||||
retrieved_documents: Annotated[
|
||||
list[InferenceSection], dedup_inference_sections
|
||||
] = []
|
||||
log_messages: list[str] = []
|
||||
|
||||
|
||||
class DocRerankingUpdate(BaseModel):
|
||||
reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
|
||||
sub_question_retrieval_stats: RetrievalFitStats | None = None
|
||||
log_messages: list[str] = []
|
||||
|
||||
|
||||
class ExpandedRetrievalUpdate(BaseModel):
|
||||
expanded_retrieval_result: ExpandedRetrievalResult
|
||||
|
||||
|
||||
## Graph Output State
|
||||
|
||||
|
||||
class ExpandedRetrievalOutput(BaseModel):
|
||||
expanded_retrieval_result: ExpandedRetrievalResult = ExpandedRetrievalResult()
|
||||
base_expanded_retrieval_result: ExpandedRetrievalResult = ExpandedRetrievalResult()
|
||||
log_messages: list[str] = []
|
||||
|
||||
|
||||
## Graph State
|
||||
|
||||
|
||||
class ExpandedRetrievalState(
|
||||
# This includes the core state
|
||||
ExpandedRetrievalInput,
|
||||
QueryExpansionUpdate,
|
||||
DocRetrievalUpdate,
|
||||
DocVerificationUpdate,
|
||||
DocRerankingUpdate,
|
||||
ExpandedRetrievalOutput,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
## Conditional Input States
|
||||
|
||||
|
||||
class DocVerificationInput(ExpandedRetrievalInput):
|
||||
doc_to_verify: InferenceSection
|
||||
|
||||
|
||||
class RetrievalInput(ExpandedRetrievalInput):
|
||||
query_to_retrieve: str = ""
|
||||
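A note on the Annotated[...] fields in these state classes: LangGraph uses the annotated reducer to combine updates from parallel branches instead of overwriting the key. A minimal sketch of the pattern with a stand-in dedup reducer (not the real dedup_inference_sections) and toy field names:

from operator import add
from typing import Annotated

from pydantic import BaseModel


def dedup_strings(existing: list[str], new: list[str]) -> list[str]:
    # Keep the first occurrence of each value, preserving order
    return list(dict.fromkeys(existing + new))


class ToyRetrievalState(BaseModel):
    retrieved_documents: Annotated[list[str], dedup_strings] = []
    log_messages: Annotated[list[str], add] = []


# Two parallel "doc_retrieval" branches returning overlapping documents would be
# merged by the graph runtime roughly like this:
merged = dedup_strings(["doc_a", "doc_b"], ["doc_b", "doc_c"])
print(merged)  # ['doc_a', 'doc_b', 'doc_c']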
backend/onyx/agents/agent_search/models.py (new file, 101 lines)
@@ -0,0 +1,101 @@
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import PersonaExpressions
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.file_store.utils import InMemoryChatFile
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
|
||||
|
||||
class AgentSearchConfig(BaseModel):
|
||||
"""
|
||||
Configuration for the Agent Search feature.
|
||||
"""
|
||||
|
||||
# The search request that was used to generate the Pro Search
|
||||
search_request: SearchRequest
|
||||
|
||||
primary_llm: LLM
|
||||
fast_llm: LLM
|
||||
|
||||
# Whether to force use of a tool, or to
|
||||
# force tool args IF the tool is used
|
||||
force_use_tool: ForceUseTool
|
||||
|
||||
# contains message history for the current chat session
|
||||
# has the following (at most one is non-None)
|
||||
# message_history: list[PreviousMessage] | None = None
|
||||
# single_message_history: str | None = None
|
||||
prompt_builder: AnswerPromptBuilder
|
||||
|
||||
search_tool: SearchTool | None = None
|
||||
|
||||
use_agentic_search: bool = False
|
||||
|
||||
# For persisting agent search data
|
||||
chat_session_id: UUID | None = None
|
||||
|
||||
# The message ID of the user message that triggered the Pro Search
|
||||
message_id: int | None = None
|
||||
|
||||
    # Whether to persist data for Agentic Search (turned off for testing)
use_agentic_persistence: bool = True
|
||||
|
||||
# The database session for Agentic Search
|
||||
db_session: Session | None = None
|
||||
|
||||
# Whether to perform initial search to inform decomposition
|
||||
# perform_initial_search_path_decision: bool = True
|
||||
|
||||
# Whether to perform initial search to inform decomposition
|
||||
perform_initial_search_decomposition: bool = True
|
||||
|
||||
# Whether to allow creation of refinement questions (and entity extraction, etc.)
|
||||
allow_refinement: bool = True
|
||||
|
||||
# Tools available for use
|
||||
tools: list[Tool] | None = None
|
||||
|
||||
using_tool_calling_llm: bool = False
|
||||
|
||||
files: list[InMemoryChatFile] | None = None
|
||||
|
||||
structured_response_format: dict | None = None
|
||||
|
||||
skip_gen_ai_answer_generation: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_db_session(self) -> "AgentSearchConfig":
|
||||
if self.use_agentic_persistence and self.db_session is None:
|
||||
raise ValueError(
|
||||
"db_session must be provided for pro search when using persistence"
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_search_tool(self) -> "AgentSearchConfig":
|
||||
if self.use_agentic_search and self.search_tool is None:
|
||||
raise ValueError("search_tool must be provided for agentic search")
|
||||
return self
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class AgentDocumentCitations(BaseModel):
|
||||
document_id: str
|
||||
document_title: str
|
||||
link: str
|
||||
|
||||
|
||||
class AgentPromptEnrichmentComponents(BaseModel):
|
||||
persona_prompts: PersonaExpressions
|
||||
history: str
|
||||
date_str: str
|
||||
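The two model_validator(mode="after") hooks on AgentSearchConfig implement cross-field checks that run once the whole model has been constructed. A minimal, generic sketch of that pydantic v2 pattern (field names here are illustrative only, not the onyx ones):

from pydantic import BaseModel, model_validator


class ToyAgentConfig(BaseModel):
    use_agentic_persistence: bool = True
    db_session_present: bool = False  # stand-in for the real Session | None field

    @model_validator(mode="after")
    def validate_db_session(self) -> "ToyAgentConfig":
        # Runs after all fields are set, so it can relate several fields at once
        if self.use_agentic_persistence and not self.db_session_present:
            raise ValueError("db_session must be provided when persistence is enabled")
        return self


ToyAgentConfig(use_agentic_persistence=False)   # constructs fine
# ToyAgentConfig(use_agentic_persistence=True)  # would raise ValueError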
@@ -0,0 +1,72 @@
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.basic.states import BasicOutput
|
||||
from onyx.agents.agent_search.basic.states import BasicState
|
||||
from onyx.agents.agent_search.basic.utils import process_llm_stream
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_DOC_CONTENT_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search_like_tool_utils import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicOutput:
|
||||
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
structured_response_format = agent_config.structured_response_format
|
||||
llm = agent_config.primary_llm
|
||||
tool_choice = state.tool_choice
|
||||
if tool_choice is None:
|
||||
raise ValueError("Tool choice is None")
|
||||
tool = tool_choice.tool
|
||||
prompt_builder = agent_config.prompt_builder
|
||||
if state.tool_call_output is None:
|
||||
raise ValueError("Tool call output is None")
|
||||
tool_call_output = state.tool_call_output
|
||||
tool_call_summary = tool_call_output.tool_call_summary
|
||||
tool_call_responses = tool_call_output.tool_call_responses
|
||||
|
||||
new_prompt_builder = tool.build_next_prompt(
|
||||
prompt_builder=prompt_builder,
|
||||
tool_call_summary=tool_call_summary,
|
||||
tool_responses=tool_call_responses,
|
||||
using_tool_calling_llm=agent_config.using_tool_calling_llm,
|
||||
)
|
||||
|
||||
final_search_results = []
|
||||
initial_search_results = []
|
||||
for yield_item in tool_call_responses:
|
||||
if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
final_search_results = cast(list[LlmDoc], yield_item.response)
|
||||
elif yield_item.id == SEARCH_DOC_CONTENT_ID:
|
||||
search_contexts = yield_item.response.contexts
|
||||
for doc in search_contexts:
|
||||
if doc.document_id not in initial_search_results:
|
||||
initial_search_results.append(doc)
|
||||
|
||||
initial_search_results = cast(list[LlmDoc], initial_search_results)
|
||||
|
||||
new_tool_call_chunk = AIMessageChunk(content="")
|
||||
if not agent_config.skip_gen_ai_answer_generation:
|
||||
stream = llm.stream(
|
||||
prompt=new_prompt_builder.build(),
|
||||
structured_response_format=structured_response_format,
|
||||
)
|
||||
|
||||
# For now, we don't do multiple tool calls, so we ignore the tool_message
|
||||
new_tool_call_chunk = process_llm_stream(
|
||||
stream,
|
||||
True,
|
||||
final_search_results=final_search_results,
|
||||
displayed_search_results=initial_search_results,
|
||||
)
|
||||
|
||||
return BasicOutput(tool_call_chunk=new_tool_call_chunk)
|
||||
@@ -0,0 +1,144 @@
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
from langchain_core.messages import ToolCall
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.basic.utils import process_llm_stream
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoice
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceState
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
|
||||
from onyx.chat.tool_handling.tool_response_handler import (
|
||||
get_tool_call_for_non_tool_calling_llm_impl,
|
||||
)
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# TODO: break this out into an implementation function
|
||||
# and a function that handles extracting the necessary fields
|
||||
# from the state and config
|
||||
# TODO: fan-out to multiple tool call nodes? Make this configurable?
|
||||
def llm_tool_choice(state: ToolChoiceState, config: RunnableConfig) -> ToolChoiceUpdate:
|
||||
"""
|
||||
This node is responsible for calling the LLM to choose a tool. If no tool is chosen,
|
||||
The node MAY emit an answer, depending on whether state["should_stream_answer"] is set.
|
||||
"""
|
||||
should_stream_answer = state.should_stream_answer
|
||||
|
||||
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
using_tool_calling_llm = agent_config.using_tool_calling_llm
|
||||
prompt_builder = state.prompt_snapshot or agent_config.prompt_builder
|
||||
|
||||
llm = agent_config.primary_llm
|
||||
skip_gen_ai_answer_generation = agent_config.skip_gen_ai_answer_generation
|
||||
|
||||
structured_response_format = agent_config.structured_response_format
|
||||
tools = [tool for tool in (agent_config.tools or []) if tool.name in state.tools]
|
||||
force_use_tool = agent_config.force_use_tool
|
||||
|
||||
tool, tool_args = None, None
|
||||
if force_use_tool.force_use and force_use_tool.args is not None:
|
||||
tool_name, tool_args = (
|
||||
force_use_tool.tool_name,
|
||||
force_use_tool.args,
|
||||
)
|
||||
tool = get_tool_by_name(tools, tool_name)
|
||||
|
||||
# special pre-logic for non-tool calling LLM case
|
||||
elif not using_tool_calling_llm and tools:
|
||||
chosen_tool_and_args = get_tool_call_for_non_tool_calling_llm_impl(
|
||||
force_use_tool=force_use_tool,
|
||||
tools=tools,
|
||||
prompt_builder=prompt_builder,
|
||||
llm=llm,
|
||||
)
|
||||
if chosen_tool_and_args:
|
||||
tool, tool_args = chosen_tool_and_args
|
||||
|
||||
    # If we have a tool and tool args, we are ready to request a tool call.
# This only happens if the tool call was forced or we are using a non-tool calling LLM.
|
||||
if tool and tool_args:
|
||||
return ToolChoiceUpdate(
|
||||
tool_choice=ToolChoice(
|
||||
tool=tool,
|
||||
tool_args=tool_args,
|
||||
id=str(uuid4()),
|
||||
),
|
||||
)
|
||||
|
||||
# if we're skipping gen ai answer generation, we should only
|
||||
# continue if we're forcing a tool call (which will be emitted by
|
||||
# the tool calling llm in the stream() below)
|
||||
if skip_gen_ai_answer_generation and not force_use_tool.force_use:
|
||||
return ToolChoiceUpdate(
|
||||
tool_choice=None,
|
||||
)
|
||||
|
||||
built_prompt = (
|
||||
prompt_builder.build()
|
||||
if isinstance(prompt_builder, AnswerPromptBuilder)
|
||||
else prompt_builder.built_prompt
|
||||
)
|
||||
# At this point, we are either using a tool calling LLM or we are skipping the tool call.
|
||||
# DEBUG: good breakpoint
|
||||
stream = llm.stream(
|
||||
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
|
||||
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
|
||||
prompt=built_prompt,
|
||||
tools=[tool.tool_definition() for tool in tools] or None,
|
||||
tool_choice=("required" if tools and force_use_tool.force_use else None),
|
||||
structured_response_format=structured_response_format,
|
||||
)
|
||||
|
||||
tool_message = process_llm_stream(
|
||||
stream, should_stream_answer and not agent_config.skip_gen_ai_answer_generation
|
||||
)
|
||||
|
||||
# If no tool calls are emitted by the LLM, we should not choose a tool
|
||||
if len(tool_message.tool_calls) == 0:
|
||||
logger.info("No tool calls emitted by LLM")
|
||||
return ToolChoiceUpdate(
|
||||
tool_choice=None,
|
||||
)
|
||||
|
||||
# TODO: here we could handle parallel tool calls. Right now
|
||||
# we just pick the first one that matches.
|
||||
selected_tool: Tool | None = None
|
||||
selected_tool_call_request: ToolCall | None = None
|
||||
for tool_call_request in tool_message.tool_calls:
|
||||
known_tools_by_name = [
|
||||
tool for tool in tools if tool.name == tool_call_request["name"]
|
||||
]
|
||||
|
||||
if known_tools_by_name:
|
||||
selected_tool = known_tools_by_name[0]
|
||||
selected_tool_call_request = tool_call_request
|
||||
break
|
||||
|
||||
logger.error(
|
||||
"Tool call requested with unknown name field. \n"
|
||||
f"tools: {tools}"
|
||||
f"tool_call_request: {tool_call_request}"
|
||||
)
|
||||
|
||||
if not selected_tool or not selected_tool_call_request:
|
||||
raise ValueError(
|
||||
f"Tool call attempted with tool {selected_tool}, request {selected_tool_call_request}"
|
||||
)
|
||||
|
||||
logger.info(f"Selected tool: {selected_tool.name}")
|
||||
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
|
||||
|
||||
return ToolChoiceUpdate(
|
||||
tool_choice=ToolChoice(
|
||||
tool=selected_tool,
|
||||
tool_args=selected_tool_call_request["args"],
|
||||
id=selected_tool_call_request["id"],
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,17 @@
from typing import Any
from typing import cast

from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput


def prepare_tool_input(state: Any, config: RunnableConfig) -> ToolChoiceInput:
    agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
    return ToolChoiceInput(
        # NOTE: this node is used at the top level of the agent, so we always stream
        should_stream_answer=True,
        prompt_snapshot=None,  # uses default prompt builder
        tools=[tool.name for tool in (agent_config.tools or [])],
    )
@@ -0,0 +1,65 @@
from typing import cast

from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import AIMessageChunk
from langchain_core.messages.tool import ToolCall
from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.orchestration.states import ToolCallOutput
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.chat.models import AnswerPacket
from onyx.tools.message import build_tool_message
from onyx.tools.message import ToolCallSummary
from onyx.tools.tool_runner import ToolRunner
from onyx.utils.logger import setup_logger


logger = setup_logger()


def emit_packet(packet: AnswerPacket) -> None:
    dispatch_custom_event("basic_response", packet)


def tool_call(state: ToolChoiceUpdate, config: RunnableConfig) -> ToolCallUpdate:
    """Calls the tool specified in the state and updates the state with the result"""

    cast(AgentSearchConfig, config["metadata"]["config"])

    tool_choice = state.tool_choice
    if tool_choice is None:
        raise ValueError("Cannot invoke tool call node without a tool choice")

    tool = tool_choice.tool
    tool_args = tool_choice.tool_args
    tool_id = tool_choice.id
    tool_runner = ToolRunner(tool, tool_args)
    tool_kickoff = tool_runner.kickoff()

    emit_packet(tool_kickoff)

    tool_responses = []
    for response in tool_runner.tool_responses():
        tool_responses.append(response)
        emit_packet(response)

    tool_final_result = tool_runner.tool_final_result()
    emit_packet(tool_final_result)

    tool_call = ToolCall(name=tool.name, args=tool_args, id=tool_id)
    tool_call_summary = ToolCallSummary(
        tool_call_request=AIMessageChunk(content="", tool_calls=[tool_call]),
        tool_call_result=build_tool_message(
            tool_call, tool_runner.tool_message_content()
        ),
    )

    tool_call_output = ToolCallOutput(
        tool_call_summary=tool_call_summary,
        tool_call_kickoff=tool_kickoff,
        tool_call_responses=tool_responses,
        tool_call_final_result=tool_final_result,
    )
    return ToolCallUpdate(tool_call_output=tool_call_output)
backend/onyx/agents/agent_search/orchestration/states.py (new file, 48 lines)
@@ -0,0 +1,48 @@
from pydantic import BaseModel

from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool


# TODO: adapt the tool choice/tool call to allow for parallel tool calls by
# creating a subgraph that can be invoked in parallel via Send/Command APIs
class ToolChoiceInput(BaseModel):
    should_stream_answer: bool = True
    # default to the prompt builder from the config, but
    # allow overrides for arbitrary tool calls
    prompt_snapshot: PromptSnapshot | None = None

    # names of tools to use for tool calling. Filters the tools available in the config
    tools: list[str] = []


class ToolCallOutput(BaseModel):
    tool_call_summary: ToolCallSummary
    tool_call_kickoff: ToolCallKickoff
    tool_call_responses: list[ToolResponse]
    tool_call_final_result: ToolCallFinalResult


class ToolCallUpdate(BaseModel):
    tool_call_output: ToolCallOutput | None = None


class ToolChoice(BaseModel):
    tool: Tool
    tool_args: dict
    id: str | None

    class Config:
        arbitrary_types_allowed = True


class ToolChoiceUpdate(BaseModel):
    tool_choice: ToolChoice | None = None


class ToolChoiceState(ToolChoiceUpdate, ToolChoiceInput):
    pass
backend/onyx/agents/agent_search/run_graph.py (new file, 261 lines)
@@ -0,0 +1,261 @@
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable
|
||||
from collections.abc import Iterable
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables.schema import StreamEvent
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from onyx.agents.agent_search.basic.graph_builder import basic_graph_builder
|
||||
from onyx.agents.agent_search.basic.states import BasicInput
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.graph_builder import (
|
||||
main_graph_builder as main_graph_builder_a,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
|
||||
MainInput as MainInput_a,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import AnswerPacket
|
||||
from onyx.chat.models import AnswerStream
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.chat.models import RefinedAnswerImprovement
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import SubQueryPiece
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.chat.models import ToolResponse
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.tools.tool_runner import ToolCallKickoff
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_COMPILED_GRAPH: CompiledStateGraph | None = None
|
||||
|
||||
|
||||
def _set_combined_token_value(
|
||||
combined_token: str, parsed_object: AgentAnswerPiece
|
||||
) -> AgentAnswerPiece:
|
||||
parsed_object.answer_piece = combined_token
|
||||
|
||||
return parsed_object
|
||||
|
||||
|
||||
def _parse_agent_event(
|
||||
event: StreamEvent,
|
||||
) -> AnswerPacket | None:
|
||||
"""
|
||||
Parse the event into a typed object.
|
||||
Return None if we are not interested in the event.
|
||||
"""
|
||||
|
||||
event_type = event["event"]
|
||||
|
||||
# We always just yield the event data, but this piece is useful for two development reasons:
|
||||
# 1. It's a list of the names of every place we dispatch a custom event
|
||||
# 2. We maintain the intended types yielded by each event
|
||||
if event_type == "on_custom_event":
|
||||
# TODO: different AnswerStream types for different events
|
||||
if event["name"] == "decomp_qs":
|
||||
return cast(SubQuestionPiece, event["data"])
|
||||
elif event["name"] == "subqueries":
|
||||
return cast(SubQueryPiece, event["data"])
|
||||
elif event["name"] == "sub_answers":
|
||||
return cast(AgentAnswerPiece, event["data"])
|
||||
elif event["name"] == "stream_finished":
|
||||
return cast(StreamStopInfo, event["data"])
|
||||
elif event["name"] == "initial_agent_answer":
|
||||
return cast(AgentAnswerPiece, event["data"])
|
||||
elif event["name"] == "refined_agent_answer":
|
||||
return cast(AgentAnswerPiece, event["data"])
|
||||
elif event["name"] == "start_refined_answer_creation":
|
||||
return cast(ToolCallKickoff, event["data"])
|
||||
elif event["name"] == "tool_response":
|
||||
return cast(ToolResponse, event["data"])
|
||||
elif event["name"] == "basic_response":
|
||||
return cast(AnswerPacket, event["data"])
|
||||
elif event["name"] == "refined_answer_improvement":
|
||||
return cast(RefinedAnswerImprovement, event["data"])
|
||||
return None
|
||||
|
||||
|
||||
# https://stackoverflow.com/questions/60226557/how-to-forcefully-close-an-async-generator
|
||||
# https://stackoverflow.com/questions/40897428/please-explain-task-was-destroyed-but-it-is-pending-after-cancelling-tasks
|
||||
task_references: set[asyncio.Task[StreamEvent]] = set()
|
||||
|
||||
|
||||
def _manage_async_event_streaming(
|
||||
compiled_graph: CompiledStateGraph,
|
||||
config: AgentSearchConfig | None,
|
||||
graph_input: MainInput_a | BasicInput,
|
||||
) -> Iterable[StreamEvent]:
|
||||
async def _run_async_event_stream() -> AsyncIterable[StreamEvent]:
|
||||
message_id = config.message_id if config else None
|
||||
async for event in compiled_graph.astream_events(
|
||||
input=graph_input,
|
||||
config={"metadata": {"config": config, "thread_id": str(message_id)}},
|
||||
# debug=True,
|
||||
# indicating v2 here deserves further scrutiny
|
||||
version="v2",
|
||||
):
|
||||
yield event
|
||||
|
||||
# This might be able to be simplified
|
||||
def _yield_async_to_sync() -> Iterable[StreamEvent]:
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
# Get the async generator
|
||||
async_gen = _run_async_event_stream()
|
||||
# Convert to AsyncIterator
|
||||
async_iter = async_gen.__aiter__()
|
||||
while True:
|
||||
try:
|
||||
# Create a coroutine by calling anext with the async iterator
|
||||
next_coro = anext(async_iter)
|
||||
task = asyncio.ensure_future(next_coro, loop=loop)
|
||||
task_references.add(task)
|
||||
# Run the coroutine to get the next event
|
||||
event = loop.run_until_complete(task)
|
||||
yield event
|
||||
except (StopAsyncIteration, GeneratorExit):
|
||||
break
|
||||
        finally:
            # Cancel any tasks that are still pending before closing the loop
            for task in task_references:
                task.cancel()
            task_references.clear()
            loop.close()
|
||||
return _yield_async_to_sync()


def run_graph(
    compiled_graph: CompiledStateGraph,
    config: AgentSearchConfig,
    input: BasicInput | MainInput_a,
) -> AnswerStream:
    # TODO: add these to the environment
    # config.perform_initial_search_path_decision = False
    config.perform_initial_search_decomposition = True
    config.allow_refinement = True

    for event in _manage_async_event_streaming(
        compiled_graph=compiled_graph, config=config, graph_input=input
    ):
        if not (parsed_object := _parse_agent_event(event)):
            continue

        yield parsed_object


# TODO: call this once on startup, TBD where and if it should be gated based
# on dev mode or not
def load_compiled_graph() -> CompiledStateGraph:
    global _COMPILED_GRAPH
    if _COMPILED_GRAPH is None:
        graph = main_graph_builder_a()
        _COMPILED_GRAPH = graph.compile()
    return _COMPILED_GRAPH


def run_main_graph(
    config: AgentSearchConfig,
) -> AnswerStream:
    compiled_graph = load_compiled_graph()
    input = MainInput_a(base_question=config.search_request.query, log_messages=[])

    # Agent search is not a Tool per se, but this is helpful for the frontend
    yield ToolCallKickoff(
        tool_name="agent_search_0",
        tool_args={"query": config.search_request.query},
    )
    yield from run_graph(compiled_graph, config, input)


# TODO: unify input types, especially prosearchconfig
def run_basic_graph(
    config: AgentSearchConfig,
) -> AnswerStream:
    graph = basic_graph_builder()
    compiled_graph = graph.compile()
    # TODO: unify basic input
    input = BasicInput()
    return run_graph(compiled_graph, config, input)


if __name__ == "__main__":
    from onyx.llm.factory import get_default_llms

    for _ in range(1):
        now_start = datetime.now()
        logger.debug(f"Start at {now_start}")
        graph = main_graph_builder_a()
        compiled_graph = graph.compile()
        now_end = datetime.now()
        logger.debug(f"Graph compiled in {now_end - now_start} seconds")
        primary_llm, fast_llm = get_default_llms()
        search_request = SearchRequest(
            # query="what can you do with gitlab?",
            # query="What are the guiding principles behind the development of cockroachDB",
            # query="What are the temperatures in Munich, Hawaii, and New York?",
            # query="When was Washington born?",
            # query="What is Onyx?",
            # query="What is the difference between astronomy and astrology?",
            query="Do a search to tell me what is the difference between astronomy and astrology?",
        )
        # Joachim custom persona

        with get_session_context_manager() as db_session:
            config, search_tool = get_test_config(
                db_session, primary_llm, fast_llm, search_request
            )
            # search_request.persona = get_persona_by_id(1, None, db_session)
            config.use_agentic_persistence = True
            # config.perform_initial_search_path_decision = False
            config.perform_initial_search_decomposition = True
            input = MainInput_a(
                base_question=config.search_request.query, log_messages=[]
            )

            # with open("output.txt", "w") as f:
            tool_responses: list = []
            for output in run_graph(compiled_graph, config, input):
                # pass

                if isinstance(output, ToolCallKickoff):
                    pass
                elif isinstance(output, ExtendedToolResponse):
                    tool_responses.append(output.response)
                    logger.info(
                        f" ---- ET {output.level} - {output.level_question_nr} | "
                    )
                elif isinstance(output, SubQueryPiece):
                    logger.info(
                        f"Sq {output.level} - {output.level_question_nr} - {output.sub_query} | "
                    )
                elif isinstance(output, SubQuestionPiece):
                    logger.info(
                        f"SQ {output.level} - {output.level_question_nr} - {output.sub_question} | "
                    )
                elif (
                    isinstance(output, AgentAnswerPiece)
                    and output.answer_type == "agent_sub_answer"
                ):
                    logger.info(
                        f" ---- SA {output.level} - {output.level_question_nr} {output.answer_piece} | "
                    )
                elif (
                    isinstance(output, AgentAnswerPiece)
                    and output.answer_type == "agent_level_answer"
                ):
                    logger.info(
                        f" ---------- FA {output.level} - {output.level_question_nr} {output.answer_piece} | "
                    )
                elif isinstance(output, RefinedAnswerImprovement):
                    logger.info(
                        f" ---------- RE {output.refined_answer_improvement} | "
                    )
@@ -0,0 +1,134 @@
from langchain.schema import AIMessage
from langchain.schema import HumanMessage
from langchain.schema import SystemMessage
from langchain_core.messages.tool import ToolMessage

from onyx.agents.agent_search.models import AgentPromptEnrichmentComponents
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT_v2
from onyx.agents.agent_search.shared_graph_utils.prompts import HISTORY_PROMPT
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_persona_agent_prompt_expressions,
)
from onyx.agents.agent_search.shared_graph_utils.utils import get_today_prompt
from onyx.agents.agent_search.shared_graph_utils.utils import summarize_history
from onyx.configs.agent_configs import AGENT_MAX_STATIC_HISTORY_CHAR_LENGTH
from onyx.context.search.models import InferenceSection
from onyx.llm.interfaces import LLMConfig
from onyx.llm.utils import get_max_input_tokens
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.natural_language_processing.utils import tokenizer_trim_content


def build_sub_question_answer_prompt(
    question: str,
    original_question: str,
    docs: list[InferenceSection],
    persona_specification: str,
    config: LLMConfig,
) -> list[SystemMessage | HumanMessage | AIMessage | ToolMessage]:
    system_message = SystemMessage(
        content=persona_specification,
    )

    date_str = get_today_prompt()

    docs_format_list = [
        f"""Document Number: [D{doc_nr + 1}]\n
            Content: {doc.combined_content}\n\n"""
        for doc_nr, doc in enumerate(docs)
    ]

    docs_str = "\n\n".join(docs_format_list)

    docs_str = trim_prompt_piece(
        config, docs_str, BASE_RAG_PROMPT_v2 + question + original_question + date_str
    )
    human_message = HumanMessage(
        content=BASE_RAG_PROMPT_v2.format(
            question=question,
            original_question=original_question,
            context=docs_str,
            date_prompt=date_str,
        )
    )

    return [system_message, human_message]


def trim_prompt_piece(config: LLMConfig, prompt_piece: str, reserved_str: str) -> str:
    # TODO: this truncating might add latency. We could do a rougher + faster check
    # first to determine whether truncation is needed

    # TODO: maybe save the tokenizer and max input tokens if this is getting called multiple times?
    llm_tokenizer = get_tokenizer(
        provider_type=config.model_provider,
        model_name=config.model_name,
    )

    max_tokens = get_max_input_tokens(
        model_provider=config.model_provider,
        model_name=config.model_name,
    )

    # slightly conservative trimming
    return tokenizer_trim_content(
        content=prompt_piece,
        desired_length=max_tokens - len(llm_tokenizer.encode(reserved_str)),
        tokenizer=llm_tokenizer,
    )
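
A minimal budgeting sketch for trim_prompt_piece (illustration only; the 4096-token window, the 600 reserved tokens, and the _illustrate_trim_budget helper are hypothetical and not part of this file): the context is trimmed to whatever the model's input window leaves after the reserved prompt pieces.

def _illustrate_trim_budget(
    max_input_tokens: int, reserved_token_count: int, content_token_count: int
) -> int:
    # the document context may use at most what the window leaves after the reserved pieces
    desired_length = max_input_tokens - reserved_token_count
    return min(content_token_count, desired_length)


assert _illustrate_trim_budget(4096, 600, 5000) == 3496  # 4096 - 600
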


def build_history_prompt(config: AgentSearchConfig, question: str) -> str:
    prompt_builder = config.prompt_builder
    model = config.fast_llm
    persona_base = get_persona_agent_prompt_expressions(
        config.search_request.persona
    ).base_prompt

    if prompt_builder is None:
        return ""

    if prompt_builder.single_message_history is not None:
        history = prompt_builder.single_message_history
    else:
        history_components = []
        previous_message_type = None
        for message in prompt_builder.raw_message_history:
            if "user" in message.message_type:
                history_components.append(f"User: {message.message}\n")
                previous_message_type = "user"
            elif "assistant" in message.message_type:
                # only use the last agent answer for the history
                if previous_message_type != "assistant":
                    history_components.append(f"You/Agent: {message.message}\n")
                else:
                    history_components = history_components[:-1]
                    history_components.append(f"You/Agent: {message.message}\n")
                previous_message_type = "assistant"
            else:
                continue
        history = "\n".join(history_components)

    if len(history) > AGENT_MAX_STATIC_HISTORY_CHAR_LENGTH:
        history = summarize_history(history, question, persona_base, model)

    return HISTORY_PROMPT.format(history=history) if history else ""


def get_prompt_enrichment_components(
    config: AgentSearchConfig,
) -> AgentPromptEnrichmentComponents:
    persona_prompts = get_persona_agent_prompt_expressions(
        config.search_request.persona
    )

    history = build_history_prompt(config, config.search_request.query)

    date_str = get_today_prompt()

    return AgentPromptEnrichmentComponents(
        persona_prompts=persona_prompts,
        history=history,
        date_str=date_str,
    )
@@ -0,0 +1,98 @@
import numpy as np

from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitScoreMetrics
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.chat.models import SectionRelevancePiece
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger

logger = setup_logger()


def unique_chunk_id(doc: InferenceSection) -> str:
    return f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}"


def calculate_rank_shift(list1: list, list2: list, top_n: int = 20) -> float:
    shift = 0
    for rank_first, doc_id in enumerate(list1[:top_n], 1):
        try:
            rank_second = list2.index(doc_id) + 1
        except ValueError:
            rank_second = len(list2)  # Document not found in second list

        shift += np.abs(rank_first - rank_second) / np.log(1 + rank_first * rank_second)

    return shift / top_n


def get_fit_scores(
    pre_reranked_results: list[InferenceSection],
    post_reranked_results: list[InferenceSection] | list[SectionRelevancePiece],
) -> RetrievalFitStats | None:
    """
    Calculate retrieval metrics for search purposes
    """

    if len(pre_reranked_results) == 0 or len(post_reranked_results) == 0:
        return None

    ranked_sections = {
        "initial": pre_reranked_results,
        "reranked": post_reranked_results,
    }

    fit_eval: RetrievalFitStats = RetrievalFitStats(
        fit_score_lift=0,
        rerank_effect=0,
        fit_scores={
            "initial": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
            "reranked": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
        },
    )

    for rank_type, docs in ranked_sections.items():
        logger.debug(f"rank_type: {rank_type}")

        for i in [1, 5, 10]:
            fit_eval.fit_scores[rank_type].scores[str(i)] = (
                sum(
                    [
                        float(doc.center_chunk.score)
                        for doc in docs[:i]
                        if type(doc) == InferenceSection
                        and doc.center_chunk.score is not None
                    ]
                )
                / i
            )

        fit_eval.fit_scores[rank_type].scores["fit_score"] = (
            1
            / 3
            * (
                fit_eval.fit_scores[rank_type].scores["1"]
                + fit_eval.fit_scores[rank_type].scores["5"]
                + fit_eval.fit_scores[rank_type].scores["10"]
            )
        )

        fit_eval.fit_scores[rank_type].scores["fit_score"] = fit_eval.fit_scores[
            rank_type
        ].scores["1"]

        fit_eval.fit_scores[rank_type].chunk_ids = [
            unique_chunk_id(doc) for doc in docs if type(doc) == InferenceSection
        ]

    fit_eval.fit_score_lift = (
        fit_eval.fit_scores["reranked"].scores["fit_score"]
        / fit_eval.fit_scores["initial"].scores["fit_score"]
    )

    fit_eval.rerank_effect = calculate_rank_shift(
        fit_eval.fit_scores["initial"].chunk_ids,
        fit_eval.fit_scores["reranked"].chunk_ids,
    )

    return fit_eval
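
A minimal usage sketch for calculate_rank_shift with made-up chunk ids (not from this diff): an unchanged ranking yields a rerank_effect of 0.0, while any reordering yields a positive shift.

unchanged = calculate_rank_shift(["a", "b", "c"], ["a", "b", "c"], top_n=3)
swapped = calculate_rank_shift(["a", "b", "c"], ["b", "a", "c"], top_n=3)
assert unchanged == 0.0  # every document keeps its rank, so every term is zero
assert swapped > 0.0  # "a" and "b" trade places, contributing a positive shift
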
backend/onyx/agents/agent_search/shared_graph_utils/models.py (new file, 123 lines)
@@ -0,0 +1,123 @@
from typing import Literal

from pydantic import BaseModel

from onyx.agents.agent_search.deep_search_a.main__graph.models import (
    AgentAdditionalMetrics,
)
from onyx.agents.agent_search.deep_search_a.main__graph.models import AgentBaseMetrics
from onyx.agents.agent_search.deep_search_a.main__graph.models import (
    AgentRefinedMetrics,
)
from onyx.agents.agent_search.deep_search_a.main__graph.models import AgentTimings
from onyx.context.search.models import InferenceSection
from onyx.tools.models import SearchQueryInfo


# Pydantic models for structured outputs
class RewrittenQueries(BaseModel):
    rewritten_queries: list[str]


class BinaryDecision(BaseModel):
    decision: Literal["yes", "no"]


class BinaryDecisionWithReasoning(BaseModel):
    reasoning: str
    decision: Literal["yes", "no"]


class RetrievalFitScoreMetrics(BaseModel):
    scores: dict[str, float]
    chunk_ids: list[str]


class RetrievalFitStats(BaseModel):
    fit_score_lift: float
    rerank_effect: float
    fit_scores: dict[str, RetrievalFitScoreMetrics]


class AgentChunkScores(BaseModel):
    scores: dict[str, dict[str, list[int | float]]]


class AgentChunkStats(BaseModel):
    verified_count: int | None = None
    verified_avg_scores: float | None = None
    rejected_count: int | None = None
    rejected_avg_scores: float | None = None
    verified_doc_chunk_ids: list[str] = []
    dismissed_doc_chunk_ids: list[str] = []


class InitialAgentResultStats(BaseModel):
    sub_questions: dict[str, float | int | None]
    original_question: dict[str, float | int | None]
    agent_effectiveness: dict[str, float | int | None]


class RefinedAgentStats(BaseModel):
    revision_doc_efficiency: float | None
    revision_question_efficiency: float | None


class Term(BaseModel):
    term_name: str = ""
    term_type: str = ""
    term_similar_to: list[str] = []


### Models ###


class Entity(BaseModel):
    entity_name: str = ""
    entity_type: str = ""


class Relationship(BaseModel):
    relationship_name: str = ""
    relationship_type: str = ""
    relationship_entities: list[str] = []


class EntityRelationshipTermExtraction(BaseModel):
    entities: list[Entity] = []
    relationships: list[Relationship] = []
    terms: list[Term] = []


### Models ###


class QueryResult(BaseModel):
    query: str
    search_results: list[InferenceSection]
    stats: RetrievalFitStats | None
    query_info: SearchQueryInfo | None


class QuestionAnswerResults(BaseModel):
    question: str
    question_id: str
    answer: str
    quality: str
    expanded_retrieval_results: list[QueryResult]
    documents: list[InferenceSection]
    context_documents: list[InferenceSection]
    cited_docs: list[InferenceSection]
    sub_question_retrieval_stats: AgentChunkStats


class CombinedAgentMetrics(BaseModel):
    timings: AgentTimings
    base_metrics: AgentBaseMetrics | None
    refined_metrics: AgentRefinedMetrics
    additional_metrics: AgentAdditionalMetrics


class PersonaExpressions(BaseModel):
    contextualized_prompt: str
    base_prompt: str
@@ -0,0 +1,31 @@
from onyx.agents.agent_search.shared_graph_utils.models import (
    QuestionAnswerResults,
)
from onyx.chat.prune_and_merge import _merge_sections
from onyx.context.search.models import InferenceSection


def dedup_inference_sections(
    list1: list[InferenceSection], list2: list[InferenceSection]
) -> list[InferenceSection]:
    deduped = _merge_sections(list1 + list2)
    return deduped


def dedup_question_answer_results(
    question_answer_results_1: list[QuestionAnswerResults],
    question_answer_results_2: list[QuestionAnswerResults],
) -> list[QuestionAnswerResults]:
    deduped_question_answer_results: list[
        QuestionAnswerResults
    ] = question_answer_results_1
    utilized_question_ids: set[str] = set(
        [x.question_id for x in question_answer_results_1]
    )

    for question_answer_result in question_answer_results_2:
        if question_answer_result.question_id not in utilized_question_ids:
            deduped_question_answer_results.append(question_answer_result)
            utilized_question_ids.add(question_answer_result.question_id)

    return deduped_question_answer_results
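
A behavior sketch for dedup_question_answer_results (illustration only; _FakeResult is a hypothetical stand-in, since real QuestionAnswerResults objects require many fields): entries from the second list are appended only when their question_id is not already present in the first list.

from dataclasses import dataclass


@dataclass
class _FakeResult:
    question_id: str


merged = dedup_question_answer_results(
    [_FakeResult("0_0"), _FakeResult("0_1")],  # type: ignore[list-item]
    [_FakeResult("0_1"), _FakeResult("0_2")],  # type: ignore[list-item]
)
# "0_1" from the second list is dropped as a duplicate; "0_2" is new and kept
assert [r.question_id for r in merged] == ["0_0", "0_1", "0_2"]
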
backend/onyx/agents/agent_search/shared_graph_utils/prompts.py (new file, 1113 lines; diff suppressed because it is too large)
backend/onyx/agents/agent_search/shared_graph_utils/utils.py (new file, 367 lines)
@@ -0,0 +1,367 @@
|
||||
import ast
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
EntityRelationshipTermExtraction,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import PersonaExpressions
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
ASSISTANT_SYSTEM_PROMPT_PERSONA,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import DATE_PROMPT
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
HISTORY_CONTEXT_SUMMARY_PROMPT,
|
||||
)
|
||||
from onyx.chat.models import AnswerStyleConfig
|
||||
from onyx.chat.models import CitationConfig
|
||||
from onyx.chat.models import DocumentPruningConfig
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.context.search.enums import LLMEvaluationType
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.persona import Persona
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.tool_constructor import SearchToolConfig
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
|
||||
|
||||
def normalize_whitespace(text: str) -> str:
|
||||
"""Normalize whitespace in text to single spaces and strip leading/trailing whitespace."""
|
||||
import re
|
||||
|
||||
return re.sub(r"\s+", " ", text.strip())
|
||||
|
||||
|
||||
# Post-processing
|
||||
def format_docs(docs: Sequence[InferenceSection]) -> str:
|
||||
formatted_doc_list = []
|
||||
|
||||
for doc_nr, doc in enumerate(docs):
|
||||
formatted_doc_list.append(f"Document D{doc_nr + 1}:\n{doc.combined_content}")
|
||||
|
||||
return "\n\n".join(formatted_doc_list)
|
||||
|
||||
|
||||
def format_docs_content_flat(docs: Sequence[InferenceSection]) -> str:
|
||||
formatted_doc_list = []
|
||||
|
||||
for _, doc in enumerate(docs):
|
||||
formatted_doc_list.append(f"\n...{doc.combined_content}\n")
|
||||
|
||||
return "\n\n".join(formatted_doc_list)
|
||||
|
||||
|
||||
def clean_and_parse_list_string(json_string: str) -> list[dict]:
|
||||
# Remove any prefixes/labels before the actual JSON content
|
||||
json_string = re.sub(r"^.*?(?=\[)", "", json_string, flags=re.DOTALL)
|
||||
|
||||
# Remove markdown code block markers and any newline prefixes
|
||||
cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
|
||||
cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
|
||||
cleaned_string = " ".join(cleaned_string.split())
|
||||
|
||||
# Try parsing with json.loads first, fall back to ast.literal_eval
|
||||
try:
|
||||
return json.loads(cleaned_string)
|
||||
except json.JSONDecodeError:
|
||||
try:
|
||||
return ast.literal_eval(cleaned_string)
|
||||
except (ValueError, SyntaxError) as e:
|
||||
raise ValueError(f"Failed to parse JSON string: {cleaned_string}") from e
|
||||
|
||||
|
||||
def clean_and_parse_json_string(json_string: str) -> dict[str, Any]:
|
||||
# Remove markdown code block markers and any newline prefixes
|
||||
cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
|
||||
cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
|
||||
cleaned_string = " ".join(cleaned_string.split())
|
||||
# Parse the cleaned string into a Python dictionary
|
||||
return json.loads(cleaned_string)
|
||||
|
||||
|
||||
def format_entity_term_extraction(
|
||||
entity_term_extraction_dict: EntityRelationshipTermExtraction,
|
||||
) -> str:
|
||||
entities = entity_term_extraction_dict.entities
|
||||
terms = entity_term_extraction_dict.terms
|
||||
relationships = entity_term_extraction_dict.relationships
|
||||
|
||||
entity_strs = ["\nEntities:\n"]
|
||||
for entity in entities:
|
||||
entity_str = f"{entity.entity_name} ({entity.entity_type})"
|
||||
entity_strs.append(entity_str)
|
||||
|
||||
entity_str = "\n - ".join(entity_strs)
|
||||
|
||||
relationship_strs = ["\n\nRelationships:\n"]
|
||||
for relationship in relationships:
|
||||
relationship_name = relationship.relationship_name
|
||||
relationship_type = relationship.relationship_type
|
||||
relationship_entities = relationship.relationship_entities
|
||||
relationship_str = (
|
||||
f"""{relationship_name} ({relationship_type}): {relationship_entities}"""
|
||||
)
|
||||
relationship_strs.append(relationship_str)
|
||||
|
||||
relationship_str = "\n - ".join(relationship_strs)
|
||||
|
||||
term_strs = ["\n\nTerms:\n"]
|
||||
for term in terms:
|
||||
term_str = f"{term.term_name} ({term.term_type}): similar to {', '.join(term.term_similar_to)}"
|
||||
term_strs.append(term_str)
|
||||
|
||||
term_str = "\n - ".join(term_strs)
|
||||
|
||||
return "\n".join(entity_strs + relationship_strs + term_strs)
|
||||
|
||||
|
||||
def _format_time_delta(time: timedelta) -> str:
|
||||
seconds_from_start = f"{((time).seconds):03d}"
|
||||
microseconds_from_start = f"{((time).microseconds):06d}"
|
||||
return f"{seconds_from_start}.{microseconds_from_start}"
|
||||
|
||||
|
||||
def generate_log_message(
|
||||
message: str,
|
||||
node_start_time: datetime,
|
||||
graph_start_time: datetime | None = None,
|
||||
) -> str:
|
||||
current_time = datetime.now()
|
||||
|
||||
if graph_start_time is not None:
|
||||
graph_time_str = _format_time_delta(current_time - graph_start_time)
|
||||
else:
|
||||
graph_time_str = "N/A"
|
||||
|
||||
node_time_str = _format_time_delta(current_time - node_start_time)
|
||||
|
||||
return f"{graph_time_str} ({node_time_str} s): {message}"
|
||||
|
||||
|
||||
def get_test_config(
|
||||
db_session: Session,
|
||||
primary_llm: LLM,
|
||||
fast_llm: LLM,
|
||||
search_request: SearchRequest,
|
||||
use_agentic_search: bool = True,
|
||||
) -> tuple[AgentSearchConfig, SearchTool]:
|
||||
persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session)
|
||||
document_pruning_config = DocumentPruningConfig(
|
||||
max_chunks=int(
|
||||
persona.num_chunks
|
||||
if persona.num_chunks is not None
|
||||
else MAX_CHUNKS_FED_TO_CHAT
|
||||
),
|
||||
max_window_percentage=CHAT_TARGET_CHUNK_PERCENTAGE,
|
||||
)
|
||||
|
||||
answer_style_config = AnswerStyleConfig(
|
||||
citation_config=CitationConfig(
|
||||
# The docs retrieved by this flow are already relevance-filtered
|
||||
all_docs_useful=True
|
||||
),
|
||||
document_pruning_config=document_pruning_config,
|
||||
structured_response_format=None,
|
||||
)
|
||||
|
||||
search_tool_config = SearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
document_pruning_config=document_pruning_config,
|
||||
retrieval_options=RetrievalDetails(), # may want to set dedupe_docs=True
|
||||
rerank_settings=None, # Can use this to change reranking model
|
||||
selected_sections=None,
|
||||
latest_query_files=None,
|
||||
bypass_acl=False,
|
||||
)
|
||||
|
||||
prompt_config = PromptConfig.from_model(persona.prompts[0])
|
||||
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=None,
|
||||
persona=persona,
|
||||
retrieval_options=search_tool_config.retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=primary_llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=search_tool_config.document_pruning_config,
|
||||
answer_style_config=search_tool_config.answer_style_config,
|
||||
selected_sections=search_tool_config.selected_sections,
|
||||
chunks_above=search_tool_config.chunks_above,
|
||||
chunks_below=search_tool_config.chunks_below,
|
||||
full_doc=search_tool_config.full_doc,
|
||||
evaluation_type=(
|
||||
LLMEvaluationType.BASIC
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP
|
||||
),
|
||||
rerank_settings=search_tool_config.rerank_settings,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
)
|
||||
|
||||
config = AgentSearchConfig(
|
||||
search_request=search_request,
|
||||
primary_llm=primary_llm,
|
||||
fast_llm=fast_llm,
|
||||
search_tool=search_tool,
|
||||
force_use_tool=ForceUseTool(force_use=False, tool_name=""),
|
||||
prompt_builder=AnswerPromptBuilder(
|
||||
user_message=HumanMessage(content=search_request.query),
|
||||
message_history=[],
|
||||
llm_config=primary_llm.config,
|
||||
raw_user_query=search_request.query,
|
||||
raw_user_uploaded_files=[],
|
||||
),
|
||||
# chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
|
||||
chat_session_id=UUID("edda10d5-6cef-45d8-acfb-39317552a1f4"), # Joachim
|
||||
# chat_session_id=UUID("d1acd613-2692-4bc3-9d65-c6d3da62e58e"), # Evan
|
||||
message_id=1,
|
||||
use_agentic_persistence=True,
|
||||
db_session=db_session,
|
||||
tools=[search_tool],
|
||||
use_agentic_search=use_agentic_search,
|
||||
)
|
||||
|
||||
return config, search_tool
|
||||
|
||||
|
||||
def get_persona_agent_prompt_expressions(persona: Persona | None) -> PersonaExpressions:
|
||||
if persona is None:
|
||||
persona_prompt = ASSISTANT_SYSTEM_PROMPT_DEFAULT
|
||||
persona_base = ""
|
||||
else:
|
||||
persona_base = "\n".join([x.system_prompt for x in persona.prompts])
|
||||
|
||||
persona_prompt = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
|
||||
persona_prompt=persona_base
|
||||
)
|
||||
return PersonaExpressions(
|
||||
contextualized_prompt=persona_prompt, base_prompt=persona_base
|
||||
)
|
||||
|
||||
|
||||
def make_question_id(level: int, question_nr: int) -> str:
|
||||
return f"{level}_{question_nr}"
|
||||
|
||||
|
||||
def parse_question_id(question_id: str) -> tuple[int, int]:
|
||||
level, question_nr = question_id.split("_")
|
||||
return int(level), int(question_nr)
|
||||
|
||||
|
||||
def _dispatch_nonempty(
|
||||
content: str, dispatch_event: Callable[[str, int], None], num: int
|
||||
) -> None:
|
||||
if content != "":
|
||||
dispatch_event(content, num)
|
||||
|
||||
|
||||
def dispatch_separated(
|
||||
token_itr: Iterator[BaseMessage],
|
||||
dispatch_event: Callable[[str, int], None],
|
||||
sep: str = "\n",
|
||||
) -> list[str | list[str | dict[str, Any]]]:
|
||||
num = 1
|
||||
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
|
||||
for message in token_itr:
|
||||
content = cast(str, message.content)
|
||||
if sep in content:
|
||||
sub_question_parts = content.split(sep)
|
||||
_dispatch_nonempty(sub_question_parts[0], dispatch_event, num)
|
||||
num += 1
|
||||
_dispatch_nonempty(
|
||||
"".join(sub_question_parts[1:]).strip(), dispatch_event, num
|
||||
)
|
||||
else:
|
||||
_dispatch_nonempty(content, dispatch_event, num)
|
||||
streamed_tokens.append(content)
|
||||
|
||||
return streamed_tokens
|
||||
|
||||
|
||||
def dispatch_main_answer_stop_info(level: int) -> None:
|
||||
stop_event = StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
stream_type="main_answer",
|
||||
level=level,
|
||||
)
|
||||
dispatch_custom_event("stream_finished", stop_event)
|
||||
|
||||
|
||||
def get_today_prompt() -> str:
|
||||
return DATE_PROMPT.format(date=datetime.now().strftime("%A, %B %d, %Y"))
|
||||
|
||||
|
||||
def retrieve_search_docs(
|
||||
search_tool: SearchTool, question: str
|
||||
) -> list[InferenceSection]:
|
||||
retrieved_docs: list[InferenceSection] = []
|
||||
|
||||
# new db session to avoid concurrency issues
|
||||
with get_session_context_manager() as db_session:
|
||||
for tool_response in search_tool.run(
|
||||
query=question,
|
||||
force_no_rerank=True,
|
||||
alternate_db_session=db_session,
|
||||
):
|
||||
# get retrieved docs to send to the rest of the graph
|
||||
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
response = cast(SearchResponseSummary, tool_response.response)
|
||||
retrieved_docs = response.top_sections
|
||||
break
|
||||
|
||||
return retrieved_docs
|
||||
|
||||
|
||||
def get_answer_citation_ids(answer_str: str) -> list[int]:
|
||||
citation_ids = re.findall(r"\[\[D(\d+)\]\]", answer_str)
|
||||
return list(set([(int(id) - 1) for id in citation_ids]))
|
||||
|
||||
|
||||
def summarize_history(
|
||||
history: str, question: str, persona_specification: str, model: LLM
|
||||
) -> str:
|
||||
history_context_prompt = HISTORY_CONTEXT_SUMMARY_PROMPT.format(
|
||||
persona_specification=persona_specification, question=question, history=history
|
||||
)
|
||||
|
||||
history_response = model.invoke(history_context_prompt)
|
||||
|
||||
if isinstance(history_response.content, str):
|
||||
history_context_response_str = history_response.content
|
||||
else:
|
||||
history_context_response_str = ""
|
||||
|
||||
return history_context_response_str
|
||||
@@ -1,110 +1,81 @@
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from uuid import uuid4
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages import ToolCall
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.llm_response_handler import LLMResponseHandlerManager
|
||||
from onyx.chat.models import AnswerQuestionPossibleReturn
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.run_graph import run_basic_graph
|
||||
from onyx.agents.agent_search.run_graph import run_main_graph
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import AnswerPacket
|
||||
from onyx.chat.models import AnswerStream
|
||||
from onyx.chat.models import AnswerStyleConfig
|
||||
from onyx.chat.models import CitationInfo
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
|
||||
from onyx.chat.stream_processing.answer_response_handler import (
|
||||
CitationResponseHandler,
|
||||
)
|
||||
from onyx.chat.stream_processing.answer_response_handler import (
|
||||
DummyAnswerResponseHandler,
|
||||
)
|
||||
from onyx.chat.stream_processing.utils import (
|
||||
map_document_id_order,
|
||||
)
|
||||
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
|
||||
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
|
||||
from onyx.configs.constants import BASIC_KEY
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.file_store.utils import InMemoryChatFile
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import PreviousMessage
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_runner import ToolCallKickoff
|
||||
from onyx.tools.utils import explicit_tool_calling_supported
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse]
|
||||
|
||||
|
||||
class Answer:
|
||||
def __init__(
|
||||
self,
|
||||
question: str,
|
||||
prompt_builder: AnswerPromptBuilder,
|
||||
answer_style_config: AnswerStyleConfig,
|
||||
llm: LLM,
|
||||
prompt_config: PromptConfig,
|
||||
fast_llm: LLM,
|
||||
force_use_tool: ForceUseTool,
|
||||
# must be the same length as `docs`. If None, all docs are considered "relevant"
|
||||
message_history: list[PreviousMessage] | None = None,
|
||||
single_message_history: str | None = None,
|
||||
search_request: SearchRequest,
|
||||
chat_session_id: UUID,
|
||||
current_agent_message_id: int,
|
||||
# newly passed in files to include as part of this question
|
||||
# TODO THIS NEEDS TO BE HANDLED
|
||||
latest_query_files: list[InMemoryChatFile] | None = None,
|
||||
files: list[InMemoryChatFile] | None = None,
|
||||
tools: list[Tool] | None = None,
|
||||
# NOTE: for native tool-calling, this is only supported by OpenAI atm,
|
||||
# but we only support them anyways
|
||||
# if set to True, then never use the LLMs provided tool-calling functonality
|
||||
skip_explicit_tool_calling: bool = False,
|
||||
# Returns the full document sections text from the search tool
|
||||
return_contexts: bool = False,
|
||||
skip_gen_ai_answer_generation: bool = False,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
db_session: Session | None = None,
|
||||
use_agentic_search: bool = False,
|
||||
use_agentic_persistence: bool = True,
|
||||
) -> None:
|
||||
if single_message_history and message_history:
|
||||
raise ValueError(
|
||||
"Cannot provide both `message_history` and `single_message_history`"
|
||||
)
|
||||
|
||||
self.question = question
|
||||
self.is_connected: Callable[[], bool] | None = is_connected
|
||||
|
||||
self.latest_query_files = latest_query_files or []
|
||||
self.file_id_to_file = {file.file_id: file for file in (files or [])}
|
||||
|
||||
self.tools = tools or []
|
||||
self.force_use_tool = force_use_tool
|
||||
|
||||
self.message_history = message_history or []
|
||||
# used for QA flow where we only want to send a single message
|
||||
self.single_message_history = single_message_history
|
||||
|
||||
self.answer_style_config = answer_style_config
|
||||
self.prompt_config = prompt_config
|
||||
|
||||
self.llm = llm
|
||||
self.fast_llm = fast_llm
|
||||
self.llm_tokenizer = get_tokenizer(
|
||||
provider_type=llm.config.model_provider,
|
||||
model_name=llm.config.model_name,
|
||||
)
|
||||
|
||||
self._final_prompt: list[BaseMessage] | None = None
|
||||
|
||||
self._streamed_output: list[str] | None = None
|
||||
self._processed_stream: (
|
||||
list[AnswerQuestionPossibleReturn | ToolResponse | ToolCallKickoff] | None
|
||||
) = None
|
||||
self._processed_stream: (list[AnswerPacket] | None) = None
|
||||
|
||||
self._return_contexts = return_contexts
|
||||
self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
|
||||
self._is_cancelled = False
|
||||
|
||||
@@ -115,167 +86,76 @@ class Answer:
|
||||
and not skip_explicit_tool_calling
|
||||
)
|
||||
|
||||
search_tools = [tool for tool in (tools or []) if isinstance(tool, SearchTool)]
|
||||
search_tool: SearchTool | None = None
|
||||
|
||||
if len(search_tools) > 1:
|
||||
# TODO: handle multiple search tools
|
||||
raise ValueError("Multiple search tools found")
|
||||
elif len(search_tools) == 1:
|
||||
search_tool = search_tools[0]
|
||||
|
||||
using_tool_calling_llm = explicit_tool_calling_supported(
|
||||
llm.config.model_provider, llm.config.model_name
|
||||
)
|
||||
self.agent_search_config = AgentSearchConfig(
|
||||
search_request=search_request,
|
||||
primary_llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
search_tool=search_tool,
|
||||
force_use_tool=force_use_tool,
|
||||
use_agentic_search=use_agentic_search,
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=current_agent_message_id,
|
||||
use_agentic_persistence=use_agentic_persistence,
|
||||
allow_refinement=True,
|
||||
db_session=db_session,
|
||||
prompt_builder=prompt_builder,
|
||||
tools=tools,
|
||||
using_tool_calling_llm=using_tool_calling_llm,
|
||||
files=latest_query_files,
|
||||
structured_response_format=answer_style_config.structured_response_format,
|
||||
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
|
||||
)
|
||||
self.db_session = db_session
|
||||
|
||||
def _get_tools_list(self) -> list[Tool]:
|
||||
if not self.force_use_tool.force_use:
|
||||
return self.tools
|
||||
|
||||
tool = next(
|
||||
(t for t in self.tools if t.name == self.force_use_tool.tool_name), None
|
||||
)
|
||||
if tool is None:
|
||||
raise RuntimeError(f"Tool '{self.force_use_tool.tool_name}' not found")
|
||||
tool = get_tool_by_name(self.tools, self.force_use_tool.tool_name)
|
||||
|
||||
logger.info(
|
||||
f"Forcefully using tool='{tool.name}'"
|
||||
+ (
|
||||
f" with args='{self.force_use_tool.args}'"
|
||||
if self.force_use_tool.args is not None
|
||||
else ""
|
||||
)
|
||||
args_str = (
|
||||
f" with args='{self.force_use_tool.args}'"
|
||||
if self.force_use_tool.args
|
||||
else ""
|
||||
)
|
||||
logger.info(f"Forcefully using tool='{tool.name}'{args_str}")
|
||||
return [tool]
|
||||
|
||||
def _handle_specified_tool_call(
|
||||
self, llm_calls: list[LLMCall], tool: Tool, tool_args: dict
|
||||
) -> AnswerStream:
|
||||
current_llm_call = llm_calls[-1]
|
||||
|
||||
# make a dummy tool handler
|
||||
tool_handler = ToolResponseHandler([tool])
|
||||
|
||||
dummy_tool_call_chunk = AIMessageChunk(content="")
|
||||
dummy_tool_call_chunk.tool_calls = [
|
||||
ToolCall(name=tool.name, args=tool_args, id=str(uuid4()))
|
||||
]
|
||||
|
||||
response_handler_manager = LLMResponseHandlerManager(
|
||||
tool_handler, DummyAnswerResponseHandler(), self.is_cancelled
|
||||
)
|
||||
yield from response_handler_manager.handle_llm_response(
|
||||
iter([dummy_tool_call_chunk])
|
||||
)
|
||||
|
||||
new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
|
||||
if new_llm_call:
|
||||
yield from self._get_response(llm_calls + [new_llm_call])
|
||||
else:
|
||||
raise RuntimeError("Tool call handler did not return a new LLM call")
|
||||
|
||||
def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream:
|
||||
current_llm_call = llm_calls[-1]
|
||||
|
||||
# handle the case where no decision has to be made; we simply run the tool
|
||||
if (
|
||||
current_llm_call.force_use_tool.force_use
|
||||
and current_llm_call.force_use_tool.args is not None
|
||||
):
|
||||
tool_name, tool_args = (
|
||||
current_llm_call.force_use_tool.tool_name,
|
||||
current_llm_call.force_use_tool.args,
|
||||
)
|
||||
tool = next(
|
||||
(t for t in current_llm_call.tools if t.name == tool_name), None
|
||||
)
|
||||
if not tool:
|
||||
raise RuntimeError(f"Tool '{tool_name}' not found")
|
||||
|
||||
yield from self._handle_specified_tool_call(llm_calls, tool, tool_args)
|
||||
return
|
||||
|
||||
# special pre-logic for non-tool calling LLM case
|
||||
if not self.using_tool_calling_llm and current_llm_call.tools:
|
||||
chosen_tool_and_args = (
|
||||
ToolResponseHandler.get_tool_call_for_non_tool_calling_llm(
|
||||
current_llm_call, self.llm
|
||||
)
|
||||
)
|
||||
if chosen_tool_and_args:
|
||||
tool, tool_args = chosen_tool_and_args
|
||||
yield from self._handle_specified_tool_call(llm_calls, tool, tool_args)
|
||||
return
|
||||
|
||||
# if we're skipping gen ai answer generation, we should break
|
||||
# out unless we're forcing a tool call. If we don't, we might generate an
|
||||
# answer, which is a no-no!
|
||||
if (
|
||||
self.skip_gen_ai_answer_generation
|
||||
and not current_llm_call.force_use_tool.force_use
|
||||
):
|
||||
return
|
||||
|
||||
# set up "handlers" to listen to the LLM response stream and
|
||||
# feed back the processed results + handle tool call requests
|
||||
# + figure out what the next LLM call should be
|
||||
tool_call_handler = ToolResponseHandler(current_llm_call.tools)
|
||||
|
||||
final_search_results, displayed_search_results = SearchTool.get_search_result(
|
||||
current_llm_call
|
||||
) or ([], [])
|
||||
|
||||
answer_handler = CitationResponseHandler(
|
||||
context_docs=final_search_results,
|
||||
final_doc_id_to_rank_map=map_document_id_order(final_search_results),
|
||||
display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
|
||||
)
|
||||
|
||||
response_handler_manager = LLMResponseHandlerManager(
|
||||
tool_call_handler, answer_handler, self.is_cancelled
|
||||
)
|
||||
|
||||
# DEBUG: good breakpoint
|
||||
stream = self.llm.stream(
|
||||
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
|
||||
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
|
||||
prompt=current_llm_call.prompt_builder.build(),
|
||||
tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
|
||||
tool_choice=(
|
||||
"required"
|
||||
if current_llm_call.tools and current_llm_call.force_use_tool.force_use
|
||||
else None
|
||||
),
|
||||
structured_response_format=self.answer_style_config.structured_response_format,
|
||||
)
|
||||
yield from response_handler_manager.handle_llm_response(stream)
|
||||
|
||||
new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
|
||||
if new_llm_call:
|
||||
yield from self._get_response(llm_calls + [new_llm_call])
|
||||
|
||||
@property
|
||||
def processed_streamed_output(self) -> AnswerStream:
|
||||
if self._processed_stream is not None:
|
||||
yield from self._processed_stream
|
||||
return
|
||||
|
||||
prompt_builder = AnswerPromptBuilder(
|
||||
user_message=default_build_user_message(
|
||||
user_query=self.question,
|
||||
prompt_config=self.prompt_config,
|
||||
files=self.latest_query_files,
|
||||
single_message_history=self.single_message_history,
|
||||
),
|
||||
message_history=self.message_history,
|
||||
llm_config=self.llm.config,
|
||||
raw_user_query=self.question,
|
||||
raw_user_uploaded_files=self.latest_query_files or [],
|
||||
single_message_history=self.single_message_history,
|
||||
run_langgraph = (
|
||||
run_main_graph
|
||||
if self.agent_search_config.use_agentic_search
|
||||
else run_basic_graph
|
||||
)
|
||||
prompt_builder.update_system_prompt(
|
||||
default_build_system_message(self.prompt_config)
|
||||
)
|
||||
llm_call = LLMCall(
|
||||
prompt_builder=prompt_builder,
|
||||
tools=self._get_tools_list(),
|
||||
force_use_tool=self.force_use_tool,
|
||||
files=self.latest_query_files,
|
||||
tool_call_info=[],
|
||||
using_tool_calling_llm=self.using_tool_calling_llm,
|
||||
stream = run_langgraph(
|
||||
self.agent_search_config,
|
||||
)
|
||||
|
||||
processed_stream = []
|
||||
for processed_packet in self._get_response([llm_call]):
|
||||
processed_stream.append(processed_packet)
|
||||
yield processed_packet
|
||||
for packet in stream:
|
||||
if self.is_cancelled():
|
||||
packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
|
||||
yield packet
|
||||
break
|
||||
processed_stream.append(packet)
|
||||
yield packet
|
||||
|
||||
self._processed_stream = processed_stream
|
||||
|
||||
@@ -283,20 +163,57 @@ class Answer:
|
||||
def llm_answer(self) -> str:
|
||||
answer = ""
|
||||
for packet in self.processed_streamed_output:
|
||||
if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
|
||||
# handle basic answer flow, plus level 0 agent answer flow
|
||||
# since level 0 is the first answer the user sees and therefore the
|
||||
# child message of the user message in the db (so it is handled
|
||||
# like a basic flow answer)
|
||||
if (isinstance(packet, OnyxAnswerPiece) and packet.answer_piece) or (
|
||||
isinstance(packet, AgentAnswerPiece)
|
||||
and packet.answer_piece
|
||||
and packet.answer_type == "agent_level_answer"
|
||||
and packet.level == 0
|
||||
):
|
||||
answer += packet.answer_piece
|
||||
|
||||
return answer
|
||||
|
||||
def llm_answer_by_level(self) -> dict[int, str]:
|
||||
answer_by_level: dict[int, str] = defaultdict(str)
|
||||
for packet in self.processed_streamed_output:
|
||||
if (
|
||||
isinstance(packet, AgentAnswerPiece)
|
||||
and packet.answer_piece
|
||||
and packet.answer_type == "agent_level_answer"
|
||||
):
|
||||
answer_by_level[packet.level] += packet.answer_piece
|
||||
elif isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
|
||||
answer_by_level[BASIC_KEY[0]] += packet.answer_piece
|
||||
return answer_by_level
|
||||
|
||||
@property
|
||||
def citations(self) -> list[CitationInfo]:
|
||||
citations: list[CitationInfo] = []
|
||||
for packet in self.processed_streamed_output:
|
||||
if isinstance(packet, CitationInfo):
|
||||
if isinstance(packet, CitationInfo) and packet.level is None:
|
||||
citations.append(packet)
|
||||
|
||||
return citations
|
||||
|
||||
# TODO: replace tuple of ints with SubQuestionId EVERYWHERE
|
||||
def citations_by_subquestion(self) -> dict[tuple[int, int], list[CitationInfo]]:
|
||||
citations_by_subquestion: dict[
|
||||
tuple[int, int], list[CitationInfo]
|
||||
] = defaultdict(list)
|
||||
for packet in self.processed_streamed_output:
|
||||
if isinstance(packet, CitationInfo):
|
||||
if packet.level_question_nr is not None and packet.level is not None:
|
||||
citations_by_subquestion[
|
||||
(packet.level, packet.level_question_nr)
|
||||
].append(packet)
|
||||
elif packet.level is None:
|
||||
citations_by_subquestion[BASIC_KEY].append(packet)
|
||||
return citations_by_subquestion
|
||||
|
||||
def is_cancelled(self) -> bool:
|
||||
if self._is_cancelled:
|
||||
return True
|
||||
|
||||
@@ -48,6 +48,8 @@ def prepare_chat_message_request(
|
||||
retrieval_details: RetrievalDetails | None,
|
||||
rerank_settings: RerankingDetails | None,
|
||||
db_session: Session,
|
||||
use_agentic_search: bool = False,
|
||||
skip_gen_ai_answer_generation: bool = False,
|
||||
) -> CreateChatMessageRequest:
|
||||
# Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
|
||||
new_chat_session = create_chat_session(
|
||||
@@ -72,6 +74,8 @@ def prepare_chat_message_request(
|
||||
search_doc_ids=None,
|
||||
retrieval_options=retrieval_details,
|
||||
rerank_settings=rerank_settings,
|
||||
use_agentic_search=use_agentic_search,
|
||||
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -9,25 +9,37 @@ from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
|
||||
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
|
||||
from onyx.chat.stream_processing.answer_response_handler import (
|
||||
DummyAnswerResponseHandler,
|
||||
)
|
||||
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
|
||||
|
||||
|
||||
class LLMResponseHandlerManager:
|
||||
"""
|
||||
This class is responsible for postprocessing the LLM response stream.
|
||||
In particular, we:
|
||||
1. handle the tool call requests
|
||||
2. handle citations
|
||||
3. pass through answers generated by the LLM
|
||||
4. Stop yielding if the client disconnects
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tool_handler: ToolResponseHandler,
|
||||
answer_handler: AnswerResponseHandler,
|
||||
tool_handler: ToolResponseHandler | None,
|
||||
answer_handler: AnswerResponseHandler | None,
|
||||
is_cancelled: Callable[[], bool],
|
||||
):
|
||||
self.tool_handler = tool_handler
|
||||
self.answer_handler = answer_handler
|
||||
self.tool_handler = tool_handler or ToolResponseHandler([])
|
||||
self.answer_handler = answer_handler or DummyAnswerResponseHandler()
|
||||
self.is_cancelled = is_cancelled
|
||||
|
||||
def handle_llm_response(
|
||||
self,
|
||||
stream: Iterator[BaseMessage],
|
||||
) -> Generator[ResponsePart, None, None]:
|
||||
all_messages: list[BaseMessage] = []
|
||||
all_messages: list[BaseMessage | str] = []
|
||||
for message in stream:
|
||||
if self.is_cancelled():
|
||||
yield StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
|
||||
|
||||
@@ -3,6 +3,7 @@ from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -48,6 +49,8 @@ class QADocsResponse(RetrievalDocs):
|
||||
applied_source_filters: list[DocumentSource] | None
|
||||
applied_time_cutoff: datetime | None
|
||||
recency_bias_multiplier: float
|
||||
level: int | None = None
|
||||
level_question_nr: int | None = None
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
|
||||
@@ -61,11 +64,17 @@ class QADocsResponse(RetrievalDocs):
|
||||
class StreamStopReason(Enum):
|
||||
CONTEXT_LENGTH = "context_length"
|
||||
CANCELLED = "cancelled"
|
||||
FINISHED = "finished"
|
||||
|
||||
|
||||
class StreamStopInfo(BaseModel):
|
||||
stop_reason: StreamStopReason
|
||||
|
||||
stream_type: Literal["", "sub_questions", "sub_answer", "main_answer"] = ""
|
||||
# used to identify the stream that was stopped for agent search
|
||||
level: int | None = None
|
||||
level_question_nr: int | None = None
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
data = super().model_dump(mode="json", *args, **kwargs) # type: ignore
|
||||
data["stop_reason"] = self.stop_reason.name
|
||||
@@ -108,6 +117,8 @@ class OnyxAnswerPiece(BaseModel):
|
||||
class CitationInfo(BaseModel):
|
||||
citation_num: int
|
||||
document_id: str
|
||||
level: int | None = None
|
||||
level_question_nr: int | None = None
|
||||
|
||||
|
||||
class AllCitations(BaseModel):
|
||||
@@ -273,7 +284,7 @@ class AnswerStyleConfig(BaseModel):
|
||||
|
||||
class PromptConfig(BaseModel):
|
||||
"""Final representation of the Prompt configuration passed
|
||||
into the `Answer` object."""
|
||||
into the `PromptBuilder` object."""
|
||||
|
||||
system_prompt: str
|
||||
task_prompt: str
|
||||
@@ -299,6 +310,48 @@ class PromptConfig(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
|
||||
class SubQueryPiece(BaseModel):
|
||||
sub_query: str
|
||||
level: int
|
||||
level_question_nr: int
|
||||
query_id: int
|
||||
|
||||
|
||||
class AgentAnswerPiece(BaseModel):
|
||||
answer_piece: str
|
||||
level: int
|
||||
level_question_nr: int
|
||||
answer_type: Literal["agent_sub_answer", "agent_level_answer"]
|
||||
|
||||
|
||||
class SubQuestionPiece(BaseModel):
|
||||
sub_question: str
|
||||
level: int
|
||||
level_question_nr: int
|
||||
|
||||
|
||||
class ExtendedToolResponse(ToolResponse):
|
||||
level: int
|
||||
level_question_nr: int
|
||||
|
||||
|
||||
class RefinedAnswerImprovement(BaseModel):
|
||||
refined_answer_improvement: bool
|
||||
|
||||
|
||||
AgentSearchPacket = (
|
||||
SubQuestionPiece
|
||||
| AgentAnswerPiece
|
||||
| SubQueryPiece
|
||||
| ExtendedToolResponse
|
||||
| RefinedAnswerImprovement
|
||||
)
|
||||
|
||||
AnswerPacket = (
|
||||
AnswerQuestionPossibleReturn | AgentSearchPacket | ToolCallKickoff | ToolResponse
|
||||
)
|
||||
|
||||
|
||||
ResponsePart = (
|
||||
OnyxAnswerPiece
|
||||
| CitationInfo
|
||||
@@ -306,4 +359,7 @@ ResponsePart = (
|
||||
| ToolResponse
|
||||
| ToolCallFinalResult
|
||||
| StreamStopInfo
|
||||
| AgentSearchPacket
|
||||
)
|
||||
|
||||
AnswerStream = Iterator[AnswerPacket]
|
||||
|
||||
@@ -1,6 +1,8 @@
import traceback
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Iterator
from dataclasses import dataclass
from functools import partial
from typing import cast

@@ -9,6 +11,7 @@ from sqlalchemy.orm import Session
from onyx.chat.answer import Answer
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import create_temporary_persona
from onyx.chat.models import AgentSearchPacket
from onyx.chat.models import AllCitations
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import ChatOnyxBotResponse
@@ -16,6 +19,7 @@ from onyx.chat.models import CitationConfig
from onyx.chat.models import CitationInfo
from onyx.chat.models import CustomToolResponse
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import FileChatDisplay
from onyx.chat.models import FinalUsedContextDocsResponse
from onyx.chat.models import LLMRelevanceFilterResponse
@@ -25,19 +29,28 @@ from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamingError
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import AGENT_SEARCH_INITIAL_KEY
from onyx.configs.constants import BASIC_KEY
from onyx.configs.constants import MessageType
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import NO_AUTH_USER_ID
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.enums import OptionalSearchSetting
from onyx.context.search.enums import QueryFlow
from onyx.context.search.enums import SearchType
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SearchRequest
from onyx.context.search.retrieval.search_runner import inference_sections_from_ids
from onyx.context.search.utils import chunks_or_sections_to_search_docs
from onyx.context.search.utils import dedupe_documents
@@ -127,7 +140,6 @@ from onyx.utils.timing import log_function_time
from onyx.utils.timing import log_generator_function_time
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR


logger = setup_logger()

@@ -159,12 +171,15 @@ def _handle_search_tool_response_summary(
) -> tuple[QADocsResponse, list[DbSearchDoc], list[int] | None]:
response_sumary = cast(SearchResponseSummary, packet.response)

is_extended = isinstance(packet, ExtendedToolResponse)
dropped_inds = None
if not selected_search_docs:
top_docs = chunks_or_sections_to_search_docs(response_sumary.top_sections)

deduped_docs = top_docs
if dedupe_docs:
if (
dedupe_docs and not is_extended
): # Extended tool responses are already deduped
deduped_docs, dropped_inds = dedupe_documents(top_docs)

reference_db_search_docs = [
@@ -178,6 +193,10 @@ def _handle_search_tool_response_summary(
translate_db_search_doc_to_server_search_doc(db_search_doc)
for db_search_doc in reference_db_search_docs
]

level, question_nr = None, None
if isinstance(packet, ExtendedToolResponse):
level, question_nr = packet.level, packet.level_question_nr
return (
QADocsResponse(
rephrased_query=response_sumary.rephrased_query,
@@ -187,6 +206,8 @@ def _handle_search_tool_response_summary(
applied_source_filters=response_sumary.final_filters.source_type,
applied_time_cutoff=response_sumary.final_filters.time_cutoff,
recency_bias_multiplier=response_sumary.recency_bias_multiplier,
level=level,
level_question_nr=question_nr,
),
reference_db_search_docs,
dropped_inds,
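The dedupe path above only records `dropped_inds` when deduplication actually runs; those indices are what later lets the LLM relevance indices be realigned to the deduped list. A minimal sketch of that relationship, with toy stand-ins for the real `dedupe_documents` and `drop_llm_indices` helpers (their exact behavior is assumed from the call sites shown here):

```python
# Illustrative only: toy versions of the dedupe / index-realignment pattern used above.
def toy_dedupe(doc_ids: list[str]) -> tuple[list[str], list[int]]:
    seen: set[str] = set()
    kept: list[str] = []
    dropped_inds: list[int] = []
    for ind, doc_id in enumerate(doc_ids):
        if doc_id in seen:
            dropped_inds.append(ind)  # remember positions removed from the original list
        else:
            seen.add(doc_id)
            kept.append(doc_id)
    return kept, dropped_inds


def toy_drop_llm_indices(llm_indices: list[int], dropped_inds: list[int]) -> list[int]:
    # Shift indices that pointed into the pre-dedupe list so they point into the deduped one.
    dropped = set(dropped_inds)
    remaining = [i for i in llm_indices if i not in dropped]
    return [i - sum(1 for d in dropped if d < i) for i in remaining]


docs = ["a", "b", "a", "c"]
deduped, dropped = toy_dedupe(docs)              # (["a", "b", "c"], [2])
print(toy_drop_llm_indices([0, 2, 3], dropped))  # [0, 2] -> "a" and "c" in the deduped list
```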
@@ -282,10 +303,22 @@ ChatPacket = (
| MessageSpecificCitations
| MessageResponseIDInfo
| StreamStopInfo
| AgentSearchPacket
)
ChatPacketStream = Iterator[ChatPacket]


# can't store a DbSearchDoc in a Pydantic BaseModel
@dataclass
class AnswerPostInfo:
ai_message_files: list[FileDescriptor]
qa_docs_response: QADocsResponse | None = None
reference_db_search_docs: list[DbSearchDoc] | None = None
dropped_indices: list[int] | None = None
tool_result: ToolCallFinalResult | None = None
message_specific_citations: MessageSpecificCitations | None = None
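Later in this file, one `AnswerPostInfo` is kept per `(level, question_nr)` key: plain tool responses are filed under `BASIC_KEY`, while agent sub-question responses get their own keys. A minimal sketch of that routing with made-up key values (the real `BASIC_KEY` and `AGENT_SEARCH_INITIAL_KEY` constants come from `onyx.configs.constants`):

```python
from collections import defaultdict
from dataclasses import dataclass, field

BASIC_KEY = (-1, -1)  # assumed sentinel; the real value lives in onyx.configs.constants


@dataclass
class Info:  # stand-in for AnswerPostInfo
    ai_message_files: list[str] = field(default_factory=list)
    tool_result: object | None = None


info_by_subq: dict[tuple[int, int], Info] = defaultdict(Info)

# A packet that carries agent-search metadata goes to its sub-question bucket...
info_by_subq[(1, 2)].tool_result = {"tool": "search", "level": 1, "question": 2}
# ...while a packet without level info is filed under the basic key.
info_by_subq[BASIC_KEY].ai_message_files.append("generated_image.png")

print(sorted(info_by_subq.keys()))  # [(-1, -1), (1, 2)]
```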
def stream_chat_message_objects(
new_msg_req: CreateChatMessageRequest,
user: User | None,
@@ -324,6 +357,7 @@ def stream_chat_message_objects(
new_msg_req.chunks_above = 0
new_msg_req.chunks_below = 0

llm = None
try:
user_id = user.id if user is not None else None

@@ -502,11 +536,8 @@ def stream_chat_message_objects(
files = load_all_chat_files(
history_msgs, new_msg_req.file_descriptors, db_session
)
latest_query_files = [
file
for file in files
if file.file_id in [f["id"] for f in new_msg_req.file_descriptors]
]
req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
latest_query_files = [file for file in files if file.file_id in req_file_ids]

if user_message:
attach_files_to_chat_message(
@@ -679,13 +710,58 @@ def stream_chat_message_objects(
for tool_list in tool_dict.values():
tools.extend(tool_list)

# TODO: unify message history with single message history
message_history = [
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
]

search_request = SearchRequest(
query=final_msg.message,
evaluation_type=(
LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP
),
human_selected_filters=(
retrieval_options.filters if retrieval_options else None
),
persona=persona,
offset=(retrieval_options.offset if retrieval_options else None),
limit=retrieval_options.limit if retrieval_options else None,
rerank_settings=new_msg_req.rerank_settings,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
enable_auto_detect_filters=(
retrieval_options.enable_auto_detect_filters
if retrieval_options
else None
),
)

force_use_tool = _get_force_search_settings(new_msg_req, tools)
prompt_builder = AnswerPromptBuilder(
user_message=default_build_user_message(
user_query=final_msg.message,
prompt_config=prompt_config,
files=latest_query_files,
single_message_history=single_message_history,
),
system_message=default_build_system_message(prompt_config),
message_history=message_history,
llm_config=llm.config,
raw_user_query=final_msg.message,
raw_user_uploaded_files=latest_query_files or [],
single_message_history=single_message_history,
)
prompt_builder.update_system_prompt(default_build_system_message(prompt_config))

# LLM prompt building, response capturing, etc.
answer = Answer(
prompt_builder=prompt_builder,
is_connected=is_connected,
question=final_msg.message,
latest_query_files=latest_query_files,
answer_style_config=answer_style_config,
prompt_config=prompt_config,
llm=(
llm
or get_main_llm_from_tuple(
@@ -698,28 +774,42 @@ def stream_chat_message_objects(
)
)
),
message_history=[
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
],
fast_llm=fast_llm,
force_use_tool=force_use_tool,
search_request=search_request,
chat_session_id=chat_session_id,
current_agent_message_id=reserved_message_id,
tools=tools,
force_use_tool=_get_force_search_settings(new_msg_req, tools),
single_message_history=single_message_history,
db_session=db_session,
use_agentic_search=new_msg_req.use_agentic_search,
)
reference_db_search_docs = None
qa_docs_response = None
# any files to associate with the AI message e.g. dall-e generated images
ai_message_files = []
dropped_indices = None
tool_result = None
# reference_db_search_docs = None
# qa_docs_response = None
# # any files to associate with the AI message e.g. dall-e generated images
# ai_message_files = []
# dropped_indices = None
# tool_result = None

# TODO: different channels for stored info when it's coming from the agent flow
info_by_subq: dict[tuple[int, int], AnswerPostInfo] = defaultdict(
lambda: AnswerPostInfo(ai_message_files=[])
)
refined_answer_improvement = True
for packet in answer.processed_streamed_output:
if isinstance(packet, ToolResponse):
level, level_question_nr = (
(packet.level, packet.level_question_nr)
if isinstance(packet, ExtendedToolResponse)
else BASIC_KEY
)
info = info_by_subq[(level, level_question_nr)]
# TODO: don't need to dedupe here when we do it in agent flow
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
(
qa_docs_response,
reference_db_search_docs,
dropped_indices,
info.qa_docs_response,
info.reference_db_search_docs,
info.dropped_indices,
) = _handle_search_tool_response_summary(
packet=packet,
db_session=db_session,
@@ -731,29 +821,34 @@ def stream_chat_message_objects(
else False
),
)
yield qa_docs_response
yield info.qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
relevance_sections = packet.response

if reference_db_search_docs is not None:
llm_indices = relevant_sections_to_indices(
relevance_sections=relevance_sections,
items=[
translate_db_search_doc_to_server_search_doc(doc)
for doc in reference_db_search_docs
],
if info.reference_db_search_docs is None:
logger.warning(
"No reference docs found for relevance filtering"
)
continue

llm_indices = relevant_sections_to_indices(
relevance_sections=relevance_sections,
items=[
translate_db_search_doc_to_server_search_doc(doc)
for doc in info.reference_db_search_docs
],
)

if info.dropped_indices:
llm_indices = drop_llm_indices(
llm_indices=llm_indices,
search_docs=info.reference_db_search_docs,
dropped_indices=info.dropped_indices,
)

if dropped_indices:
llm_indices = drop_llm_indices(
llm_indices=llm_indices,
search_docs=reference_db_search_docs,
dropped_indices=dropped_indices,
)

yield LLMRelevanceFilterResponse(
llm_selected_doc_indices=llm_indices
)
yield LLMRelevanceFilterResponse(
llm_selected_doc_indices=llm_indices
)
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
yield FinalUsedContextDocsResponse(
final_context_docs=packet.response
@@ -773,22 +868,24 @@ def stream_chat_message_objects(
],
tenant_id=tenant_id,
)
ai_message_files = [
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
for file_id in file_ids
]
info.ai_message_files.extend(
[
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
for file_id in file_ids
]
)
yield FileChatDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
(
qa_docs_response,
reference_db_search_docs,
info.qa_docs_response,
info.reference_db_search_docs,
) = _handle_internet_search_tool_response_summary(
packet=packet,
db_session=db_session,
)
yield qa_docs_response
yield info.qa_docs_response
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
custom_tool_response = cast(CustomToolCallSummary, packet.response)

@@ -797,7 +894,7 @@ def stream_chat_message_objects(
or custom_tool_response.response_type == "csv"
):
file_ids = custom_tool_response.tool_result.file_ids
ai_message_files.extend(
info.ai_message_files.extend(
[
FileDescriptor(
id=str(file_id),
@@ -822,10 +919,21 @@ def stream_chat_message_objects(
yield cast(OnyxContexts, packet.response)

elif isinstance(packet, StreamStopInfo):
pass
if packet.stop_reason == StreamStopReason.FINISHED:
yield packet
elif isinstance(packet, RefinedAnswerImprovement):
refined_answer_improvement = packet.refined_answer_improvement
yield packet
else:
if isinstance(packet, ToolCallFinalResult):
tool_result = packet
level, level_question_nr = (
(packet.level, packet.level_question_nr)
if packet.level is not None
and packet.level_question_nr is not None
else BASIC_KEY
)
info = info_by_subq[(level, level_question_nr)]
info.tool_result = packet
yield cast(ChatPacket, packet)
logger.debug("Reached end of stream")
except ValueError as e:
@@ -841,59 +949,99 @@ def stream_chat_message_objects(

error_msg = str(e)
stack_trace = traceback.format_exc()
client_error_msg = litellm_exception_to_error_msg(e, llm)
if llm.config.api_key and len(llm.config.api_key) > 2:
error_msg = error_msg.replace(llm.config.api_key, "[REDACTED_API_KEY]")
stack_trace = stack_trace.replace(llm.config.api_key, "[REDACTED_API_KEY]")
if llm:
client_error_msg = litellm_exception_to_error_msg(e, llm)
if llm.config.api_key and len(llm.config.api_key) > 2:
error_msg = error_msg.replace(llm.config.api_key, "[REDACTED_API_KEY]")
stack_trace = stack_trace.replace(
llm.config.api_key, "[REDACTED_API_KEY]"
)

yield StreamingError(error=client_error_msg, stack_trace=stack_trace)
yield StreamingError(error=client_error_msg, stack_trace=stack_trace)
db_session.rollback()
return

# Post-LLM answer processing
try:
logger.debug("Post-LLM answer processing")
message_specific_citations: MessageSpecificCitations | None = None
if reference_db_search_docs:
message_specific_citations = _translate_citations(
citations_list=answer.citations,
db_docs=reference_db_search_docs,
)
if not answer.is_cancelled():
yield AllCitations(citations=answer.citations)

# Saving Gen AI answer and responding with message info
tool_name_to_tool_id: dict[str, int] = {}
for tool_id, tool_list in tool_dict.items():
for tool in tool_list:
tool_name_to_tool_id[tool.name] = tool_id

subq_citations = answer.citations_by_subquestion()
for pair in subq_citations:
level, level_question_nr = pair
info = info_by_subq[(level, level_question_nr)]
logger.debug("Post-LLM answer processing")
if info.reference_db_search_docs:
info.message_specific_citations = _translate_citations(
citations_list=subq_citations[pair],
db_docs=info.reference_db_search_docs,
)

# TODO: AllCitations should contain subq info?
if not answer.is_cancelled():
yield AllCitations(citations=subq_citations[pair])

# Saving Gen AI answer and responding with message info

info = (
info_by_subq[BASIC_KEY]
if BASIC_KEY in info_by_subq
else info_by_subq[AGENT_SEARCH_INITIAL_KEY]
)
gen_ai_response_message = partial_response(
message=answer.llm_answer,
rephrased_query=(
qa_docs_response.rephrased_query if qa_docs_response else None
info.qa_docs_response.rephrased_query if info.qa_docs_response else None
),
reference_docs=reference_db_search_docs,
files=ai_message_files,
reference_docs=info.reference_db_search_docs,
files=info.ai_message_files,
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=(
message_specific_citations.citation_map
if message_specific_citations
info.message_specific_citations.citation_map
if info.message_specific_citations
else None
),
error=None,
tool_call=(
ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
tool_id=tool_name_to_tool_id[info.tool_result.tool_name],
tool_name=info.tool_result.tool_name,
tool_arguments=info.tool_result.tool_args,
tool_result=info.tool_result.tool_result,
)
if tool_result
if info.tool_result
else None
),
)

# TODO: add answers for levels >= 1, where each level has the previous as its parent. Use
# the answer_by_level method in answer.py to get the answers for each level
next_level = 1
prev_message = gen_ai_response_message
agent_answers = answer.llm_answer_by_level()
while next_level in agent_answers:
next_answer = agent_answers[next_level]
info = info_by_subq[(next_level, AGENT_SEARCH_INITIAL_KEY[1])]
next_answer_message = create_new_chat_message(
chat_session_id=chat_session_id,
parent_message=prev_message,
message=next_answer,
prompt_id=None,
token_count=len(llm_tokenizer_encode_func(next_answer)),
message_type=MessageType.ASSISTANT,
db_session=db_session,
files=info.ai_message_files,
reference_docs=info.reference_db_search_docs,
citations=info.message_specific_citations.citation_map
if info.message_specific_citations
else None,
refined_answer_improvement=refined_answer_improvement,
)
next_level += 1
prev_message = next_answer_message

logger.debug("Committing messages")
db_session.commit()  # actually save user / assistant message
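The level loop above stores one assistant message per agent answer level, each parented to the previous one, so the refined answers form a chain below the base answer. A minimal sketch of that chaining with plain dictionaries standing in for the ORM objects (names here are illustrative, not the real `create_new_chat_message` signature):

```python
# Illustrative only: chain one stored message per agent answer level.
agent_answers = {1: "initial agent answer", 2: "refined agent answer"}  # level -> text

messages = []
prev_message = {"id": 0, "text": "base answer", "parent": None}
next_level = 1
while next_level in agent_answers:
    next_message = {
        "id": prev_message["id"] + 1,
        "text": agent_answers[next_level],
        "parent": prev_message["id"],  # each level points at the level before it
    }
    messages.append(next_message)
    prev_message = next_message
    next_level += 1

print([(m["id"], m["parent"]) for m in messages])  # [(1, 0), (2, 1)]
```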
@@ -4,6 +4,7 @@ from typing import cast
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModel__v1

from onyx.chat.models import PromptConfig
@@ -84,6 +85,7 @@ class AnswerPromptBuilder:
raw_user_query: str,
raw_user_uploaded_files: list[InMemoryChatFile],
single_message_history: str | None = None,
system_message: SystemMessage | None = None,
) -> None:
self.max_tokens = compute_max_llm_input_tokens(llm_config)

@@ -108,7 +110,14 @@ class AnswerPromptBuilder:
),
)

self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = (
(
system_message,
check_message_tokens(system_message, self.llm_tokenizer_encode_func),
)
if system_message
else None
)
self.user_message_and_token_cnt = (
user_message,
check_message_tokens(
@@ -174,6 +183,14 @@ class AnswerPromptBuilder:
)


# Stores some parts of a prompt builder as needed for tool calls
class PromptSnapshot(BaseModel):
raw_message_history: list[PreviousMessage]
raw_user_query: str
built_prompt: list[BaseMessage]


# TODO: rename this? AnswerConfig maybe?
class LLMCall(BaseModel__v1):
prompt_builder: AnswerPromptBuilder
tools: list[Tool]
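With this change the system prompt can be supplied directly at construction time, with its token count computed up front, while `update_system_prompt(...)` remains available for callers that set it later. A hedged sketch of that "optional constructor argument plus derived cache" pattern, using a simplified stand-in class rather than the real `AnswerPromptBuilder`:

```python
# Illustrative only: accept an optional system message at construction time and
# precompute its token count, instead of requiring a separate setter call.
from typing import Callable


class TinyPromptBuilder:
    def __init__(
        self,
        tokenizer_encode: Callable[[str], list[str]],
        system_message: str | None = None,
    ) -> None:
        self.tokenizer_encode = tokenizer_encode
        self.system_message_and_token_cnt: tuple[str, int] | None = (
            (system_message, len(tokenizer_encode(system_message)))
            if system_message
            else None
        )

    def update_system_prompt(self, system_message: str) -> None:
        # still available for callers that construct first and set the prompt later
        self.system_message_and_token_cnt = (
            system_message,
            len(self.tokenizer_encode(system_message)),
        )


builder = TinyPromptBuilder(lambda s: s.split(), system_message="You are a helpful assistant")
print(builder.system_message_and_token_cnt)  # ('You are a helpful assistant', 5)
```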
@@ -3,9 +3,10 @@ from collections.abc import Generator

from langchain_core.messages import BaseMessage

from onyx.chat.llm_response_handler import ResponsePart
from onyx.chat.models import CitationInfo
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import ResponsePart
from onyx.chat.stream_processing.citation_processing import CitationProcessor
from onyx.chat.stream_processing.utils import DocumentIdOrderMapping
from onyx.utils.logger import setup_logger
@@ -13,21 +14,32 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()


# TODO: remove update() once it is no longer needed
class AnswerResponseHandler(abc.ABC):
@abc.abstractmethod
def handle_response_part(
self,
response_item: BaseMessage | None,
previous_response_items: list[BaseMessage],
response_item: BaseMessage | str | None,
previous_response_items: list[BaseMessage | str],
) -> Generator[ResponsePart, None, None]:
raise NotImplementedError


class PassThroughAnswerResponseHandler(AnswerResponseHandler):
def handle_response_part(
self,
response_item: BaseMessage | str | None,
previous_response_items: list[BaseMessage | str],
) -> Generator[ResponsePart, None, None]:
content = _message_to_str(response_item)
yield OnyxAnswerPiece(answer_piece=content)


class DummyAnswerResponseHandler(AnswerResponseHandler):
def handle_response_part(
self,
response_item: BaseMessage | None,
previous_response_items: list[BaseMessage],
response_item: BaseMessage | str | None,
previous_response_items: list[BaseMessage | str],
) -> Generator[ResponsePart, None, None]:
# This is a dummy handler that returns nothing
yield from []
@@ -56,43 +68,25 @@ class CitationResponseHandler(AnswerResponseHandler):

def handle_response_part(
self,
response_item: BaseMessage | None,
previous_response_items: list[BaseMessage],
response_item: BaseMessage | str | None,
previous_response_items: list[BaseMessage | str],
) -> Generator[ResponsePart, None, None]:
if response_item is None:
return

content = (
response_item.content if isinstance(response_item.content, str) else ""
)
content = _message_to_str(response_item)

# Process the new content through the citation processor
yield from self.citation_processor.process_token(content)


# No longer in use, remove later
# class QuotesResponseHandler(AnswerResponseHandler):
# def __init__(
# self,
# context_docs: list[LlmDoc],
# is_json_prompt: bool = True,
# ):
# self.quotes_processor = QuotesProcessor(
# context_docs=context_docs,
# is_json_prompt=is_json_prompt,
# )

# def handle_response_part(
# self,
# response_item: BaseMessage | None,
# previous_response_items: list[BaseMessage],
# ) -> Generator[ResponsePart, None, None]:
# if response_item is None:
# yield from self.quotes_processor.process_token(None)
# return

# content = (
# response_item.content if isinstance(response_item.content, str) else ""
# )

# yield from self.quotes_processor.process_token(content)
def _message_to_str(message: BaseMessage | str | None) -> str:
if message is None:
return ""
if isinstance(message, str):
return message
content = message.content if isinstance(message, BaseMessage) else message
if not isinstance(content, str):
logger.warning(f"Received non-string content: {type(content)}")
content = str(content) if content is not None else ""
return content
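The widened `BaseMessage | str | None` signatures above exist so handlers can receive plain strings from the agent flow as well as LangChain messages; `_message_to_str` normalizes all three cases. A small usage sketch, re-implemented standalone for illustration since the real helper depends on the module's logger:

```python
from langchain_core.messages import AIMessage


def message_to_str(message: AIMessage | str | None) -> str:
    # Mirrors the helper above: None -> "", str passes through, messages yield .content.
    if message is None:
        return ""
    if isinstance(message, str):
        return message
    content = message.content
    return content if isinstance(content, str) else str(content)


print(message_to_str(None))                     # ""
print(message_to_str("partial agent token"))    # "partial agent token"
print(message_to_str(AIMessage(content="hi")))  # "hi"
```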
@@ -5,7 +5,9 @@ from langchain_core.messages import BaseMessage
from langchain_core.messages import ToolCall

from onyx.chat.models import ResponsePart
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.message import build_tool_message
@@ -25,6 +27,13 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()


def get_tool_by_name(tools: list[Tool], tool_name: str) -> Tool:
for tool in tools:
if tool.name == tool_name:
return tool
raise RuntimeError(f"Tool '{tool_name}' not found")


class ToolResponseHandler:
def __init__(self, tools: list[Tool]):
self.tools = tools
@@ -43,67 +52,12 @@ class ToolResponseHandler:
def get_tool_call_for_non_tool_calling_llm(
cls, llm_call: LLMCall, llm: LLM
) -> tuple[Tool, dict] | None:
if llm_call.force_use_tool.force_use:
# if we are forcing a tool, we don't need to check which tools to run
tool = next(
(
t
for t in llm_call.tools
if t.name == llm_call.force_use_tool.tool_name
),
None,
)
if not tool:
raise RuntimeError(
f"Tool '{llm_call.force_use_tool.tool_name}' not found"
)

tool_args = (
llm_call.force_use_tool.args
if llm_call.force_use_tool.args is not None
else tool.get_args_for_non_tool_calling_llm(
query=llm_call.prompt_builder.raw_user_query,
history=llm_call.prompt_builder.raw_message_history,
llm=llm,
force_run=True,
)
)

if tool_args is None:
raise RuntimeError(f"Tool '{tool.name}' did not return args")

return (tool, tool_args)
else:
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
tools=llm_call.tools,
query=llm_call.prompt_builder.raw_user_query,
history=llm_call.prompt_builder.raw_message_history,
llm=llm,
)

available_tools_and_args = [
(llm_call.tools[ind], args)
for ind, args in enumerate(tool_options)
if args is not None
]

logger.info(
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
)

chosen_tool_and_args = (
select_single_tool_for_non_tool_calling_llm(
tools_and_args=available_tools_and_args,
history=llm_call.prompt_builder.raw_message_history,
query=llm_call.prompt_builder.raw_user_query,
llm=llm,
)
if available_tools_and_args
else None
)

logger.notice(f"Chosen tool: {chosen_tool_and_args}")
return chosen_tool_and_args
return get_tool_call_for_non_tool_calling_llm_impl(
force_use_tool=llm_call.force_use_tool,
tools=llm_call.tools,
prompt_builder=llm_call.prompt_builder,
llm=llm,
)
def _handle_tool_call(self) -> Generator[ResponsePart, None, None]:
if not self.tool_call_chunk or not self.tool_call_chunk.tool_calls:
@@ -118,20 +72,17 @@ class ToolResponseHandler:
tool for tool in self.tools if tool.name == tool_call_request["name"]
]

if not known_tools_by_name:
logger.error(
"Tool call requested with unknown name field. \n"
f"self.tools: {self.tools}"
f"tool_call_request: {tool_call_request}"
)
continue
else:
if known_tools_by_name:
selected_tool = known_tools_by_name[0]
selected_tool_call_request = tool_call_request

if selected_tool and selected_tool_call_request:
break

logger.error(
"Tool call requested with unknown name field. \n"
f"self.tools: {self.tools}"
f"tool_call_request: {tool_call_request}"
)

if not selected_tool or not selected_tool_call_request:
return

@@ -157,8 +108,8 @@ class ToolResponseHandler:

def handle_response_part(
self,
response_item: BaseMessage | None,
previous_response_items: list[BaseMessage],
response_item: BaseMessage | str | None,
previous_response_items: list[BaseMessage | str],
) -> Generator[ResponsePart, None, None]:
if response_item is None:
yield from self._handle_tool_call()
@@ -171,8 +122,6 @@ class ToolResponseHandler:
else:
self.tool_call_chunk += response_item  # type: ignore

return

def next_llm_call(self, current_llm_call: LLMCall) -> LLMCall | None:
if (
self.tool_runner is None
@@ -205,3 +154,61 @@ class ToolResponseHandler:
self.tool_final_result,
],
)


def get_tool_call_for_non_tool_calling_llm_impl(
force_use_tool: ForceUseTool,
tools: list[Tool],
prompt_builder: AnswerPromptBuilder | PromptSnapshot,
llm: LLM,
) -> tuple[Tool, dict] | None:
if force_use_tool.force_use:
# if we are forcing a tool, we don't need to check which tools to run
tool = get_tool_by_name(tools, force_use_tool.tool_name)

tool_args = (
force_use_tool.args
if force_use_tool.args is not None
else tool.get_args_for_non_tool_calling_llm(
query=prompt_builder.raw_user_query,
history=prompt_builder.raw_message_history,
llm=llm,
force_run=True,
)
)

if tool_args is None:
raise RuntimeError(f"Tool '{tool.name}' did not return args")

return (tool, tool_args)
else:
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
tools=tools,
query=prompt_builder.raw_user_query,
history=prompt_builder.raw_message_history,
llm=llm,
)

available_tools_and_args = [
(tools[ind], args)
for ind, args in enumerate(tool_options)
if args is not None
]

logger.info(
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
)

chosen_tool_and_args = (
select_single_tool_for_non_tool_calling_llm(
tools_and_args=available_tools_and_args,
history=prompt_builder.raw_message_history,
query=prompt_builder.raw_user_query,
llm=llm,
)
if available_tools_and_args
else None
)

logger.notice(f"Chosen tool: {chosen_tool_and_args}")
return chosen_tool_and_args
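The `_handle_tool_call` hunk above restructures the selection loop so that the first requested tool whose name is recognized wins, and the "unknown name" error is reported when no requested tool could be matched rather than once per unmatched request. A minimal sketch of that control flow as it reads after the change, with plain dictionaries in place of the real tool-call chunks (names here are illustrative):

```python
# Illustrative only: pick the first requested tool whose name we actually know.
known_tool_names = {"run_search", "generate_image"}
tool_call_requests = [
    {"name": "unknown_tool", "args": {}},
    {"name": "run_search", "args": {"query": "onyx agent search"}},
]

selected_request = None
for request in tool_call_requests:
    if request["name"] in known_tool_names:
        selected_request = request
        break  # first recognized request wins

if selected_request is None:
    # nothing matched, so report the unknown-name problem once and bail out
    print(f"Tool call requested with unknown name field: {tool_call_requests}")
else:
    print(f"Dispatching {selected_request['name']} with {selected_request['args']}")
```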
backend/onyx/configs/agent_configs.py (new file, 77 lines)
@@ -0,0 +1,77 @@
import os

AGENT_DEFAULT_RETRIEVAL_HITS = 15
AGENT_DEFAULT_RERANKING_HITS = 10
AGENT_DEFAULT_SUB_QUESTION_MAX_CONTEXT_HITS = 8
AGENT_DEFAULT_NUM_DOCS_FOR_INITIAL_DECOMPOSITION = 3
AGENT_DEFAULT_NUM_DOCS_FOR_REFINED_DECOMPOSITION = 5
AGENT_DEFAULT_EXPLORATORY_SEARCH_RESULTS = 3
AGENT_DEFAULT_MIN_ORIG_QUESTION_DOCS = 3
AGENT_DEFAULT_MAX_ANSWER_CONTEXT_DOCS = 10
AGENT_DEFAULT_MAX_STATIC_HISTORY_CHAR_LENGTH = 10000

#####
# Agent Configs
#####


AGENT_RETRIEVAL_STATS = (
os.environ.get("AGENT_RETRIEVAL_STATS") != "False"
)  # default True

AGENT_MAX_QUERY_RETRIEVAL_RESULTS = int(
os.environ.get("AGENT_MAX_QUERY_RETRIEVAL_RESULTS") or AGENT_DEFAULT_RETRIEVAL_HITS
)  # 15

# Reranking agent configs
# Reranking stats - no influence on flow outside of stats collection
AGENT_RERANKING_STATS = (
os.environ.get("AGENT_RERANKING_STATS") == "True"
)  # default False

AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS = int(
os.environ.get("AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS")
or AGENT_DEFAULT_RERANKING_HITS
)  # 10

AGENT_NUM_DOCS_FOR_DECOMPOSITION = int(
os.environ.get("AGENT_NUM_DOCS_FOR_DECOMPOSITION")
or AGENT_DEFAULT_NUM_DOCS_FOR_INITIAL_DECOMPOSITION
)  # 3

AGENT_NUM_DOCS_FOR_REFINED_DECOMPOSITION = int(
os.environ.get("AGENT_NUM_DOCS_FOR_REFINED_DECOMPOSITION")
or AGENT_DEFAULT_NUM_DOCS_FOR_REFINED_DECOMPOSITION
)  # 5

AGENT_EXPLORATORY_SEARCH_RESULTS = int(
os.environ.get("AGENT_EXPLORATORY_SEARCH_RESULTS")
or AGENT_DEFAULT_EXPLORATORY_SEARCH_RESULTS
)  # 3

AGENT_MIN_ORIG_QUESTION_DOCS = int(
os.environ.get("AGENT_MIN_ORIG_QUESTION_DOCS")
or AGENT_DEFAULT_MIN_ORIG_QUESTION_DOCS
)  # 3

AGENT_MAX_ANSWER_CONTEXT_DOCS = int(
os.environ.get("AGENT_MAX_ANSWER_CONTEXT_DOCS")
or AGENT_DEFAULT_SUB_QUESTION_MAX_CONTEXT_HITS
)  # 8


AGENT_MAX_STATIC_HISTORY_CHAR_LENGTH = int(
os.environ.get("AGENT_MAX_STATIC_HISTORY_CHAR_LENGTH")
or AGENT_DEFAULT_MAX_STATIC_HISTORY_CHAR_LENGTH
)  # 10000

GRAPH_VERSION_NAME: str = "a"
Some files were not shown because too many files have changed in this diff.