Compare commits

...

86 Commits

Author SHA1 Message Date
pablodanswer
9ee4c3daa9 k 2025-02-02 11:43:19 -08:00
pablodanswer
a0d1d95df6 latex update 2025-01-31 22:03:15 -08:00
pablodanswer
f546e85cff post rebase fix 2025-01-31 22:03:15 -08:00
Evan Lohn
5c82f920ec fix rebase issue 2025-01-31 22:03:15 -08:00
Evan Lohn
ec9888fca2 first pass at dead code deletion 2025-01-31 22:03:14 -08:00
joachim-danswer
8962c768a0 var initialization 2025-01-31 22:03:14 -08:00
joachim-danswer
22565d7334 k 2025-01-31 22:03:14 -08:00
joachim-danswer
5d9bab05e0 k 2025-01-31 22:03:14 -08:00
joachim-danswer
8a99d24434 k 2025-01-31 22:03:14 -08:00
joachim-danswer
e5f1f67a71 Enrichment prompts, prompt improvements, dispatch logging & reinsert empty tool response 2025-01-31 22:03:14 -08:00
joachim-danswer
bee0137509 variable renaming 2025-01-31 22:03:14 -08:00
joachim-danswer
f6afc95e2f fix for merge error (#3814) 2025-01-31 22:03:14 -08:00
joachim-danswer
1c9bc48705 graph directory renamings 2025-01-31 22:03:14 -08:00
joachim-danswer
e6242f7438 persona_prompt improvements 2025-01-31 22:03:14 -08:00
joachim-danswer
ad16d66684 average dispatch time collection for sub-answers 2025-01-31 22:03:14 -08:00
joachim-danswer
76c72f88e9 added total time to logging 2025-01-31 22:03:14 -08:00
joachim-danswer
4d218c8a5d agent default changes/restructuring 2025-01-31 22:03:14 -08:00
joachim-danswer
36d4e0a1c6 increased logging 2025-01-31 22:03:14 -08:00
joachim-danswer
f233ed071e cleanup of refined answer generation 2025-01-31 22:03:14 -08:00
joachim-danswer
9770db967b application of content limitation ion refined answer as well 2025-01-31 22:03:14 -08:00
joachim-danswer
38a63d74a9 Optimizations: docs for context & history
- summarize history if long
- introduced cited_docs from SQ as those must be provided to answer generations
- limit number of docs

TODO: same for refined flow
2025-01-31 22:03:14 -08:00
Evan Lohn
17fb4219da nit 2025-01-31 22:03:14 -08:00
Evan Lohn
d019e14d00 AgentPromptConfig in Answer class 2025-01-31 22:03:14 -08:00
Evan Lohn
70f61d53e0 use reranking settings and persona during preprocessing in reranker 2025-01-31 22:03:14 -08:00
Evan Lohn
c2d9ad9143 removed unused files 2025-01-31 22:03:14 -08:00
Evan Lohn
7b6901a8c1 always send search response 2025-01-31 22:03:14 -08:00
Evan Lohn
16300b42a1 remove debug 2025-01-31 22:03:14 -08:00
pablodanswer
5635b32ac4 improve regeneration state 2025-01-31 22:03:14 -08:00
pablodanswer
13ec194153 nit 2025-01-31 22:03:14 -08:00
pablodanswer
db0dee1aac improved timing 2025-01-31 22:03:14 -08:00
Evan Lohn
23dbb521c3 increased timeout to get rid of asyncio logger errors 2025-01-31 22:03:14 -08:00
joachim-danswer
b81f03131f addressing nits of EL 2025-01-31 22:03:14 -08:00
joachim-danswer
e896786693 updated answer_comparison prompt + small cleanup 2025-01-31 22:03:14 -08:00
joachim-danswer
1ebddb1ebe refined search + question answering as sub-graphs 2025-01-31 22:03:14 -08:00
joachim-danswer
0146d3c66d sub-graphs for initial question/search 2025-01-31 22:03:14 -08:00
joachim-danswer
f2a54b79ac refined search + question answering as sub-graphs 2025-01-31 22:03:14 -08:00
pablodanswer
27982a1d5a minor update 2025-01-31 22:03:14 -08:00
pablodanswer
ff3b4d28f4 k 2025-01-31 22:03:14 -08:00
pablodanswer
081efa0831 update switching logic 2025-01-31 22:03:14 -08:00
pablodanswer
4dec31c93d fix toggling edge case 2025-01-31 22:03:14 -08:00
pablodanswer
b693fe8248 update bool 2025-01-31 22:03:14 -08:00
pablodanswer
9c2b152dd7 various improvements 2025-01-31 22:03:14 -08:00
pablodanswer
3afb79a7d8 quick nit 2025-01-31 22:03:14 -08:00
Evan Lohn
a1c3ba7eba allowed empty Search Tool for non-agentic search 2025-01-31 22:03:14 -08:00
pablodanswer
1bb7f782c8 minor update - doc ordering 2025-01-31 22:03:14 -08:00
pablodanswer
ee8d9ddc3d k 2025-01-31 22:03:14 -08:00
pablodanswer
54d6f31f0d quick nit 2025-01-31 22:03:14 -08:00
pablodanswer
3306479674 k 2025-01-31 22:03:14 -08:00
joachim-danswer
c08ed464a9 Replaced additional limit with variable 2025-01-31 22:03:14 -08:00
joachim-danswer
d07588d1ce Addressing EL's comments
- created vars for a couple of agent settings
 - moved agent configs
 - created a search function
2025-01-31 22:03:14 -08:00
joachim-danswer
9d494adc3e taking out Extraction for now 2025-01-31 22:03:14 -08:00
joachim-danswer
be7f3f6eed earlier entity extraction & sharper generation prompts 2025-01-31 22:03:14 -08:00
joachim-danswer
8561a50eac tmp: force agent search 2025-01-31 22:03:14 -08:00
Evan Lohn
3ce783a1fe skip reranking for <=1 doc 2025-01-31 22:03:14 -08:00
Evan Lohn
e9ea2a1b1f stop infos when done streaming answers 2025-01-31 22:03:14 -08:00
Evan Lohn
b6caedf50b make field nullable 2025-01-31 22:03:14 -08:00
Evan Lohn
4ad1dca233 persisting refined answer improvement 2025-01-31 22:03:14 -08:00
Evan Lohn
df9f808c4a address JR comments 2025-01-31 22:03:14 -08:00
Evan Lohn
d68bcb6ac2 fixed chat tests 2025-01-31 22:03:14 -08:00
Evan Lohn
e9f2d6468d implemented top-level tool calling + force search 2025-01-31 22:03:14 -08:00
Evan Lohn
0ca153b651 WIP, but working basic search using initial tool choice node 2025-01-31 22:03:14 -08:00
pablodanswer
279f0e9374 k 2025-01-31 22:03:14 -08:00
pablodanswer
ea73f12844 updated + functional 2025-01-31 22:03:14 -08:00
pablodanswer
1007d4684a update- reorg 2025-01-31 22:03:14 -08:00
pablodanswer
d336fa31b4 k 2025-01-31 22:03:13 -08:00
pablodanswer
0b7719c158 build fix 2025-01-31 22:02:56 -08:00
joachim-danswer
c41f4599bf EL comments addressed 2025-01-31 22:02:56 -08:00
joachim-danswer
6bc232f040 loser verification prompt 2025-01-31 22:02:56 -08:00
joachim-danswer
b10166d4d2 turning off initial search pre route decision 2025-01-31 22:02:56 -08:00
joachim-danswer
836f84f946 change of sub-question answer if no docs recovered 2025-01-31 22:02:56 -08:00
joachim-danswer
ffb627621e various fixes from Yuhong's list 2025-01-31 22:02:56 -08:00
Yuhong Sun
ca776932a9 Copy changes 2025-01-31 22:02:56 -08:00
Evan Lohn
64e13986bb removed print statements, fixed pass through handling 2025-01-31 22:02:56 -08:00
Evan Lohn
fcaee1979f fixed basic flow citations and second test 2025-01-31 22:02:56 -08:00
Evan Lohn
105a0f7b49 fix for early cancellation test; solves issue with tasks being destroyed while pending 2025-01-31 22:02:56 -08:00
pablodanswer
afce385fd0 add agent search frontend 2025-01-31 22:02:54 -08:00
Evan Lohn
978650028a fix alembic history 2025-01-31 22:01:08 -08:00
joachim-danswer
c87270a279 streaming + saving of search docs of no verified ones available
- sub-questions only
2025-01-31 22:01:08 -08:00
Evan Lohn
bc13a6caa7 reworked history messages in agent config 2025-01-31 22:01:08 -08:00
Evan Lohn
f14bc9ab8a missed files from prev commit 2025-01-31 22:01:08 -08:00
Evan Lohn
b95c566c15 basic search restructure: WIP on fixing tests 2025-01-31 22:01:08 -08:00
joachim-danswer
9f5c9142b5 prompts that even further motivates to cite docs over sub-q's 2025-01-31 22:01:08 -08:00
joachim-danswer
0c5b47e36f pydantic for LangGraph + changed ERT extraction flow 2025-01-31 22:01:08 -08:00
joachim-danswer
09020a202e history added to agent flow 2025-01-31 22:01:08 -08:00
pablodanswer
2a662c40b8 minor fixes to branch 2025-01-31 22:01:08 -08:00
Evan Lohn
35593cae0d second clean commit 2025-01-31 22:01:08 -08:00
155 changed files with 13489 additions and 1088 deletions

4
.gitignore vendored
View File

@@ -7,4 +7,6 @@
.vscode/
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml
/web/test-results/
/web/test-results/
backend/onyx/agent_search/main/test_data.json
backend/tests/regression/answer_quality/test_data.json

View File

@@ -52,3 +52,9 @@ BING_API_KEY=<REPLACE THIS>
# Enable the full set of Danswer Enterprise Edition features
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False
# Agent Search configs # TODO: Remove give proper namings
AGENT_RETRIEVAL_STATS=False # Note: This setting will incur substantial re-ranking effort
AGENT_RERANKING_STATS=True
AGENT_MAX_QUERY_RETRIEVAL_RESULTS=20
AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS=20

View File

@@ -0,0 +1,29 @@
"""agent_doc_result_col
Revision ID: 1adf5ea20d2b
Revises: e9cf2bd7baed
Create Date: 2025-01-05 13:14:58.344316
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "1adf5ea20d2b"
down_revision = "e9cf2bd7baed"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add the new column with JSONB type
op.add_column(
"sub_question",
sa.Column("sub_question_doc_results", postgresql.JSONB(), nullable=True),
)
def downgrade() -> None:
# Drop the column
op.drop_column("sub_question", "sub_question_doc_results")

View File

@@ -0,0 +1,31 @@
"""refined answer improvement
Revision ID: 211b14ab5a91
Revises: 925b58bd75b6
Create Date: 2025-01-24 14:05:03.334309
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "211b14ab5a91"
down_revision = "925b58bd75b6"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"chat_message",
sa.Column(
"refined_answer_improvement",
sa.Boolean(),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("chat_message", "refined_answer_improvement")

View File

@@ -1,80 +0,0 @@
"""foreign key input prompts
Revision ID: 33ea50e88f24
Revises: a6df6b88ef81
Create Date: 2025-01-29 10:54:22.141765
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "33ea50e88f24"
down_revision = "a6df6b88ef81"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Safely drop constraints if exists
op.execute(
"""
ALTER TABLE inputprompt__user
DROP CONSTRAINT IF EXISTS inputprompt__user_input_prompt_id_fkey
"""
)
op.execute(
"""
ALTER TABLE inputprompt__user
DROP CONSTRAINT IF EXISTS inputprompt__user_user_id_fkey
"""
)
# Recreate with ON DELETE CASCADE
op.create_foreign_key(
"inputprompt__user_input_prompt_id_fkey",
"inputprompt__user",
"inputprompt",
["input_prompt_id"],
["id"],
ondelete="CASCADE",
)
op.create_foreign_key(
"inputprompt__user_user_id_fkey",
"inputprompt__user",
"user",
["user_id"],
["id"],
ondelete="CASCADE",
)
def downgrade() -> None:
# Drop the new FKs with ondelete
op.drop_constraint(
"inputprompt__user_input_prompt_id_fkey",
"inputprompt__user",
type_="foreignkey",
)
op.drop_constraint(
"inputprompt__user_user_id_fkey",
"inputprompt__user",
type_="foreignkey",
)
# Recreate them without cascading
op.create_foreign_key(
"inputprompt__user_input_prompt_id_fkey",
"inputprompt__user",
"inputprompt",
["input_prompt_id"],
["id"],
)
op.create_foreign_key(
"inputprompt__user_user_id_fkey",
"inputprompt__user",
"user",
["user_id"],
["id"],
)

View File

@@ -0,0 +1,35 @@
"""agent_metric_col_rename__s
Revision ID: 925b58bd75b6
Revises: 9787be927e58
Create Date: 2025-01-06 11:20:26.752441
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "925b58bd75b6"
down_revision = "9787be927e58"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Rename columns using PostgreSQL syntax
op.alter_column(
"agent__search_metrics", "base_duration_s", new_column_name="base_duration__s"
)
op.alter_column(
"agent__search_metrics", "full_duration_s", new_column_name="full_duration__s"
)
def downgrade() -> None:
# Revert the column renames
op.alter_column(
"agent__search_metrics", "base_duration__s", new_column_name="base_duration_s"
)
op.alter_column(
"agent__search_metrics", "full_duration__s", new_column_name="full_duration_s"
)

View File

@@ -0,0 +1,25 @@
"""agent_metric_table_renames__agent__
Revision ID: 9787be927e58
Revises: bceb76d618ec
Create Date: 2025-01-06 11:01:44.210160
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "9787be927e58"
down_revision = "bceb76d618ec"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Rename table from agent_search_metrics to agent__search_metrics
op.rename_table("agent_search_metrics", "agent__search_metrics")
def downgrade() -> None:
# Rename table back from agent__search_metrics to agent_search_metrics
op.rename_table("agent__search_metrics", "agent_search_metrics")

View File

@@ -0,0 +1,42 @@
"""agent_tracking
Revision ID: 98a5008d8711
Revises: 4d58345da04a
Create Date: 2025-01-04 14:41:52.732238
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "98a5008d8711"
down_revision = "4d58345da04a"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"agent_search_metrics",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=True),
sa.Column("persona_id", sa.Integer(), nullable=True),
sa.Column("agent_type", sa.String(), nullable=False),
sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
sa.Column("base_duration_s", sa.Float(), nullable=False),
sa.Column("full_duration_s", sa.Float(), nullable=False),
sa.Column("base_metrics", postgresql.JSONB(), nullable=True),
sa.Column("refined_metrics", postgresql.JSONB(), nullable=True),
sa.Column("all_metrics", postgresql.JSONB(), nullable=True),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
def downgrade() -> None:
op.drop_table("agent_search_metrics")

View File

@@ -1,29 +0,0 @@
"""remove recent assistants
Revision ID: a6df6b88ef81
Revises: 4d58345da04a
Create Date: 2025-01-29 10:25:52.790407
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "a6df6b88ef81"
down_revision = "4d58345da04a"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_column("user", "recent_assistants")
def downgrade() -> None:
op.add_column(
"user",
sa.Column(
"recent_assistants", postgresql.JSONB(), server_default="[]", nullable=False
),
)

View File

@@ -0,0 +1,84 @@
"""agent_table_renames__agent__
Revision ID: bceb76d618ec
Revises: c0132518a25b
Create Date: 2025-01-06 10:50:48.109285
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "bceb76d618ec"
down_revision = "c0132518a25b"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_constraint(
"sub_query__search_doc_sub_query_id_fkey",
"sub_query__search_doc",
type_="foreignkey",
)
op.drop_constraint(
"sub_query__search_doc_search_doc_id_fkey",
"sub_query__search_doc",
type_="foreignkey",
)
# Rename tables
op.rename_table("sub_query", "agent__sub_query")
op.rename_table("sub_question", "agent__sub_question")
op.rename_table("sub_query__search_doc", "agent__sub_query__search_doc")
# Update both foreign key constraints for agent__sub_query__search_doc
# Create new foreign keys with updated names
op.create_foreign_key(
"agent__sub_query__search_doc_sub_query_id_fkey",
"agent__sub_query__search_doc",
"agent__sub_query",
["sub_query_id"],
["id"],
)
op.create_foreign_key(
"agent__sub_query__search_doc_search_doc_id_fkey",
"agent__sub_query__search_doc",
"search_doc", # This table name doesn't change
["search_doc_id"],
["id"],
)
def downgrade() -> None:
# Update foreign key constraints for sub_query__search_doc
op.drop_constraint(
"agent__sub_query__search_doc_sub_query_id_fkey",
"agent__sub_query__search_doc",
type_="foreignkey",
)
op.drop_constraint(
"agent__sub_query__search_doc_search_doc_id_fkey",
"agent__sub_query__search_doc",
type_="foreignkey",
)
# Rename tables back
op.rename_table("agent__sub_query__search_doc", "sub_query__search_doc")
op.rename_table("agent__sub_question", "sub_question")
op.rename_table("agent__sub_query", "sub_query")
op.create_foreign_key(
"sub_query__search_doc_sub_query_id_fkey",
"sub_query__search_doc",
"sub_query",
["sub_query_id"],
["id"],
)
op.create_foreign_key(
"sub_query__search_doc_search_doc_id_fkey",
"sub_query__search_doc",
"search_doc", # This table name doesn't change
["search_doc_id"],
["id"],
)

View File

@@ -0,0 +1,40 @@
"""agent_table_changes_rename_level
Revision ID: c0132518a25b
Revises: 1adf5ea20d2b
Create Date: 2025-01-05 16:38:37.660152
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c0132518a25b"
down_revision = "1adf5ea20d2b"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add level and level_question_nr columns with NOT NULL constraint
op.add_column(
"sub_question",
sa.Column("level", sa.Integer(), nullable=False, server_default="0"),
)
op.add_column(
"sub_question",
sa.Column(
"level_question_nr", sa.Integer(), nullable=False, server_default="0"
),
)
# Remove the server_default after the columns are created
op.alter_column("sub_question", "level", server_default=None)
op.alter_column("sub_question", "level_question_nr", server_default=None)
def downgrade() -> None:
# Remove the columns
op.drop_column("sub_question", "level_question_nr")
op.drop_column("sub_question", "level")

View File

@@ -0,0 +1,68 @@
"""create pro search persistence tables
Revision ID: e9cf2bd7baed
Revises: 98a5008d8711
Create Date: 2025-01-02 17:55:56.544246
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID
# revision identifiers, used by Alembic.
revision = "e9cf2bd7baed"
down_revision = "98a5008d8711"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create sub_question table
op.create_table(
"sub_question",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("primary_question_id", sa.Integer, sa.ForeignKey("chat_message.id")),
sa.Column(
"chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id")
),
sa.Column("sub_question", sa.Text),
sa.Column(
"time_created", sa.DateTime(timezone=True), server_default=sa.func.now()
),
sa.Column("sub_answer", sa.Text),
)
# Create sub_query table
op.create_table(
"sub_query",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("parent_question_id", sa.Integer, sa.ForeignKey("sub_question.id")),
sa.Column(
"chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id")
),
sa.Column("sub_query", sa.Text),
sa.Column(
"time_created", sa.DateTime(timezone=True), server_default=sa.func.now()
),
)
# Create sub_query__search_doc association table
op.create_table(
"sub_query__search_doc",
sa.Column(
"sub_query_id", sa.Integer, sa.ForeignKey("sub_query.id"), primary_key=True
),
sa.Column(
"search_doc_id",
sa.Integer,
sa.ForeignKey("search_doc.id"),
primary_key=True,
),
)
def downgrade() -> None:
op.drop_table("sub_query__search_doc")
op.drop_table("sub_query")
op.drop_table("sub_question")

View File

@@ -179,6 +179,7 @@ def handle_simplified_chat_message(
chunks_below=0,
full_doc=chat_message_req.full_doc,
structured_response_format=chat_message_req.structured_response_format,
use_agentic_search=chat_message_req.use_agentic_search,
)
packets = stream_chat_message_objects(
@@ -301,6 +302,7 @@ def handle_send_message_simple_with_history(
chunks_below=0,
full_doc=req.full_doc,
structured_response_format=req.structured_response_format,
use_agentic_search=req.use_agentic_search,
)
packets = stream_chat_message_objects(

View File

@@ -57,6 +57,9 @@ class BasicCreateChatMessageRequest(ChunkContext):
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None
# If True, uses agentic search instead of basic search
use_agentic_search: bool = False
class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
# Last element is the new query. All previous elements are historical context
@@ -71,6 +74,8 @@ class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
# only works if using an OpenAI model. See the following for more details:
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None
# If True, uses agentic search instead of basic search
use_agentic_search: bool = False
class SimpleDoc(BaseModel):
@@ -123,6 +128,9 @@ class OneShotQARequest(ChunkContext):
# If True, skips generative an AI response to the search query
skip_gen_ai_answer_generation: bool = False
# If True, uses pro search instead of basic search
use_agentic_search: bool = False
@model_validator(mode="after")
def check_persona_fields(self) -> "OneShotQARequest":
if self.persona_override_config is None and self.persona_id is None:

View File

@@ -196,6 +196,8 @@ def get_answer_stream(
retrieval_details=query_request.retrieval_options,
rerank_settings=query_request.rerank_settings,
db_session=db_session,
use_agentic_search=query_request.use_agentic_search,
skip_gen_ai_answer_generation=query_request.skip_gen_ai_answer_generation,
)
packets = stream_chat_message_objects(

View File

@@ -0,0 +1,98 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.basic.states import BasicInput
from onyx.agents.agent_search.basic.states import BasicOutput
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
basic_use_tool_response,
)
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
prepare_tool_input,
)
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
from onyx.utils.logger import setup_logger
logger = setup_logger()
def basic_graph_builder() -> StateGraph:
graph = StateGraph(
state_schema=BasicState,
input=BasicInput,
output=BasicOutput,
)
### Add nodes ###
graph.add_node(
node="prepare_tool_input",
action=prepare_tool_input,
)
graph.add_node(
node="llm_tool_choice",
action=llm_tool_choice,
)
graph.add_node(
node="tool_call",
action=tool_call,
)
graph.add_node(
node="basic_use_tool_response",
action=basic_use_tool_response,
)
### Add edges ###
graph.add_edge(start_key=START, end_key="prepare_tool_input")
graph.add_edge(start_key="prepare_tool_input", end_key="llm_tool_choice")
graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])
graph.add_edge(
start_key="tool_call",
end_key="basic_use_tool_response",
)
graph.add_edge(
start_key="basic_use_tool_response",
end_key=END,
)
return graph
def should_continue(state: BasicState) -> str:
return (
# If there are no tool calls, basic graph already streamed the answer
END
if state.tool_choice is None
else "tool_call"
)
if __name__ == "__main__":
from onyx.db.engine import get_session_context_manager
from onyx.context.search.models import SearchRequest
from onyx.llm.factory import get_default_llms
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
graph = basic_graph_builder()
compiled_graph = graph.compile()
# TODO: unify basic input
input = BasicInput(logs="")
primary_llm, fast_llm = get_default_llms()
with get_session_context_manager() as db_session:
config, _ = get_test_config(
db_session=db_session,
primary_llm=primary_llm,
fast_llm=fast_llm,
search_request=SearchRequest(query="How does onyx use FastAPI?"),
)
compiled_graph.invoke(input, config={"metadata": {"config": config}})

View File

@@ -0,0 +1,42 @@
from typing import TypedDict
from langchain_core.messages import AIMessageChunk
from pydantic import BaseModel
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
# States contain values that change over the course of graph execution,
# Config is for values that are set at the start and never change.
# If you are using a value from the config and realize it needs to change,
# you should add it to the state and use/update the version in the state.
## Graph Input State
class BasicInput(BaseModel):
# TODO: subclass global log update state
logs: str = ""
## Graph Output State
class BasicOutput(TypedDict):
tool_call_chunk: AIMessageChunk
## Update States
## Graph State
class BasicState(
BasicInput,
ToolChoiceInput,
ToolCallUpdate,
ToolChoiceUpdate,
):
pass

View File

@@ -0,0 +1,69 @@
from collections.abc import Iterator
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from onyx.chat.models import LlmDoc
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
PassThroughAnswerResponseHandler,
)
from onyx.chat.stream_processing.utils import map_document_id_order
from onyx.utils.logger import setup_logger
logger = setup_logger()
# TODO: handle citations here; below is what was previously passed in
# see basic_use_tool_response.py for where these variables come from
# answer_handler = CitationResponseHandler(
# context_docs=final_search_results,
# final_doc_id_to_rank_map=map_document_id_order(final_search_results),
# display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
# )
def process_llm_stream(
stream: Iterator[BaseMessage],
should_stream_answer: bool,
final_search_results: list[LlmDoc] | None = None,
displayed_search_results: list[LlmDoc] | None = None,
) -> AIMessageChunk:
tool_call_chunk = AIMessageChunk(content="")
# for response in response_handler_manager.handle_llm_response(stream):
if final_search_results and displayed_search_results:
answer_handler: AnswerResponseHandler = CitationResponseHandler(
context_docs=final_search_results,
final_doc_id_to_rank_map=map_document_id_order(final_search_results),
display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
)
else:
answer_handler = PassThroughAnswerResponseHandler()
full_answer = ""
# This stream will be the llm answer if no tool is chosen. When a tool is chosen,
# the stream will contain AIMessageChunks with tool call information.
for response in stream:
answer_piece = response.content
if not isinstance(answer_piece, str):
# TODO: handle non-string content
logger.warning(f"Received non-string content: {type(answer_piece)}")
answer_piece = str(answer_piece)
full_answer += answer_piece
if isinstance(response, AIMessageChunk) and (
response.tool_call_chunks or response.tool_calls
):
tool_call_chunk += response # type: ignore
elif should_stream_answer:
for response_part in answer_handler.handle_response_part(response, []):
dispatch_custom_event(
"basic_response",
response_part,
)
logger.info(f"Full answer: {full_answer}")
return cast(AIMessageChunk, tool_call_chunk)

View File

@@ -0,0 +1,21 @@
from operator import add
from typing import Annotated
from pydantic import BaseModel
class CoreState(BaseModel):
"""
This is the core state that is shared across all subgraphs.
"""
base_question: str = ""
log_messages: Annotated[list[str], add] = []
class SubgraphCoreState(BaseModel):
"""
This is the core state that is shared across all subgraphs.
"""
log_messages: Annotated[list[str], add]

View File

@@ -0,0 +1,66 @@
from uuid import UUID
from sqlalchemy.orm import Session
from onyx.db.models import AgentSubQuery
from onyx.db.models import AgentSubQuestion
def create_sub_question(
db_session: Session,
chat_session_id: UUID,
primary_message_id: int,
sub_question: str,
sub_answer: str,
) -> AgentSubQuestion:
"""Create a new sub-question record in the database."""
sub_q = AgentSubQuestion(
chat_session_id=chat_session_id,
primary_question_id=primary_message_id,
sub_question=sub_question,
sub_answer=sub_answer,
)
db_session.add(sub_q)
db_session.flush()
return sub_q
def create_sub_query(
db_session: Session,
chat_session_id: UUID,
parent_question_id: int,
sub_query: str,
) -> AgentSubQuery:
"""Create a new sub-query record in the database."""
sub_q = AgentSubQuery(
chat_session_id=chat_session_id,
parent_question_id=parent_question_id,
sub_query=sub_query,
)
db_session.add(sub_q)
db_session.flush()
return sub_q
def get_sub_questions_for_message(
db_session: Session,
primary_message_id: int,
) -> list[AgentSubQuestion]:
"""Get all sub-questions for a given primary message."""
return (
db_session.query(AgentSubQuestion)
.filter(AgentSubQuestion.primary_question_id == primary_message_id)
.all()
)
def get_sub_queries_for_question(
db_session: Session,
sub_question_id: int,
) -> list[AgentSubQuery]:
"""Get all sub-queries for a given sub-question."""
return (
db_session.query(AgentSubQuery)
.filter(AgentSubQuery.parent_question_id == sub_question_id)
.all()
)

View File

@@ -0,0 +1,55 @@
from collections.abc import Hashable
from datetime import datetime
from langgraph.types import Send
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
AnswerQuestionInput,
)
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search_a.initial__retrieval_sub_answers__subgraph.states import (
SearchSQState,
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
def parallelize_initial_sub_question_answering(
state: SearchSQState,
) -> list[Send | Hashable]:
now_start = datetime.now()
if len(state.initial_decomp_questions) > 0:
# sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]]
# if len(state["sub_question_records"]) == 0:
# if state["config"].use_persistence:
# raise ValueError("No sub-questions found for initial decompozed questions")
# else:
# # in this case, we are doing retrieval on the original question.
# # to make all the logic consistent, we create a new sub-question
# # with the same content as the original question
# sub_question_record_ids = [1] * len(state["initial_decomp_questions"])
return [
Send(
"answer_query_subgraph",
AnswerQuestionInput(
question=question,
question_id=make_question_id(0, question_nr + 1),
log_messages=[
f"{now_start} -- Main Edge - Parallelize Initial Sub-question Answering"
],
),
)
for question_nr, question in enumerate(state.initial_decomp_questions)
]
else:
return [
Send(
"ingest_answers",
AnswerQuestionOutput(
answer_results=[],
),
)
]

View File

@@ -0,0 +1,100 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search_a.initial__consolidate_sub_answers__subgraph.edges import (
parallelize_initial_sub_question_answering,
)
from onyx.agents.agent_search.deep_search_a.initial__consolidate_sub_answers__subgraph.nodes.ingest_initial_sub_answers import (
ingest_initial_sub_answers,
)
from onyx.agents.agent_search.deep_search_a.initial__consolidate_sub_answers__subgraph.nodes.initial_decomposition import (
initial_sub_question_creation,
)
from onyx.agents.agent_search.deep_search_a.initial__consolidate_sub_answers__subgraph.states import (
SQInput,
)
from onyx.agents.agent_search.deep_search_a.initial__consolidate_sub_answers__subgraph.states import (
SQState,
)
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.graph_builder import (
answer_query_graph_builder,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
test_mode = False
def initial_sq_subgraph_builder(test_mode: bool = False) -> StateGraph:
graph = StateGraph(
state_schema=SQState,
input=SQInput,
)
graph.add_node(
node="initial_sub_question_creation",
action=initial_sub_question_creation,
)
answer_query_subgraph = answer_query_graph_builder().compile()
graph.add_node(
node="answer_query_subgraph",
action=answer_query_subgraph,
)
graph.add_node(
node="ingest_initial_sub_question_answers",
action=ingest_initial_sub_answers,
)
### Add edges ###
# raph.add_edge(start_key=START, end_key="base_raw_search_subgraph")
# graph.add_edge(
# start_key="agent_search_start",
# end_key="entity_term_extraction_llm",
# )
graph.add_edge(
start_key=START,
end_key="initial_sub_question_creation",
)
# graph.add_edge(
# start_key="LLM",
# end_key=END,
# )
# graph.add_edge(
# start_key=START,
# end_key="initial_sub_question_creation",
# )
graph.add_conditional_edges(
source="initial_sub_question_creation",
path=parallelize_initial_sub_question_answering,
path_map=["answer_query_subgraph"],
)
graph.add_edge(
start_key="answer_query_subgraph",
end_key="ingest_initial_sub_question_answers",
)
graph.add_edge(
start_key="ingest_initial_sub_question_answers",
end_key=END,
)
# graph.add_edge(
# start_key="generate_refined_answer",
# end_key="check_refined_answer",
# )
# graph.add_edge(
# start_key="check_refined_answer",
# end_key=END,
# )
return graph

View File

@@ -0,0 +1,45 @@
from datetime import datetime
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
DecompAnswersUpdate,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
def ingest_initial_sub_answers(
state: AnswerQuestionOutput,
) -> DecompAnswersUpdate:
now_start = datetime.now()
logger.info(f"--------{now_start}--------INGEST ANSWERS---")
documents = []
context_documents = []
cited_docs = []
answer_results = state.answer_results if hasattr(state, "answer_results") else []
for answer_result in answer_results:
documents.extend(answer_result.documents)
context_documents.extend(answer_result.context_documents)
cited_docs.extend(answer_result.cited_docs)
now_end = datetime.now()
logger.debug(
f"--------{now_end}--{now_end - now_start}--------INGEST ANSWERS END---"
)
return DecompAnswersUpdate(
# Deduping is done by the documents operator for the main graph
# so we might not need to dedup here
documents=dedup_inference_sections(documents, []),
context_documents=dedup_inference_sections(context_documents, []),
cited_docs=dedup_inference_sections(cited_docs, []),
decomp_answer_results=answer_results,
log_messages=[
f"{now_start} -- Main - Ingest initial processed sub questions, Time taken: {now_end - now_start}"
],
)

View File

@@ -0,0 +1,138 @@
from datetime import datetime
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search_a.initial__retrieval_sub_answers__subgraph.states import (
SearchSQState,
)
from onyx.agents.agent_search.deep_search_a.main__graph.models import (
AgentRefinedMetrics,
)
from onyx.agents.agent_search.deep_search_a.main__graph.operations import (
dispatch_subquestion,
)
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
from onyx.agents.agent_search.deep_search_a.main__graph.states import BaseDecompUpdate
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH,
)
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import SubQuestionPiece
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
def initial_sub_question_creation(
state: SearchSQState, config: RunnableConfig
) -> BaseDecompUpdate:
now_start = datetime.now()
logger.info(f"--------{now_start}--------BASE DECOMP START---")
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = agent_a_config.search_request.query
chat_session_id = agent_a_config.chat_session_id
primary_message_id = agent_a_config.message_id
perform_initial_search_decomposition = (
agent_a_config.perform_initial_search_decomposition
)
# Get the rewritten queries in a defined format
model = agent_a_config.fast_llm
history = build_history_prompt(agent_a_config, question)
# Use the initial search results to inform the decomposition
sample_doc_str = state.sample_doc_str if hasattr(state, "sample_doc_str") else ""
if not chat_session_id or not primary_message_id:
raise ValueError(
"chat_session_id and message_id must be provided for agent search"
)
agent_start_time = datetime.now()
# Initial search to inform decomposition. Just get top 3 fits
if perform_initial_search_decomposition:
sample_doc_str = "\n\n".join(
[
doc.combined_content
for doc in state.exploratory_search_results[
:AGENT_NUM_DOCS_FOR_DECOMPOSITION
]
]
)
decomposition_prompt = (
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH.format(
question=question, sample_doc_str=sample_doc_str, history=history
)
)
else:
decomposition_prompt = INITIAL_DECOMPOSITION_PROMPT_QUESTIONS.format(
question=question, history=history
)
# Start decomposition
msg = [HumanMessage(content=decomposition_prompt)]
# Send the initial question as a subquestion with number 0
dispatch_custom_event(
"decomp_qs",
SubQuestionPiece(
sub_question=question,
level=0,
level_question_nr=0,
),
)
# dispatches custom events for subquestion tokens, adding in subquestion ids.
streamed_tokens = dispatch_separated(model.stream(msg), dispatch_subquestion(0))
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type="sub_questions",
level=0,
)
dispatch_custom_event("stream_finished", stop_event)
deomposition_response = merge_content(*streamed_tokens)
# this call should only return strings. Commenting out for efficiency
# assert [type(tok) == str for tok in streamed_tokens]
# use no-op cast() instead of str() which runs code
# list_of_subquestions = clean_and_parse_list_string(cast(str, response))
list_of_subqs = cast(str, deomposition_response).split("\n")
decomp_list: list[str] = [sq.strip() for sq in list_of_subqs if sq.strip() != ""]
now_end = datetime.now()
logger.info(
f"{now_start} -- INITIAL SUBQUESTION ANSWERING - Base Decomposition, Time taken: {now_end - now_start}"
)
return BaseDecompUpdate(
initial_decomp_questions=decomp_list,
agent_start_time=agent_start_time,
agent_refined_start_time=None,
agent_refined_end_time=None,
agent_refined_metrics=AgentRefinedMetrics(
refined_doc_boost_factor=None,
refined_question_boost_factor=None,
duration__s=None,
),
)

View File

@@ -0,0 +1,37 @@
from typing import TypedDict
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.deep_search_a.main__graph.states import BaseDecompUpdate
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
DecompAnswersUpdate,
)
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
InitialAnswerUpdate,
)
### States ###
class SQInput(CoreState):
pass
## Graph State
class SQState(
# This includes the core state
SQInput,
BaseDecompUpdate,
InitialAnswerUpdate,
DecompAnswersUpdate,
):
# expanded_retrieval_result: Annotated[list[ExpandedRetrievalResult], add]
pass
## Graph Output State - presently not used
class SQOutput(TypedDict):
log_messages: list[str]

View File

@@ -0,0 +1,29 @@
from collections.abc import Hashable
from datetime import datetime
from langgraph.types import Send
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
AnswerQuestionInput,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
ExpandedRetrievalInput,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def send_to_expanded_retrieval(state: AnswerQuestionInput) -> Send | Hashable:
logger.debug("sending to expanded retrieval via edge")
now_start = datetime.now()
return Send(
"initial_sub_question_expanded_retrieval",
ExpandedRetrievalInput(
question=state.question,
base_search=False,
sub_question_id=state.question_id,
log_messages=[f"{now_start} -- Sending to expanded retrieval"],
),
)

View File

@@ -0,0 +1,126 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.edges import (
send_to_expanded_retrieval,
)
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.nodes.answer_check import (
answer_check,
)
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.nodes.answer_generation import (
answer_generation,
)
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.nodes.format_answer import (
format_answer,
)
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.nodes.ingest_retrieval import (
ingest_retrieval,
)
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
AnswerQuestionInput,
)
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.graph_builder import (
expanded_retrieval_graph_builder,
)
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger
logger = setup_logger()
def answer_query_graph_builder() -> StateGraph:
graph = StateGraph(
state_schema=AnswerQuestionState,
input=AnswerQuestionInput,
output=AnswerQuestionOutput,
)
### Add nodes ###
expanded_retrieval = expanded_retrieval_graph_builder().compile()
graph.add_node(
node="initial_sub_question_expanded_retrieval",
action=expanded_retrieval,
)
graph.add_node(
node="answer_check",
action=answer_check,
)
graph.add_node(
node="answer_generation",
action=answer_generation,
)
graph.add_node(
node="format_answer",
action=format_answer,
)
graph.add_node(
node="ingest_retrieval",
action=ingest_retrieval,
)
### Add edges ###
graph.add_conditional_edges(
source=START,
path=send_to_expanded_retrieval,
path_map=["initial_sub_question_expanded_retrieval"],
)
graph.add_edge(
start_key="initial_sub_question_expanded_retrieval",
end_key="ingest_retrieval",
)
graph.add_edge(
start_key="ingest_retrieval",
end_key="answer_generation",
)
graph.add_edge(
start_key="answer_generation",
end_key="answer_check",
)
graph.add_edge(
start_key="answer_check",
end_key="format_answer",
)
graph.add_edge(
start_key="format_answer",
end_key=END,
)
return graph
if __name__ == "__main__":
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
graph = answer_query_graph_builder()
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
search_request = SearchRequest(
query="what can you do with onyx or danswer?",
)
with get_session_context_manager() as db_session:
agent_search_config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)
inputs = AnswerQuestionInput(
question="what can you do with onyx?",
question_id="0_0",
log_messages=[],
)
for thing in compiled_graph.stream(
input=inputs,
config={"configurable": {"config": agent_search_config}},
# debug=True,
# subgraphs=True,
):
logger.debug(thing)

View File

@@ -0,0 +1,8 @@
from pydantic import BaseModel
### Models ###
class AnswerRetrievalStats(BaseModel):
answer_retrieval_stats: dict[str, float | int]

View File

@@ -0,0 +1,59 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
QACheckUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_CHECK_NO
from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
def answer_check(state: AnswerQuestionState, config: RunnableConfig) -> QACheckUpdate:
now_start = datetime.now()
level, question_num = parse_question_id(state.question_id)
if state.answer == UNKNOWN_ANSWER:
now_end = datetime.now()
return QACheckUpdate(
answer_quality=SUB_CHECK_NO,
log_messages=[
f"{now_start} -- Answer check SQ-{level}-{question_num} - unknown answer, Time taken: {now_end - now_start}"
],
)
msg = [
HumanMessage(
content=SUB_CHECK_PROMPT.format(
question=state.question,
base_answer=state.answer,
)
)
]
agent_searchch_config = cast(AgentSearchConfig, config["metadata"]["config"])
fast_llm = agent_searchch_config.fast_llm
response = list(
fast_llm.stream(
prompt=msg,
)
)
quality_str = merge_message_runs(response, chunk_separator="")[0].content
now_end = datetime.now()
return QACheckUpdate(
answer_quality=quality_str,
log_messages=[
f"""{now_start} -- Answer check SQ-{level}-{question_num} - Answer quality: {quality_str},
Time taken: {now_end - now_start}"""
],
)

View File

@@ -0,0 +1,126 @@
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
QAGenerationUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_sub_question_answer_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import NO_RECOVERED_DOCS
from onyx.agents.agent_search.shared_graph_utils.utils import get_answer_citation_ids
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_persona_agent_prompt_expressions,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.utils.logger import setup_logger
logger = setup_logger()
def answer_generation(
state: AnswerQuestionState, config: RunnableConfig
) -> QAGenerationUpdate:
now_start = datetime.now()
logger.info(f"--------{now_start}--------START ANSWER GENERATION---")
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = state.question
state.documents
level, question_nr = parse_question_id(state.question_id)
context_docs = state.context_documents[:AGENT_MAX_ANSWER_CONTEXT_DOCS]
persona_contextualized_prompt = get_persona_agent_prompt_expressions(
agent_search_config.search_request.persona
).contextualized_prompt
if len(context_docs) == 0:
answer_str = NO_RECOVERED_DOCS
dispatch_custom_event(
"sub_answers",
AgentAnswerPiece(
answer_piece=answer_str,
level=level,
level_question_nr=question_nr,
answer_type="agent_sub_answer",
),
)
else:
logger.debug(f"Number of verified retrieval docs: {len(context_docs)}")
fast_llm = agent_search_config.fast_llm
msg = build_sub_question_answer_prompt(
question=question,
original_question=agent_search_config.search_request.query,
docs=context_docs,
persona_specification=persona_contextualized_prompt,
config=fast_llm.config,
)
response: list[str | list[str | dict[str, Any]]] = []
dispatch_timings: list[float] = []
for message in fast_llm.stream(
prompt=msg,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
dispatch_custom_event(
"sub_answers",
AgentAnswerPiece(
answer_piece=content,
level=level,
level_question_nr=question_nr,
answer_type="agent_sub_answer",
),
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
response.append(content)
answer_str = merge_message_runs(response, chunk_separator="")[0].content
logger.info(
f"Average dispatch time: {sum(dispatch_timings) / len(dispatch_timings)}"
)
answer_citation_ids = get_answer_citation_ids(answer_str)
cited_docs = [context_docs[id] for id in answer_citation_ids]
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type="sub_answer",
level=level,
level_question_nr=question_nr,
)
dispatch_custom_event("stream_finished", stop_event)
now_end = datetime.now()
logger.info(
f"{now_start} -- Answer generation SQ-{level} - Q{question_nr} - Time taken: {now_end - now_start}"
)
return QAGenerationUpdate(
answer=answer_str,
cited_docs=cited_docs,
log_messages=[
f"{now_start} -- Answer generation SQ-{level} - Q{question_nr} - Time taken: {now_end - now_start}"
],
)

View File

@@ -0,0 +1,29 @@
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.shared_graph_utils.models import (
QuestionAnswerResults,
)
def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput:
return AnswerQuestionOutput(
answer_results=[
QuestionAnswerResults(
question=state.question,
question_id=state.question_id,
quality=state.answer_quality
if hasattr(state, "answer_quality")
else "No",
answer=state.answer,
expanded_retrieval_results=state.expanded_retrieval_results,
documents=state.documents,
context_documents=state.context_documents,
cited_docs=state.cited_docs,
sub_question_retrieval_stats=state.sub_question_retrieval_stats,
)
],
)

View File

@@ -0,0 +1,22 @@
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
RetrievalIngestionUpdate,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
ExpandedRetrievalOutput,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
def ingest_retrieval(state: ExpandedRetrievalOutput) -> RetrievalIngestionUpdate:
sub_question_retrieval_stats = (
state.expanded_retrieval_result.sub_question_retrieval_stats
)
if sub_question_retrieval_stats is None:
sub_question_retrieval_stats = [AgentChunkStats()]
return RetrievalIngestionUpdate(
expanded_retrieval_results=state.expanded_retrieval_result.expanded_queries_results,
documents=state.expanded_retrieval_result.reranked_documents,
context_documents=state.expanded_retrieval_result.context_documents,
sub_question_retrieval_stats=sub_question_retrieval_stats,
)

View File

@@ -0,0 +1,72 @@
from operator import add
from typing import Annotated
from pydantic import BaseModel
from onyx.agents.agent_search.core_state import SubgraphCoreState
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.agents.agent_search.shared_graph_utils.models import (
QuestionAnswerResults,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.context.search.models import InferenceSection
## Update States
class QACheckUpdate(BaseModel):
answer_quality: str = ""
log_messages: list[str] = []
class QAGenerationUpdate(BaseModel):
answer: str = ""
log_messages: list[str] = []
cited_docs: Annotated[list[InferenceSection], dedup_inference_sections] = []
# answer_stat: AnswerStats
class RetrievalIngestionUpdate(BaseModel):
expanded_retrieval_results: list[QueryResult] = []
documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
context_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
sub_question_retrieval_stats: AgentChunkStats = AgentChunkStats()
## Graph Input State
class AnswerQuestionInput(SubgraphCoreState):
question: str = ""
question_id: str = (
"" # 0_0 is original question, everything else is <level>_<question_num>.
)
# level 0 is original question and first decomposition, level 1 is follow up, etc
# question_num is a unique number per original question per level.
## Graph State
class AnswerQuestionState(
AnswerQuestionInput,
QAGenerationUpdate,
QACheckUpdate,
RetrievalIngestionUpdate,
):
pass
## Graph Output State
class AnswerQuestionOutput(BaseModel):
"""
This is a list of results even though each call of this subgraph only returns one result.
This is because if we parallelize the answer query subgraph, there will be multiple
results in a list so the add operator is used to add them together.
"""
answer_results: Annotated[list[QuestionAnswerResults], add] = []

View File

@@ -0,0 +1,89 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search_a.initial__retrieval__subgraph.nodes.format_raw_search_results import (
format_raw_search_results,
)
from onyx.agents.agent_search.deep_search_a.initial__retrieval__subgraph.nodes.generate_raw_search_data import (
generate_raw_search_data,
)
from onyx.agents.agent_search.deep_search_a.initial__retrieval__subgraph.nodes.ingest_initial_base_retrieval import (
ingest_initial_base_retrieval,
)
from onyx.agents.agent_search.deep_search_a.initial__retrieval__subgraph.states import (
BaseRawSearchInput,
)
from onyx.agents.agent_search.deep_search_a.initial__retrieval__subgraph.states import (
BaseRawSearchOutput,
)
from onyx.agents.agent_search.deep_search_a.initial__retrieval__subgraph.states import (
BaseRawSearchState,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.graph_builder import (
expanded_retrieval_graph_builder,
)
def base_raw_search_graph_builder() -> StateGraph:
graph = StateGraph(
state_schema=BaseRawSearchState,
input=BaseRawSearchInput,
output=BaseRawSearchOutput,
)
### Add nodes ###
graph.add_node(
node="generate_raw_search_data",
action=generate_raw_search_data,
)
expanded_retrieval = expanded_retrieval_graph_builder().compile()
graph.add_node(
node="expanded_retrieval_base_search",
action=expanded_retrieval,
)
graph.add_node(
node="format_raw_search_results",
action=format_raw_search_results,
)
graph.add_node(
node="ingest_initial_base_retrieval",
action=ingest_initial_base_retrieval,
)
### Add edges ###
graph.add_edge(start_key=START, end_key="generate_raw_search_data")
graph.add_edge(
start_key="generate_raw_search_data",
end_key="expanded_retrieval_base_search",
)
graph.add_edge(
start_key="expanded_retrieval_base_search",
end_key="format_raw_search_results",
)
# graph.add_edge(
# start_key="expanded_retrieval_base_search",
# end_key=END,
# )
graph.add_edge(
start_key="format_raw_search_results",
end_key="ingest_initial_base_retrieval",
)
graph.add_edge(
start_key="ingest_initial_base_retrieval",
end_key=END,
)
return graph
if __name__ == "__main__":
pass

View File

@@ -0,0 +1,20 @@
from pydantic import BaseModel
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.context.search.models import InferenceSection
### Models ###
class AnswerRetrievalStats(BaseModel):
answer_retrieval_stats: dict[str, float | int]
class QuestionAnswerResults(BaseModel):
question: str
answer: str
quality: str
expanded_retrieval_results: list[QueryResult]
documents: list[InferenceSection]
sub_question_retrieval_stats: list[AgentChunkStats]

View File

@@ -0,0 +1,18 @@
from onyx.agents.agent_search.deep_search_a.initial__retrieval__subgraph.states import (
BaseRawSearchOutput,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
ExpandedRetrievalOutput,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def format_raw_search_results(state: ExpandedRetrievalOutput) -> BaseRawSearchOutput:
logger.debug("format_raw_search_results")
return BaseRawSearchOutput(
base_expanded_retrieval_result=state.expanded_retrieval_result,
# base_retrieval_results=[state.expanded_retrieval_result],
# base_search_documents=[],
)

View File

@@ -0,0 +1,25 @@
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
ExpandedRetrievalInput,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.utils.logger import setup_logger
logger = setup_logger()
def generate_raw_search_data(
state: CoreState, config: RunnableConfig
) -> ExpandedRetrievalInput:
logger.debug("generate_raw_search_data")
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
return ExpandedRetrievalInput(
question=agent_a_config.search_request.query,
base_search=True,
sub_question_id=None, # This graph is always and only used for the original question
log_messages=[],
)

View File

@@ -0,0 +1,41 @@
from datetime import datetime
from onyx.agents.agent_search.deep_search_a.initial__retrieval__subgraph.states import (
BaseRawSearchOutput,
)
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
ExpandedRetrievalUpdate,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
def ingest_initial_base_retrieval(
state: BaseRawSearchOutput,
) -> ExpandedRetrievalUpdate:
now_start = datetime.now()
logger.info(f"--------{now_start}--------INGEST INITIAL RETRIEVAL---")
sub_question_retrieval_stats = (
state.base_expanded_retrieval_result.sub_question_retrieval_stats
)
if sub_question_retrieval_stats is None:
sub_question_retrieval_stats = AgentChunkStats()
else:
sub_question_retrieval_stats = sub_question_retrieval_stats
now_end = datetime.now()
logger.debug(
f"--------{now_end}--{now_end - now_start}--------INGEST INITIAL RETRIEVAL END---"
)
return ExpandedRetrievalUpdate(
original_question_retrieval_results=state.base_expanded_retrieval_result.expanded_queries_results,
all_original_question_documents=state.base_expanded_retrieval_result.context_documents,
original_question_retrieval_stats=sub_question_retrieval_stats,
log_messages=[
f"{now_start} -- Main - Ingestion base retrieval, Time taken: {now_end - now_start}"
],
)

View File

@@ -0,0 +1,44 @@
from pydantic import BaseModel
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
ExpandedRetrievalUpdate,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.models import (
ExpandedRetrievalResult,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
ExpandedRetrievalInput,
)
## Update States
## Graph Input State
class BaseRawSearchInput(ExpandedRetrievalInput):
pass
## Graph Output State
class BaseRawSearchOutput(BaseModel):
"""
This is a list of results even though each call of this subgraph only returns one result.
This is because if we parallelize the answer query subgraph, there will be multiple
results in a list so the add operator is used to add them together.
"""
# base_search_documents: Annotated[list[InferenceSection], dedup_inference_sections]
# base_retrieval_results: Annotated[list[ExpandedRetrievalResult], add]
base_expanded_retrieval_result: ExpandedRetrievalResult = ExpandedRetrievalResult()
## Graph State
class BaseRawSearchState(
BaseRawSearchInput, BaseRawSearchOutput, ExpandedRetrievalUpdate
):
pass

View File

@@ -0,0 +1,55 @@
from collections.abc import Hashable
from datetime import datetime
from langgraph.types import Send
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
AnswerQuestionInput,
)
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search_a.initial__retrieval_sub_answers__subgraph.states import (
SearchSQState,
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
def parallelize_initial_sub_question_answering(
state: SearchSQState,
) -> list[Send | Hashable]:
now_start = datetime.now()
if len(state.initial_decomp_questions) > 0:
# sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]]
# if len(state["sub_question_records"]) == 0:
# if state["config"].use_persistence:
# raise ValueError("No sub-questions found for initial decompozed questions")
# else:
# # in this case, we are doing retrieval on the original question.
# # to make all the logic consistent, we create a new sub-question
# # with the same content as the original question
# sub_question_record_ids = [1] * len(state["initial_decomp_questions"])
return [
Send(
"answer_query_subgraph",
AnswerQuestionInput(
question=question,
question_id=make_question_id(0, question_nr + 1),
log_messages=[
f"{now_start} -- Main Edge - Parallelize Initial Sub-question Answering"
],
),
)
for question_nr, question in enumerate(state.initial_decomp_questions)
]
else:
return [
Send(
"ingest_answers",
AnswerQuestionOutput(
answer_results=[],
),
)
]

View File

@@ -0,0 +1,139 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search_a.initial__consolidate_sub_answers__subgraph.graph_builder import (
initial_sq_subgraph_builder,
)
from onyx.agents.agent_search.deep_search_a.initial__retrieval__subgraph.graph_builder import (
base_raw_search_graph_builder,
)
from onyx.agents.agent_search.deep_search_a.initial__retrieval_sub_answers__subgraph.nodes.generate_initial_answer import (
generate_initial_answer,
)
from onyx.agents.agent_search.deep_search_a.initial__retrieval_sub_answers__subgraph.nodes.initial_answer_quality_check import (
initial_answer_quality_check,
)
from onyx.agents.agent_search.deep_search_a.initial__retrieval_sub_answers__subgraph.nodes.retrieval_consolidation import (
retrieval_consolidation,
)
from onyx.agents.agent_search.deep_search_a.initial__retrieval_sub_answers__subgraph.states import (
SearchSQInput,
)
from onyx.agents.agent_search.deep_search_a.initial__retrieval_sub_answers__subgraph.states import (
SearchSQState,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def initial_search_sq_subgraph_builder(test_mode: bool = False) -> StateGraph:
graph = StateGraph(
state_schema=SearchSQState,
input=SearchSQInput,
)
# graph.add_node(
# node="initial_sub_question_creation",
# action=initial_sub_question_creation,
# )
sub_question_answering_subgraph = initial_sq_subgraph_builder().compile()
graph.add_node(
node="sub_question_answering_subgraph",
action=sub_question_answering_subgraph,
)
# answer_query_subgraph = answer_query_graph_builder().compile()
# graph.add_node(
# node="answer_query_subgraph",
# action=answer_query_subgraph,
# )
base_raw_search_subgraph = base_raw_search_graph_builder().compile()
graph.add_node(
node="base_raw_search_subgraph",
action=base_raw_search_subgraph,
)
graph.add_node(
node="retrieval_consolidation",
action=retrieval_consolidation,
)
graph.add_node(
node="generate_initial_answer",
action=generate_initial_answer,
)
graph.add_node(
node="initial_answer_quality_check",
action=initial_answer_quality_check,
)
### Add edges ###
# raph.add_edge(start_key=START, end_key="base_raw_search_subgraph")
graph.add_edge(
start_key=START,
end_key="base_raw_search_subgraph",
)
# graph.add_edge(
# start_key="agent_search_start",
# end_key="entity_term_extraction_llm",
# )
graph.add_edge(
start_key=START,
end_key="sub_question_answering_subgraph",
)
graph.add_edge(
start_key=["base_raw_search_subgraph", "sub_question_answering_subgraph"],
end_key="retrieval_consolidation",
)
graph.add_edge(
start_key="retrieval_consolidation",
end_key="generate_initial_answer",
)
# graph.add_edge(
# start_key="LLM",
# end_key=END,
# )
# graph.add_edge(
# start_key=START,
# end_key="initial_sub_question_creation",
# )
graph.add_edge(
start_key="retrieval_consolidation",
end_key="generate_initial_answer",
)
graph.add_edge(
start_key="generate_initial_answer",
end_key="initial_answer_quality_check",
)
graph.add_edge(
start_key="initial_answer_quality_check",
end_key=END,
)
# graph.add_edge(
# start_key="generate_refined_answer",
# end_key="check_refined_answer",
# )
# graph.add_edge(
# start_key="check_refined_answer",
# end_key=END,
# )
return graph

View File

@@ -0,0 +1,276 @@
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search_a.initial__retrieval_sub_answers__subgraph.states import (
SearchSQState,
)
from onyx.agents.agent_search.deep_search_a.main__graph.models import AgentBaseMetrics
from onyx.agents.agent_search.deep_search_a.main__graph.operations import (
calculate_initial_agent_stats,
)
from onyx.agents.agent_search.deep_search_a.main__graph.operations import get_query_info
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
from onyx.agents.agent_search.deep_search_a.main__graph.operations import (
remove_document_citations,
)
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
InitialAnswerUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
get_prompt_enrichment_components,
)
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import INITIAL_RAG_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import (
INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
SUB_QUESTION_ANSWER_TEMPLATE,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.context.search.models import InferenceSection
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
def generate_initial_answer(
state: SearchSQState, config: RunnableConfig
) -> InitialAnswerUpdate:
now_start = datetime.now()
logger.info(f"--------{now_start}--------GENERATE INITIAL---")
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = agent_a_config.search_request.query
prompt_enrichment_components = get_prompt_enrichment_components(agent_a_config)
sub_questions_cited_docs = state.cited_docs
all_original_question_documents = state.all_original_question_documents
consolidated_context_docs: list[InferenceSection] = sub_questions_cited_docs
counter = 0
for original_doc_number, original_doc in enumerate(all_original_question_documents):
if original_doc_number not in sub_questions_cited_docs:
if (
counter <= AGENT_MIN_ORIG_QUESTION_DOCS
or len(consolidated_context_docs) < AGENT_MAX_ANSWER_CONTEXT_DOCS
):
consolidated_context_docs.append(original_doc)
counter += 1
# sort docs by their scores - though the scores refer to different questions
relevant_docs = dedup_inference_sections(
consolidated_context_docs, consolidated_context_docs
)
decomp_questions = []
# Use the query info from the base document retrieval
query_info = get_query_info(state.original_question_retrieval_results)
if agent_a_config.search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
for tool_response in yield_search_responses(
query=question,
reranked_sections=relevant_docs,
final_context_sections=relevant_docs,
search_query_info=query_info,
get_section_relevance=lambda: None, # TODO: add relevance
search_tool=agent_a_config.search_tool,
):
dispatch_custom_event(
"tool_response",
ExtendedToolResponse(
id=tool_response.id,
response=tool_response.response,
level=0,
level_question_nr=0, # 0, 0 is the base question
),
)
if len(relevant_docs) == 0:
dispatch_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=UNKNOWN_ANSWER,
level=0,
level_question_nr=0,
answer_type="agent_level_answer",
),
)
dispatch_main_answer_stop_info(0)
answer = UNKNOWN_ANSWER
initial_agent_stats = InitialAgentResultStats(
sub_questions={},
original_question={},
agent_effectiveness={},
)
else:
decomp_answer_results = state.decomp_answer_results
good_qa_list: list[str] = []
sub_question_nr = 1
for decomp_answer_result in decomp_answer_results:
decomp_questions.append(decomp_answer_result.question)
_, question_nr = parse_question_id(decomp_answer_result.question_id)
if (
decomp_answer_result.quality.lower().startswith("yes")
and len(decomp_answer_result.answer) > 0
and decomp_answer_result.answer != UNKNOWN_ANSWER
):
good_qa_list.append(
SUB_QUESTION_ANSWER_TEMPLATE.format(
sub_question=decomp_answer_result.question,
sub_answer=decomp_answer_result.answer,
sub_question_nr=sub_question_nr,
)
)
sub_question_nr += 1
if len(good_qa_list) > 0:
sub_question_answer_str = "\n\n------\n\n".join(good_qa_list)
else:
sub_question_answer_str = ""
# Determine which base prompt to use given the sub-question information
if len(good_qa_list) > 0:
base_prompt = INITIAL_RAG_PROMPT
else:
base_prompt = INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS
model = agent_a_config.fast_llm
doc_context = format_docs(relevant_docs)
doc_context = trim_prompt_piece(
model.config,
doc_context,
base_prompt
+ sub_question_answer_str
+ prompt_enrichment_components.persona_prompts.contextualized_prompt
+ prompt_enrichment_components.history
+ prompt_enrichment_components.date_str,
)
msg = [
HumanMessage(
content=base_prompt.format(
question=question,
answered_sub_questions=remove_document_citations(
sub_question_answer_str
),
relevant_docs=format_docs(relevant_docs),
persona_specification=prompt_enrichment_components.persona_prompts.contextualized_prompt,
history=prompt_enrichment_components.history,
date_prompt=prompt_enrichment_components.date_str,
)
)
]
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
dispatch_timings: list[float] = []
for message in model.stream(msg):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
dispatch_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=0,
level_question_nr=0,
answer_type="agent_level_answer",
),
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
streamed_tokens.append(content)
logger.info(
f"Average dispatch time for initial answer: {sum(dispatch_timings) / len(dispatch_timings)}"
)
dispatch_main_answer_stop_info(0)
response = merge_content(*streamed_tokens)
answer = cast(str, response)
initial_agent_stats = calculate_initial_agent_stats(
state.decomp_answer_results, state.original_question_retrieval_stats
)
logger.debug(
f"\n\nYYYYY--Sub-Questions:\n\n{sub_question_answer_str}\n\nStats:\n\n"
)
if initial_agent_stats:
logger.debug(initial_agent_stats.original_question)
logger.debug(initial_agent_stats.sub_questions)
logger.debug(initial_agent_stats.agent_effectiveness)
now_end = datetime.now()
agent_base_end_time = datetime.now()
agent_base_metrics = AgentBaseMetrics(
num_verified_documents_total=len(relevant_docs),
num_verified_documents_core=state.original_question_retrieval_stats.verified_count,
verified_avg_score_core=state.original_question_retrieval_stats.verified_avg_scores,
num_verified_documents_base=initial_agent_stats.sub_questions.get(
"num_verified_documents", None
),
verified_avg_score_base=initial_agent_stats.sub_questions.get(
"verified_avg_score", None
),
base_doc_boost_factor=initial_agent_stats.agent_effectiveness.get(
"utilized_chunk_ratio", None
),
support_boost_factor=initial_agent_stats.agent_effectiveness.get(
"support_ratio", None
),
duration__s=(agent_base_end_time - state.agent_start_time).total_seconds(),
)
logger.info(
f"{now_start} -- Main - Initial Answer generation, Time taken: {now_end - now_start}"
)
return InitialAnswerUpdate(
initial_answer=answer,
initial_agent_stats=initial_agent_stats,
generated_sub_questions=decomp_questions,
agent_base_end_time=agent_base_end_time,
agent_base_metrics=agent_base_metrics,
log_messages=[
f"{now_start} -- Main - Initial Answer generation, Time taken: {now_end - now_start}"
],
)

View File

@@ -0,0 +1,43 @@
from datetime import datetime
from onyx.agents.agent_search.deep_search_a.initial__retrieval__subgraph.states import (
BaseRawSearchOutput,
)
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
ExpandedRetrievalUpdate,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
def ingest_initial_base_retrieval(
state: BaseRawSearchOutput,
) -> ExpandedRetrievalUpdate:
now_start = datetime.now()
logger.info(f"--------{now_start}--------INGEST INITIAL RETRIEVAL---")
sub_question_retrieval_stats = (
state.base_expanded_retrieval_result.sub_question_retrieval_stats
)
# if sub_question_retrieval_stats is None:
# sub_question_retrieval_stats = AgentChunkStats()
# else:
# sub_question_retrieval_stats = sub_question_retrieval_stats
sub_question_retrieval_stats = sub_question_retrieval_stats or AgentChunkStats()
now_end = datetime.now()
logger.debug(
f"--------{now_end}--{now_end - now_start}--------INGEST INITIAL RETRIEVAL END---"
)
return ExpandedRetrievalUpdate(
original_question_retrieval_results=state.base_expanded_retrieval_result.expanded_queries_results,
all_original_question_documents=state.base_expanded_retrieval_result.context_documents,
original_question_retrieval_stats=sub_question_retrieval_stats,
log_messages=[
f"{now_start} -- Main - Ingestion base retrieval, Time taken: {now_end - now_start}"
],
)

View File

@@ -0,0 +1,42 @@
from datetime import datetime
from onyx.agents.agent_search.deep_search_a.initial__retrieval_sub_answers__subgraph.states import (
SearchSQState,
)
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
InitialAnswerQualityUpdate,
)
def initial_answer_quality_check(state: SearchSQState) -> InitialAnswerQualityUpdate:
"""
Check whether the final output satisfies the original user question
Args:
state (messages): The current state
Returns:
InitialAnswerQualityUpdate
"""
now_start = datetime.now()
logger.debug(
f"--------{now_start}--------Checking for base answer validity - for not set True/False manually"
)
verdict = True
now_end = datetime.now()
logger.debug(
f"--------{now_end}--{now_end - now_start}--------INITIAL ANSWER QUALITY CHECK END---"
)
return InitialAnswerQualityUpdate(
initial_answer_quality_eval=verdict,
log_messages=[
f"{now_start} -- Main - Initial answer quality check, Time taken: {now_end - now_start}"
],
)

View File

@@ -0,0 +1,14 @@
from datetime import datetime
from onyx.agents.agent_search.deep_search_a.initial__retrieval_sub_answers__subgraph.states import (
SearchSQState,
)
from onyx.agents.agent_search.deep_search_a.main__graph.states import LoggerUpdate
def retrieval_consolidation(
state: SearchSQState,
) -> LoggerUpdate:
now_start = datetime.now()
return LoggerUpdate(log_messages=[f"{now_start} -- Retrieval consolidation"])

View File

@@ -0,0 +1,54 @@
from operator import add
from typing import Annotated
from typing import TypedDict
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.deep_search_a.main__graph.states import BaseDecompUpdate
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
DecompAnswersUpdate,
)
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
ExpandedRetrievalUpdate,
)
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
ExploratorySearchUpdate,
)
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
InitialAnswerQualityUpdate,
)
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
InitialAnswerUpdate,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.models import (
ExpandedRetrievalResult,
)
### States ###
class SearchSQInput(CoreState):
pass
## Graph State
class SearchSQState(
# This includes the core state
SearchSQInput,
BaseDecompUpdate,
InitialAnswerUpdate,
DecompAnswersUpdate,
ExpandedRetrievalUpdate,
InitialAnswerQualityUpdate,
ExploratorySearchUpdate,
):
# expanded_retrieval_result: Annotated[list[ExpandedRetrievalResult], add]
base_raw_search_result: Annotated[list[ExpandedRetrievalResult], add]
## Graph Output State - presently not used
class SearchSQOutput(TypedDict):
log_messages: list[str]

View File

@@ -0,0 +1,120 @@
from collections.abc import Hashable
from datetime import datetime
from typing import cast
from typing import Literal
from langchain_core.runnables import RunnableConfig
from langgraph.types import Send
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
AnswerQuestionInput,
)
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search_a.main__graph.states import MainState
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
RequireRefinedAnswerUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.utils.logger import setup_logger
logger = setup_logger()
def route_initial_tool_choice(
state: MainState, config: RunnableConfig
) -> Literal["tool_call", "agent_search_start", "logging_node"]:
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
if state.tool_choice is not None:
if (
agent_config.use_agentic_search
and agent_config.search_tool is not None
and state.tool_choice.tool.name == agent_config.search_tool.name
):
return "agent_search_start"
else:
return "tool_call"
else:
return "logging_node"
def parallelize_initial_sub_question_answering(
state: MainState,
) -> list[Send | Hashable]:
now_start = datetime.now()
if len(state.initial_decomp_questions) > 0:
# sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]]
# if len(state["sub_question_records"]) == 0:
# if state["config"].use_persistence:
# raise ValueError("No sub-questions found for initial decompozed questions")
# else:
# # in this case, we are doing retrieval on the original question.
# # to make all the logic consistent, we create a new sub-question
# # with the same content as the original question
# sub_question_record_ids = [1] * len(state["initial_decomp_questions"])
return [
Send(
"answer_query_subgraph",
AnswerQuestionInput(
question=question,
question_id=make_question_id(0, question_nr + 1),
log_messages=[
f"{now_start} -- Main Edge - Parallelize Initial Sub-question Answering"
],
),
)
for question_nr, question in enumerate(state.initial_decomp_questions)
]
else:
return [
Send(
"ingest_answers",
AnswerQuestionOutput(
answer_results=[],
),
)
]
# Define the function that determines whether to continue or not
def continue_to_refined_answer_or_end(
state: RequireRefinedAnswerUpdate,
) -> Literal["refined_sub_question_creation", "logging_node"]:
if state.require_refined_answer_eval:
return "refined_sub_question_creation"
else:
return "logging_node"
def parallelize_refined_sub_question_answering(
state: MainState,
) -> list[Send | Hashable]:
now_start = datetime.now()
if len(state.refined_sub_questions) > 0:
return [
Send(
"answer_refined_question",
AnswerQuestionInput(
question=question_data.sub_question,
question_id=make_question_id(1, question_nr),
log_messages=[
f"{now_start} -- Main Edge - Parallelize Refined Sub-question Answering"
],
),
)
for question_nr, question_data in state.refined_sub_questions.items()
]
else:
return [
Send(
"ingest_refined_sub_answers",
AnswerQuestionOutput(
answer_results=[],
),
)
]

View File

@@ -0,0 +1,410 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search_a.initial__retrieval_sub_answers__subgraph.graph_builder import (
initial_search_sq_subgraph_builder,
)
from onyx.agents.agent_search.deep_search_a.main__graph.edges import (
continue_to_refined_answer_or_end,
)
from onyx.agents.agent_search.deep_search_a.main__graph.edges import (
parallelize_refined_sub_question_answering,
)
from onyx.agents.agent_search.deep_search_a.main__graph.edges import (
route_initial_tool_choice,
)
from onyx.agents.agent_search.deep_search_a.main__graph.nodes.agent_logging import (
agent_logging,
)
from onyx.agents.agent_search.deep_search_a.main__graph.nodes.agent_search_start import (
agent_search_start,
)
from onyx.agents.agent_search.deep_search_a.main__graph.nodes.answer_comparison import (
answer_comparison,
)
from onyx.agents.agent_search.deep_search_a.main__graph.nodes.entity_term_extraction_llm import (
entity_term_extraction_llm,
)
from onyx.agents.agent_search.deep_search_a.main__graph.nodes.generate_refined_answer import (
generate_refined_answer,
)
from onyx.agents.agent_search.deep_search_a.main__graph.nodes.ingest_refined_answers import (
ingest_refined_answers,
)
from onyx.agents.agent_search.deep_search_a.main__graph.nodes.refined_answer_decision import (
refined_answer_decision,
)
from onyx.agents.agent_search.deep_search_a.main__graph.nodes.refined_sub_question_creation import (
refined_sub_question_creation,
)
from onyx.agents.agent_search.deep_search_a.main__graph.states import MainInput
from onyx.agents.agent_search.deep_search_a.main__graph.states import MainState
from onyx.agents.agent_search.deep_search_a.refinement__consolidate_sub_answers__subgraph.graph_builder import (
answer_refined_query_graph_builder,
)
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
basic_use_tool_response,
)
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
prepare_tool_input,
)
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger
logger = setup_logger()
test_mode = False
def main_graph_builder(test_mode: bool = False) -> StateGraph:
graph = StateGraph(
state_schema=MainState,
input=MainInput,
)
# graph.add_node(
# node="agent_path_decision",
# action=agent_path_decision,
# )
# graph.add_node(
# node="agent_path_routing",
# action=agent_path_routing,
# )
# graph.add_node(
# node="LLM",
# action=direct_llm_handling,
# )
graph.add_node(
node="prepare_tool_input",
action=prepare_tool_input,
)
graph.add_node(
node="initial_tool_choice",
action=llm_tool_choice,
)
graph.add_node(
node="tool_call",
action=tool_call,
)
graph.add_node(
node="basic_use_tool_response",
action=basic_use_tool_response,
)
graph.add_node(
node="agent_search_start",
action=agent_search_start,
)
# graph.add_node(
# node="initial_sub_question_creation",
# action=initial_sub_question_creation,
# )
initial_search_sq_subgraph = initial_search_sq_subgraph_builder().compile()
graph.add_node(
node="initial_search_sq_subgraph",
action=initial_search_sq_subgraph,
)
# answer_query_subgraph = answer_query_graph_builder().compile()
# graph.add_node(
# node="answer_query_subgraph",
# action=answer_query_subgraph,
# )
# base_raw_search_subgraph = base_raw_search_graph_builder().compile()
# graph.add_node(
# node="base_raw_search_subgraph",
# action=base_raw_search_subgraph,
# )
# refined_answer_subgraph = refined_answers_graph_builder().compile()
# graph.add_node(
# node="refined_answer_subgraph",
# action=refined_answer_subgraph,
# )
graph.add_node(
node="refined_sub_question_creation",
action=refined_sub_question_creation,
)
answer_refined_question = answer_refined_query_graph_builder().compile()
graph.add_node(
node="answer_refined_question",
action=answer_refined_question,
)
graph.add_node(
node="ingest_refined_answers",
action=ingest_refined_answers,
)
graph.add_node(
node="generate_refined_answer",
action=generate_refined_answer,
)
# graph.add_node(
# node="check_refined_answer",
# action=check_refined_answer,
# )
# graph.add_node(
# node="ingest_initial_retrieval",
# action=ingest_initial_base_retrieval,
# )
# graph.add_node(
# node="retrieval_consolidation",
# action=retrieval_consolidation,
# )
# graph.add_node(
# node="ingest_initial_sub_question_answers",
# action=ingest_initial_sub_question_answers,
# )
# graph.add_node(
# node="generate_initial_answer",
# action=generate_initial_answer,
# )
# graph.add_node(
# node="initial_answer_quality_check",
# action=initial_answer_quality_check,
# )
graph.add_node(
node="entity_term_extraction_llm",
action=entity_term_extraction_llm,
)
graph.add_node(
node="refined_answer_decision",
action=refined_answer_decision,
)
graph.add_node(
node="answer_comparison",
action=answer_comparison,
)
graph.add_node(
node="logging_node",
action=agent_logging,
)
# if test_mode:
# graph.add_node(
# node="generate_initial_base_answer",
# action=generate_initial_base_answer,
# )
### Add edges ###
# raph.add_edge(start_key=START, end_key="base_raw_search_subgraph")
# graph.add_edge(
# start_key=START,
# end_key="agent_path_decision",
# )
# graph.add_edge(
# start_key="agent_path_decision",
# end_key="agent_path_routing",
# )
graph.add_edge(start_key=START, end_key="prepare_tool_input")
graph.add_edge(
start_key="prepare_tool_input",
end_key="initial_tool_choice",
)
graph.add_conditional_edges(
"initial_tool_choice",
route_initial_tool_choice,
["tool_call", "agent_search_start", "logging_node"],
)
graph.add_edge(
start_key="tool_call",
end_key="basic_use_tool_response",
)
graph.add_edge(
start_key="basic_use_tool_response",
end_key="logging_node",
)
graph.add_edge(
start_key="agent_search_start",
end_key="initial_search_sq_subgraph",
)
# graph.add_edge(
# start_key="agent_search_start",
# end_key="base_raw_search_subgraph",
# )
graph.add_edge(
start_key="agent_search_start",
end_key="entity_term_extraction_llm",
)
# graph.add_edge(
# start_key="agent_search_start",
# end_key="initial_sub_question_creation",
# )
# graph.add_edge(
# start_key="base_raw_search_subgraph",
# end_key="ingest_initial_retrieval",
# )
# graph.add_edge(
# start_key=["ingest_initial_retrieval", "ingest_initial_sub_question_answers"],
# end_key="retrieval_consolidation",
# )
# graph.add_edge(
# start_key="retrieval_consolidation",
# end_key="generate_initial_answer",
# )
# graph.add_edge(
# start_key="LLM",
# end_key=END,
# )
# graph.add_edge(
# start_key=START,
# end_key="initial_sub_question_creation",
# )
# graph.add_conditional_edges(
# source="initial_sub_question_creation",
# path=parallelize_initial_sub_question_answering,
# path_map=["answer_query_subgraph"],
# )
# graph.add_edge(
# start_key="answer_query_subgraph",
# end_key="ingest_initial_sub_question_answers",
# )
# graph.add_edge(
# start_key="retrieval_consolidation",
# end_key="generate_initial_answer",
# )
# graph.add_edge(
# start_key="generate_initial_answer",
# end_key="entity_term_extraction_llm",
# )
# graph.add_edge(
# start_key="generate_initial_answer",
# end_key="initial_answer_quality_check",
# )
# graph.add_edge(
# start_key=["initial_answer_quality_check", "entity_term_extraction_llm"],
# end_key="refined_answer_decision",
# )
# graph.add_edge(
# start_key="initial_answer_quality_check",
# end_key="refined_answer_decision",
# )
graph.add_edge(
start_key=["initial_search_sq_subgraph", "entity_term_extraction_llm"],
end_key="refined_answer_decision",
)
graph.add_conditional_edges(
source="refined_answer_decision",
path=continue_to_refined_answer_or_end,
path_map=["refined_sub_question_creation", "logging_node"],
)
graph.add_conditional_edges(
source="refined_sub_question_creation", # DONE
path=parallelize_refined_sub_question_answering,
path_map=["answer_refined_question"],
)
graph.add_edge(
start_key="answer_refined_question", # HERE
end_key="ingest_refined_answers",
)
graph.add_edge(
start_key="ingest_refined_answers",
end_key="generate_refined_answer",
)
# graph.add_conditional_edges(
# source="refined_answer_decision",
# path=continue_to_refined_answer_or_end,
# path_map=["refined_answer_subgraph", END],
# )
# graph.add_edge(
# start_key="refined_answer_subgraph",
# end_key="generate_refined_answer",
# )
graph.add_edge(
start_key="generate_refined_answer",
end_key="answer_comparison",
)
graph.add_edge(
start_key="answer_comparison",
end_key="logging_node",
)
graph.add_edge(
start_key="logging_node",
end_key=END,
)
# graph.add_edge(
# start_key="generate_refined_answer",
# end_key="check_refined_answer",
# )
# graph.add_edge(
# start_key="check_refined_answer",
# end_key=END,
# )
return graph
if __name__ == "__main__":
pass
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
graph = main_graph_builder()
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
with get_session_context_manager() as db_session:
search_request = SearchRequest(query="Who created Excel?")
agent_a_config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)
inputs = MainInput(
base_question=agent_a_config.search_request.query, log_messages=[]
)
for thing in compiled_graph.stream(
input=inputs,
config={"configurable": {"config": agent_a_config}},
# stream_mode="debug",
# debug=True,
subgraphs=True,
):
logger.debug(thing)

View File

@@ -0,0 +1,36 @@
from pydantic import BaseModel
class FollowUpSubQuestion(BaseModel):
sub_question: str
sub_question_id: str
verified: bool
answered: bool
answer: str
class AgentTimings(BaseModel):
base_duration__s: float | None
refined_duration__s: float | None
full_duration__s: float | None
class AgentBaseMetrics(BaseModel):
num_verified_documents_total: int | None
num_verified_documents_core: int | None
verified_avg_score_core: float | None
num_verified_documents_base: int | float | None
verified_avg_score_base: float | None = None
base_doc_boost_factor: float | None = None
support_boost_factor: float | None = None
duration__s: float | None = None
class AgentRefinedMetrics(BaseModel):
refined_doc_boost_factor: float | None = None
refined_question_boost_factor: float | None = None
duration__s: float | None = None
class AgentAdditionalMetrics(BaseModel):
pass

View File

@@ -0,0 +1,135 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search_a.main__graph.models import (
AgentAdditionalMetrics,
)
from onyx.agents.agent_search.deep_search_a.main__graph.models import AgentTimings
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
from onyx.agents.agent_search.deep_search_a.main__graph.states import MainOutput
from onyx.agents.agent_search.deep_search_a.main__graph.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.models import CombinedAgentMetrics
from onyx.db.chat import log_agent_metrics
from onyx.db.chat import log_agent_sub_question_results
def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput:
now_start = datetime.now()
logger.info(f"--------{now_start}--------LOGGING NODE---")
agent_start_time = state.agent_start_time
agent_base_end_time = state.agent_base_end_time
agent_refined_start_time = state.agent_refined_start_time
agent_refined_end_time = state.agent_refined_end_time
agent_end_time = agent_refined_end_time or agent_base_end_time
agent_base_duration = None
if agent_base_end_time:
agent_base_duration = (agent_base_end_time - agent_start_time).total_seconds()
agent_refined_duration = None
if agent_refined_start_time and agent_refined_end_time:
agent_refined_duration = (
agent_refined_end_time - agent_refined_start_time
).total_seconds()
agent_full_duration = None
if agent_end_time:
agent_full_duration = (agent_end_time - agent_start_time).total_seconds()
agent_type = "refined" if agent_refined_duration else "base"
agent_base_metrics = state.agent_base_metrics
agent_refined_metrics = state.agent_refined_metrics
combined_agent_metrics = CombinedAgentMetrics(
timings=AgentTimings(
base_duration__s=agent_base_duration,
refined_duration__s=agent_refined_duration,
full_duration__s=agent_full_duration,
),
base_metrics=agent_base_metrics,
refined_metrics=agent_refined_metrics,
additional_metrics=AgentAdditionalMetrics(),
)
persona_id = None
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
if agent_a_config.search_request.persona:
persona_id = agent_a_config.search_request.persona.id
user_id = None
if agent_a_config.search_tool is not None:
user = agent_a_config.search_tool.user
if user:
user_id = user.id
# log the agent metrics
if agent_a_config.db_session is not None:
if agent_base_duration is not None:
log_agent_metrics(
db_session=agent_a_config.db_session,
user_id=user_id,
persona_id=persona_id,
agent_type=agent_type,
start_time=agent_start_time,
agent_metrics=combined_agent_metrics,
)
if agent_a_config.use_agentic_persistence:
# Persist the sub-answer in the database
db_session = agent_a_config.db_session
chat_session_id = agent_a_config.chat_session_id
primary_message_id = agent_a_config.message_id
sub_question_answer_results = state.decomp_answer_results
log_agent_sub_question_results(
db_session=db_session,
chat_session_id=chat_session_id,
primary_message_id=primary_message_id,
sub_question_answer_results=sub_question_answer_results,
)
# if chat_session_id is not None and primary_message_id is not None and sub_question_id is not None:
# create_sub_answer(
# db_session=db_session,
# chat_session_id=chat_session_id,
# primary_message_id=primary_message_id,
# sub_question_id=sub_question_id,
# answer=answer_str,
# # )
# pass
now_end = datetime.now()
main_output = MainOutput(
log_messages=[
f"{now_start} -- Main - Logging, Time taken: {now_end - now_start}"
],
)
logger.info(f"--------{now_end}--{now_end - now_start}--------LOGGING NODE END---")
for log_message in state.log_messages:
logger.info(log_message)
logger.info("")
if state.agent_base_metrics:
logger.info(f"Initial loop: {state.agent_base_metrics.duration__s}")
if state.agent_refined_metrics:
logger.info(f"Refined loop: {state.agent_refined_metrics.duration__s}")
if (
state.agent_base_metrics
and state.agent_refined_metrics
and state.agent_base_metrics.duration__s
and state.agent_refined_metrics.duration__s
):
logger.info(
f"Total time: {float(state.agent_base_metrics.duration__s) + float(state.agent_refined_metrics.duration__s)}"
)
return main_output

View File

@@ -0,0 +1,36 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
from onyx.agents.agent_search.deep_search_a.main__graph.states import MainState
from onyx.agents.agent_search.deep_search_a.main__graph.states import RoutingDecision
from onyx.agents.agent_search.models import AgentSearchConfig
def agent_path_decision(state: MainState, config: RunnableConfig) -> RoutingDecision:
now_start = datetime.now()
cast(AgentSearchConfig, config["metadata"]["config"])
# perform_initial_search_path_decision = (
# agent_a_config.perform_initial_search_path_decision
# )
logger.info(f"--------{now_start}--------DECIDING TO SEARCH OR GO TO LLM---")
routing = "agent_search"
now_end = datetime.now()
logger.debug(
f"--------{now_end}--{now_end - now_start}--------DECIDING TO SEARCH OR GO TO LLM END---"
)
return RoutingDecision(
# Decide which route to take
routing_decision=routing,
log_messages=[
f"{now_end} -- Path decision: {routing}, Time taken: {now_end - now_start}"
],
)

View File

@@ -0,0 +1,31 @@
from datetime import datetime
from typing import Literal
from langgraph.types import Command
from onyx.agents.agent_search.deep_search_a.main__graph.states import MainState
def agent_path_routing(
state: MainState,
) -> Command[Literal["agent_search_start", "LLM"]]:
now_start = datetime.now()
routing = state.routing_decision if hasattr(state, "routing") else "agent_search"
if routing == "agent_search":
agent_path = "agent_search_start"
else:
agent_path = "LLM"
now_end = datetime.now()
return Command(
# state update
update={
"log_messages": [
f"{now_start} -- Main - Path routing: {agent_path}, Time taken: {now_end - now_start}"
]
},
# control flow
goto=agent_path,
)

View File

@@ -0,0 +1,59 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
ExploratorySearchUpdate,
)
from onyx.agents.agent_search.deep_search_a.main__graph.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.utils import retrieve_search_docs
from onyx.configs.agent_configs import AGENT_EXPLORATORY_SEARCH_RESULTS
from onyx.context.search.models import InferenceSection
def agent_search_start(
state: MainState, config: RunnableConfig
) -> ExploratorySearchUpdate:
now_start = datetime.now()
logger.info(f"--------{now_start}--------EXPLORATORY SEARCH START---")
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = agent_a_config.search_request.query
chat_session_id = agent_a_config.chat_session_id
primary_message_id = agent_a_config.message_id
agent_a_config.fast_llm
history = build_history_prompt(agent_a_config, question)
if chat_session_id is None or primary_message_id is None:
raise ValueError(
"chat_session_id and message_id must be provided for agent search"
)
# Initial search to inform decomposition. Just get top 3 fits
search_tool = agent_a_config.search_tool
if search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
retrieved_docs: list[InferenceSection] = retrieve_search_docs(search_tool, question)
exploratory_search_results = retrieved_docs[:AGENT_EXPLORATORY_SEARCH_RESULTS]
now_end = datetime.now()
logger.debug(
f"--------{now_end}--{now_end - now_start}--------EXPLORATORY SEARCH END---"
)
return ExploratorySearchUpdate(
exploratory_search_results=exploratory_search_results,
previous_history_summary=history,
log_messages=[
f"{now_start} -- Main - Exploratory Search, Time taken: {now_end - now_start}"
],
)

View File

@@ -0,0 +1,60 @@
from datetime import datetime
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
from onyx.agents.agent_search.deep_search_a.main__graph.states import AnswerComparison
from onyx.agents.agent_search.deep_search_a.main__graph.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.prompts import ANSWER_COMPARISON_PROMPT
from onyx.chat.models import RefinedAnswerImprovement
def answer_comparison(state: MainState, config: RunnableConfig) -> AnswerComparison:
now_start = datetime.now()
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = agent_a_config.search_request.query
initial_answer = state.initial_answer
refined_answer = state.refined_answer
logger.info(f"--------{now_start}--------ANSWER COMPARISON STARTED--")
answer_comparison_prompt = ANSWER_COMPARISON_PROMPT.format(
question=question, initial_answer=initial_answer, refined_answer=refined_answer
)
msg = [HumanMessage(content=answer_comparison_prompt)]
# Get the rewritten queries in a defined format
model = agent_a_config.fast_llm
# no need to stream this
resp = model.invoke(msg)
refined_answer_improvement = (
isinstance(resp.content, str) and "yes" in resp.content.lower()
)
dispatch_custom_event(
"refined_answer_improvement",
RefinedAnswerImprovement(
refined_answer_improvement=refined_answer_improvement,
),
)
now_end = datetime.now()
logger.info(
f"{now_start} -- MAIN - Answer comparison, Time taken: {now_end - now_start}"
)
return AnswerComparison(
refined_answer_improvement_eval=refined_answer_improvement,
log_messages=[
f"{now_start} -- Answer comparison: {refined_answer_improvement}, Time taken: {now_end - now_start}"
],
)

View File

@@ -0,0 +1,83 @@
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
InitialAnswerUpdate,
)
from onyx.agents.agent_search.deep_search_a.main__graph.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.prompts import DIRECT_LLM_PROMPT
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_persona_agent_prompt_expressions,
)
from onyx.chat.models import AgentAnswerPiece
def direct_llm_handling(
state: MainState, config: RunnableConfig
) -> InitialAnswerUpdate:
now_start = datetime.now()
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = agent_a_config.search_request.query
persona_contextualialized_prompt = get_persona_agent_prompt_expressions(
agent_a_config.search_request.persona
).contextualized_prompt
logger.info(f"--------{now_start}--------LLM HANDLING START---")
model = agent_a_config.fast_llm
msg = [
HumanMessage(
content=DIRECT_LLM_PROMPT.format(
persona_specification=persona_contextualialized_prompt,
question=question,
)
)
]
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
for message in model.stream(msg):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
dispatch_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=0,
level_question_nr=0,
answer_type="agent_level_answer",
),
)
streamed_tokens.append(content)
response = merge_content(*streamed_tokens)
answer = cast(str, response)
now_end = datetime.now()
logger.info(f"--------{now_end}--{now_end - now_start}--------LLM HANDLING END---")
return InitialAnswerUpdate(
initial_answer=answer,
initial_agent_stats=None,
generated_sub_questions=[],
agent_base_end_time=now_end,
agent_base_metrics=None,
log_messages=[
f"{now_start} -- Main - LLM handling: {answer}, Time taken: {now_end - now_start}"
],
)

View File

@@ -0,0 +1,126 @@
import json
import re
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
EntityTermExtractionUpdate,
)
from onyx.agents.agent_search.deep_search_a.main__graph.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.models import Entity
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
)
from onyx.agents.agent_search.shared_graph_utils.models import Relationship
from onyx.agents.agent_search.shared_graph_utils.models import Term
from onyx.agents.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROMPT
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
def entity_term_extraction_llm(
state: MainState, config: RunnableConfig
) -> EntityTermExtractionUpdate:
now_start = datetime.now()
logger.info(f"--------{now_start}--------GENERATE ENTITIES & TERMS---")
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
if not agent_a_config.allow_refinement:
now_end = datetime.now()
return EntityTermExtractionUpdate(
entity_relation_term_extractions=EntityRelationshipTermExtraction(
entities=[],
relationships=[],
terms=[],
),
log_messages=[
f"{now_start} -- Main - ETR Extraction, Time taken: {now_end - now_start}"
],
)
# first four lines duplicates from generate_initial_answer
question = agent_a_config.search_request.query
initial_search_docs = state.exploratory_search_results[:15]
# start with the entity/term/extraction
doc_context = format_docs(initial_search_docs)
doc_context = trim_prompt_piece(
agent_a_config.fast_llm.config, doc_context, ENTITY_TERM_PROMPT + question
)
msg = [
HumanMessage(
content=ENTITY_TERM_PROMPT.format(question=question, context=doc_context),
)
]
fast_llm = agent_a_config.fast_llm
# Grader
llm_response = fast_llm.invoke(
prompt=msg,
)
cleaned_response = re.sub(r"```json\n|\n```", "", str(llm_response.content))
parsed_response = json.loads(cleaned_response)
entities = []
relationships = []
terms = []
for entity in parsed_response.get("retrieved_entities_relationships", {}).get(
"entities", {}
):
entity_name = entity.get("entity_name", "")
entity_type = entity.get("entity_type", "")
entities.append(Entity(entity_name=entity_name, entity_type=entity_type))
for relationship in parsed_response.get("retrieved_entities_relationships", {}).get(
"relationships", {}
):
relationship_name = relationship.get("relationship_name", "")
relationship_type = relationship.get("relationship_type", "")
relationship_entities = relationship.get("relationship_entities", [])
relationships.append(
Relationship(
relationship_name=relationship_name,
relationship_type=relationship_type,
relationship_entities=relationship_entities,
)
)
for term in parsed_response.get("retrieved_entities_relationships", {}).get(
"terms", {}
):
term_name = term.get("term_name", "")
term_type = term.get("term_type", "")
term_similar_to = term.get("term_similar_to", [])
terms.append(
Term(
term_name=term_name,
term_type=term_type,
term_similar_to=term_similar_to,
)
)
now_end = datetime.now()
logger.info(
f"{now_start} -- MAIN - Entity term extraction, Time taken: {now_end - now_start}"
)
return EntityTermExtractionUpdate(
entity_relation_term_extractions=EntityRelationshipTermExtraction(
entities=entities,
relationships=relationships,
terms=terms,
),
log_messages=[
f"{now_start} -- Main - ETR Extraction, Time taken: {now_end - now_start}"
],
)

View File

@@ -0,0 +1,58 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
InitialAnswerBASEUpdate,
)
from onyx.agents.agent_search.deep_search_a.main__graph.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import INITIAL_RAG_BASE_PROMPT
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
def generate_initial_base_search_only_answer(
state: MainState,
config: RunnableConfig,
) -> InitialAnswerBASEUpdate:
now_start = datetime.now()
logger.info(f"--------{now_start}--------GENERATE INITIAL BASE ANSWER---")
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = agent_a_config.search_request.query
original_question_docs = state.all_original_question_documents
model = agent_a_config.fast_llm
doc_context = format_docs(original_question_docs)
doc_context = trim_prompt_piece(
model.config, doc_context, INITIAL_RAG_BASE_PROMPT + question
)
msg = [
HumanMessage(
content=INITIAL_RAG_BASE_PROMPT.format(
question=question,
context=doc_context,
)
)
]
# Grader
response = model.invoke(msg)
answer = response.pretty_repr()
now_end = datetime.now()
logger.debug(
f"--------{now_end}--{now_end - now_start}--------INITIAL BASE ANSWER END---\n\n"
)
return InitialAnswerBASEUpdate(initial_base_answer=answer)

View File

@@ -0,0 +1,335 @@
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search_a.main__graph.models import (
AgentRefinedMetrics,
)
from onyx.agents.agent_search.deep_search_a.main__graph.operations import get_query_info
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
from onyx.agents.agent_search.deep_search_a.main__graph.operations import (
remove_document_citations,
)
from onyx.agents.agent_search.deep_search_a.main__graph.states import MainState
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
RefinedAnswerUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
get_prompt_enrichment_components,
)
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.models import InferenceSection
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import REVISED_RAG_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import (
REVISED_RAG_PROMPT_NO_SUB_QUESTIONS,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
SUB_QUESTION_ANSWER_TEMPLATE,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
def generate_refined_answer(
state: MainState, config: RunnableConfig
) -> RefinedAnswerUpdate:
now_start = datetime.now()
logger.info(f"--------{now_start}--------GENERATE REFINED ANSWER---")
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = agent_a_config.search_request.query
prompt_enrichment_components = get_prompt_enrichment_components(agent_a_config)
persona_contextualized_prompt = (
prompt_enrichment_components.persona_prompts.contextualized_prompt
)
initial_documents = state.documents
refined_documents = state.refined_documents
sub_questions_cited_docs = state.cited_docs
all_original_question_documents = state.all_original_question_documents
consolidated_context_docs: list[InferenceSection] = sub_questions_cited_docs
counter = 0
for original_doc_number, original_doc in enumerate(all_original_question_documents):
if original_doc_number not in sub_questions_cited_docs:
if (
counter <= AGENT_MIN_ORIG_QUESTION_DOCS
or len(consolidated_context_docs)
< 1.5
* AGENT_MAX_ANSWER_CONTEXT_DOCS # allow for larger context in refinement
):
consolidated_context_docs.append(original_doc)
counter += 1
# sort docs by their scores - though the scores refer to different questions
relevant_docs = dedup_inference_sections(
consolidated_context_docs, consolidated_context_docs
)
query_info = get_query_info(state.original_question_retrieval_results)
if agent_a_config.search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
# stream refined answer docs
for tool_response in yield_search_responses(
query=question,
reranked_sections=relevant_docs,
final_context_sections=relevant_docs,
search_query_info=query_info,
get_section_relevance=lambda: None, # TODO: add relevance
search_tool=agent_a_config.search_tool,
):
dispatch_custom_event(
"tool_response",
ExtendedToolResponse(
id=tool_response.id,
response=tool_response.response,
level=1,
level_question_nr=0, # 0, 0 is the base question
),
)
if len(initial_documents) > 0:
revision_doc_effectiveness = len(relevant_docs) / len(initial_documents)
elif len(refined_documents) == 0:
revision_doc_effectiveness = 0.0
else:
revision_doc_effectiveness = 10.0
decomp_answer_results = state.decomp_answer_results
# revised_answer_results = state.refined_decomp_answer_results
answered_qa_list: list[str] = []
decomp_questions = []
initial_good_sub_questions: list[str] = []
new_revised_good_sub_questions: list[str] = []
sub_question_nr = 1
for decomp_answer_result in decomp_answer_results:
question_level, question_nr = parse_question_id(
decomp_answer_result.question_id
)
decomp_questions.append(decomp_answer_result.question)
if (
decomp_answer_result.quality.lower().startswith("yes")
and len(decomp_answer_result.answer) > 0
and decomp_answer_result.answer != UNKNOWN_ANSWER
):
answered_qa_list.append(
SUB_QUESTION_ANSWER_TEMPLATE.format(
sub_question=decomp_answer_result.question,
sub_answer=decomp_answer_result.answer,
sub_question_nr=sub_question_nr,
)
)
if question_level == 0:
initial_good_sub_questions.append(decomp_answer_result.question)
else:
new_revised_good_sub_questions.append(decomp_answer_result.question)
sub_question_nr += 1
initial_good_sub_questions = list(set(initial_good_sub_questions))
new_revised_good_sub_questions = list(set(new_revised_good_sub_questions))
total_good_sub_questions = list(
set(initial_good_sub_questions + new_revised_good_sub_questions)
)
if len(initial_good_sub_questions) > 0:
revision_question_efficiency: float = len(total_good_sub_questions) / len(
initial_good_sub_questions
)
elif len(new_revised_good_sub_questions) > 0:
revision_question_efficiency = 10.0
else:
revision_question_efficiency = 1.0
sub_question_answer_str = "\n\n------\n\n".join(list(set(answered_qa_list)))
# original answer
initial_answer = state.initial_answer
# Determine which persona-specification prompt to use
# Determine which base prompt to use given the sub-question information
if len(answered_qa_list) > 0:
base_prompt = REVISED_RAG_PROMPT
else:
base_prompt = REVISED_RAG_PROMPT_NO_SUB_QUESTIONS
model = agent_a_config.fast_llm
relevant_docs_str = format_docs(relevant_docs)
relevant_docs_str = trim_prompt_piece(
model.config,
relevant_docs_str,
base_prompt
+ question
+ sub_question_answer_str
+ initial_answer
+ persona_contextualized_prompt
+ prompt_enrichment_components.history,
)
msg = [
HumanMessage(
content=base_prompt.format(
question=question,
history=prompt_enrichment_components.history,
answered_sub_questions=remove_document_citations(
sub_question_answer_str
),
relevant_docs=relevant_docs,
initial_answer=remove_document_citations(initial_answer),
persona_specification=persona_contextualized_prompt,
date_prompt=prompt_enrichment_components.date_str,
)
)
]
# Grader
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
dispatch_timings: list[float] = []
for message in model.stream(msg):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
dispatch_custom_event(
"refined_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=1,
level_question_nr=0,
answer_type="agent_level_answer",
),
)
end_stream_token = datetime.now()
dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
streamed_tokens.append(content)
logger.info(
f"Average dispatch time for refined answer: {sum(dispatch_timings) / len(dispatch_timings)}"
)
dispatch_main_answer_stop_info(1)
response = merge_content(*streamed_tokens)
answer = cast(str, response)
# refined_agent_stats = _calculate_refined_agent_stats(
# state.decomp_answer_results, state.original_question_retrieval_stats
# )
refined_agent_stats = RefinedAgentStats(
revision_doc_efficiency=revision_doc_effectiveness,
revision_question_efficiency=revision_question_efficiency,
)
logger.debug(f"\n\n---INITIAL ANSWER ---\n\n Answer:\n Agent: {initial_answer}")
logger.debug("-" * 10)
logger.debug(f"\n\n---REVISED AGENT ANSWER ---\n\n Answer:\n Agent: {answer}")
logger.debug("-" * 100)
if state.initial_agent_stats:
initial_doc_boost_factor = state.initial_agent_stats.agent_effectiveness.get(
"utilized_chunk_ratio", "--"
)
initial_support_boost_factor = (
state.initial_agent_stats.agent_effectiveness.get("support_ratio", "--")
)
num_initial_verified_docs = state.initial_agent_stats.original_question.get(
"num_verified_documents", "--"
)
initial_verified_docs_avg_score = (
state.initial_agent_stats.original_question.get("verified_avg_score", "--")
)
initial_sub_questions_verified_docs = (
state.initial_agent_stats.sub_questions.get("num_verified_documents", "--")
)
logger.debug("INITIAL AGENT STATS")
logger.debug(f"Document Boost Factor: {initial_doc_boost_factor}")
logger.debug(f"Support Boost Factor: {initial_support_boost_factor}")
logger.debug(f"Originally Verified Docs: {num_initial_verified_docs}")
logger.debug(
f"Originally Verified Docs Avg Score: {initial_verified_docs_avg_score}"
)
logger.debug(
f"Sub-Questions Verified Docs: {initial_sub_questions_verified_docs}"
)
if refined_agent_stats:
logger.debug("-" * 10)
logger.debug("REFINED AGENT STATS")
logger.debug(
f"Revision Doc Factor: {refined_agent_stats.revision_doc_efficiency}"
)
logger.debug(
f"Revision Question Factor: {refined_agent_stats.revision_question_efficiency}"
)
now_end = datetime.now()
logger.debug(
f"--------{now_end}--{now_end - now_start}--------INITIAL AGENT ANSWER END---\n\n"
)
agent_refined_end_time = datetime.now()
if state.agent_refined_start_time:
agent_refined_duration = (
agent_refined_end_time - state.agent_refined_start_time
).total_seconds()
else:
agent_refined_duration = None
agent_refined_metrics = AgentRefinedMetrics(
refined_doc_boost_factor=refined_agent_stats.revision_doc_efficiency,
refined_question_boost_factor=refined_agent_stats.revision_question_efficiency,
duration__s=agent_refined_duration,
)
now_end = datetime.now()
logger.info(
f"{now_start} -- MAIN - Generate refined answer, Time taken: {now_end - now_start}"
)
return RefinedAnswerUpdate(
refined_answer=answer,
refined_answer_quality=True, # TODO: replace this with the actual check value
refined_agent_stats=refined_agent_stats,
agent_refined_end_time=agent_refined_end_time,
agent_refined_metrics=agent_refined_metrics,
log_messages=[
f"{now_start} -- MAIN - Generate refined answer, Time taken: {now_end - now_start}"
],
)

View File

@@ -0,0 +1,41 @@
from datetime import datetime
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
DecompAnswersUpdate,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
def ingest_refined_answers(
state: AnswerQuestionOutput,
) -> DecompAnswersUpdate:
now_start = datetime.now()
logger.info(f"--------{now_start}--------INGEST FOLLOW UP ANSWERS---")
documents = []
answer_results = state.answer_results if hasattr(state, "answer_results") else []
for answer_result in answer_results:
documents.extend(answer_result.documents)
now_end = datetime.now()
logger.debug(
f"--------{now_end}--{now_end - now_start}--------INGEST FOLLOW UP ANSWERS END---"
)
return DecompAnswersUpdate(
# Deduping is done by the documents operator for the main graph
# so we might not need to dedup here
documents=dedup_inference_sections(documents, []),
decomp_answer_results=answer_results,
log_messages=[
f"{now_start} -- Main - Ingest refined answers, Time taken: {now_end - now_start}"
],
)

View File

@@ -0,0 +1,47 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
from onyx.agents.agent_search.deep_search_a.main__graph.states import MainState
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
RequireRefinedAnswerUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
def refined_answer_decision(
state: MainState, config: RunnableConfig
) -> RequireRefinedAnswerUpdate:
now_start = datetime.now()
logger.info(f"--------{now_start}--------REFINED ANSWER DECISION---")
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
if "?" in agent_a_config.search_request.query:
decision = False
else:
decision = True
decision = True
now_end = datetime.now()
logger.info(
f"{now_start} -- MAIN - Refined answer decision, Time taken: {now_end - now_start}"
)
log_messages = [
f"{now_start} -- Main - Refined answer decision: {decision}, Time taken: {now_end - now_start}"
]
if agent_a_config.allow_refinement:
return RequireRefinedAnswerUpdate(
require_refined_answer_eval=decision,
log_messages=log_messages,
)
else:
return RequireRefinedAnswerUpdate(
require_refined_answer_eval=False,
log_messages=log_messages,
)

View File

@@ -0,0 +1,122 @@
from datetime import datetime
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search_a.main__graph.models import (
FollowUpSubQuestion,
)
from onyx.agents.agent_search.deep_search_a.main__graph.operations import (
dispatch_subquestion,
)
from onyx.agents.agent_search.deep_search_a.main__graph.operations import logger
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
FollowUpSubQuestionsUpdate,
)
from onyx.agents.agent_search.deep_search_a.main__graph.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
DEEP_DECOMPOSE_PROMPT_WITH_ENTITIES,
)
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
format_entity_term_extraction,
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.tools.models import ToolCallKickoff
def refined_sub_question_creation(
state: MainState, config: RunnableConfig
) -> FollowUpSubQuestionsUpdate:
""" """
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
dispatch_custom_event(
"start_refined_answer_creation",
ToolCallKickoff(
tool_name="agent_search_1",
tool_args={
"query": agent_a_config.search_request.query,
"answer": state.initial_answer,
},
),
)
now_start = datetime.now()
logger.info(f"--------{now_start}--------FOLLOW UP DECOMPOSE---")
agent_refined_start_time = datetime.now()
question = agent_a_config.search_request.query
base_answer = state.initial_answer
history = build_history_prompt(agent_a_config, question)
# get the entity term extraction dict and properly format it
entity_retlation_term_extractions = state.entity_relation_term_extractions
entity_term_extraction_str = format_entity_term_extraction(
entity_retlation_term_extractions
)
initial_question_answers = state.decomp_answer_results
addressed_question_list = [
x.question for x in initial_question_answers if "yes" in x.quality.lower()
]
failed_question_list = [
x.question for x in initial_question_answers if "no" in x.quality.lower()
]
msg = [
HumanMessage(
content=DEEP_DECOMPOSE_PROMPT_WITH_ENTITIES.format(
question=question,
history=history,
entity_term_extraction_str=entity_term_extraction_str,
base_answer=base_answer,
answered_sub_questions="\n - ".join(addressed_question_list),
failed_sub_questions="\n - ".join(failed_question_list),
),
)
]
# Grader
model = agent_a_config.fast_llm
streamed_tokens = dispatch_separated(model.stream(msg), dispatch_subquestion(1))
response = merge_content(*streamed_tokens)
if isinstance(response, str):
parsed_response = [q for q in response.split("\n") if q.strip() != ""]
else:
raise ValueError("LLM response is not a string")
refined_sub_question_dict = {}
for sub_question_nr, sub_question in enumerate(parsed_response):
refined_sub_question = FollowUpSubQuestion(
sub_question=sub_question,
sub_question_id=make_question_id(1, sub_question_nr + 1),
verified=False,
answered=False,
answer="",
)
refined_sub_question_dict[sub_question_nr + 1] = refined_sub_question
now_end = datetime.now()
logger.info(
f"{now_start} -- MAIN - Refined sub question creation, Time taken: {now_end - now_start}"
)
return FollowUpSubQuestionsUpdate(
refined_sub_questions=refined_sub_question_dict,
agent_refined_start_time=agent_refined_start_time,
)

View File

@@ -0,0 +1,145 @@
import re
from collections.abc import Callable
from langchain_core.callbacks.manager import dispatch_custom_event
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.agents.agent_search.shared_graph_utils.models import (
QuestionAnswerResults,
)
from onyx.chat.models import SubQuestionPiece
from onyx.tools.models import SearchQueryInfo
from onyx.utils.logger import setup_logger
logger = setup_logger()
def remove_document_citations(text: str) -> str:
"""
Removes citation expressions of format '[[D1]]()' from text.
The number after D can vary.
Args:
text: Input text containing citations
Returns:
Text with citations removed
"""
# Pattern explanation:
# \[\[D\d+\]\]\(\) matches:
# \[\[ - literal [[ characters
# D - literal D character
# \d+ - one or more digits
# \]\] - literal ]] characters
# \(\) - literal () characters
return re.sub(r"\[\[(?:D|Q)\d+\]\]\(\)", "", text)
def dispatch_subquestion(level: int) -> Callable[[str, int], None]:
def _helper(sub_question_part: str, num: int) -> None:
dispatch_custom_event(
"decomp_qs",
SubQuestionPiece(
sub_question=sub_question_part,
level=level,
level_question_nr=num,
),
)
return _helper
def calculate_initial_agent_stats(
decomp_answer_results: list[QuestionAnswerResults],
original_question_stats: AgentChunkStats,
) -> InitialAgentResultStats:
initial_agent_result_stats: InitialAgentResultStats = InitialAgentResultStats(
sub_questions={},
original_question={},
agent_effectiveness={},
)
orig_verified = original_question_stats.verified_count
orig_support_score = original_question_stats.verified_avg_scores
verified_document_chunk_ids = []
support_scores = 0.0
for decomp_answer_result in decomp_answer_results:
verified_document_chunk_ids += (
decomp_answer_result.sub_question_retrieval_stats.verified_doc_chunk_ids
)
if (
decomp_answer_result.sub_question_retrieval_stats.verified_avg_scores
is not None
):
support_scores += (
decomp_answer_result.sub_question_retrieval_stats.verified_avg_scores
)
verified_document_chunk_ids = list(set(verified_document_chunk_ids))
# Calculate sub-question stats
if (
verified_document_chunk_ids
and len(verified_document_chunk_ids) > 0
and support_scores is not None
):
sub_question_stats: dict[str, float | int | None] = {
"num_verified_documents": len(verified_document_chunk_ids),
"verified_avg_score": float(support_scores / len(decomp_answer_results)),
}
else:
sub_question_stats = {"num_verified_documents": 0, "verified_avg_score": None}
initial_agent_result_stats.sub_questions.update(sub_question_stats)
# Get original question stats
initial_agent_result_stats.original_question.update(
{
"num_verified_documents": original_question_stats.verified_count,
"verified_avg_score": original_question_stats.verified_avg_scores,
}
)
# Calculate chunk utilization ratio
sub_verified = initial_agent_result_stats.sub_questions["num_verified_documents"]
chunk_ratio: float | None = None
if sub_verified is not None and orig_verified is not None and orig_verified > 0:
chunk_ratio = (float(sub_verified) / orig_verified) if sub_verified > 0 else 0.0
elif sub_verified is not None and sub_verified > 0:
chunk_ratio = 10.0
initial_agent_result_stats.agent_effectiveness["utilized_chunk_ratio"] = chunk_ratio
if (
orig_support_score is None
or orig_support_score == 0.0
and initial_agent_result_stats.sub_questions["verified_avg_score"] is None
):
initial_agent_result_stats.agent_effectiveness["support_ratio"] = None
elif orig_support_score is None or orig_support_score == 0.0:
initial_agent_result_stats.agent_effectiveness["support_ratio"] = 10
elif initial_agent_result_stats.sub_questions["verified_avg_score"] is None:
initial_agent_result_stats.agent_effectiveness["support_ratio"] = 0
else:
initial_agent_result_stats.agent_effectiveness["support_ratio"] = (
initial_agent_result_stats.sub_questions["verified_avg_score"]
/ orig_support_score
)
return initial_agent_result_stats
def get_query_info(results: list[QueryResult]) -> SearchQueryInfo:
# Use the query info from the base document retrieval
# TODO: see if this is the right way to do this
query_infos = [
result.query_info for result in results if result.query_info is not None
]
if len(query_infos) == 0:
raise ValueError("No query info found")
return query_infos[0]

View File

@@ -0,0 +1,180 @@
from datetime import datetime
from operator import add
from typing import Annotated
from typing import TypedDict
from pydantic import BaseModel
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.deep_search_a.main__graph.models import AgentBaseMetrics
from onyx.agents.agent_search.deep_search_a.main__graph.models import (
AgentRefinedMetrics,
)
from onyx.agents.agent_search.deep_search_a.main__graph.models import (
FollowUpSubQuestion,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.models import (
ExpandedRetrievalResult,
)
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
)
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.agents.agent_search.shared_graph_utils.models import (
QuestionAnswerResults,
)
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_question_answer_results,
)
from onyx.context.search.models import InferenceSection
### States ###
## Update States
class LoggerUpdate(BaseModel):
log_messages: Annotated[list[str], add] = []
class RefinedAgentStartStats(BaseModel):
agent_refined_start_time: datetime | None = None
class RefinedAgentEndStats(BaseModel):
agent_refined_end_time: datetime | None = None
agent_refined_metrics: AgentRefinedMetrics = AgentRefinedMetrics()
class BaseDecompUpdate(RefinedAgentStartStats, RefinedAgentEndStats):
agent_start_time: datetime = datetime.now()
previous_history: str = ""
initial_decomp_questions: list[str] = []
class ExploratorySearchUpdate(LoggerUpdate):
exploratory_search_results: list[InferenceSection] = []
previous_history_summary: str = ""
class AnswerComparison(LoggerUpdate):
refined_answer_improvement_eval: bool = False
class RoutingDecision(LoggerUpdate):
routing_decision: str = ""
# Not used in current graph
class InitialAnswerBASEUpdate(BaseModel):
initial_base_answer: str = ""
class InitialAnswerUpdate(LoggerUpdate):
initial_answer: str = ""
initial_agent_stats: InitialAgentResultStats | None = None
generated_sub_questions: list[str] = []
agent_base_end_time: datetime | None = None
agent_base_metrics: AgentBaseMetrics | None = None
class RefinedAnswerUpdate(RefinedAgentEndStats, LoggerUpdate):
refined_answer: str = ""
refined_agent_stats: RefinedAgentStats | None = None
refined_answer_quality: bool = False
class InitialAnswerQualityUpdate(LoggerUpdate):
initial_answer_quality_eval: bool = False
class RequireRefinedAnswerUpdate(LoggerUpdate):
require_refined_answer_eval: bool = True
class DecompAnswersUpdate(LoggerUpdate):
documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
context_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
cited_docs: Annotated[
list[InferenceSection], dedup_inference_sections
] = [] # cited docs from sub-answers are used for answer context
decomp_answer_results: Annotated[
list[QuestionAnswerResults], dedup_question_answer_results
] = []
class FollowUpDecompAnswersUpdate(LoggerUpdate):
refined_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
refined_decomp_answer_results: Annotated[list[QuestionAnswerResults], add] = []
class ExpandedRetrievalUpdate(LoggerUpdate):
all_original_question_documents: Annotated[
list[InferenceSection], dedup_inference_sections
]
original_question_retrieval_results: list[QueryResult] = []
original_question_retrieval_stats: AgentChunkStats = AgentChunkStats()
class EntityTermExtractionUpdate(LoggerUpdate):
entity_relation_term_extractions: EntityRelationshipTermExtraction = (
EntityRelationshipTermExtraction()
)
class FollowUpSubQuestionsUpdate(RefinedAgentStartStats):
refined_sub_questions: dict[int, FollowUpSubQuestion] = {}
## Graph Input State
## Graph Input State
class MainInput(CoreState):
pass
## Graph State
class MainState(
# This includes the core state
MainInput,
ToolChoiceInput,
ToolCallUpdate,
ToolChoiceUpdate,
BaseDecompUpdate,
InitialAnswerUpdate,
InitialAnswerBASEUpdate,
DecompAnswersUpdate,
ExpandedRetrievalUpdate,
EntityTermExtractionUpdate,
InitialAnswerQualityUpdate,
RequireRefinedAnswerUpdate,
FollowUpSubQuestionsUpdate,
FollowUpDecompAnswersUpdate,
RefinedAnswerUpdate,
RefinedAgentStartStats,
RefinedAgentEndStats,
RoutingDecision,
AnswerComparison,
ExploratorySearchUpdate,
):
# expanded_retrieval_result: Annotated[list[ExpandedRetrievalResult], add]
base_raw_search_result: Annotated[list[ExpandedRetrievalResult], add]
## Graph Output State - presently not used
class MainOutput(TypedDict):
log_messages: list[str]

View File

@@ -0,0 +1,28 @@
from collections.abc import Hashable
from datetime import datetime
from langgraph.types import Send
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
AnswerQuestionInput,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
ExpandedRetrievalInput,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def send_to_expanded_refined_retrieval(state: AnswerQuestionInput) -> Send | Hashable:
logger.debug("sending to expanded retrieval for follow up question via edge")
datetime.now()
return Send(
"refined_sub_question_expanded_retrieval",
ExpandedRetrievalInput(
question=state.question,
sub_question_id=state.question_id,
base_search=False,
log_messages=[f"{datetime.now()} -- Sending to expanded retrieval"],
),
)

View File

@@ -0,0 +1,123 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.nodes.answer_check import (
answer_check,
)
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.nodes.answer_generation import (
answer_generation,
)
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.nodes.format_answer import (
format_answer,
)
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.nodes.ingest_retrieval import (
ingest_retrieval,
)
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
AnswerQuestionInput,
)
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subgraph.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search_a.refinement__consolidate_sub_answers__subgraph.edges import (
send_to_expanded_refined_retrieval,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.graph_builder import (
expanded_retrieval_graph_builder,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def answer_refined_query_graph_builder() -> StateGraph:
graph = StateGraph(
state_schema=AnswerQuestionState,
input=AnswerQuestionInput,
output=AnswerQuestionOutput,
)
### Add nodes ###
expanded_retrieval = expanded_retrieval_graph_builder().compile()
graph.add_node(
node="refined_sub_question_expanded_retrieval",
action=expanded_retrieval,
)
graph.add_node(
node="refined_sub_answer_check",
action=answer_check,
)
graph.add_node(
node="refined_sub_answer_generation",
action=answer_generation,
)
graph.add_node(
node="format_refined_sub_answer",
action=format_answer,
)
graph.add_node(
node="ingest_refined_retrieval",
action=ingest_retrieval,
)
### Add edges ###
graph.add_conditional_edges(
source=START,
path=send_to_expanded_refined_retrieval,
path_map=["refined_sub_question_expanded_retrieval"],
)
graph.add_edge(
start_key="refined_sub_question_expanded_retrieval",
end_key="ingest_refined_retrieval",
)
graph.add_edge(
start_key="ingest_refined_retrieval",
end_key="refined_sub_answer_generation",
)
graph.add_edge(
start_key="refined_sub_answer_generation",
end_key="refined_sub_answer_check",
)
graph.add_edge(
start_key="refined_sub_answer_check",
end_key="format_refined_sub_answer",
)
graph.add_edge(
start_key="format_refined_sub_answer",
end_key=END,
)
return graph
if __name__ == "__main__":
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
graph = answer_refined_query_graph_builder()
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
search_request = SearchRequest(
query="what can you do with onyx or danswer?",
)
with get_session_context_manager() as db_session:
inputs = AnswerQuestionInput(
question="what can you do with onyx?",
question_id="0_0",
log_messages=[],
)
for thing in compiled_graph.stream(
input=inputs,
# debug=True,
# subgraphs=True,
):
logger.debug(thing)
# output = compiled_graph.invoke(inputs)
# logger.debug(output)

View File

@@ -0,0 +1,19 @@
from pydantic import BaseModel
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.context.search.models import InferenceSection
### Models ###
class AnswerRetrievalStats(BaseModel):
answer_retrieval_stats: dict[str, float | int]
class QuestionAnswerResults(BaseModel):
question: str
answer: str
quality: str
# expanded_retrieval_results: list[QueryResult]
documents: list[InferenceSection]
sub_question_retrieval_stats: AgentChunkStats

View File

@@ -0,0 +1,37 @@
from collections.abc import Hashable
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import Send
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
RetrievalInput,
)
from onyx.agents.agent_search.models import AgentSearchConfig
def parallel_retrieval_edge(
state: ExpandedRetrievalState, config: RunnableConfig
) -> list[Send | Hashable]:
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = state.question if state.question else agent_a_config.search_request.query
query_expansions = (
state.expanded_queries if state.expanded_queries else [] + [question]
)
return [
Send(
"doc_retrieval",
RetrievalInput(
query_to_retrieve=query,
question=question,
base_search=False,
sub_question_id=state.sub_question_id,
log_messages=[],
),
)
for query in query_expansions
]

View File

@@ -0,0 +1,147 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.edges import (
parallel_retrieval_edge,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.nodes.doc_reranking import (
doc_reranking,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.nodes.doc_retrieval import (
doc_retrieval,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.nodes.doc_verification import (
doc_verification,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.nodes.dummy import (
dummy,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.nodes.expand_queries import (
expand_queries,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.nodes.format_results import (
format_results,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.nodes.verification_kickoff import (
verification_kickoff,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
ExpandedRetrievalInput,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
ExpandedRetrievalOutput,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger
logger = setup_logger()
def expanded_retrieval_graph_builder() -> StateGraph:
graph = StateGraph(
state_schema=ExpandedRetrievalState,
input=ExpandedRetrievalInput,
output=ExpandedRetrievalOutput,
)
### Add nodes ###
graph.add_node(
node="expand_queries",
action=expand_queries,
)
graph.add_node(
node="dummy",
action=dummy,
)
graph.add_node(
node="doc_retrieval",
action=doc_retrieval,
)
graph.add_node(
node="verification_kickoff",
action=verification_kickoff,
)
graph.add_node(
node="doc_verification",
action=doc_verification,
)
graph.add_node(
node="doc_reranking",
action=doc_reranking,
)
graph.add_node(
node="format_results",
action=format_results,
)
### Add edges ###
graph.add_edge(
start_key=START,
end_key="expand_queries",
)
graph.add_edge(
start_key="expand_queries",
end_key="dummy",
)
graph.add_conditional_edges(
source="dummy",
path=parallel_retrieval_edge,
path_map=["doc_retrieval"],
)
graph.add_edge(
start_key="doc_retrieval",
end_key="verification_kickoff",
)
graph.add_edge(
start_key="doc_verification",
end_key="doc_reranking",
)
graph.add_edge(
start_key="doc_reranking",
end_key="format_results",
)
graph.add_edge(
start_key="format_results",
end_key=END,
)
return graph
if __name__ == "__main__":
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
graph = expanded_retrieval_graph_builder()
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
search_request = SearchRequest(
query="what can you do with onyx or danswer?",
)
with get_session_context_manager() as db_session:
agent_a_config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)
inputs = ExpandedRetrievalInput(
question="what can you do with onyx?",
base_search=False,
sub_question_id=None,
log_messages=[],
)
for thing in compiled_graph.stream(
input=inputs,
config={"configurable": {"config": agent_a_config}},
# debug=True,
subgraphs=True,
):
logger.debug(thing)

View File

@@ -0,0 +1,12 @@
from pydantic import BaseModel
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.context.search.models import InferenceSection
class ExpandedRetrievalResult(BaseModel):
expanded_queries_results: list[QueryResult] = []
reranked_documents: list[InferenceSection] = []
context_documents: list[InferenceSection] = []
sub_question_retrieval_stats: AgentChunkStats = AgentChunkStats()

View File

@@ -0,0 +1,96 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.operations import (
logger,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
DocRerankingUpdate,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.calculations import get_fit_scores
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.configs.agent_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS
from onyx.configs.agent_configs import AGENT_RERANKING_STATS
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import SearchRequest
from onyx.context.search.pipeline import retrieval_preprocessing
from onyx.context.search.postprocessing.postprocessing import rerank_sections
from onyx.db.engine import get_session_context_manager
def doc_reranking(
state: ExpandedRetrievalState, config: RunnableConfig
) -> DocRerankingUpdate:
now_start = datetime.now()
verified_documents = state.verified_documents
# Rerank post retrieval and verification. First, create a search query
# then create the list of reranked sections
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = state.question if state.question else agent_a_config.search_request.query
if agent_a_config.search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
with get_session_context_manager() as db_session:
# we ignore some of the user specified fields since this search is
# internal to agentic search, but we still want to pass through
# persona (for stuff like document sets) and rerank settings
# (to not make an unnecessary db call).
search_request = SearchRequest(
query=question,
persona=agent_a_config.search_request.persona,
rerank_settings=agent_a_config.search_request.rerank_settings,
)
_search_query = retrieval_preprocessing(
search_request=search_request,
user=agent_a_config.search_tool.user, # bit of a hack
llm=agent_a_config.fast_llm,
db_session=db_session,
)
# skip section filtering
if (
_search_query.rerank_settings
and _search_query.rerank_settings.rerank_model_name
and _search_query.rerank_settings.num_rerank > 0
and len(verified_documents) > 0
):
if len(verified_documents) > 1:
reranked_documents = rerank_sections(
_search_query,
verified_documents,
)
else:
num = "No" if len(verified_documents) == 0 else "One"
logger.warning(f"{num} verified document(s) found, skipping reranking")
reranked_documents = verified_documents
else:
logger.warning("No reranking settings found, using unranked documents")
reranked_documents = verified_documents
if AGENT_RERANKING_STATS:
fit_scores = get_fit_scores(verified_documents, reranked_documents)
else:
fit_scores = RetrievalFitStats(fit_score_lift=0, rerank_effect=0, fit_scores={})
# TODO: stream deduped docs here, or decide to use search tool ranking/verification
now_end = datetime.now()
logger.info(
f"{now_start} -- Expanded Retrieval - Reranking - Time taken: {now_end - now_start}"
)
return DocRerankingUpdate(
reranked_documents=[
doc for doc in reranked_documents if type(doc) == InferenceSection
][:AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS],
sub_question_retrieval_stats=fit_scores,
log_messages=[
f"{now_start} -- Expanded Retrieval - Reranking - Time taken: {now_end - now_start}"
],
)

View File

@@ -0,0 +1,110 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.operations import (
logger,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
DocRetrievalUpdate,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
RetrievalInput,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.calculations import get_fit_scores
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.configs.agent_configs import AGENT_MAX_QUERY_RETRIEVAL_RESULTS
from onyx.configs.agent_configs import AGENT_RETRIEVAL_STATS
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_context_manager
from onyx.tools.models import SearchQueryInfo
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
def doc_retrieval(state: RetrievalInput, config: RunnableConfig) -> DocRetrievalUpdate:
"""
Retrieve documents
Args:
state (RetrievalInput): Primary state + the query to retrieve
config (RunnableConfig): Configuration containing ProSearchConfig
Updates:
expanded_retrieval_results: list[ExpandedRetrievalResult]
retrieved_documents: list[InferenceSection]
"""
now_start = datetime.now()
query_to_retrieve = state.query_to_retrieve
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
search_tool = agent_a_config.search_tool
retrieved_docs: list[InferenceSection] = []
if not query_to_retrieve.strip():
logger.warning("Empty query, skipping retrieval")
now_end = datetime.now()
return DocRetrievalUpdate(
expanded_retrieval_results=[],
retrieved_documents=[],
log_messages=[
f"{now_start} -- Expanded Retrieval - Retrieval - Empty Query - Time taken: {now_end - now_start}"
],
)
query_info = None
if search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
# new db session to avoid concurrency issues
with get_session_context_manager() as db_session:
for tool_response in search_tool.run(
query=query_to_retrieve,
force_no_rerank=True,
alternate_db_session=db_session,
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
response = cast(SearchResponseSummary, tool_response.response)
retrieved_docs = response.top_sections
query_info = SearchQueryInfo(
predicted_search=response.predicted_search,
final_filters=response.final_filters,
recency_bias_multiplier=response.recency_bias_multiplier,
)
break
retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS]
pre_rerank_docs = retrieved_docs
if search_tool.search_pipeline is not None:
pre_rerank_docs = (
search_tool.search_pipeline._retrieved_sections or retrieved_docs
)
if AGENT_RETRIEVAL_STATS:
fit_scores = get_fit_scores(
pre_rerank_docs,
retrieved_docs,
)
else:
fit_scores = None
expanded_retrieval_result = QueryResult(
query=query_to_retrieve,
search_results=retrieved_docs,
stats=fit_scores,
query_info=query_info,
)
now_end = datetime.now()
logger.info(
f"{now_start} -- Expanded Retrieval - Retrieval - Time taken: {now_end - now_start}"
)
return DocRetrievalUpdate(
expanded_retrieval_results=[expanded_retrieval_result],
retrieved_documents=retrieved_docs,
log_messages=[
f"{now_start} -- Expanded Retrieval - Retrieval - Time taken: {now_end - now_start}"
],
)

View File

@@ -0,0 +1,60 @@
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
DocVerificationInput,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
DocVerificationUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
def doc_verification(
state: DocVerificationInput, config: RunnableConfig
) -> DocVerificationUpdate:
"""
Check whether the document is relevant for the original user question
Args:
state (DocVerificationInput): The current state
config (RunnableConfig): Configuration containing ProSearchConfig
Updates:
verified_documents: list[InferenceSection]
"""
question = state.question
doc_to_verify = state.doc_to_verify
document_content = doc_to_verify.combined_content
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
fast_llm = agent_a_config.fast_llm
document_content = trim_prompt_piece(
fast_llm.config, document_content, VERIFIER_PROMPT + question
)
msg = [
HumanMessage(
content=VERIFIER_PROMPT.format(
question=question, document_content=document_content
)
)
]
response = fast_llm.invoke(msg)
verified_documents = []
if isinstance(response.content, str) and "yes" in response.content.lower():
verified_documents.append(doc_to_verify)
return DocVerificationUpdate(
verified_documents=verified_documents,
)

View File

@@ -0,0 +1,16 @@
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
QueryExpansionUpdate,
)
def dummy(
state: ExpandedRetrievalState, config: RunnableConfig
) -> QueryExpansionUpdate:
return QueryExpansionUpdate(
expanded_queries=state.expanded_queries,
)

View File

@@ -0,0 +1,74 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.operations import (
dispatch_subquery,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.operations import (
logger,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
ExpandedRetrievalInput,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
QueryExpansionUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.prompts import (
REWRITE_PROMPT_MULTI_ORIGINAL,
)
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
def expand_queries(
state: ExpandedRetrievalInput, config: RunnableConfig
) -> QueryExpansionUpdate:
# Sometimes we want to expand the original question, sometimes we want to expand a sub-question.
# When we are running this node on the original question, no question is explictly passed in.
# Instead, we use the original question from the search request.
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
now_start = datetime.now()
question = (
state.question
if hasattr(state, "question")
else agent_a_config.search_request.query
)
llm = agent_a_config.fast_llm
chat_session_id = agent_a_config.chat_session_id
sub_question_id = state.sub_question_id
if sub_question_id is None:
level, question_nr = 0, 0
else:
level, question_nr = parse_question_id(sub_question_id)
if chat_session_id is None:
raise ValueError("chat_session_id must be provided for agent search")
msg = [
HumanMessage(
content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question),
)
]
llm_response_list = dispatch_separated(
llm.stream(prompt=msg), dispatch_subquery(level, question_nr)
)
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
rewritten_queries = llm_response.split("\n")
now_end = datetime.now()
logger.info(
f"{now_start} -- Expanded Retrieval - Query Expansion - Time taken: {now_end - now_start}"
)
return QueryExpansionUpdate(
expanded_queries=rewritten_queries,
log_messages=[
f"{now_start} -- Expanded Retrieval - Query Expansion - Time taken: {now_end - now_start}"
],
)

View File

@@ -0,0 +1,84 @@
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.models import (
ExpandedRetrievalResult,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.operations import (
calculate_sub_question_retrieval_stats,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
ExpandedRetrievalUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.chat.models import ExtendedToolResponse
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
def format_results(
state: ExpandedRetrievalState, config: RunnableConfig
) -> ExpandedRetrievalUpdate:
level, question_nr = parse_question_id(state.sub_question_id or "0_0")
query_infos = [
result.query_info
for result in state.expanded_retrieval_results
if result.query_info is not None
]
if len(query_infos) == 0:
raise ValueError("No query info found")
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
# main question docs will be sent later after aggregation and deduping with sub-question docs
reranked_documents = state.reranked_documents
if not (level == 0 and question_nr == 0):
if len(reranked_documents) == 0:
# The sub-question is used as the last query. If no verified documents are found, stream
# the top 3 for that one. We may want to revisit this.
reranked_documents = state.expanded_retrieval_results[-1].search_results[:3]
if agent_a_config.search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
for tool_response in yield_search_responses(
query=state.question,
reranked_sections=state.retrieved_documents, # TODO: rename params. (sections pre-merging here.)
final_context_sections=reranked_documents,
search_query_info=query_infos[0], # TODO: handle differing query infos?
get_section_relevance=lambda: None, # TODO: add relevance
search_tool=agent_a_config.search_tool,
):
dispatch_custom_event(
"tool_response",
ExtendedToolResponse(
id=tool_response.id,
response=tool_response.response,
level=level,
level_question_nr=question_nr,
),
)
sub_question_retrieval_stats = calculate_sub_question_retrieval_stats(
verified_documents=state.verified_documents,
expanded_retrieval_results=state.expanded_retrieval_results,
)
if sub_question_retrieval_stats is None:
sub_question_retrieval_stats = AgentChunkStats()
# else:
# sub_question_retrieval_stats = [sub_question_retrieval_stats]
return ExpandedRetrievalUpdate(
expanded_retrieval_result=ExpandedRetrievalResult(
expanded_queries_results=state.expanded_retrieval_results,
reranked_documents=reranked_documents,
context_documents=state.reranked_documents,
sub_question_retrieval_stats=sub_question_retrieval_stats,
),
)

View File

@@ -0,0 +1,44 @@
from typing import cast
from typing import Literal
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import Command
from langgraph.types import Send
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
DocVerificationInput,
)
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.models import AgentSearchConfig
def verification_kickoff(
state: ExpandedRetrievalState,
config: RunnableConfig,
) -> Command[Literal["doc_verification"]]:
documents = state.retrieved_documents
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
verification_question = (
state.question
if hasattr(state, "question")
else agent_a_config.search_request.query
)
sub_question_id = state.sub_question_id
return Command(
update={},
goto=[
Send(
node="doc_verification",
arg=DocVerificationInput(
doc_to_verify=doc,
question=verification_question,
base_search=False,
sub_question_id=sub_question_id,
log_messages=[],
),
)
for doc in documents
],
)

View File

@@ -0,0 +1,97 @@
from collections import defaultdict
from collections.abc import Callable
import numpy as np
from langchain_core.callbacks.manager import dispatch_custom_event
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.chat.models import SubQueryPiece
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger
logger = setup_logger()
def dispatch_subquery(level: int, question_nr: int) -> Callable[[str, int], None]:
def helper(token: str, num: int) -> None:
dispatch_custom_event(
"subqueries",
SubQueryPiece(
sub_query=token,
level=level,
level_question_nr=question_nr,
query_id=num,
),
)
return helper
def calculate_sub_question_retrieval_stats(
verified_documents: list[InferenceSection],
expanded_retrieval_results: list[QueryResult],
) -> AgentChunkStats:
chunk_scores: dict[str, dict[str, list[int | float]]] = defaultdict(
lambda: defaultdict(list)
)
for expanded_retrieval_result in expanded_retrieval_results:
for doc in expanded_retrieval_result.search_results:
doc_chunk_id = f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}"
if doc.center_chunk.score is not None:
chunk_scores[doc_chunk_id]["score"].append(doc.center_chunk.score)
verified_doc_chunk_ids = [
f"{verified_document.center_chunk.document_id}_{verified_document.center_chunk.chunk_id}"
for verified_document in verified_documents
]
dismissed_doc_chunk_ids = []
raw_chunk_stats_counts: dict[str, int] = defaultdict(int)
raw_chunk_stats_scores: dict[str, float] = defaultdict(float)
for doc_chunk_id, chunk_data in chunk_scores.items():
if doc_chunk_id in verified_doc_chunk_ids:
raw_chunk_stats_counts["verified_count"] += 1
valid_chunk_scores = [
score for score in chunk_data["score"] if score is not None
]
raw_chunk_stats_scores["verified_scores"] += float(
np.mean(valid_chunk_scores)
)
else:
raw_chunk_stats_counts["rejected_count"] += 1
valid_chunk_scores = [
score for score in chunk_data["score"] if score is not None
]
raw_chunk_stats_scores["rejected_scores"] += float(
np.mean(valid_chunk_scores)
)
dismissed_doc_chunk_ids.append(doc_chunk_id)
if raw_chunk_stats_counts["verified_count"] == 0:
verified_avg_scores = 0.0
else:
verified_avg_scores = raw_chunk_stats_scores["verified_scores"] / float(
raw_chunk_stats_counts["verified_count"]
)
rejected_scores = raw_chunk_stats_scores.get("rejected_scores", None)
if rejected_scores is not None:
rejected_avg_scores = rejected_scores / float(
raw_chunk_stats_counts["rejected_count"]
)
else:
rejected_avg_scores = None
chunk_stats = AgentChunkStats(
verified_count=raw_chunk_stats_counts["verified_count"],
verified_avg_scores=verified_avg_scores,
rejected_count=raw_chunk_stats_counts["rejected_count"],
rejected_avg_scores=rejected_avg_scores,
verified_doc_chunk_ids=verified_doc_chunk_ids,
dismissed_doc_chunk_ids=dismissed_doc_chunk_ids,
)
return chunk_stats

View File

@@ -0,0 +1,91 @@
from operator import add
from typing import Annotated
from pydantic import BaseModel
from onyx.agents.agent_search.core_state import SubgraphCoreState
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.models import (
ExpandedRetrievalResult,
)
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.context.search.models import InferenceSection
### States ###
## Graph Input State
class ExpandedRetrievalInput(SubgraphCoreState):
question: str = ""
base_search: bool = False
sub_question_id: str | None = None
## Update/Return States
class QueryExpansionUpdate(BaseModel):
expanded_queries: list[str] = ["aaa", "bbb"]
log_messages: list[str] = []
class DocVerificationUpdate(BaseModel):
verified_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
class DocRetrievalUpdate(BaseModel):
expanded_retrieval_results: Annotated[list[QueryResult], add] = []
retrieved_documents: Annotated[
list[InferenceSection], dedup_inference_sections
] = []
log_messages: list[str] = []
class DocRerankingUpdate(BaseModel):
reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
sub_question_retrieval_stats: RetrievalFitStats | None = None
log_messages: list[str] = []
class ExpandedRetrievalUpdate(BaseModel):
expanded_retrieval_result: ExpandedRetrievalResult
## Graph Output State
class ExpandedRetrievalOutput(BaseModel):
expanded_retrieval_result: ExpandedRetrievalResult = ExpandedRetrievalResult()
base_expanded_retrieval_result: ExpandedRetrievalResult = ExpandedRetrievalResult()
log_messages: list[str] = []
## Graph State
class ExpandedRetrievalState(
# This includes the core state
ExpandedRetrievalInput,
QueryExpansionUpdate,
DocRetrievalUpdate,
DocVerificationUpdate,
DocRerankingUpdate,
ExpandedRetrievalOutput,
):
pass
## Conditional Input States
class DocVerificationInput(ExpandedRetrievalInput):
doc_to_verify: InferenceSection
class RetrievalInput(ExpandedRetrievalInput):
query_to_retrieve: str = ""

View File

@@ -0,0 +1,101 @@
from uuid import UUID
from pydantic import BaseModel
from pydantic import model_validator
from sqlalchemy.orm import Session
from onyx.agents.agent_search.shared_graph_utils.models import PersonaExpressions
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.context.search.models import SearchRequest
from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
class AgentSearchConfig(BaseModel):
"""
Configuration for the Agent Search feature.
"""
# The search request that was used to generate the Pro Search
search_request: SearchRequest
primary_llm: LLM
fast_llm: LLM
# Whether to force use of a tool, or to
# force tool args IF the tool is used
force_use_tool: ForceUseTool
# contains message history for the current chat session
# has the following (at most one is non-None)
# message_history: list[PreviousMessage] | None = None
# single_message_history: str | None = None
prompt_builder: AnswerPromptBuilder
search_tool: SearchTool | None = None
use_agentic_search: bool = False
# For persisting agent search data
chat_session_id: UUID | None = None
# The message ID of the user message that triggered the Pro Search
message_id: int | None = None
# Whether to persistence data for Agentic Search (turned off for testing)
use_agentic_persistence: bool = True
# The database session for Agentic Search
db_session: Session | None = None
# Whether to perform initial search to inform decomposition
# perform_initial_search_path_decision: bool = True
# Whether to perform initial search to inform decomposition
perform_initial_search_decomposition: bool = True
# Whether to allow creation of refinement questions (and entity extraction, etc.)
allow_refinement: bool = True
# Tools available for use
tools: list[Tool] | None = None
using_tool_calling_llm: bool = False
files: list[InMemoryChatFile] | None = None
structured_response_format: dict | None = None
skip_gen_ai_answer_generation: bool = False
@model_validator(mode="after")
def validate_db_session(self) -> "AgentSearchConfig":
if self.use_agentic_persistence and self.db_session is None:
raise ValueError(
"db_session must be provided for pro search when using persistence"
)
return self
@model_validator(mode="after")
def validate_search_tool(self) -> "AgentSearchConfig":
if self.use_agentic_search and self.search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
return self
class Config:
arbitrary_types_allowed = True
class AgentDocumentCitations(BaseModel):
document_id: str
document_title: str
link: str
class AgentPromptEnrichmentComponents(BaseModel):
persona_prompts: PersonaExpressions
history: str
date_str: str

View File

@@ -0,0 +1,72 @@
from typing import cast
from langchain_core.messages import AIMessageChunk
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.basic.states import BasicOutput
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.chat.models import LlmDoc
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_DOC_CONTENT_ID,
)
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicOutput:
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
structured_response_format = agent_config.structured_response_format
llm = agent_config.primary_llm
tool_choice = state.tool_choice
if tool_choice is None:
raise ValueError("Tool choice is None")
tool = tool_choice.tool
prompt_builder = agent_config.prompt_builder
if state.tool_call_output is None:
raise ValueError("Tool call output is None")
tool_call_output = state.tool_call_output
tool_call_summary = tool_call_output.tool_call_summary
tool_call_responses = tool_call_output.tool_call_responses
new_prompt_builder = tool.build_next_prompt(
prompt_builder=prompt_builder,
tool_call_summary=tool_call_summary,
tool_responses=tool_call_responses,
using_tool_calling_llm=agent_config.using_tool_calling_llm,
)
final_search_results = []
initial_search_results = []
for yield_item in tool_call_responses:
if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
final_search_results = cast(list[LlmDoc], yield_item.response)
elif yield_item.id == SEARCH_DOC_CONTENT_ID:
search_contexts = yield_item.response.contexts
for doc in search_contexts:
if doc.document_id not in initial_search_results:
initial_search_results.append(doc)
initial_search_results = cast(list[LlmDoc], initial_search_results)
new_tool_call_chunk = AIMessageChunk(content="")
if not agent_config.skip_gen_ai_answer_generation:
stream = llm.stream(
prompt=new_prompt_builder.build(),
structured_response_format=structured_response_format,
)
# For now, we don't do multiple tool calls, so we ignore the tool_message
new_tool_call_chunk = process_llm_stream(
stream,
True,
final_search_results=final_search_results,
displayed_search_results=initial_search_results,
)
return BasicOutput(tool_call_chunk=new_tool_call_chunk)

View File

@@ -0,0 +1,144 @@
from typing import cast
from uuid import uuid4
from langchain_core.messages import ToolCall
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.orchestration.states import ToolChoice
from onyx.agents.agent_search.orchestration.states import ToolChoiceState
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.chat.tool_handling.tool_response_handler import (
get_tool_call_for_non_tool_calling_llm_impl,
)
from onyx.tools.tool import Tool
from onyx.utils.logger import setup_logger
logger = setup_logger()
# TODO: break this out into an implementation function
# and a function that handles extracting the necessary fields
# from the state and config
# TODO: fan-out to multiple tool call nodes? Make this configurable?
def llm_tool_choice(state: ToolChoiceState, config: RunnableConfig) -> ToolChoiceUpdate:
"""
This node is responsible for calling the LLM to choose a tool. If no tool is chosen,
The node MAY emit an answer, depending on whether state["should_stream_answer"] is set.
"""
should_stream_answer = state.should_stream_answer
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
using_tool_calling_llm = agent_config.using_tool_calling_llm
prompt_builder = state.prompt_snapshot or agent_config.prompt_builder
llm = agent_config.primary_llm
skip_gen_ai_answer_generation = agent_config.skip_gen_ai_answer_generation
structured_response_format = agent_config.structured_response_format
tools = [tool for tool in (agent_config.tools or []) if tool.name in state.tools]
force_use_tool = agent_config.force_use_tool
tool, tool_args = None, None
if force_use_tool.force_use and force_use_tool.args is not None:
tool_name, tool_args = (
force_use_tool.tool_name,
force_use_tool.args,
)
tool = get_tool_by_name(tools, tool_name)
# special pre-logic for non-tool calling LLM case
elif not using_tool_calling_llm and tools:
chosen_tool_and_args = get_tool_call_for_non_tool_calling_llm_impl(
force_use_tool=force_use_tool,
tools=tools,
prompt_builder=prompt_builder,
llm=llm,
)
if chosen_tool_and_args:
tool, tool_args = chosen_tool_and_args
# If we have a tool and tool args, we are redy to request a tool call.
# This only happens if the tool call was forced or we are using a non-tool calling LLM.
if tool and tool_args:
return ToolChoiceUpdate(
tool_choice=ToolChoice(
tool=tool,
tool_args=tool_args,
id=str(uuid4()),
),
)
# if we're skipping gen ai answer generation, we should only
# continue if we're forcing a tool call (which will be emitted by
# the tool calling llm in the stream() below)
if skip_gen_ai_answer_generation and not force_use_tool.force_use:
return ToolChoiceUpdate(
tool_choice=None,
)
built_prompt = (
prompt_builder.build()
if isinstance(prompt_builder, AnswerPromptBuilder)
else prompt_builder.built_prompt
)
# At this point, we are either using a tool calling LLM or we are skipping the tool call.
# DEBUG: good breakpoint
stream = llm.stream(
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
prompt=built_prompt,
tools=[tool.tool_definition() for tool in tools] or None,
tool_choice=("required" if tools and force_use_tool.force_use else None),
structured_response_format=structured_response_format,
)
tool_message = process_llm_stream(
stream, should_stream_answer and not agent_config.skip_gen_ai_answer_generation
)
# If no tool calls are emitted by the LLM, we should not choose a tool
if len(tool_message.tool_calls) == 0:
logger.info("No tool calls emitted by LLM")
return ToolChoiceUpdate(
tool_choice=None,
)
# TODO: here we could handle parallel tool calls. Right now
# we just pick the first one that matches.
selected_tool: Tool | None = None
selected_tool_call_request: ToolCall | None = None
for tool_call_request in tool_message.tool_calls:
known_tools_by_name = [
tool for tool in tools if tool.name == tool_call_request["name"]
]
if known_tools_by_name:
selected_tool = known_tools_by_name[0]
selected_tool_call_request = tool_call_request
break
logger.error(
"Tool call requested with unknown name field. \n"
f"tools: {tools}"
f"tool_call_request: {tool_call_request}"
)
if not selected_tool or not selected_tool_call_request:
raise ValueError(
f"Tool call attempted with tool {selected_tool}, request {selected_tool_call_request}"
)
logger.info(f"Selected tool: {selected_tool.name}")
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
return ToolChoiceUpdate(
tool_choice=ToolChoice(
tool=selected_tool,
tool_args=selected_tool_call_request["args"],
id=selected_tool_call_request["id"],
),
)

View File

@@ -0,0 +1,17 @@
from typing import Any
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
def prepare_tool_input(state: Any, config: RunnableConfig) -> ToolChoiceInput:
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
return ToolChoiceInput(
# NOTE: this node is used at the top level of the agent, so we always stream
should_stream_answer=True,
prompt_snapshot=None, # uses default prompt builder
tools=[tool.name for tool in (agent_config.tools or [])],
)

View File

@@ -0,0 +1,65 @@
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import AIMessageChunk
from langchain_core.messages.tool import ToolCall
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.orchestration.states import ToolCallOutput
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.chat.models import AnswerPacket
from onyx.tools.message import build_tool_message
from onyx.tools.message import ToolCallSummary
from onyx.tools.tool_runner import ToolRunner
from onyx.utils.logger import setup_logger
logger = setup_logger()
def emit_packet(packet: AnswerPacket) -> None:
dispatch_custom_event("basic_response", packet)
def tool_call(state: ToolChoiceUpdate, config: RunnableConfig) -> ToolCallUpdate:
"""Calls the tool specified in the state and updates the state with the result"""
cast(AgentSearchConfig, config["metadata"]["config"])
tool_choice = state.tool_choice
if tool_choice is None:
raise ValueError("Cannot invoke tool call node without a tool choice")
tool = tool_choice.tool
tool_args = tool_choice.tool_args
tool_id = tool_choice.id
tool_runner = ToolRunner(tool, tool_args)
tool_kickoff = tool_runner.kickoff()
emit_packet(tool_kickoff)
tool_responses = []
for response in tool_runner.tool_responses():
tool_responses.append(response)
emit_packet(response)
tool_final_result = tool_runner.tool_final_result()
emit_packet(tool_final_result)
tool_call = ToolCall(name=tool.name, args=tool_args, id=tool_id)
tool_call_summary = ToolCallSummary(
tool_call_request=AIMessageChunk(content="", tool_calls=[tool_call]),
tool_call_result=build_tool_message(
tool_call, tool_runner.tool_message_content()
),
)
tool_call_output = ToolCallOutput(
tool_call_summary=tool_call_summary,
tool_call_kickoff=tool_kickoff,
tool_call_responses=tool_responses,
tool_call_final_result=tool_final_result,
)
return ToolCallUpdate(tool_call_output=tool_call_output)

View File

@@ -0,0 +1,48 @@
from pydantic import BaseModel
from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
# TODO: adapt the tool choice/tool call to allow for parallel tool calls by
# creating a subgraph that can be invoked in parallel via Send/Command APIs
class ToolChoiceInput(BaseModel):
should_stream_answer: bool = True
# default to the prompt builder from the config, but
# allow overrides for arbitrary tool calls
prompt_snapshot: PromptSnapshot | None = None
# names of tools to use for tool calling. Filters the tools available in the config
tools: list[str] = []
class ToolCallOutput(BaseModel):
tool_call_summary: ToolCallSummary
tool_call_kickoff: ToolCallKickoff
tool_call_responses: list[ToolResponse]
tool_call_final_result: ToolCallFinalResult
class ToolCallUpdate(BaseModel):
tool_call_output: ToolCallOutput | None = None
class ToolChoice(BaseModel):
tool: Tool
tool_args: dict
id: str | None
class Config:
arbitrary_types_allowed = True
class ToolChoiceUpdate(BaseModel):
tool_choice: ToolChoice | None = None
class ToolChoiceState(ToolChoiceUpdate, ToolChoiceInput):
pass

View File

@@ -0,0 +1,261 @@
import asyncio
from collections.abc import AsyncIterable
from collections.abc import Iterable
from datetime import datetime
from typing import cast
from langchain_core.runnables.schema import StreamEvent
from langgraph.graph.state import CompiledStateGraph
from onyx.agents.agent_search.basic.graph_builder import basic_graph_builder
from onyx.agents.agent_search.basic.states import BasicInput
from onyx.agents.agent_search.deep_search_a.main__graph.graph_builder import (
main_graph_builder as main_graph_builder_a,
)
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
MainInput as MainInput_a,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStream
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionPiece
from onyx.chat.models import ToolResponse
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager
from onyx.tools.tool_runner import ToolCallKickoff
from onyx.utils.logger import setup_logger
logger = setup_logger()
_COMPILED_GRAPH: CompiledStateGraph | None = None
def _set_combined_token_value(
combined_token: str, parsed_object: AgentAnswerPiece
) -> AgentAnswerPiece:
parsed_object.answer_piece = combined_token
return parsed_object
def _parse_agent_event(
event: StreamEvent,
) -> AnswerPacket | None:
"""
Parse the event into a typed object.
Return None if we are not interested in the event.
"""
event_type = event["event"]
# We always just yield the event data, but this piece is useful for two development reasons:
# 1. It's a list of the names of every place we dispatch a custom event
# 2. We maintain the intended types yielded by each event
if event_type == "on_custom_event":
# TODO: different AnswerStream types for different events
if event["name"] == "decomp_qs":
return cast(SubQuestionPiece, event["data"])
elif event["name"] == "subqueries":
return cast(SubQueryPiece, event["data"])
elif event["name"] == "sub_answers":
return cast(AgentAnswerPiece, event["data"])
elif event["name"] == "stream_finished":
return cast(StreamStopInfo, event["data"])
elif event["name"] == "initial_agent_answer":
return cast(AgentAnswerPiece, event["data"])
elif event["name"] == "refined_agent_answer":
return cast(AgentAnswerPiece, event["data"])
elif event["name"] == "start_refined_answer_creation":
return cast(ToolCallKickoff, event["data"])
elif event["name"] == "tool_response":
return cast(ToolResponse, event["data"])
elif event["name"] == "basic_response":
return cast(AnswerPacket, event["data"])
elif event["name"] == "refined_answer_improvement":
return cast(RefinedAnswerImprovement, event["data"])
return None
# https://stackoverflow.com/questions/60226557/how-to-forcefully-close-an-async-generator
# https://stackoverflow.com/questions/40897428/please-explain-task-was-destroyed-but-it-is-pending-after-cancelling-tasks
task_references: set[asyncio.Task[StreamEvent]] = set()
def _manage_async_event_streaming(
compiled_graph: CompiledStateGraph,
config: AgentSearchConfig | None,
graph_input: MainInput_a | BasicInput,
) -> Iterable[StreamEvent]:
async def _run_async_event_stream() -> AsyncIterable[StreamEvent]:
message_id = config.message_id if config else None
async for event in compiled_graph.astream_events(
input=graph_input,
config={"metadata": {"config": config, "thread_id": str(message_id)}},
# debug=True,
# indicating v2 here deserves further scrutiny
version="v2",
):
yield event
# This might be able to be simplified
def _yield_async_to_sync() -> Iterable[StreamEvent]:
loop = asyncio.new_event_loop()
try:
# Get the async generator
async_gen = _run_async_event_stream()
# Convert to AsyncIterator
async_iter = async_gen.__aiter__()
while True:
try:
# Create a coroutine by calling anext with the async iterator
next_coro = anext(async_iter)
task = asyncio.ensure_future(next_coro, loop=loop)
task_references.add(task)
# Run the coroutine to get the next event
event = loop.run_until_complete(task)
yield event
except (StopAsyncIteration, GeneratorExit):
break
finally:
try:
for task in task_references.pop():
task.cancel()
except StopAsyncIteration:
pass
loop.close()
return _yield_async_to_sync()
def run_graph(
compiled_graph: CompiledStateGraph,
config: AgentSearchConfig,
input: BasicInput | MainInput_a,
) -> AnswerStream:
# TODO: add these to the environment
# config.perform_initial_search_path_decision = False
config.perform_initial_search_decomposition = True
config.allow_refinement = True
for event in _manage_async_event_streaming(
compiled_graph=compiled_graph, config=config, graph_input=input
):
if not (parsed_object := _parse_agent_event(event)):
continue
yield parsed_object
# TODO: call this once on startup, TBD where and if it should be gated based
# on dev mode or not
def load_compiled_graph() -> CompiledStateGraph:
global _COMPILED_GRAPH
if _COMPILED_GRAPH is None:
graph = main_graph_builder_a()
_COMPILED_GRAPH = graph.compile()
return _COMPILED_GRAPH
def run_main_graph(
config: AgentSearchConfig,
) -> AnswerStream:
compiled_graph = load_compiled_graph()
input = MainInput_a(base_question=config.search_request.query, log_messages=[])
# Agent search is not a Tool per se, but this is helpful for the frontend
yield ToolCallKickoff(
tool_name="agent_search_0",
tool_args={"query": config.search_request.query},
)
yield from run_graph(compiled_graph, config, input)
# TODO: unify input types, especially prosearchconfig
def run_basic_graph(
config: AgentSearchConfig,
) -> AnswerStream:
graph = basic_graph_builder()
compiled_graph = graph.compile()
# TODO: unify basic input
input = BasicInput()
return run_graph(compiled_graph, config, input)
if __name__ == "__main__":
from onyx.llm.factory import get_default_llms
for _ in range(1):
now_start = datetime.now()
logger.debug(f"Start at {now_start}")
graph = main_graph_builder_a()
compiled_graph = graph.compile()
now_end = datetime.now()
logger.debug(f"Graph compiled in {now_end - now_start} seconds")
primary_llm, fast_llm = get_default_llms()
search_request = SearchRequest(
# query="what can you do with gitlab?",
# query="What are the guiding principles behind the development of cockroachDB",
# query="What are the temperatures in Munich, Hawaii, and New York?",
# query="When was Washington born?",
# query="What is Onyx?",
# query="What is the difference between astronomy and astrology?",
query="Do a search to tell me what is the difference between astronomy and astrology?",
)
# Joachim custom persona
with get_session_context_manager() as db_session:
config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)
# search_request.persona = get_persona_by_id(1, None, db_session)
config.use_agentic_persistence = True
# config.perform_initial_search_path_decision = False
config.perform_initial_search_decomposition = True
input = MainInput_a(
base_question=config.search_request.query, log_messages=[]
)
# with open("output.txt", "w") as f:
tool_responses: list = []
for output in run_graph(compiled_graph, config, input):
# pass
if isinstance(output, ToolCallKickoff):
pass
elif isinstance(output, ExtendedToolResponse):
tool_responses.append(output.response)
logger.info(
f" ---- ET {output.level} - {output.level_question_nr} | "
)
elif isinstance(output, SubQueryPiece):
logger.info(
f"Sq {output.level} - {output.level_question_nr} - {output.sub_query} | "
)
elif isinstance(output, SubQuestionPiece):
logger.info(
f"SQ {output.level} - {output.level_question_nr} - {output.sub_question} | "
)
elif (
isinstance(output, AgentAnswerPiece)
and output.answer_type == "agent_sub_answer"
):
logger.info(
f" ---- SA {output.level} - {output.level_question_nr} {output.answer_piece} | "
)
elif (
isinstance(output, AgentAnswerPiece)
and output.answer_type == "agent_level_answer"
):
logger.info(
f" ---------- FA {output.level} - {output.level_question_nr} {output.answer_piece} | "
)
elif isinstance(output, RefinedAnswerImprovement):
logger.info(
f" ---------- RE {output.refined_answer_improvement} | "
)

View File

@@ -0,0 +1,134 @@
from langchain.schema import AIMessage
from langchain.schema import HumanMessage
from langchain.schema import SystemMessage
from langchain_core.messages.tool import ToolMessage
from onyx.agents.agent_search.models import AgentPromptEnrichmentComponents
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT_v2
from onyx.agents.agent_search.shared_graph_utils.prompts import HISTORY_PROMPT
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_persona_agent_prompt_expressions,
)
from onyx.agents.agent_search.shared_graph_utils.utils import get_today_prompt
from onyx.agents.agent_search.shared_graph_utils.utils import summarize_history
from onyx.configs.agent_configs import AGENT_MAX_STATIC_HISTORY_CHAR_LENGTH
from onyx.context.search.models import InferenceSection
from onyx.llm.interfaces import LLMConfig
from onyx.llm.utils import get_max_input_tokens
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.natural_language_processing.utils import tokenizer_trim_content
def build_sub_question_answer_prompt(
question: str,
original_question: str,
docs: list[InferenceSection],
persona_specification: str,
config: LLMConfig,
) -> list[SystemMessage | HumanMessage | AIMessage | ToolMessage]:
system_message = SystemMessage(
content=persona_specification,
)
date_str = get_today_prompt()
docs_format_list = [
f"""Document Number: [D{doc_nr + 1}]\n
Content: {doc.combined_content}\n\n"""
for doc_nr, doc in enumerate(docs)
]
docs_str = "\n\n".join(docs_format_list)
docs_str = trim_prompt_piece(
config, docs_str, BASE_RAG_PROMPT_v2 + question + original_question + date_str
)
human_message = HumanMessage(
content=BASE_RAG_PROMPT_v2.format(
question=question,
original_question=original_question,
context=docs_str,
date_prompt=date_str,
)
)
return [system_message, human_message]
def trim_prompt_piece(config: LLMConfig, prompt_piece: str, reserved_str: str) -> str:
# TODO: this truncating might add latency. We could do a rougher + faster check
# first to determine whether truncation is needed
# TODO: maybe save the tokenizer and max input tokens if this is getting called multiple times?
llm_tokenizer = get_tokenizer(
provider_type=config.model_provider,
model_name=config.model_name,
)
max_tokens = get_max_input_tokens(
model_provider=config.model_provider,
model_name=config.model_name,
)
# slightly conservative trimming
return tokenizer_trim_content(
content=prompt_piece,
desired_length=max_tokens - len(llm_tokenizer.encode(reserved_str)),
tokenizer=llm_tokenizer,
)
def build_history_prompt(config: AgentSearchConfig, question: str) -> str:
prompt_builder = config.prompt_builder
model = config.fast_llm
persona_base = get_persona_agent_prompt_expressions(
config.search_request.persona
).base_prompt
if prompt_builder is None:
return ""
if prompt_builder.single_message_history is not None:
history = prompt_builder.single_message_history
else:
history_components = []
previous_message_type = None
for message in prompt_builder.raw_message_history:
if "user" in message.message_type:
history_components.append(f"User: {message.message}\n")
previous_message_type = "user"
elif "assistant" in message.message_type:
# only use the last agent answer for the history
if previous_message_type != "assistant":
history_components.append(f"You/Agent: {message.message}\n")
else:
history_components = history_components[:-1]
history_components.append(f"You/Agent: {message.message}\n")
previous_message_type = "assistant"
else:
continue
history = "\n".join(history_components)
if len(history) > AGENT_MAX_STATIC_HISTORY_CHAR_LENGTH:
history = summarize_history(history, question, persona_base, model)
return HISTORY_PROMPT.format(history=history) if history else ""
def get_prompt_enrichment_components(
config: AgentSearchConfig,
) -> AgentPromptEnrichmentComponents:
persona_prompts = get_persona_agent_prompt_expressions(
config.search_request.persona
)
history = build_history_prompt(config, config.search_request.query)
date_str = get_today_prompt()
return AgentPromptEnrichmentComponents(
persona_prompts=persona_prompts,
history=history,
date_str=date_str,
)

View File

@@ -0,0 +1,98 @@
import numpy as np
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitScoreMetrics
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.chat.models import SectionRelevancePiece
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger
logger = setup_logger()
def unique_chunk_id(doc: InferenceSection) -> str:
return f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}"
def calculate_rank_shift(list1: list, list2: list, top_n: int = 20) -> float:
shift = 0
for rank_first, doc_id in enumerate(list1[:top_n], 1):
try:
rank_second = list2.index(doc_id) + 1
except ValueError:
rank_second = len(list2) # Document not found in second list
shift += np.abs(rank_first - rank_second) / np.log(1 + rank_first * rank_second)
return shift / top_n
def get_fit_scores(
pre_reranked_results: list[InferenceSection],
post_reranked_results: list[InferenceSection] | list[SectionRelevancePiece],
) -> RetrievalFitStats | None:
"""
Calculate retrieval metrics for search purposes
"""
if len(pre_reranked_results) == 0 or len(post_reranked_results) == 0:
return None
ranked_sections = {
"initial": pre_reranked_results,
"reranked": post_reranked_results,
}
fit_eval: RetrievalFitStats = RetrievalFitStats(
fit_score_lift=0,
rerank_effect=0,
fit_scores={
"initial": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
"reranked": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
},
)
for rank_type, docs in ranked_sections.items():
logger.debug(f"rank_type: {rank_type}")
for i in [1, 5, 10]:
fit_eval.fit_scores[rank_type].scores[str(i)] = (
sum(
[
float(doc.center_chunk.score)
for doc in docs[:i]
if type(doc) == InferenceSection
and doc.center_chunk.score is not None
]
)
/ i
)
fit_eval.fit_scores[rank_type].scores["fit_score"] = (
1
/ 3
* (
fit_eval.fit_scores[rank_type].scores["1"]
+ fit_eval.fit_scores[rank_type].scores["5"]
+ fit_eval.fit_scores[rank_type].scores["10"]
)
)
fit_eval.fit_scores[rank_type].scores["fit_score"] = fit_eval.fit_scores[
rank_type
].scores["1"]
fit_eval.fit_scores[rank_type].chunk_ids = [
unique_chunk_id(doc) for doc in docs if type(doc) == InferenceSection
]
fit_eval.fit_score_lift = (
fit_eval.fit_scores["reranked"].scores["fit_score"]
/ fit_eval.fit_scores["initial"].scores["fit_score"]
)
fit_eval.rerank_effect = calculate_rank_shift(
fit_eval.fit_scores["initial"].chunk_ids,
fit_eval.fit_scores["reranked"].chunk_ids,
)
return fit_eval

View File

@@ -0,0 +1,123 @@
from typing import Literal
from pydantic import BaseModel
from onyx.agents.agent_search.deep_search_a.main__graph.models import (
AgentAdditionalMetrics,
)
from onyx.agents.agent_search.deep_search_a.main__graph.models import AgentBaseMetrics
from onyx.agents.agent_search.deep_search_a.main__graph.models import (
AgentRefinedMetrics,
)
from onyx.agents.agent_search.deep_search_a.main__graph.models import AgentTimings
from onyx.context.search.models import InferenceSection
from onyx.tools.models import SearchQueryInfo
# Pydantic models for structured outputs
class RewrittenQueries(BaseModel):
rewritten_queries: list[str]
class BinaryDecision(BaseModel):
decision: Literal["yes", "no"]
class BinaryDecisionWithReasoning(BaseModel):
reasoning: str
decision: Literal["yes", "no"]
class RetrievalFitScoreMetrics(BaseModel):
scores: dict[str, float]
chunk_ids: list[str]
class RetrievalFitStats(BaseModel):
fit_score_lift: float
rerank_effect: float
fit_scores: dict[str, RetrievalFitScoreMetrics]
class AgentChunkScores(BaseModel):
scores: dict[str, dict[str, list[int | float]]]
class AgentChunkStats(BaseModel):
verified_count: int | None = None
verified_avg_scores: float | None = None
rejected_count: int | None = None
rejected_avg_scores: float | None = None
verified_doc_chunk_ids: list[str] = []
dismissed_doc_chunk_ids: list[str] = []
class InitialAgentResultStats(BaseModel):
sub_questions: dict[str, float | int | None]
original_question: dict[str, float | int | None]
agent_effectiveness: dict[str, float | int | None]
class RefinedAgentStats(BaseModel):
revision_doc_efficiency: float | None
revision_question_efficiency: float | None
class Term(BaseModel):
term_name: str = ""
term_type: str = ""
term_similar_to: list[str] = []
### Models ###
class Entity(BaseModel):
entity_name: str = ""
entity_type: str = ""
class Relationship(BaseModel):
relationship_name: str = ""
relationship_type: str = ""
relationship_entities: list[str] = []
class EntityRelationshipTermExtraction(BaseModel):
entities: list[Entity] = []
relationships: list[Relationship] = []
terms: list[Term] = []
### Models ###
class QueryResult(BaseModel):
query: str
search_results: list[InferenceSection]
stats: RetrievalFitStats | None
query_info: SearchQueryInfo | None
class QuestionAnswerResults(BaseModel):
question: str
question_id: str
answer: str
quality: str
expanded_retrieval_results: list[QueryResult]
documents: list[InferenceSection]
context_documents: list[InferenceSection]
cited_docs: list[InferenceSection]
sub_question_retrieval_stats: AgentChunkStats
class CombinedAgentMetrics(BaseModel):
timings: AgentTimings
base_metrics: AgentBaseMetrics | None
refined_metrics: AgentRefinedMetrics
additional_metrics: AgentAdditionalMetrics
class PersonaExpressions(BaseModel):
contextualized_prompt: str
base_prompt: str

View File

@@ -0,0 +1,31 @@
from onyx.agents.agent_search.shared_graph_utils.models import (
QuestionAnswerResults,
)
from onyx.chat.prune_and_merge import _merge_sections
from onyx.context.search.models import InferenceSection
def dedup_inference_sections(
list1: list[InferenceSection], list2: list[InferenceSection]
) -> list[InferenceSection]:
deduped = _merge_sections(list1 + list2)
return deduped
def dedup_question_answer_results(
question_answer_results_1: list[QuestionAnswerResults],
question_answer_results_2: list[QuestionAnswerResults],
) -> list[QuestionAnswerResults]:
deduped_question_answer_results: list[
QuestionAnswerResults
] = question_answer_results_1
utilized_question_ids: set[str] = set(
[x.question_id for x in question_answer_results_1]
)
for question_answer_result in question_answer_results_2:
if question_answer_result.question_id not in utilized_question_ids:
deduped_question_answer_results.append(question_answer_result)
utilized_question_ids.add(question_answer_result.question_id)
return deduped_question_answer_results

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,367 @@
import ast
import json
import re
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from typing import Any
from typing import cast
from uuid import UUID
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from sqlalchemy.orm import Session
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
)
from onyx.agents.agent_search.shared_graph_utils.models import PersonaExpressions
from onyx.agents.agent_search.shared_graph_utils.prompts import (
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
ASSISTANT_SYSTEM_PROMPT_PERSONA,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import DATE_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import (
HISTORY_CONTEXT_SUMMARY_PROMPT,
)
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationConfig
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import PromptConfig
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager
from onyx.db.persona import get_persona_by_id
from onyx.db.persona import Persona
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.tool_constructor import SearchToolConfig
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
def normalize_whitespace(text: str) -> str:
"""Normalize whitespace in text to single spaces and strip leading/trailing whitespace."""
import re
return re.sub(r"\s+", " ", text.strip())
# Post-processing
def format_docs(docs: Sequence[InferenceSection]) -> str:
formatted_doc_list = []
for doc_nr, doc in enumerate(docs):
formatted_doc_list.append(f"Document D{doc_nr + 1}:\n{doc.combined_content}")
return "\n\n".join(formatted_doc_list)
def format_docs_content_flat(docs: Sequence[InferenceSection]) -> str:
formatted_doc_list = []
for _, doc in enumerate(docs):
formatted_doc_list.append(f"\n...{doc.combined_content}\n")
return "\n\n".join(formatted_doc_list)
def clean_and_parse_list_string(json_string: str) -> list[dict]:
# Remove any prefixes/labels before the actual JSON content
json_string = re.sub(r"^.*?(?=\[)", "", json_string, flags=re.DOTALL)
# Remove markdown code block markers and any newline prefixes
cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
cleaned_string = " ".join(cleaned_string.split())
# Try parsing with json.loads first, fall back to ast.literal_eval
try:
return json.loads(cleaned_string)
except json.JSONDecodeError:
try:
return ast.literal_eval(cleaned_string)
except (ValueError, SyntaxError) as e:
raise ValueError(f"Failed to parse JSON string: {cleaned_string}") from e
def clean_and_parse_json_string(json_string: str) -> dict[str, Any]:
# Remove markdown code block markers and any newline prefixes
cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
cleaned_string = " ".join(cleaned_string.split())
# Parse the cleaned string into a Python dictionary
return json.loads(cleaned_string)
def format_entity_term_extraction(
entity_term_extraction_dict: EntityRelationshipTermExtraction,
) -> str:
entities = entity_term_extraction_dict.entities
terms = entity_term_extraction_dict.terms
relationships = entity_term_extraction_dict.relationships
entity_strs = ["\nEntities:\n"]
for entity in entities:
entity_str = f"{entity.entity_name} ({entity.entity_type})"
entity_strs.append(entity_str)
entity_str = "\n - ".join(entity_strs)
relationship_strs = ["\n\nRelationships:\n"]
for relationship in relationships:
relationship_name = relationship.relationship_name
relationship_type = relationship.relationship_type
relationship_entities = relationship.relationship_entities
relationship_str = (
f"""{relationship_name} ({relationship_type}): {relationship_entities}"""
)
relationship_strs.append(relationship_str)
relationship_str = "\n - ".join(relationship_strs)
term_strs = ["\n\nTerms:\n"]
for term in terms:
term_str = f"{term.term_name} ({term.term_type}): similar to {', '.join(term.term_similar_to)}"
term_strs.append(term_str)
term_str = "\n - ".join(term_strs)
return "\n".join(entity_strs + relationship_strs + term_strs)
def _format_time_delta(time: timedelta) -> str:
seconds_from_start = f"{((time).seconds):03d}"
microseconds_from_start = f"{((time).microseconds):06d}"
return f"{seconds_from_start}.{microseconds_from_start}"
def generate_log_message(
message: str,
node_start_time: datetime,
graph_start_time: datetime | None = None,
) -> str:
current_time = datetime.now()
if graph_start_time is not None:
graph_time_str = _format_time_delta(current_time - graph_start_time)
else:
graph_time_str = "N/A"
node_time_str = _format_time_delta(current_time - node_start_time)
return f"{graph_time_str} ({node_time_str} s): {message}"
def get_test_config(
db_session: Session,
primary_llm: LLM,
fast_llm: LLM,
search_request: SearchRequest,
use_agentic_search: bool = True,
) -> tuple[AgentSearchConfig, SearchTool]:
persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session)
document_pruning_config = DocumentPruningConfig(
max_chunks=int(
persona.num_chunks
if persona.num_chunks is not None
else MAX_CHUNKS_FED_TO_CHAT
),
max_window_percentage=CHAT_TARGET_CHUNK_PERCENTAGE,
)
answer_style_config = AnswerStyleConfig(
citation_config=CitationConfig(
# The docs retrieved by this flow are already relevance-filtered
all_docs_useful=True
),
document_pruning_config=document_pruning_config,
structured_response_format=None,
)
search_tool_config = SearchToolConfig(
answer_style_config=answer_style_config,
document_pruning_config=document_pruning_config,
retrieval_options=RetrievalDetails(), # may want to set dedupe_docs=True
rerank_settings=None, # Can use this to change reranking model
selected_sections=None,
latest_query_files=None,
bypass_acl=False,
)
prompt_config = PromptConfig.from_model(persona.prompts[0])
search_tool = SearchTool(
db_session=db_session,
user=None,
persona=persona,
retrieval_options=search_tool_config.retrieval_options,
prompt_config=prompt_config,
llm=primary_llm,
fast_llm=fast_llm,
pruning_config=search_tool_config.document_pruning_config,
answer_style_config=search_tool_config.answer_style_config,
selected_sections=search_tool_config.selected_sections,
chunks_above=search_tool_config.chunks_above,
chunks_below=search_tool_config.chunks_below,
full_doc=search_tool_config.full_doc,
evaluation_type=(
LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP
),
rerank_settings=search_tool_config.rerank_settings,
bypass_acl=search_tool_config.bypass_acl,
)
config = AgentSearchConfig(
search_request=search_request,
primary_llm=primary_llm,
fast_llm=fast_llm,
search_tool=search_tool,
force_use_tool=ForceUseTool(force_use=False, tool_name=""),
prompt_builder=AnswerPromptBuilder(
user_message=HumanMessage(content=search_request.query),
message_history=[],
llm_config=primary_llm.config,
raw_user_query=search_request.query,
raw_user_uploaded_files=[],
),
# chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
chat_session_id=UUID("edda10d5-6cef-45d8-acfb-39317552a1f4"), # Joachim
# chat_session_id=UUID("d1acd613-2692-4bc3-9d65-c6d3da62e58e"), # Evan
message_id=1,
use_agentic_persistence=True,
db_session=db_session,
tools=[search_tool],
use_agentic_search=use_agentic_search,
)
return config, search_tool
def get_persona_agent_prompt_expressions(persona: Persona | None) -> PersonaExpressions:
if persona is None:
persona_prompt = ASSISTANT_SYSTEM_PROMPT_DEFAULT
persona_base = ""
else:
persona_base = "\n".join([x.system_prompt for x in persona.prompts])
persona_prompt = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
persona_prompt=persona_base
)
return PersonaExpressions(
contextualized_prompt=persona_prompt, base_prompt=persona_base
)
def make_question_id(level: int, question_nr: int) -> str:
return f"{level}_{question_nr}"
def parse_question_id(question_id: str) -> tuple[int, int]:
level, question_nr = question_id.split("_")
return int(level), int(question_nr)
def _dispatch_nonempty(
content: str, dispatch_event: Callable[[str, int], None], num: int
) -> None:
if content != "":
dispatch_event(content, num)
def dispatch_separated(
token_itr: Iterator[BaseMessage],
dispatch_event: Callable[[str, int], None],
sep: str = "\n",
) -> list[str | list[str | dict[str, Any]]]:
num = 1
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
for message in token_itr:
content = cast(str, message.content)
if sep in content:
sub_question_parts = content.split(sep)
_dispatch_nonempty(sub_question_parts[0], dispatch_event, num)
num += 1
_dispatch_nonempty(
"".join(sub_question_parts[1:]).strip(), dispatch_event, num
)
else:
_dispatch_nonempty(content, dispatch_event, num)
streamed_tokens.append(content)
return streamed_tokens
def dispatch_main_answer_stop_info(level: int) -> None:
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type="main_answer",
level=level,
)
dispatch_custom_event("stream_finished", stop_event)
def get_today_prompt() -> str:
return DATE_PROMPT.format(date=datetime.now().strftime("%A, %B %d, %Y"))
def retrieve_search_docs(
search_tool: SearchTool, question: str
) -> list[InferenceSection]:
retrieved_docs: list[InferenceSection] = []
# new db session to avoid concurrency issues
with get_session_context_manager() as db_session:
for tool_response in search_tool.run(
query=question,
force_no_rerank=True,
alternate_db_session=db_session,
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
response = cast(SearchResponseSummary, tool_response.response)
retrieved_docs = response.top_sections
break
return retrieved_docs
def get_answer_citation_ids(answer_str: str) -> list[int]:
citation_ids = re.findall(r"\[\[D(\d+)\]\]", answer_str)
return list(set([(int(id) - 1) for id in citation_ids]))
def summarize_history(
history: str, question: str, persona_specification: str, model: LLM
) -> str:
history_context_prompt = HISTORY_CONTEXT_SUMMARY_PROMPT.format(
persona_specification=persona_specification, question=question, history=history
)
history_response = model.invoke(history_context_prompt)
if isinstance(history_response.content, str):
history_context_response_str = history_response.content
else:
history_context_response_str = ""
return history_context_response_str

View File

@@ -1,110 +1,81 @@
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Iterator
from uuid import uuid4
from uuid import UUID
from langchain.schema.messages import BaseMessage
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import ToolCall
from sqlalchemy.orm import Session
from onyx.chat.llm_response_handler import LLMResponseHandlerManager
from onyx.chat.models import AnswerQuestionPossibleReturn
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.run_graph import run_basic_graph
from onyx.agents.agent_search.run_graph import run_main_graph
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStream
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import PromptConfig
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
from onyx.chat.stream_processing.answer_response_handler import (
CitationResponseHandler,
)
from onyx.chat.stream_processing.answer_response_handler import (
DummyAnswerResponseHandler,
)
from onyx.chat.stream_processing.utils import (
map_document_id_order,
)
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.configs.constants import BASIC_KEY
from onyx.context.search.models import SearchRequest
from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.tools.force import ForceUseTool
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_runner import ToolCallKickoff
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.logger import setup_logger
logger = setup_logger()
AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse]
class Answer:
def __init__(
self,
question: str,
prompt_builder: AnswerPromptBuilder,
answer_style_config: AnswerStyleConfig,
llm: LLM,
prompt_config: PromptConfig,
fast_llm: LLM,
force_use_tool: ForceUseTool,
# must be the same length as `docs`. If None, all docs are considered "relevant"
message_history: list[PreviousMessage] | None = None,
single_message_history: str | None = None,
search_request: SearchRequest,
chat_session_id: UUID,
current_agent_message_id: int,
# newly passed in files to include as part of this question
# TODO THIS NEEDS TO BE HANDLED
latest_query_files: list[InMemoryChatFile] | None = None,
files: list[InMemoryChatFile] | None = None,
tools: list[Tool] | None = None,
# NOTE: for native tool-calling, this is only supported by OpenAI atm,
# but we only support them anyways
# if set to True, then never use the LLMs provided tool-calling functonality
skip_explicit_tool_calling: bool = False,
# Returns the full document sections text from the search tool
return_contexts: bool = False,
skip_gen_ai_answer_generation: bool = False,
is_connected: Callable[[], bool] | None = None,
db_session: Session | None = None,
use_agentic_search: bool = False,
use_agentic_persistence: bool = True,
) -> None:
if single_message_history and message_history:
raise ValueError(
"Cannot provide both `message_history` and `single_message_history`"
)
self.question = question
self.is_connected: Callable[[], bool] | None = is_connected
self.latest_query_files = latest_query_files or []
self.file_id_to_file = {file.file_id: file for file in (files or [])}
self.tools = tools or []
self.force_use_tool = force_use_tool
self.message_history = message_history or []
# used for QA flow where we only want to send a single message
self.single_message_history = single_message_history
self.answer_style_config = answer_style_config
self.prompt_config = prompt_config
self.llm = llm
self.fast_llm = fast_llm
self.llm_tokenizer = get_tokenizer(
provider_type=llm.config.model_provider,
model_name=llm.config.model_name,
)
self._final_prompt: list[BaseMessage] | None = None
self._streamed_output: list[str] | None = None
self._processed_stream: (
list[AnswerQuestionPossibleReturn | ToolResponse | ToolCallKickoff] | None
) = None
self._processed_stream: (list[AnswerPacket] | None) = None
self._return_contexts = return_contexts
self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
self._is_cancelled = False
@@ -115,167 +86,76 @@ class Answer:
and not skip_explicit_tool_calling
)
search_tools = [tool for tool in (tools or []) if isinstance(tool, SearchTool)]
search_tool: SearchTool | None = None
if len(search_tools) > 1:
# TODO: handle multiple search tools
raise ValueError("Multiple search tools found")
elif len(search_tools) == 1:
search_tool = search_tools[0]
using_tool_calling_llm = explicit_tool_calling_supported(
llm.config.model_provider, llm.config.model_name
)
self.agent_search_config = AgentSearchConfig(
search_request=search_request,
primary_llm=llm,
fast_llm=fast_llm,
search_tool=search_tool,
force_use_tool=force_use_tool,
use_agentic_search=use_agentic_search,
chat_session_id=chat_session_id,
message_id=current_agent_message_id,
use_agentic_persistence=use_agentic_persistence,
allow_refinement=True,
db_session=db_session,
prompt_builder=prompt_builder,
tools=tools,
using_tool_calling_llm=using_tool_calling_llm,
files=latest_query_files,
structured_response_format=answer_style_config.structured_response_format,
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
)
self.db_session = db_session
def _get_tools_list(self) -> list[Tool]:
if not self.force_use_tool.force_use:
return self.tools
tool = next(
(t for t in self.tools if t.name == self.force_use_tool.tool_name), None
)
if tool is None:
raise RuntimeError(f"Tool '{self.force_use_tool.tool_name}' not found")
tool = get_tool_by_name(self.tools, self.force_use_tool.tool_name)
logger.info(
f"Forcefully using tool='{tool.name}'"
+ (
f" with args='{self.force_use_tool.args}'"
if self.force_use_tool.args is not None
else ""
)
args_str = (
f" with args='{self.force_use_tool.args}'"
if self.force_use_tool.args
else ""
)
logger.info(f"Forcefully using tool='{tool.name}'{args_str}")
return [tool]
def _handle_specified_tool_call(
self, llm_calls: list[LLMCall], tool: Tool, tool_args: dict
) -> AnswerStream:
current_llm_call = llm_calls[-1]
# make a dummy tool handler
tool_handler = ToolResponseHandler([tool])
dummy_tool_call_chunk = AIMessageChunk(content="")
dummy_tool_call_chunk.tool_calls = [
ToolCall(name=tool.name, args=tool_args, id=str(uuid4()))
]
response_handler_manager = LLMResponseHandlerManager(
tool_handler, DummyAnswerResponseHandler(), self.is_cancelled
)
yield from response_handler_manager.handle_llm_response(
iter([dummy_tool_call_chunk])
)
new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
if new_llm_call:
yield from self._get_response(llm_calls + [new_llm_call])
else:
raise RuntimeError("Tool call handler did not return a new LLM call")
def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream:
current_llm_call = llm_calls[-1]
# handle the case where no decision has to be made; we simply run the tool
if (
current_llm_call.force_use_tool.force_use
and current_llm_call.force_use_tool.args is not None
):
tool_name, tool_args = (
current_llm_call.force_use_tool.tool_name,
current_llm_call.force_use_tool.args,
)
tool = next(
(t for t in current_llm_call.tools if t.name == tool_name), None
)
if not tool:
raise RuntimeError(f"Tool '{tool_name}' not found")
yield from self._handle_specified_tool_call(llm_calls, tool, tool_args)
return
# special pre-logic for non-tool calling LLM case
if not self.using_tool_calling_llm and current_llm_call.tools:
chosen_tool_and_args = (
ToolResponseHandler.get_tool_call_for_non_tool_calling_llm(
current_llm_call, self.llm
)
)
if chosen_tool_and_args:
tool, tool_args = chosen_tool_and_args
yield from self._handle_specified_tool_call(llm_calls, tool, tool_args)
return
# if we're skipping gen ai answer generation, we should break
# out unless we're forcing a tool call. If we don't, we might generate an
# answer, which is a no-no!
if (
self.skip_gen_ai_answer_generation
and not current_llm_call.force_use_tool.force_use
):
return
# set up "handlers" to listen to the LLM response stream and
# feed back the processed results + handle tool call requests
# + figure out what the next LLM call should be
tool_call_handler = ToolResponseHandler(current_llm_call.tools)
final_search_results, displayed_search_results = SearchTool.get_search_result(
current_llm_call
) or ([], [])
answer_handler = CitationResponseHandler(
context_docs=final_search_results,
final_doc_id_to_rank_map=map_document_id_order(final_search_results),
display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
)
response_handler_manager = LLMResponseHandlerManager(
tool_call_handler, answer_handler, self.is_cancelled
)
# DEBUG: good breakpoint
stream = self.llm.stream(
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
prompt=current_llm_call.prompt_builder.build(),
tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
tool_choice=(
"required"
if current_llm_call.tools and current_llm_call.force_use_tool.force_use
else None
),
structured_response_format=self.answer_style_config.structured_response_format,
)
yield from response_handler_manager.handle_llm_response(stream)
new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
if new_llm_call:
yield from self._get_response(llm_calls + [new_llm_call])
@property
def processed_streamed_output(self) -> AnswerStream:
if self._processed_stream is not None:
yield from self._processed_stream
return
prompt_builder = AnswerPromptBuilder(
user_message=default_build_user_message(
user_query=self.question,
prompt_config=self.prompt_config,
files=self.latest_query_files,
single_message_history=self.single_message_history,
),
message_history=self.message_history,
llm_config=self.llm.config,
raw_user_query=self.question,
raw_user_uploaded_files=self.latest_query_files or [],
single_message_history=self.single_message_history,
run_langgraph = (
run_main_graph
if self.agent_search_config.use_agentic_search
else run_basic_graph
)
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)
)
llm_call = LLMCall(
prompt_builder=prompt_builder,
tools=self._get_tools_list(),
force_use_tool=self.force_use_tool,
files=self.latest_query_files,
tool_call_info=[],
using_tool_calling_llm=self.using_tool_calling_llm,
stream = run_langgraph(
self.agent_search_config,
)
processed_stream = []
for processed_packet in self._get_response([llm_call]):
processed_stream.append(processed_packet)
yield processed_packet
for packet in stream:
if self.is_cancelled():
packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
yield packet
break
processed_stream.append(packet)
yield packet
self._processed_stream = processed_stream
@@ -283,20 +163,57 @@ class Answer:
def llm_answer(self) -> str:
answer = ""
for packet in self.processed_streamed_output:
if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
# handle basic answer flow, plus level 0 agent answer flow
# since level 0 is the first answer the user sees and therefore the
# child message of the user message in the db (so it is handled
# like a basic flow answer)
if (isinstance(packet, OnyxAnswerPiece) and packet.answer_piece) or (
isinstance(packet, AgentAnswerPiece)
and packet.answer_piece
and packet.answer_type == "agent_level_answer"
and packet.level == 0
):
answer += packet.answer_piece
return answer
def llm_answer_by_level(self) -> dict[int, str]:
answer_by_level: dict[int, str] = defaultdict(str)
for packet in self.processed_streamed_output:
if (
isinstance(packet, AgentAnswerPiece)
and packet.answer_piece
and packet.answer_type == "agent_level_answer"
):
answer_by_level[packet.level] += packet.answer_piece
elif isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
answer_by_level[BASIC_KEY[0]] += packet.answer_piece
return answer_by_level
@property
def citations(self) -> list[CitationInfo]:
citations: list[CitationInfo] = []
for packet in self.processed_streamed_output:
if isinstance(packet, CitationInfo):
if isinstance(packet, CitationInfo) and packet.level is None:
citations.append(packet)
return citations
# TODO: replace tuple of ints with SubQuestionId EVERYWHERE
def citations_by_subquestion(self) -> dict[tuple[int, int], list[CitationInfo]]:
citations_by_subquestion: dict[
tuple[int, int], list[CitationInfo]
] = defaultdict(list)
for packet in self.processed_streamed_output:
if isinstance(packet, CitationInfo):
if packet.level_question_nr is not None and packet.level is not None:
citations_by_subquestion[
(packet.level, packet.level_question_nr)
].append(packet)
elif packet.level is None:
citations_by_subquestion[BASIC_KEY].append(packet)
return citations_by_subquestion
def is_cancelled(self) -> bool:
if self._is_cancelled:
return True

View File

@@ -48,6 +48,8 @@ def prepare_chat_message_request(
retrieval_details: RetrievalDetails | None,
rerank_settings: RerankingDetails | None,
db_session: Session,
use_agentic_search: bool = False,
skip_gen_ai_answer_generation: bool = False,
) -> CreateChatMessageRequest:
# Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
new_chat_session = create_chat_session(
@@ -72,6 +74,8 @@ def prepare_chat_message_request(
search_doc_ids=None,
retrieval_options=retrieval_details,
rerank_settings=rerank_settings,
use_agentic_search=use_agentic_search,
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
)

View File

@@ -9,25 +9,37 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
DummyAnswerResponseHandler,
)
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
class LLMResponseHandlerManager:
"""
This class is responsible for postprocessing the LLM response stream.
In particular, we:
1. handle the tool call requests
2. handle citations
3. pass through answers generated by the LLM
4. Stop yielding if the client disconnects
"""
def __init__(
self,
tool_handler: ToolResponseHandler,
answer_handler: AnswerResponseHandler,
tool_handler: ToolResponseHandler | None,
answer_handler: AnswerResponseHandler | None,
is_cancelled: Callable[[], bool],
):
self.tool_handler = tool_handler
self.answer_handler = answer_handler
self.tool_handler = tool_handler or ToolResponseHandler([])
self.answer_handler = answer_handler or DummyAnswerResponseHandler()
self.is_cancelled = is_cancelled
def handle_llm_response(
self,
stream: Iterator[BaseMessage],
) -> Generator[ResponsePart, None, None]:
all_messages: list[BaseMessage] = []
all_messages: list[BaseMessage | str] = []
for message in stream:
if self.is_cancelled():
yield StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)

View File

@@ -3,6 +3,7 @@ from collections.abc import Iterator
from datetime import datetime
from enum import Enum
from typing import Any
from typing import Literal
from typing import TYPE_CHECKING
from pydantic import BaseModel
@@ -48,6 +49,8 @@ class QADocsResponse(RetrievalDocs):
applied_source_filters: list[DocumentSource] | None
applied_time_cutoff: datetime | None
recency_bias_multiplier: float
level: int | None = None
level_question_nr: int | None = None
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
@@ -61,11 +64,17 @@ class QADocsResponse(RetrievalDocs):
class StreamStopReason(Enum):
CONTEXT_LENGTH = "context_length"
CANCELLED = "cancelled"
FINISHED = "finished"
class StreamStopInfo(BaseModel):
stop_reason: StreamStopReason
stream_type: Literal["", "sub_questions", "sub_answer", "main_answer"] = ""
# used to identify the stream that was stopped for agent search
level: int | None = None
level_question_nr: int | None = None
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
data = super().model_dump(mode="json", *args, **kwargs) # type: ignore
data["stop_reason"] = self.stop_reason.name
@@ -108,6 +117,8 @@ class OnyxAnswerPiece(BaseModel):
class CitationInfo(BaseModel):
citation_num: int
document_id: str
level: int | None = None
level_question_nr: int | None = None
class AllCitations(BaseModel):
@@ -273,7 +284,7 @@ class AnswerStyleConfig(BaseModel):
class PromptConfig(BaseModel):
"""Final representation of the Prompt configuration passed
into the `Answer` object."""
into the `PromptBuilder` object."""
system_prompt: str
task_prompt: str
@@ -299,6 +310,48 @@ class PromptConfig(BaseModel):
model_config = ConfigDict(frozen=True)
class SubQueryPiece(BaseModel):
sub_query: str
level: int
level_question_nr: int
query_id: int
class AgentAnswerPiece(BaseModel):
answer_piece: str
level: int
level_question_nr: int
answer_type: Literal["agent_sub_answer", "agent_level_answer"]
class SubQuestionPiece(BaseModel):
sub_question: str
level: int
level_question_nr: int
class ExtendedToolResponse(ToolResponse):
level: int
level_question_nr: int
class RefinedAnswerImprovement(BaseModel):
refined_answer_improvement: bool
AgentSearchPacket = (
SubQuestionPiece
| AgentAnswerPiece
| SubQueryPiece
| ExtendedToolResponse
| RefinedAnswerImprovement
)
AnswerPacket = (
AnswerQuestionPossibleReturn | AgentSearchPacket | ToolCallKickoff | ToolResponse
)
ResponsePart = (
OnyxAnswerPiece
| CitationInfo
@@ -306,4 +359,7 @@ ResponsePart = (
| ToolResponse
| ToolCallFinalResult
| StreamStopInfo
| AgentSearchPacket
)
AnswerStream = Iterator[AnswerPacket]

View File

@@ -1,6 +1,8 @@
import traceback
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Iterator
from dataclasses import dataclass
from functools import partial
from typing import cast
@@ -9,6 +11,7 @@ from sqlalchemy.orm import Session
from onyx.chat.answer import Answer
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import create_temporary_persona
from onyx.chat.models import AgentSearchPacket
from onyx.chat.models import AllCitations
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import ChatOnyxBotResponse
@@ -16,6 +19,7 @@ from onyx.chat.models import CitationConfig
from onyx.chat.models import CitationInfo
from onyx.chat.models import CustomToolResponse
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import FileChatDisplay
from onyx.chat.models import FinalUsedContextDocsResponse
from onyx.chat.models import LLMRelevanceFilterResponse
@@ -25,19 +29,28 @@ from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamingError
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import AGENT_SEARCH_INITIAL_KEY
from onyx.configs.constants import BASIC_KEY
from onyx.configs.constants import MessageType
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import NO_AUTH_USER_ID
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.enums import OptionalSearchSetting
from onyx.context.search.enums import QueryFlow
from onyx.context.search.enums import SearchType
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SearchRequest
from onyx.context.search.retrieval.search_runner import inference_sections_from_ids
from onyx.context.search.utils import chunks_or_sections_to_search_docs
from onyx.context.search.utils import dedupe_documents
@@ -127,7 +140,6 @@ from onyx.utils.timing import log_function_time
from onyx.utils.timing import log_generator_function_time
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -159,12 +171,15 @@ def _handle_search_tool_response_summary(
) -> tuple[QADocsResponse, list[DbSearchDoc], list[int] | None]:
response_sumary = cast(SearchResponseSummary, packet.response)
is_extended = isinstance(packet, ExtendedToolResponse)
dropped_inds = None
if not selected_search_docs:
top_docs = chunks_or_sections_to_search_docs(response_sumary.top_sections)
deduped_docs = top_docs
if dedupe_docs:
if (
dedupe_docs and not is_extended
): # Extended tool responses are already deduped
deduped_docs, dropped_inds = dedupe_documents(top_docs)
reference_db_search_docs = [
@@ -178,6 +193,10 @@ def _handle_search_tool_response_summary(
translate_db_search_doc_to_server_search_doc(db_search_doc)
for db_search_doc in reference_db_search_docs
]
level, question_nr = None, None
if isinstance(packet, ExtendedToolResponse):
level, question_nr = packet.level, packet.level_question_nr
return (
QADocsResponse(
rephrased_query=response_sumary.rephrased_query,
@@ -187,6 +206,8 @@ def _handle_search_tool_response_summary(
applied_source_filters=response_sumary.final_filters.source_type,
applied_time_cutoff=response_sumary.final_filters.time_cutoff,
recency_bias_multiplier=response_sumary.recency_bias_multiplier,
level=level,
level_question_nr=question_nr,
),
reference_db_search_docs,
dropped_inds,
@@ -282,10 +303,22 @@ ChatPacket = (
| MessageSpecificCitations
| MessageResponseIDInfo
| StreamStopInfo
| AgentSearchPacket
)
ChatPacketStream = Iterator[ChatPacket]
# can't store a DbSearchDoc in a Pydantic BaseModel
@dataclass
class AnswerPostInfo:
ai_message_files: list[FileDescriptor]
qa_docs_response: QADocsResponse | None = None
reference_db_search_docs: list[DbSearchDoc] | None = None
dropped_indices: list[int] | None = None
tool_result: ToolCallFinalResult | None = None
message_specific_citations: MessageSpecificCitations | None = None
def stream_chat_message_objects(
new_msg_req: CreateChatMessageRequest,
user: User | None,
@@ -324,6 +357,7 @@ def stream_chat_message_objects(
new_msg_req.chunks_above = 0
new_msg_req.chunks_below = 0
llm = None
try:
user_id = user.id if user is not None else None
@@ -502,11 +536,8 @@ def stream_chat_message_objects(
files = load_all_chat_files(
history_msgs, new_msg_req.file_descriptors, db_session
)
latest_query_files = [
file
for file in files
if file.file_id in [f["id"] for f in new_msg_req.file_descriptors]
]
req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
latest_query_files = [file for file in files if file.file_id in req_file_ids]
if user_message:
attach_files_to_chat_message(
@@ -679,13 +710,58 @@ def stream_chat_message_objects(
for tool_list in tool_dict.values():
tools.extend(tool_list)
# TODO: unify message history with single message history
message_history = [
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
]
search_request = SearchRequest(
query=final_msg.message,
evaluation_type=(
LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP
),
human_selected_filters=(
retrieval_options.filters if retrieval_options else None
),
persona=persona,
offset=(retrieval_options.offset if retrieval_options else None),
limit=retrieval_options.limit if retrieval_options else None,
rerank_settings=new_msg_req.rerank_settings,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
enable_auto_detect_filters=(
retrieval_options.enable_auto_detect_filters
if retrieval_options
else None
),
)
force_use_tool = _get_force_search_settings(new_msg_req, tools)
prompt_builder = AnswerPromptBuilder(
user_message=default_build_user_message(
user_query=final_msg.message,
prompt_config=prompt_config,
files=latest_query_files,
single_message_history=single_message_history,
),
system_message=default_build_system_message(prompt_config),
message_history=message_history,
llm_config=llm.config,
raw_user_query=final_msg.message,
raw_user_uploaded_files=latest_query_files or [],
single_message_history=single_message_history,
)
prompt_builder.update_system_prompt(default_build_system_message(prompt_config))
# LLM prompt building, response capturing, etc.
answer = Answer(
prompt_builder=prompt_builder,
is_connected=is_connected,
question=final_msg.message,
latest_query_files=latest_query_files,
answer_style_config=answer_style_config,
prompt_config=prompt_config,
llm=(
llm
or get_main_llm_from_tuple(
@@ -698,28 +774,42 @@ def stream_chat_message_objects(
)
)
),
message_history=[
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
],
fast_llm=fast_llm,
force_use_tool=force_use_tool,
search_request=search_request,
chat_session_id=chat_session_id,
current_agent_message_id=reserved_message_id,
tools=tools,
force_use_tool=_get_force_search_settings(new_msg_req, tools),
single_message_history=single_message_history,
db_session=db_session,
use_agentic_search=new_msg_req.use_agentic_search,
)
reference_db_search_docs = None
qa_docs_response = None
# any files to associate with the AI message e.g. dall-e generated images
ai_message_files = []
dropped_indices = None
tool_result = None
# reference_db_search_docs = None
# qa_docs_response = None
# # any files to associate with the AI message e.g. dall-e generated images
# ai_message_files = []
# dropped_indices = None
# tool_result = None
# TODO: different channels for stored info when it's coming from the agent flow
info_by_subq: dict[tuple[int, int], AnswerPostInfo] = defaultdict(
lambda: AnswerPostInfo(ai_message_files=[])
)
refined_answer_improvement = True
for packet in answer.processed_streamed_output:
if isinstance(packet, ToolResponse):
level, level_question_nr = (
(packet.level, packet.level_question_nr)
if isinstance(packet, ExtendedToolResponse)
else BASIC_KEY
)
info = info_by_subq[(level, level_question_nr)]
# TODO: don't need to dedupe here when we do it in agent flow
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
(
qa_docs_response,
reference_db_search_docs,
dropped_indices,
info.qa_docs_response,
info.reference_db_search_docs,
info.dropped_indices,
) = _handle_search_tool_response_summary(
packet=packet,
db_session=db_session,
@@ -731,29 +821,34 @@ def stream_chat_message_objects(
else False
),
)
yield qa_docs_response
yield info.qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
relevance_sections = packet.response
if reference_db_search_docs is not None:
llm_indices = relevant_sections_to_indices(
relevance_sections=relevance_sections,
items=[
translate_db_search_doc_to_server_search_doc(doc)
for doc in reference_db_search_docs
],
if info.reference_db_search_docs is None:
logger.warning(
"No reference docs found for relevance filtering"
)
continue
llm_indices = relevant_sections_to_indices(
relevance_sections=relevance_sections,
items=[
translate_db_search_doc_to_server_search_doc(doc)
for doc in info.reference_db_search_docs
],
)
if info.dropped_indices:
llm_indices = drop_llm_indices(
llm_indices=llm_indices,
search_docs=info.reference_db_search_docs,
dropped_indices=info.dropped_indices,
)
if dropped_indices:
llm_indices = drop_llm_indices(
llm_indices=llm_indices,
search_docs=reference_db_search_docs,
dropped_indices=dropped_indices,
)
yield LLMRelevanceFilterResponse(
llm_selected_doc_indices=llm_indices
)
yield LLMRelevanceFilterResponse(
llm_selected_doc_indices=llm_indices
)
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
yield FinalUsedContextDocsResponse(
final_context_docs=packet.response
@@ -773,22 +868,24 @@ def stream_chat_message_objects(
],
tenant_id=tenant_id,
)
ai_message_files = [
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
for file_id in file_ids
]
info.ai_message_files.extend(
[
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
for file_id in file_ids
]
)
yield FileChatDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
(
qa_docs_response,
reference_db_search_docs,
info.qa_docs_response,
info.reference_db_search_docs,
) = _handle_internet_search_tool_response_summary(
packet=packet,
db_session=db_session,
)
yield qa_docs_response
yield info.qa_docs_response
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
custom_tool_response = cast(CustomToolCallSummary, packet.response)
@@ -797,7 +894,7 @@ def stream_chat_message_objects(
or custom_tool_response.response_type == "csv"
):
file_ids = custom_tool_response.tool_result.file_ids
ai_message_files.extend(
info.ai_message_files.extend(
[
FileDescriptor(
id=str(file_id),
@@ -822,10 +919,21 @@ def stream_chat_message_objects(
yield cast(OnyxContexts, packet.response)
elif isinstance(packet, StreamStopInfo):
pass
if packet.stop_reason == StreamStopReason.FINISHED:
yield packet
elif isinstance(packet, RefinedAnswerImprovement):
refined_answer_improvement = packet.refined_answer_improvement
yield packet
else:
if isinstance(packet, ToolCallFinalResult):
tool_result = packet
level, level_question_nr = (
(packet.level, packet.level_question_nr)
if packet.level is not None
and packet.level_question_nr is not None
else BASIC_KEY
)
info = info_by_subq[(level, level_question_nr)]
info.tool_result = packet
yield cast(ChatPacket, packet)
logger.debug("Reached end of stream")
except ValueError as e:
@@ -841,59 +949,99 @@ def stream_chat_message_objects(
error_msg = str(e)
stack_trace = traceback.format_exc()
client_error_msg = litellm_exception_to_error_msg(e, llm)
if llm.config.api_key and len(llm.config.api_key) > 2:
error_msg = error_msg.replace(llm.config.api_key, "[REDACTED_API_KEY]")
stack_trace = stack_trace.replace(llm.config.api_key, "[REDACTED_API_KEY]")
if llm:
client_error_msg = litellm_exception_to_error_msg(e, llm)
if llm.config.api_key and len(llm.config.api_key) > 2:
error_msg = error_msg.replace(llm.config.api_key, "[REDACTED_API_KEY]")
stack_trace = stack_trace.replace(
llm.config.api_key, "[REDACTED_API_KEY]"
)
yield StreamingError(error=client_error_msg, stack_trace=stack_trace)
yield StreamingError(error=client_error_msg, stack_trace=stack_trace)
db_session.rollback()
return
# Post-LLM answer processing
try:
logger.debug("Post-LLM answer processing")
message_specific_citations: MessageSpecificCitations | None = None
if reference_db_search_docs:
message_specific_citations = _translate_citations(
citations_list=answer.citations,
db_docs=reference_db_search_docs,
)
if not answer.is_cancelled():
yield AllCitations(citations=answer.citations)
# Saving Gen AI answer and responding with message info
tool_name_to_tool_id: dict[str, int] = {}
for tool_id, tool_list in tool_dict.items():
for tool in tool_list:
tool_name_to_tool_id[tool.name] = tool_id
subq_citations = answer.citations_by_subquestion()
for pair in subq_citations:
level, level_question_nr = pair
info = info_by_subq[(level, level_question_nr)]
logger.debug("Post-LLM answer processing")
if info.reference_db_search_docs:
info.message_specific_citations = _translate_citations(
citations_list=subq_citations[pair],
db_docs=info.reference_db_search_docs,
)
# TODO: AllCitations should contain subq info?
if not answer.is_cancelled():
yield AllCitations(citations=subq_citations[pair])
# Saving Gen AI answer and responding with message info
info = (
info_by_subq[BASIC_KEY]
if BASIC_KEY in info_by_subq
else info_by_subq[AGENT_SEARCH_INITIAL_KEY]
)
gen_ai_response_message = partial_response(
message=answer.llm_answer,
rephrased_query=(
qa_docs_response.rephrased_query if qa_docs_response else None
info.qa_docs_response.rephrased_query if info.qa_docs_response else None
),
reference_docs=reference_db_search_docs,
files=ai_message_files,
reference_docs=info.reference_db_search_docs,
files=info.ai_message_files,
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=(
message_specific_citations.citation_map
if message_specific_citations
info.message_specific_citations.citation_map
if info.message_specific_citations
else None
),
error=None,
tool_call=(
ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
tool_id=tool_name_to_tool_id[info.tool_result.tool_name],
tool_name=info.tool_result.tool_name,
tool_arguments=info.tool_result.tool_args,
tool_result=info.tool_result.tool_result,
)
if tool_result
if info.tool_result
else None
),
)
# TODO: add answers for levels >= 1, where each level has the previous as its parent. Use
# the answer_by_level method in answer.py to get the answers for each level
next_level = 1
prev_message = gen_ai_response_message
agent_answers = answer.llm_answer_by_level()
while next_level in agent_answers:
next_answer = agent_answers[next_level]
info = info_by_subq[(next_level, AGENT_SEARCH_INITIAL_KEY[1])]
next_answer_message = create_new_chat_message(
chat_session_id=chat_session_id,
parent_message=prev_message,
message=next_answer,
prompt_id=None,
token_count=len(llm_tokenizer_encode_func(next_answer)),
message_type=MessageType.ASSISTANT,
db_session=db_session,
files=info.ai_message_files,
reference_docs=info.reference_db_search_docs,
citations=info.message_specific_citations.citation_map
if info.message_specific_citations
else None,
refined_answer_improvement=refined_answer_improvement,
)
next_level += 1
prev_message = next_answer_message
logger.debug("Committing messages")
db_session.commit() # actually save user / assistant message

View File

@@ -4,6 +4,7 @@ from typing import cast
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModel__v1
from onyx.chat.models import PromptConfig
@@ -84,6 +85,7 @@ class AnswerPromptBuilder:
raw_user_query: str,
raw_user_uploaded_files: list[InMemoryChatFile],
single_message_history: str | None = None,
system_message: SystemMessage | None = None,
) -> None:
self.max_tokens = compute_max_llm_input_tokens(llm_config)
@@ -108,7 +110,14 @@ class AnswerPromptBuilder:
),
)
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = (
(
system_message,
check_message_tokens(system_message, self.llm_tokenizer_encode_func),
)
if system_message
else None
)
self.user_message_and_token_cnt = (
user_message,
check_message_tokens(
@@ -174,6 +183,14 @@ class AnswerPromptBuilder:
)
# Stores some parts of a prompt builder as needed for tool calls
class PromptSnapshot(BaseModel):
raw_message_history: list[PreviousMessage]
raw_user_query: str
built_prompt: list[BaseMessage]
# TODO: rename this? AnswerConfig maybe?
class LLMCall(BaseModel__v1):
prompt_builder: AnswerPromptBuilder
tools: list[Tool]

View File

@@ -3,9 +3,10 @@ from collections.abc import Generator
from langchain_core.messages import BaseMessage
from onyx.chat.llm_response_handler import ResponsePart
from onyx.chat.models import CitationInfo
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import ResponsePart
from onyx.chat.stream_processing.citation_processing import CitationProcessor
from onyx.chat.stream_processing.utils import DocumentIdOrderMapping
from onyx.utils.logger import setup_logger
@@ -13,21 +14,32 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
# TODO: remove update() once it is no longer needed
class AnswerResponseHandler(abc.ABC):
@abc.abstractmethod
def handle_response_part(
self,
response_item: BaseMessage | None,
previous_response_items: list[BaseMessage],
response_item: BaseMessage | str | None,
previous_response_items: list[BaseMessage | str],
) -> Generator[ResponsePart, None, None]:
raise NotImplementedError
class PassThroughAnswerResponseHandler(AnswerResponseHandler):
def handle_response_part(
self,
response_item: BaseMessage | str | None,
previous_response_items: list[BaseMessage | str],
) -> Generator[ResponsePart, None, None]:
content = _message_to_str(response_item)
yield OnyxAnswerPiece(answer_piece=content)
class DummyAnswerResponseHandler(AnswerResponseHandler):
def handle_response_part(
self,
response_item: BaseMessage | None,
previous_response_items: list[BaseMessage],
response_item: BaseMessage | str | None,
previous_response_items: list[BaseMessage | str],
) -> Generator[ResponsePart, None, None]:
# This is a dummy handler that returns nothing
yield from []
@@ -56,43 +68,25 @@ class CitationResponseHandler(AnswerResponseHandler):
def handle_response_part(
self,
response_item: BaseMessage | None,
previous_response_items: list[BaseMessage],
response_item: BaseMessage | str | None,
previous_response_items: list[BaseMessage | str],
) -> Generator[ResponsePart, None, None]:
if response_item is None:
return
content = (
response_item.content if isinstance(response_item.content, str) else ""
)
content = _message_to_str(response_item)
# Process the new content through the citation processor
yield from self.citation_processor.process_token(content)
# No longer in use, remove later
# class QuotesResponseHandler(AnswerResponseHandler):
# def __init__(
# self,
# context_docs: list[LlmDoc],
# is_json_prompt: bool = True,
# ):
# self.quotes_processor = QuotesProcessor(
# context_docs=context_docs,
# is_json_prompt=is_json_prompt,
# )
# def handle_response_part(
# self,
# response_item: BaseMessage | None,
# previous_response_items: list[BaseMessage],
# ) -> Generator[ResponsePart, None, None]:
# if response_item is None:
# yield from self.quotes_processor.process_token(None)
# return
# content = (
# response_item.content if isinstance(response_item.content, str) else ""
# )
# yield from self.quotes_processor.process_token(content)
def _message_to_str(message: BaseMessage | str | None) -> str:
if message is None:
return ""
if isinstance(message, str):
return message
content = message.content if isinstance(message, BaseMessage) else message
if not isinstance(content, str):
logger.warning(f"Received non-string content: {type(content)}")
content = str(content) if content is not None else ""
return content

View File

@@ -5,7 +5,9 @@ from langchain_core.messages import BaseMessage
from langchain_core.messages import ToolCall
from onyx.chat.models import ResponsePart
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.message import build_tool_message
@@ -25,6 +27,13 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
def get_tool_by_name(tools: list[Tool], tool_name: str) -> Tool:
for tool in tools:
if tool.name == tool_name:
return tool
raise RuntimeError(f"Tool '{tool_name}' not found")
class ToolResponseHandler:
def __init__(self, tools: list[Tool]):
self.tools = tools
@@ -43,67 +52,12 @@ class ToolResponseHandler:
def get_tool_call_for_non_tool_calling_llm(
cls, llm_call: LLMCall, llm: LLM
) -> tuple[Tool, dict] | None:
if llm_call.force_use_tool.force_use:
# if we are forcing a tool, we don't need to check which tools to run
tool = next(
(
t
for t in llm_call.tools
if t.name == llm_call.force_use_tool.tool_name
),
None,
)
if not tool:
raise RuntimeError(
f"Tool '{llm_call.force_use_tool.tool_name}' not found"
)
tool_args = (
llm_call.force_use_tool.args
if llm_call.force_use_tool.args is not None
else tool.get_args_for_non_tool_calling_llm(
query=llm_call.prompt_builder.raw_user_query,
history=llm_call.prompt_builder.raw_message_history,
llm=llm,
force_run=True,
)
)
if tool_args is None:
raise RuntimeError(f"Tool '{tool.name}' did not return args")
return (tool, tool_args)
else:
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
tools=llm_call.tools,
query=llm_call.prompt_builder.raw_user_query,
history=llm_call.prompt_builder.raw_message_history,
llm=llm,
)
available_tools_and_args = [
(llm_call.tools[ind], args)
for ind, args in enumerate(tool_options)
if args is not None
]
logger.info(
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
)
chosen_tool_and_args = (
select_single_tool_for_non_tool_calling_llm(
tools_and_args=available_tools_and_args,
history=llm_call.prompt_builder.raw_message_history,
query=llm_call.prompt_builder.raw_user_query,
llm=llm,
)
if available_tools_and_args
else None
)
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
return chosen_tool_and_args
return get_tool_call_for_non_tool_calling_llm_impl(
force_use_tool=llm_call.force_use_tool,
tools=llm_call.tools,
prompt_builder=llm_call.prompt_builder,
llm=llm,
)
def _handle_tool_call(self) -> Generator[ResponsePart, None, None]:
if not self.tool_call_chunk or not self.tool_call_chunk.tool_calls:
@@ -118,20 +72,17 @@ class ToolResponseHandler:
tool for tool in self.tools if tool.name == tool_call_request["name"]
]
if not known_tools_by_name:
logger.error(
"Tool call requested with unknown name field. \n"
f"self.tools: {self.tools}"
f"tool_call_request: {tool_call_request}"
)
continue
else:
if known_tools_by_name:
selected_tool = known_tools_by_name[0]
selected_tool_call_request = tool_call_request
if selected_tool and selected_tool_call_request:
break
logger.error(
"Tool call requested with unknown name field. \n"
f"self.tools: {self.tools}"
f"tool_call_request: {tool_call_request}"
)
if not selected_tool or not selected_tool_call_request:
return
@@ -157,8 +108,8 @@ class ToolResponseHandler:
def handle_response_part(
self,
response_item: BaseMessage | None,
previous_response_items: list[BaseMessage],
response_item: BaseMessage | str | None,
previous_response_items: list[BaseMessage | str],
) -> Generator[ResponsePart, None, None]:
if response_item is None:
yield from self._handle_tool_call()
@@ -171,8 +122,6 @@ class ToolResponseHandler:
else:
self.tool_call_chunk += response_item # type: ignore
return
def next_llm_call(self, current_llm_call: LLMCall) -> LLMCall | None:
if (
self.tool_runner is None
@@ -205,3 +154,61 @@ class ToolResponseHandler:
self.tool_final_result,
],
)
def get_tool_call_for_non_tool_calling_llm_impl(
force_use_tool: ForceUseTool,
tools: list[Tool],
prompt_builder: AnswerPromptBuilder | PromptSnapshot,
llm: LLM,
) -> tuple[Tool, dict] | None:
if force_use_tool.force_use:
# if we are forcing a tool, we don't need to check which tools to run
tool = get_tool_by_name(tools, force_use_tool.tool_name)
tool_args = (
force_use_tool.args
if force_use_tool.args is not None
else tool.get_args_for_non_tool_calling_llm(
query=prompt_builder.raw_user_query,
history=prompt_builder.raw_message_history,
llm=llm,
force_run=True,
)
)
if tool_args is None:
raise RuntimeError(f"Tool '{tool.name}' did not return args")
return (tool, tool_args)
else:
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
tools=tools,
query=prompt_builder.raw_user_query,
history=prompt_builder.raw_message_history,
llm=llm,
)
available_tools_and_args = [
(tools[ind], args)
for ind, args in enumerate(tool_options)
if args is not None
]
logger.info(
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
)
chosen_tool_and_args = (
select_single_tool_for_non_tool_calling_llm(
tools_and_args=available_tools_and_args,
history=prompt_builder.raw_message_history,
query=prompt_builder.raw_user_query,
llm=llm,
)
if available_tools_and_args
else None
)
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
return chosen_tool_and_args

View File

@@ -0,0 +1,77 @@
import os
AGENT_DEFAULT_RETRIEVAL_HITS = 15
AGENT_DEFAULT_RERANKING_HITS = 10
AGENT_DEFAULT_SUB_QUESTION_MAX_CONTEXT_HITS = 8
AGENT_DEFAULT_NUM_DOCS_FOR_INITIAL_DECOMPOSITION = 3
AGENT_DEFAULT_NUM_DOCS_FOR_REFINED_DECOMPOSITION = 5
AGENT_DEFAULT_EXPLORATORY_SEARCH_RESULTS = 3
AGENT_DEFAULT_MIN_ORIG_QUESTION_DOCS = 3
AGENT_DEFAULT_MAX_ANSWER_CONTEXT_DOCS = 10
AGENT_DEFAULT_MAX_STATIC_HISTORY_CHAR_LENGTH = 10000
#####
# Agent Configs
#####
AGENT_RETRIEVAL_STATS = (
not os.environ.get("AGENT_RETRIEVAL_STATS") == "False"
) or True # default True
AGENT_MAX_QUERY_RETRIEVAL_RESULTS = int(
os.environ.get("AGENT_MAX_QUERY_RETRIEVAL_RESULTS") or AGENT_DEFAULT_RETRIEVAL_HITS
) # 15
AGENT_MAX_QUERY_RETRIEVAL_RESULTS = int(
os.environ.get("AGENT_MAX_QUERY_RETRIEVAL_RESULTS") or AGENT_DEFAULT_RETRIEVAL_HITS
) # 15
# Reranking agent configs
# Reranking stats - no influence on flow outside of stats collection
AGENT_RERANKING_STATS = (
not os.environ.get("AGENT_RERANKING_STATS") == "True"
) or False # default False
AGENT_MAX_QUERY_RETRIEVAL_RESULTS = int(
os.environ.get("AGENT_MAX_QUERY_RETRIEVAL_RESULTS") or AGENT_DEFAULT_RETRIEVAL_HITS
) # 15
AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS = int(
os.environ.get("AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS")
or AGENT_DEFAULT_RERANKING_HITS
) # 10
AGENT_NUM_DOCS_FOR_DECOMPOSITION = int(
os.environ.get("AGENT_NUM_DOCS_FOR_DECOMPOSITION")
or AGENT_DEFAULT_NUM_DOCS_FOR_INITIAL_DECOMPOSITION
) # 3
AGENT_NUM_DOCS_FOR_REFINED_DECOMPOSITION = int(
os.environ.get("AGENT_NUM_DOCS_FOR_REFINED_DECOMPOSITION")
or AGENT_DEFAULT_NUM_DOCS_FOR_REFINED_DECOMPOSITION
) # 5
AGENT_EXPLORATORY_SEARCH_RESULTS = int(
os.environ.get("AGENT_EXPLORATORY_SEARCH_RESULTS")
or AGENT_DEFAULT_EXPLORATORY_SEARCH_RESULTS
) # 3
AGENT_MIN_ORIG_QUESTION_DOCS = int(
os.environ.get("AGENT_MIN_ORIG_QUESTION_DOCS")
or AGENT_DEFAULT_MIN_ORIG_QUESTION_DOCS
) # 3
AGENT_MAX_ANSWER_CONTEXT_DOCS = int(
os.environ.get("AGENT_MAX_ANSWER_CONTEXT_DOCS")
or AGENT_DEFAULT_SUB_QUESTION_MAX_CONTEXT_HITS
) # 8
AGENT_MAX_STATIC_HISTORY_CHAR_LENGTH = int(
os.environ.get("AGENT_MAX_STATIC_HISTORY_CHAR_LENGTH")
or AGENT_DEFAULT_MAX_STATIC_HISTORY_CHAR_LENGTH
) # 10000
GRAPH_VERSION_NAME: str = "a"

Some files were not shown because too many files have changed in this diff Show More