Compare commits

...

91 Commits

Author SHA1 Message Date
Evan Lohn
75aea29ccf Merge remote-tracking branch 'origin/main' into evan_answer_rework 2025-01-10 13:24:20 -08:00
Evan Lohn
89c60db0cc answer rework basic idea 2025-01-09 23:35:09 -08:00
Evan Lohn
635ba351d0 initial movement of answer code (WIP) 2025-01-09 17:19:46 -08:00
joachim-danswer
4ab99fb4a7 created pro_search_a directory
- moved all non-common files into that new directory to enable multiple graphs
2025-01-09 15:50:54 -08:00
joachim-danswer
aee625d525 agent logging level change 2025-01-09 14:02:05 -08:00
joachim-danswer
d379453d64 Creation of AgentAnswerPiece class
- unifying sub-answers and main agent answers
2025-01-09 13:33:24 -08:00
joachim-danswer
685f12c531 sub-question citation processing 2025-01-09 12:38:26 -08:00
Evan Lohn
686eddfa52 code cleanup 2025-01-09 10:05:49 -08:00
Evan Lohn
5f3e877833 Merge branch 'search_2_0_evan_refactor_api' of https://github.com/onyx-dot-app/onyx into search_2_0_evan_refactor_api 2025-01-09 09:53:24 -08:00
Evan Lohn
df75a3115b yield docs after deduping and reranking, yield initial question docs 2025-01-09 09:53:18 -08:00
joachim-danswer
6ff439c342 Added link for document citations
- assumes that Sub-question answers will follow the same citation format
 - assumes that we will be able to fit initial and revised answers into the same format.
2025-01-08 20:22:29 -08:00
Evan Lohn
44ccb5ef0f Merge branch 'search_2_0_evan_refactor_api' of https://github.com/onyx-dot-app/onyx into search_2_0_evan_refactor_api
merge
2025-01-08 17:15:16 -08:00
Evan Lohn
8045d52090 minor bug fix 2025-01-08 17:15:12 -08:00
joachim-danswer
0d059cf835 small fix 2025-01-08 17:13:31 -08:00
Evan Lohn
969d02767f merge 2025-01-08 16:17:08 -08:00
Evan Lohn
0bc3bb5558 more minor cleanup 2025-01-08 16:16:33 -08:00
joachim-danswer
4c19b19488 initial agent citation_processing 2025-01-08 15:54:57 -08:00
Evan Lohn
79e7b73db1 minor cleanup in preparation for Answer rework 2025-01-08 13:29:49 -08:00
Evan Lohn
7ce0436d71 main merge 2025-01-07 16:39:26 -08:00
Evan Lohn
1849069c5f merge with main 2025-01-07 16:16:50 -08:00
Evan Lohn
5788902d86 tiny testing change 2025-01-07 15:44:02 -08:00
Evan Lohn
80542859b6 Merge branch 'search_2_0_evan_refactor_api' of https://github.com/onyx-dot-app/onyx into search_2_0_evan_refactor_api
merge
2025-01-07 15:17:20 -08:00
joachim-danswer
1996e22f9b remove duplication of Main State keys + SubQuestion citation 2025-01-07 15:16:04 -08:00
Evan Lohn
682b145a6a add message history to pro search config 2025-01-07 15:13:34 -08:00
joachim-danswer
ef67f9cd1e more renamings
follow_up -> refined
2025-01-07 14:00:56 -08:00
joachim-danswer
295417f85d more renamings.. 2025-01-07 13:50:36 -08:00
joachim-danswer
8650b8ff51 main nodes/edges function renaming 2025-01-07 13:18:24 -08:00
joachim-danswer
462db23683 updated run_graph to stream out answer tokens 2025-01-07 10:14:21 -08:00
joachim-danswer
f2507755c3 merge 2025-01-07 09:41:01 -08:00
joachim-danswer
fd1191637b extra info logging and allow_refinement flag 2025-01-07 08:58:10 -08:00
Evan Lohn
7898f38a5d refactor and separate id fields 2025-01-06 19:12:49 -08:00
Evan Lohn
3a38407bca added type to ChatPacket 2025-01-06 18:43:37 -08:00
Evan Lohn
239f2f2718 cleanup plus get-chat-session API v1 2025-01-06 18:37:04 -08:00
joachim-danswer
0ff44c7661 agent metric column rename
duration_s -> duration__s
2025-01-06 11:25:34 -08:00
joachim-danswer
d697ad0fc8 renamed tables + additional logging 2025-01-06 11:13:42 -08:00
joachim-danswer
6c0d051f80 merged 2025-01-06 09:40:57 -08:00
joachim-danswer
2f7f4917e3 persistence sans Citations
TODO: change commit to flush
2025-01-06 08:00:30 -08:00
Evan Lohn
695d07f0f9 non-duplicated work changes, plus finished v1 of backend streaming 2025-01-05 17:37:49 -08:00
joachim-danswer
dd2c9425bd new columns 2025-01-05 13:23:38 -08:00
Evan Lohn
e468cac28c merged logging and WIP streaming 2025-01-05 11:49:02 -08:00
joachim-danswer
d35aa1eab9 question numbers & expanded_retrieval results 2025-01-05 11:05:22 -08:00
Evan Lohn
ae65c739de WIP, deleted persistence code, partway through streaming code 2025-01-05 10:55:32 -08:00
joachim-danswer
328f4758ae persistence of metrics 2025-01-04 19:55:09 -08:00
joachim-danswer
6422ad90a5 print statements removed (and files saved) 2025-01-03 08:26:54 -08:00
joachim-danswer
46850cc2ac removal of print statements 2025-01-03 08:23:48 -08:00
joachim-danswer
72e56aa4ca assistant persona - untested 2025-01-03 07:49:41 -08:00
joachim-danswer
13a5a86dec removed refined_search subgraph 2025-01-02 16:54:52 -08:00
joachim-danswer
a206723191 moved object dependencies out of refined subgraph 2025-01-02 16:22:45 -08:00
joachim-danswer
60207589d2 remove need for refined subgraph 2025-01-02 15:52:34 -08:00
Evan Lohn
d773163502 unorganized streaming of all relevant info 2025-01-02 14:29:08 -08:00
joachim-danswer
121827e34c initial end-to-end functioning 2025-01-02 13:49:36 -08:00
Evan Lohn
dd64c3a175 ugly but workable calling langgraph from the UI 2024-12-30 20:28:50 -08:00
Evan Lohn
38a616c87c basic support for accessing langgraph through the UI 2024-12-30 18:52:50 -08:00
Evan Lohn
d7812ee807 my happy about graph building 2024-12-30 14:01:03 -08:00
Evan Lohn
11db9647f3 no more dummy plus nodes files 2024-12-30 13:36:05 -08:00
Evan Lohn
d6a385b837 merging in testing changes 2024-12-30 13:16:10 -08:00
Evan Lohn
d68cf98e77 nodes in one file plus some mypy fixes 2024-12-30 12:58:35 -08:00
joachim-danswer
821b226d25 removed test file 2024-12-30 12:38:09 -08:00
joachim-danswer
6dc81bbb7c fixed stats 2024-12-30 12:23:11 -08:00
joachim-danswer
6989441851 pre-metrics clean-up 2024-12-30 09:11:45 -08:00
joachim-danswer
683978ddb0 base functioning 2024-12-29 12:23:39 -08:00
joachim-danswer
568bc16536 all 3 options 2024-12-29 10:51:31 -08:00
joachim-danswer
0333ff648a left or right 2024-12-29 10:09:12 -08:00
joachim-danswer
cc76486d21 added streaming and small fixes 2024-12-28 09:08:30 -08:00
joachim-danswer
901d8c22c4 further clean-up 2024-12-26 14:08:34 -08:00
joachim-danswer
21928133e0 tmp_state_model_sep 2024-12-26 10:59:10 -08:00
joachim-danswer
c4af11c19b core mypy resolutions 2024-12-25 16:15:55 -08:00
joachim-danswer
ca3f3beabe initial mypy changes 2024-12-24 20:21:34 -08:00
joachim-danswer
fa481019e8 improvements 2024-12-24 08:16:56 -08:00
joachim-danswer
f4c826c4e5 regression-test graph vs regulat graph 2024-12-20 14:52:09 -08:00
joachim-danswer
2a3328fc3d initial_test_flow 2024-12-20 14:41:20 -08:00
hagen-danswer
34aa054c5d added chunk_ids and stats to QueryResult 2024-12-19 08:48:05 -08:00
hagen-danswer
cebe237705 renamed PrimaryState to CoreState 2024-12-19 08:47:39 -08:00
hagen-danswer
c759fb5709 Merge pull request #3505 from onyx-dot-app/fix-the-states
Fix the states
2024-12-18 15:48:20 -08:00
hagen-danswer
ffc81f6e45 seperate edge for initial retrieval 2024-12-18 13:10:46 -08:00
hagen-danswer
2d6f746259 made query expansion explicit 2024-12-18 13:03:28 -08:00
hagen-danswer
bca02ebec6 figured it out 2024-12-18 12:44:28 -08:00
hagen-danswer
0c75ca0579 renames 2024-12-18 11:08:43 -08:00
hagen-danswer
9d3220fcfc explicitly ingest state from retrieval 2024-12-18 10:17:07 -08:00
hagen-danswer
50a216f554 naming and comments 2024-12-18 09:56:34 -08:00
hagen-danswer
8399d2ee0a mypy fixed 2024-12-18 09:27:47 -08:00
hagen-danswer
fd694bea8f query->question 2024-12-18 08:47:43 -08:00
hagen-danswer
e76cbec53c main graph works 2024-12-18 08:43:54 -08:00
hagen-danswer
d66180fe13 Cleanup 2024-12-18 07:33:40 -08:00
hagen-danswer
442c94727e got answer subgraph working 2024-12-17 15:16:36 -08:00
hagen-danswer
2f2b9a862a fixed expanded retrieval subgraph 2024-12-17 15:11:54 -08:00
hagen-danswer
1f88b60abd Now using result objects 2024-12-17 14:05:51 -08:00
hagen-danswer
ff03d717f3 brough over joachim changes 2024-12-17 12:36:28 -08:00
hagen-danswer
82914ad365 fixed key issue 2024-12-16 13:26:09 -08:00
hagen-danswer
11ce2a62ab fix: update staged changes 2024-12-16 12:24:17 -08:00
joachim-danswer
6311b70cc6 initial onyx changes 2024-12-16 11:23:01 -08:00
71 changed files with 6086 additions and 201 deletions

4
.gitignore vendored
View File

@@ -7,4 +7,6 @@
.vscode/
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml
/web/test-results/
/web/test-results/
backend/onyx/agent_search/main/test_data.json
backend/tests/regression/answer_quality/test_data.json

View File

@@ -51,3 +51,9 @@ BING_API_KEY=<REPLACE THIS>
# Enable the full set of Danswer Enterprise Edition features
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False
# Agent Search configs # TODO: Remove give proper namings
AGENT_RETRIEVAL_STATS=False # Note: This setting will incur substantial re-ranking effort
AGENT_RERANKING_STATS=True
AGENT_MAX_QUERY_RETRIEVAL_RESULTS=20
AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS=20

View File

@@ -0,0 +1,29 @@
"""agent_doc_result_col
Revision ID: 1adf5ea20d2b
Revises: e9cf2bd7baed
Create Date: 2025-01-05 13:14:58.344316
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "1adf5ea20d2b"
down_revision = "e9cf2bd7baed"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add the new column with JSONB type
op.add_column(
"sub_question",
sa.Column("sub_question_doc_results", postgresql.JSONB(), nullable=True),
)
def downgrade() -> None:
# Drop the column
op.drop_column("sub_question", "sub_question_doc_results")

View File

@@ -0,0 +1,35 @@
"""agent_metric_col_rename__s
Revision ID: 925b58bd75b6
Revises: 9787be927e58
Create Date: 2025-01-06 11:20:26.752441
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "925b58bd75b6"
down_revision = "9787be927e58"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Rename columns using PostgreSQL syntax
op.alter_column(
"agent__search_metrics", "base_duration_s", new_column_name="base_duration__s"
)
op.alter_column(
"agent__search_metrics", "full_duration_s", new_column_name="full_duration__s"
)
def downgrade() -> None:
# Revert the column renames
op.alter_column(
"agent__search_metrics", "base_duration__s", new_column_name="base_duration_s"
)
op.alter_column(
"agent__search_metrics", "full_duration__s", new_column_name="full_duration_s"
)

View File

@@ -0,0 +1,25 @@
"""agent_metric_table_renames__agent__
Revision ID: 9787be927e58
Revises: bceb76d618ec
Create Date: 2025-01-06 11:01:44.210160
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "9787be927e58"
down_revision = "bceb76d618ec"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Rename table from agent_search_metrics to agent__search_metrics
op.rename_table("agent_search_metrics", "agent__search_metrics")
def downgrade() -> None:
# Rename table back from agent__search_metrics to agent_search_metrics
op.rename_table("agent__search_metrics", "agent_search_metrics")

View File

@@ -0,0 +1,42 @@
"""agent_tracking
Revision ID: 98a5008d8711
Revises: 2955778aa44c
Create Date: 2025-01-04 14:41:52.732238
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "98a5008d8711"
down_revision = "2955778aa44c"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"agent_search_metrics",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=True),
sa.Column("persona_id", sa.Integer(), nullable=True),
sa.Column("agent_type", sa.String(), nullable=False),
sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
sa.Column("base_duration_s", sa.Float(), nullable=False),
sa.Column("full_duration_s", sa.Float(), nullable=False),
sa.Column("base_metrics", postgresql.JSONB(), nullable=True),
sa.Column("refined_metrics", postgresql.JSONB(), nullable=True),
sa.Column("all_metrics", postgresql.JSONB(), nullable=True),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
def downgrade() -> None:
op.drop_table("agent_search_metrics")

View File

@@ -0,0 +1,84 @@
"""agent_table_renames__agent__
Revision ID: bceb76d618ec
Revises: c0132518a25b
Create Date: 2025-01-06 10:50:48.109285
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "bceb76d618ec"
down_revision = "c0132518a25b"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_constraint(
"sub_query__search_doc_sub_query_id_fkey",
"sub_query__search_doc",
type_="foreignkey",
)
op.drop_constraint(
"sub_query__search_doc_search_doc_id_fkey",
"sub_query__search_doc",
type_="foreignkey",
)
# Rename tables
op.rename_table("sub_query", "agent__sub_query")
op.rename_table("sub_question", "agent__sub_question")
op.rename_table("sub_query__search_doc", "agent__sub_query__search_doc")
# Update both foreign key constraints for agent__sub_query__search_doc
# Create new foreign keys with updated names
op.create_foreign_key(
"agent__sub_query__search_doc_sub_query_id_fkey",
"agent__sub_query__search_doc",
"agent__sub_query",
["sub_query_id"],
["id"],
)
op.create_foreign_key(
"agent__sub_query__search_doc_search_doc_id_fkey",
"agent__sub_query__search_doc",
"search_doc", # This table name doesn't change
["search_doc_id"],
["id"],
)
def downgrade() -> None:
# Update foreign key constraints for sub_query__search_doc
op.drop_constraint(
"agent__sub_query__search_doc_sub_query_id_fkey",
"agent__sub_query__search_doc",
type_="foreignkey",
)
op.drop_constraint(
"agent__sub_query__search_doc_search_doc_id_fkey",
"agent__sub_query__search_doc",
type_="foreignkey",
)
# Rename tables back
op.rename_table("agent__sub_query__search_doc", "sub_query__search_doc")
op.rename_table("agent__sub_question", "sub_question")
op.rename_table("agent__sub_query", "sub_query")
op.create_foreign_key(
"sub_query__search_doc_sub_query_id_fkey",
"sub_query__search_doc",
"sub_query",
["sub_query_id"],
["id"],
)
op.create_foreign_key(
"sub_query__search_doc_search_doc_id_fkey",
"sub_query__search_doc",
"search_doc", # This table name doesn't change
["search_doc_id"],
["id"],
)

View File

@@ -0,0 +1,40 @@
"""agent_table_changes_rename_level
Revision ID: c0132518a25b
Revises: 1adf5ea20d2b
Create Date: 2025-01-05 16:38:37.660152
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c0132518a25b"
down_revision = "1adf5ea20d2b"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add level and level_question_nr columns with NOT NULL constraint
op.add_column(
"sub_question",
sa.Column("level", sa.Integer(), nullable=False, server_default="0"),
)
op.add_column(
"sub_question",
sa.Column(
"level_question_nr", sa.Integer(), nullable=False, server_default="0"
),
)
# Remove the server_default after the columns are created
op.alter_column("sub_question", "level", server_default=None)
op.alter_column("sub_question", "level_question_nr", server_default=None)
def downgrade() -> None:
# Remove the columns
op.drop_column("sub_question", "level_question_nr")
op.drop_column("sub_question", "level")

View File

@@ -0,0 +1,68 @@
"""create pro search persistence tables
Revision ID: e9cf2bd7baed
Revises: 98a5008d8711
Create Date: 2025-01-02 17:55:56.544246
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID
# revision identifiers, used by Alembic.
revision = "e9cf2bd7baed"
down_revision = "98a5008d8711"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create sub_question table
op.create_table(
"sub_question",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("primary_question_id", sa.Integer, sa.ForeignKey("chat_message.id")),
sa.Column(
"chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id")
),
sa.Column("sub_question", sa.Text),
sa.Column(
"time_created", sa.DateTime(timezone=True), server_default=sa.func.now()
),
sa.Column("sub_answer", sa.Text),
)
# Create sub_query table
op.create_table(
"sub_query",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("parent_question_id", sa.Integer, sa.ForeignKey("sub_question.id")),
sa.Column(
"chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id")
),
sa.Column("sub_query", sa.Text),
sa.Column(
"time_created", sa.DateTime(timezone=True), server_default=sa.func.now()
),
)
# Create sub_query__search_doc association table
op.create_table(
"sub_query__search_doc",
sa.Column(
"sub_query_id", sa.Integer, sa.ForeignKey("sub_query.id"), primary_key=True
),
sa.Column(
"search_doc_id",
sa.Integer,
sa.ForeignKey("search_doc.id"),
primary_key=True,
),
)
def downgrade() -> None:
op.drop_table("sub_query__search_doc")
op.drop_table("sub_query")
op.drop_table("sub_question")

View File

@@ -179,6 +179,7 @@ def handle_simplified_chat_message(
chunks_below=0,
full_doc=chat_message_req.full_doc,
structured_response_format=chat_message_req.structured_response_format,
use_pro_search=chat_message_req.use_pro_search,
)
packets = stream_chat_message_objects(
@@ -301,6 +302,7 @@ def handle_send_message_simple_with_history(
chunks_below=0,
full_doc=req.full_doc,
structured_response_format=req.structured_response_format,
use_pro_search=req.use_pro_search,
)
packets = stream_chat_message_objects(

View File

@@ -57,6 +57,9 @@ class BasicCreateChatMessageRequest(ChunkContext):
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None
# If True, uses pro search instead of basic search
use_pro_search: bool = False
class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
# Last element is the new query. All previous elements are historical context
@@ -71,6 +74,8 @@ class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
# only works if using an OpenAI model. See the following for more details:
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None
# If True, uses pro search instead of basic search
use_pro_search: bool = False
class SimpleDoc(BaseModel):
@@ -123,6 +128,9 @@ class OneShotQARequest(ChunkContext):
# If True, skips generative an AI response to the search query
skip_gen_ai_answer_generation: bool = False
# If True, uses pro search instead of basic search
use_pro_search: bool = False
@model_validator(mode="after")
def check_persona_fields(self) -> "OneShotQARequest":
if self.persona_override_config is None and self.persona_id is None:

View File

@@ -196,6 +196,7 @@ def get_answer_stream(
retrieval_details=query_request.retrieval_options,
rerank_settings=query_request.rerank_settings,
db_session=db_session,
use_pro_search=query_request.use_pro_search,
)
packets = stream_chat_message_objects(

View File

@@ -0,0 +1,78 @@
from langchain_core.callbacks.manager import dispatch_custom_event
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agent_search.basic.states import BasicInput
from onyx.agent_search.basic.states import BasicOutput
from onyx.agent_search.basic.states import BasicState
from onyx.agent_search.basic.states import BasicStateUpdate
def basic_graph_builder() -> StateGraph:
graph = StateGraph(
state_schema=BasicState,
input=BasicInput,
output=BasicOutput,
)
### Add nodes ###
graph.add_node(
node="get_response",
action=get_response,
)
### Add edges ###
graph.add_edge(start_key=START, end_key="get_response")
graph.add_conditional_edges("get_response", should_continue, ["get_response", END])
graph.add_edge(
start_key="get_response",
end_key=END,
)
return graph
def should_continue(state: BasicState) -> str:
return (
END if state["last_llm_call"] is None or state["calls"] > 0 else "get_response"
)
def get_response(state: BasicState) -> BasicStateUpdate:
llm = state["llm"]
current_llm_call = state["last_llm_call"]
if current_llm_call is None:
raise ValueError("last_llm_call is None")
answer_style_config = state["answer_style_config"]
response_handler_manager = state["response_handler_manager"]
# DEBUG: good breakpoint
stream = llm.stream(
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
prompt=current_llm_call.prompt_builder.build(),
tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
tool_choice=(
"required"
if current_llm_call.tools and current_llm_call.force_use_tool.force_use
else None
),
structured_response_format=answer_style_config.structured_response_format,
)
for response in response_handler_manager.handle_llm_response(stream):
dispatch_custom_event(
"basic_response",
response,
)
return BasicStateUpdate(
last_llm_call=response_handler_manager.next_llm_call(current_llm_call),
calls=state["calls"] + 1,
)
if __name__ == "__main__":
pass

View File

@@ -0,0 +1,41 @@
from typing import TypedDict
from onyx.chat.llm_response_handler import LLMResponseHandlerManager
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.prompt_builder.build import LLMCall
from onyx.llm.chat_llm import LLM
## Update States
## Graph Input State
class BasicInput(TypedDict):
last_llm_call: LLMCall | None
llm: LLM
answer_style_config: AnswerStyleConfig
response_handler_manager: LLMResponseHandlerManager
calls: int
## Graph Output State
class BasicOutput(TypedDict):
pass
class BasicStateUpdate(TypedDict):
last_llm_call: LLMCall | None
calls: int
## Graph State
class BasicState(
BasicInput,
BasicOutput,
):
pass

View File

@@ -0,0 +1,66 @@
from operator import add
from typing import Annotated
from typing import TypedDict
from typing import TypeVar
from sqlalchemy.orm import Session
from onyx.chat.models import ProSearchConfig
from onyx.db.models import User
from onyx.llm.interfaces import LLM
from onyx.tools.tool_implementations.search.search_tool import SearchTool
class CoreState(TypedDict, total=False):
"""
This is the core state that is shared across all subgraphs.
"""
config: ProSearchConfig
primary_llm: LLM
fast_llm: LLM
# a single session for the entire agent search
# is fine if we are only reading
db_session: Session
user: User | None
log_messages: Annotated[list[str], add]
search_tool: SearchTool
class SubgraphCoreState(TypedDict, total=False):
"""
This is the core state that is shared across all subgraphs.
"""
subgraph_config: ProSearchConfig
subgraph_primary_llm: LLM
subgraph_fast_llm: LLM
# a single session for the entire agent search
# is fine if we are only reading
subgraph_db_session: Session
subgraph_search_tool: SearchTool
# This ensures that the state passed in extends the CoreState
T = TypeVar("T", bound=CoreState)
T_SUBGRAPH = TypeVar("T_SUBGRAPH", bound=SubgraphCoreState)
def extract_core_fields(state: T) -> CoreState:
filtered_dict = {k: v for k, v in state.items() if k in CoreState.__annotations__}
return CoreState(**dict(filtered_dict)) # type: ignore
def extract_core_fields_for_subgraph(state: T) -> SubgraphCoreState:
filtered_dict = {
"subgraph_" + k: v for k, v in state.items() if k in CoreState.__annotations__
}
return SubgraphCoreState(**dict(filtered_dict)) # type: ignore
def in_subgraph_extract_core_fields(state: T_SUBGRAPH) -> SubgraphCoreState:
filtered_dict = {
k: v for k, v in state.items() if k in SubgraphCoreState.__annotations__
}
return SubgraphCoreState(**dict(filtered_dict)) # type: ignore

View File

@@ -0,0 +1,66 @@
from uuid import UUID
from sqlalchemy.orm import Session
from onyx.db.models import AgentSubQuery
from onyx.db.models import AgentSubQuestion
def create_sub_question(
db_session: Session,
chat_session_id: UUID,
primary_message_id: int,
sub_question: str,
sub_answer: str,
) -> AgentSubQuestion:
"""Create a new sub-question record in the database."""
sub_q = AgentSubQuestion(
chat_session_id=chat_session_id,
primary_question_id=primary_message_id,
sub_question=sub_question,
sub_answer=sub_answer,
)
db_session.add(sub_q)
db_session.flush()
return sub_q
def create_sub_query(
db_session: Session,
chat_session_id: UUID,
parent_question_id: int,
sub_query: str,
) -> AgentSubQuery:
"""Create a new sub-query record in the database."""
sub_q = AgentSubQuery(
chat_session_id=chat_session_id,
parent_question_id=parent_question_id,
sub_query=sub_query,
)
db_session.add(sub_q)
db_session.flush()
return sub_q
def get_sub_questions_for_message(
db_session: Session,
primary_message_id: int,
) -> list[AgentSubQuestion]:
"""Get all sub-questions for a given primary message."""
return (
db_session.query(AgentSubQuestion)
.filter(AgentSubQuestion.primary_question_id == primary_message_id)
.all()
)
def get_sub_queries_for_question(
db_session: Session,
sub_question_id: int,
) -> list[AgentSubQuery]:
"""Get all sub-queries for a given sub-question."""
return (
db_session.query(AgentSubQuery)
.filter(AgentSubQuery.parent_question_id == sub_question_id)
.all()
)

View File

@@ -0,0 +1,7 @@
from pydantic import BaseModel
class AgentDocumentCitations(BaseModel):
document_id: str
document_title: str
link: str

View File

@@ -0,0 +1,28 @@
from collections.abc import Hashable
from langgraph.types import Send
from onyx.agent_search.core_state import in_subgraph_extract_core_fields
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
AnswerQuestionInput,
)
from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def send_to_expanded_retrieval(state: AnswerQuestionInput) -> Send | Hashable:
logger.debug("sending to expanded retrieval via edge")
return Send(
"initial_sub_question_expanded_retrieval",
ExpandedRetrievalInput(
**in_subgraph_extract_core_fields(state),
question=state["question"],
base_search=False,
sub_question_id=state["question_id"],
),
)

View File

@@ -0,0 +1,129 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agent_search.pro_search_a.answer_initial_sub_question.edges import (
send_to_expanded_retrieval,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.nodes.answer_check import (
answer_check,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.nodes.answer_generation import (
answer_generation,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.nodes.format_answer import (
format_answer,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.nodes.ingest_retrieval import (
ingest_retrieval,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
AnswerQuestionInput,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
AnswerQuestionOutput,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
AnswerQuestionState,
)
from onyx.agent_search.pro_search_a.expanded_retrieval.graph_builder import (
expanded_retrieval_graph_builder,
)
from onyx.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger
logger = setup_logger()
def answer_query_graph_builder() -> StateGraph:
graph = StateGraph(
state_schema=AnswerQuestionState,
input=AnswerQuestionInput,
output=AnswerQuestionOutput,
)
### Add nodes ###
expanded_retrieval = expanded_retrieval_graph_builder().compile()
graph.add_node(
node="initial_sub_question_expanded_retrieval",
action=expanded_retrieval,
)
graph.add_node(
node="answer_check",
action=answer_check,
)
graph.add_node(
node="answer_generation",
action=answer_generation,
)
graph.add_node(
node="format_answer",
action=format_answer,
)
graph.add_node(
node="ingest_retrieval",
action=ingest_retrieval,
)
### Add edges ###
graph.add_conditional_edges(
source=START,
path=send_to_expanded_retrieval,
path_map=["initial_sub_question_expanded_retrieval"],
)
graph.add_edge(
start_key="initial_sub_question_expanded_retrieval",
end_key="ingest_retrieval",
)
graph.add_edge(
start_key="ingest_retrieval",
end_key="answer_generation",
)
graph.add_edge(
start_key="answer_generation",
end_key="answer_check",
)
graph.add_edge(
start_key="answer_check",
end_key="format_answer",
)
graph.add_edge(
start_key="format_answer",
end_key=END,
)
return graph
if __name__ == "__main__":
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
graph = answer_query_graph_builder()
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
search_request = SearchRequest(
query="what can you do with onyx or danswer?",
)
with get_session_context_manager() as db_session:
pro_search_config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)
inputs = AnswerQuestionInput(
question="what can you do with onyx?",
subgraph_fast_llm=fast_llm,
subgraph_primary_llm=primary_llm,
subgraph_config=pro_search_config,
subgraph_search_tool=search_tool,
subgraph_db_session=db_session,
question_id="0_0",
)
for thing in compiled_graph.stream(
input=inputs,
# debug=True,
# subgraphs=True,
):
logger.debug(thing)

View File

@@ -0,0 +1,21 @@
from pydantic import BaseModel
from onyx.agent_search.pro_search_a.expanded_retrieval.models import QueryResult
from onyx.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.context.search.models import InferenceSection
### Models ###
class AnswerRetrievalStats(BaseModel):
answer_retrieval_stats: dict[str, float | int]
class QuestionAnswerResults(BaseModel):
question: str
question_id: str
answer: str
quality: str
expanded_retrieval_results: list[QueryResult]
documents: list[InferenceSection]
sub_question_retrieval_stats: AgentChunkStats

View File

@@ -0,0 +1,34 @@
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
AnswerQuestionState,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
QACheckUpdate,
)
from onyx.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT
def answer_check(state: AnswerQuestionState) -> QACheckUpdate:
msg = [
HumanMessage(
content=SUB_CHECK_PROMPT.format(
question=state["question"],
base_answer=state["answer"],
)
)
]
fast_llm = state["subgraph_fast_llm"]
response = list(
fast_llm.stream(
prompt=msg,
)
)
quality_str = merge_message_runs(response, chunk_separator="")[0].content
return QACheckUpdate(
answer_quality=quality_str,
)

View File

@@ -0,0 +1,77 @@
import datetime
from typing import Any
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import merge_message_runs
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
AnswerQuestionState,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
QAGenerationUpdate,
)
from onyx.agent_search.shared_graph_utils.agent_prompt_ops import (
build_sub_question_answer_prompt,
)
from onyx.agent_search.shared_graph_utils.prompts import ASSISTANT_SYSTEM_PROMPT_DEFAULT
from onyx.agent_search.shared_graph_utils.prompts import ASSISTANT_SYSTEM_PROMPT_PERSONA
from onyx.agent_search.shared_graph_utils.utils import get_persona_prompt
from onyx.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.chat.models import AgentAnswerPiece
from onyx.utils.logger import setup_logger
logger = setup_logger()
def answer_generation(state: AnswerQuestionState) -> QAGenerationUpdate:
now_start = datetime.datetime.now()
logger.debug(f"--------{now_start}--------START ANSWER GENERATION---")
question = state["question"]
docs = state["documents"]
level, question_nr = parse_question_id(state["question_id"])
persona_prompt = get_persona_prompt(state["subgraph_config"].search_request.persona)
if len(persona_prompt) > 0:
persona_specification = ASSISTANT_SYSTEM_PROMPT_DEFAULT
else:
persona_specification = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
persona_prompt=persona_prompt
)
logger.debug(f"Number of verified retrieval docs: {len(docs)}")
msg = build_sub_question_answer_prompt(
question=question,
original_question=state["subgraph_config"].search_request.query,
docs=docs,
persona_specification=persona_specification,
)
fast_llm = state["subgraph_fast_llm"]
response: list[str | list[str | dict[str, Any]]] = []
for message in fast_llm.stream(
prompt=msg,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
dispatch_custom_event(
"sub_answers",
AgentAnswerPiece(
answer_piece=content,
level=level,
level_question_nr=question_nr,
answer_type="agent_sub_answer",
),
)
response.append(content)
answer_str = merge_message_runs(response, chunk_separator="")[0].content
return QAGenerationUpdate(
answer=answer_str,
)

View File

@@ -0,0 +1,25 @@
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
AnswerQuestionOutput,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
AnswerQuestionState,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
QuestionAnswerResults,
)
def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput:
return AnswerQuestionOutput(
answer_results=[
QuestionAnswerResults(
question=state["question"],
question_id=state["question_id"],
quality=state.get("answer_quality", "No"),
answer=state["answer"],
expanded_retrieval_results=state["expanded_retrieval_results"],
documents=state["documents"],
sub_question_retrieval_stats=state["sub_question_retrieval_stats"],
)
],
)

View File

@@ -0,0 +1,23 @@
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
RetrievalIngestionUpdate,
)
from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
ExpandedRetrievalOutput,
)
from onyx.agent_search.shared_graph_utils.models import AgentChunkStats
def ingest_retrieval(state: ExpandedRetrievalOutput) -> RetrievalIngestionUpdate:
sub_question_retrieval_stats = state[
"expanded_retrieval_result"
].sub_question_retrieval_stats
if sub_question_retrieval_stats is None:
sub_question_retrieval_stats = [AgentChunkStats()]
return RetrievalIngestionUpdate(
expanded_retrieval_results=state[
"expanded_retrieval_result"
].expanded_queries_results,
documents=state["expanded_retrieval_result"].all_documents,
sub_question_retrieval_stats=sub_question_retrieval_stats,
)

View File

@@ -0,0 +1,63 @@
from operator import add
from typing import Annotated
from typing import TypedDict
from onyx.agent_search.core_state import SubgraphCoreState
from onyx.agent_search.pro_search_a.answer_initial_sub_question.models import (
QuestionAnswerResults,
)
from onyx.agent_search.pro_search_a.expanded_retrieval.models import QueryResult
from onyx.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections
from onyx.context.search.models import InferenceSection
## Update States
class QACheckUpdate(TypedDict):
answer_quality: str
class QAGenerationUpdate(TypedDict):
answer: str
# answer_stat: AnswerStats
class RetrievalIngestionUpdate(TypedDict):
expanded_retrieval_results: list[QueryResult]
documents: Annotated[list[InferenceSection], dedup_inference_sections]
sub_question_retrieval_stats: AgentChunkStats
## Graph Input State
class AnswerQuestionInput(SubgraphCoreState):
question: str
question_id: str # 0_0 is original question, everything else is <level>_<question_num>.
# level 0 is original question and first decomposition, level 1 is follow up, etc
# question_num is a unique number per original question per level.
## Graph State
class AnswerQuestionState(
AnswerQuestionInput,
QAGenerationUpdate,
QACheckUpdate,
RetrievalIngestionUpdate,
):
pass
## Graph Output State
class AnswerQuestionOutput(TypedDict):
"""
This is a list of results even though each call of this subgraph only returns one result.
This is because if we parallelize the answer query subgraph, there will be multiple
results in a list so the add operator is used to add them together.
"""
answer_results: Annotated[list[QuestionAnswerResults], add]

View File

@@ -0,0 +1,28 @@
from collections.abc import Hashable
from langgraph.types import Send
from onyx.agent_search.core_state import in_subgraph_extract_core_fields
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
AnswerQuestionInput,
)
from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def send_to_expanded_refined_retrieval(state: AnswerQuestionInput) -> Send | Hashable:
logger.debug("sending to expanded retrieval for follow up question via edge")
return Send(
"refined_sub_question_expanded_retrieval",
ExpandedRetrievalInput(
**in_subgraph_extract_core_fields(state),
question=state["question"],
sub_question_id=state["question_id"],
base_search=False
),
)

View File

@@ -0,0 +1,122 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agent_search.pro_search_a.answer_initial_sub_question.nodes.answer_check import (
answer_check,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.nodes.answer_generation import (
answer_generation,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.nodes.format_answer import (
format_answer,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.nodes.ingest_retrieval import (
ingest_retrieval,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
AnswerQuestionInput,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
AnswerQuestionOutput,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
AnswerQuestionState,
)
from onyx.agent_search.pro_search_a.answer_refinement_sub_question.edges import (
send_to_expanded_refined_retrieval,
)
from onyx.agent_search.pro_search_a.expanded_retrieval.graph_builder import (
expanded_retrieval_graph_builder,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def answer_refined_query_graph_builder() -> StateGraph:
graph = StateGraph(
state_schema=AnswerQuestionState,
input=AnswerQuestionInput,
output=AnswerQuestionOutput,
)
### Add nodes ###
expanded_retrieval = expanded_retrieval_graph_builder().compile()
graph.add_node(
node="refined_sub_question_expanded_retrieval",
action=expanded_retrieval,
)
graph.add_node(
node="refined_sub_answer_check",
action=answer_check,
)
graph.add_node(
node="refined_sub_answer_generation",
action=answer_generation,
)
graph.add_node(
node="format_refined_sub_answer",
action=format_answer,
)
graph.add_node(
node="ingest_refined_retrieval",
action=ingest_retrieval,
)
### Add edges ###
graph.add_conditional_edges(
source=START,
path=send_to_expanded_refined_retrieval,
path_map=["refined_sub_question_expanded_retrieval"],
)
graph.add_edge(
start_key="refined_sub_question_expanded_retrieval",
end_key="ingest_refined_retrieval",
)
graph.add_edge(
start_key="ingest_refined_retrieval",
end_key="refined_sub_answer_generation",
)
graph.add_edge(
start_key="refined_sub_answer_generation",
end_key="refined_sub_answer_check",
)
graph.add_edge(
start_key="refined_sub_answer_check",
end_key="format_refined_sub_answer",
)
graph.add_edge(
start_key="format_refined_sub_answer",
end_key=END,
)
return graph
if __name__ == "__main__":
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
graph = answer_refined_query_graph_builder()
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
search_request = SearchRequest(
query="what can you do with onyx or danswer?",
)
with get_session_context_manager() as db_session:
inputs = AnswerQuestionInput(
question="what can you do with onyx?",
question_id="0_0",
)
for thing in compiled_graph.stream(
input=inputs,
# debug=True,
# subgraphs=True,
):
logger.debug(thing)
# output = compiled_graph.invoke(inputs)
# logger.debug(output)

View File

@@ -0,0 +1,19 @@
from pydantic import BaseModel
from onyx.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.context.search.models import InferenceSection
### Models ###
class AnswerRetrievalStats(BaseModel):
answer_retrieval_stats: dict[str, float | int]
class QuestionAnswerResults(BaseModel):
question: str
answer: str
quality: str
# expanded_retrieval_results: list[QueryResult]
documents: list[InferenceSection]
sub_question_retrieval_stats: AgentChunkStats

View File

@@ -0,0 +1,70 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agent_search.pro_search_a.base_raw_search.nodes.format_raw_search_results import (
format_raw_search_results,
)
from onyx.agent_search.pro_search_a.base_raw_search.nodes.generate_raw_search_data import (
generate_raw_search_data,
)
from onyx.agent_search.pro_search_a.base_raw_search.states import BaseRawSearchInput
from onyx.agent_search.pro_search_a.base_raw_search.states import BaseRawSearchOutput
from onyx.agent_search.pro_search_a.base_raw_search.states import BaseRawSearchState
from onyx.agent_search.pro_search_a.expanded_retrieval.graph_builder import (
expanded_retrieval_graph_builder,
)
def base_raw_search_graph_builder() -> StateGraph:
graph = StateGraph(
state_schema=BaseRawSearchState,
input=BaseRawSearchInput,
output=BaseRawSearchOutput,
)
### Add nodes ###
expanded_retrieval = expanded_retrieval_graph_builder().compile()
graph.add_node(
node="generate_raw_search_data",
action=generate_raw_search_data,
)
graph.add_node(
node="expanded_retrieval_base_search",
action=expanded_retrieval,
)
graph.add_node(
node="format_raw_search_results",
action=format_raw_search_results,
)
### Add edges ###
graph.add_edge(start_key=START, end_key="generate_raw_search_data")
graph.add_edge(
start_key="generate_raw_search_data",
end_key="expanded_retrieval_base_search",
)
graph.add_edge(
start_key="expanded_retrieval_base_search",
end_key="format_raw_search_results",
)
# graph.add_edge(
# start_key="expanded_retrieval_base_search",
# end_key=END,
# )
graph.add_edge(
start_key="format_raw_search_results",
end_key=END,
)
return graph
if __name__ == "__main__":
pass

View File

@@ -0,0 +1,20 @@
from pydantic import BaseModel
from onyx.agent_search.pro_search_a.expanded_retrieval.models import QueryResult
from onyx.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.context.search.models import InferenceSection
### Models ###
class AnswerRetrievalStats(BaseModel):
answer_retrieval_stats: dict[str, float | int]
class QuestionAnswerResults(BaseModel):
question: str
answer: str
quality: str
expanded_retrieval_results: list[QueryResult]
documents: list[InferenceSection]
sub_question_retrieval_stats: list[AgentChunkStats]

View File

@@ -0,0 +1,16 @@
from onyx.agent_search.pro_search_a.base_raw_search.states import BaseRawSearchOutput
from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
ExpandedRetrievalOutput,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def format_raw_search_results(state: ExpandedRetrievalOutput) -> BaseRawSearchOutput:
logger.debug("format_raw_search_results")
return BaseRawSearchOutput(
base_expanded_retrieval_result=state["expanded_retrieval_result"],
# base_retrieval_results=[state["expanded_retrieval_result"]],
# base_search_documents=[],
)

View File

@@ -0,0 +1,21 @@
from onyx.agent_search.core_state import CoreState
from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def generate_raw_search_data(state: CoreState) -> ExpandedRetrievalInput:
logger.debug("generate_raw_search_data")
return ExpandedRetrievalInput(
subgraph_config=state["config"],
subgraph_primary_llm=state["primary_llm"],
subgraph_fast_llm=state["fast_llm"],
subgraph_db_session=state["db_session"],
question=state["config"].search_request.query,
base_search=True,
subgraph_search_tool=state["search_tool"],
sub_question_id=None, # This graph is always and only used for the original question
)

View File

@@ -0,0 +1,42 @@
from typing import TypedDict
from onyx.agent_search.core_state import CoreState
from onyx.agent_search.core_state import SubgraphCoreState
from onyx.agent_search.pro_search_a.expanded_retrieval.models import (
ExpandedRetrievalResult,
)
## Update States
## Graph Input State
class BaseRawSearchInput(CoreState, SubgraphCoreState):
pass
## Graph Output State
class BaseRawSearchOutput(TypedDict):
"""
This is a list of results even though each call of this subgraph only returns one result.
This is because if we parallelize the answer query subgraph, there will be multiple
results in a list so the add operator is used to add them together.
"""
# base_search_documents: Annotated[list[InferenceSection], dedup_inference_sections]
# base_retrieval_results: Annotated[list[ExpandedRetrievalResult], add]
base_expanded_retrieval_result: ExpandedRetrievalResult
## Graph State
class BaseRawSearchState(
BaseRawSearchInput,
BaseRawSearchOutput,
):
pass

View File

@@ -0,0 +1,28 @@
from collections.abc import Hashable
from langgraph.types import Send
from onyx.agent_search.core_state import in_subgraph_extract_core_fields
from onyx.agent_search.pro_search_a.expanded_retrieval.nodes import RetrievalInput
from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
ExpandedRetrievalState,
)
def parallel_retrieval_edge(state: ExpandedRetrievalState) -> list[Send | Hashable]:
question = state.get("question", state["subgraph_config"].search_request.query)
query_expansions = state.get("expanded_queries", []) + [question]
return [
Send(
"doc_retrieval",
RetrievalInput(
query_to_retrieve=query,
question=question,
**in_subgraph_extract_core_fields(state),
base_search=False,
sub_question_id=state.get("sub_question_id"),
),
)
for query in query_expansions
]

View File

@@ -0,0 +1,126 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agent_search.pro_search_a.expanded_retrieval.edges import (
parallel_retrieval_edge,
)
from onyx.agent_search.pro_search_a.expanded_retrieval.nodes import doc_reranking
from onyx.agent_search.pro_search_a.expanded_retrieval.nodes import doc_retrieval
from onyx.agent_search.pro_search_a.expanded_retrieval.nodes import doc_verification
from onyx.agent_search.pro_search_a.expanded_retrieval.nodes import expand_queries
from onyx.agent_search.pro_search_a.expanded_retrieval.nodes import format_results
from onyx.agent_search.pro_search_a.expanded_retrieval.nodes import verification_kickoff
from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
ExpandedRetrievalOutput,
)
from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger
logger = setup_logger()
def expanded_retrieval_graph_builder() -> StateGraph:
graph = StateGraph(
state_schema=ExpandedRetrievalState,
input=ExpandedRetrievalInput,
output=ExpandedRetrievalOutput,
)
### Add nodes ###
graph.add_node(
node="expand_queries",
action=expand_queries,
)
graph.add_node(
node="doc_retrieval",
action=doc_retrieval,
)
graph.add_node(
node="verification_kickoff",
action=verification_kickoff,
)
graph.add_node(
node="doc_verification",
action=doc_verification,
)
graph.add_node(
node="doc_reranking",
action=doc_reranking,
)
graph.add_node(
node="format_results",
action=format_results,
)
### Add edges ###
graph.add_edge(
start_key=START,
end_key="expand_queries",
)
graph.add_conditional_edges(
source="expand_queries",
path=parallel_retrieval_edge,
path_map=["doc_retrieval"],
)
graph.add_edge(
start_key="doc_retrieval",
end_key="verification_kickoff",
)
graph.add_edge(
start_key="doc_verification",
end_key="doc_reranking",
)
graph.add_edge(
start_key="doc_reranking",
end_key="format_results",
)
graph.add_edge(
start_key="format_results",
end_key=END,
)
return graph
if __name__ == "__main__":
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
graph = expanded_retrieval_graph_builder()
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
search_request = SearchRequest(
query="what can you do with onyx or danswer?",
)
with get_session_context_manager() as db_session:
pro_search_config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)
inputs = ExpandedRetrievalInput(
question="what can you do with onyx?",
base_search=False,
subgraph_fast_llm=fast_llm,
subgraph_primary_llm=primary_llm,
subgraph_db_session=db_session,
subgraph_config=pro_search_config,
subgraph_search_tool=search_tool,
sub_question_id=None,
)
for thing in compiled_graph.stream(
input=inputs,
# debug=True,
subgraphs=True,
):
logger.debug(thing)

View File

@@ -0,0 +1,21 @@
from pydantic import BaseModel
from onyx.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.context.search.models import InferenceSection
from onyx.tools.models import SearchQueryInfo
### Models ###
class QueryResult(BaseModel):
query: str
search_results: list[InferenceSection]
stats: RetrievalFitStats | None
query_info: SearchQueryInfo | None
class ExpandedRetrievalResult(BaseModel):
expanded_queries_results: list[QueryResult]
all_documents: list[InferenceSection]
sub_question_retrieval_stats: AgentChunkStats

View File

@@ -0,0 +1,410 @@
from collections import defaultdict
from collections.abc import Callable
from typing import cast
from typing import Literal
import numpy as np
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from langgraph.types import Command
from langgraph.types import Send
from onyx.agent_search.core_state import in_subgraph_extract_core_fields
from onyx.agent_search.pro_search_a.expanded_retrieval.models import (
ExpandedRetrievalResult,
)
from onyx.agent_search.pro_search_a.expanded_retrieval.models import QueryResult
from onyx.agent_search.pro_search_a.expanded_retrieval.states import DocRerankingUpdate
from onyx.agent_search.pro_search_a.expanded_retrieval.states import DocRetrievalUpdate
from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
DocVerificationInput,
)
from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
DocVerificationUpdate,
)
from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
ExpandedRetrievalUpdate,
)
from onyx.agent_search.pro_search_a.expanded_retrieval.states import InferenceSection
from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
QueryExpansionUpdate,
)
from onyx.agent_search.pro_search_a.expanded_retrieval.states import RetrievalInput
from onyx.agent_search.shared_graph_utils.calculations import get_fit_scores
from onyx.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI_ORIGINAL
from onyx.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
from onyx.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import SubQueryPiece
from onyx.configs.dev_configs import AGENT_MAX_QUERY_RETRIEVAL_RESULTS
from onyx.configs.dev_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS
from onyx.configs.dev_configs import AGENT_RERANKING_STATS
from onyx.configs.dev_configs import AGENT_RETRIEVAL_STATS
from onyx.context.search.models import SearchRequest
from onyx.context.search.pipeline import retrieval_preprocessing
from onyx.context.search.postprocessing.postprocessing import rerank_sections
from onyx.db.engine import get_session_context_manager
from onyx.llm.interfaces import LLM
from onyx.tools.models import SearchQueryInfo
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
from onyx.utils.logger import setup_logger
logger = setup_logger()
def dispatch_subquery(level: int, question_nr: int) -> Callable[[str, int], None]:
def helper(token: str, num: int) -> None:
dispatch_custom_event(
"subqueries",
SubQueryPiece(
sub_query=token,
level=level,
level_question_nr=question_nr,
query_id=num,
),
)
return helper
def expand_queries(state: ExpandedRetrievalInput) -> QueryExpansionUpdate:
# Sometimes we want to expand the original question, sometimes we want to expand a sub-question.
# When we are running this node on the original question, no question is explictly passed in.
# Instead, we use the original question from the search request.
question = state.get("question", state["subgraph_config"].search_request.query)
llm: LLM = state["subgraph_fast_llm"]
state["subgraph_db_session"]
chat_session_id = state["subgraph_config"].chat_session_id
sub_question_id = state.get("sub_question_id")
if sub_question_id is None:
level, question_nr = 0, 0
else:
level, question_nr = parse_question_id(sub_question_id)
if chat_session_id is None:
raise ValueError("chat_session_id must be provided for agent search")
msg = [
HumanMessage(
content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question),
)
]
llm_response_list = dispatch_separated(
llm.stream(prompt=msg), dispatch_subquery(level, question_nr)
)
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
rewritten_queries = llm_response.split("\n")
return QueryExpansionUpdate(
expanded_queries=rewritten_queries,
)
def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate:
"""
Retrieve documents
Args:
state (RetrievalInput): Primary state + the query to retrieve
Updates:
expanded_retrieval_results: list[ExpandedRetrievalResult]
retrieved_documents: list[InferenceSection]
"""
query_to_retrieve = state["query_to_retrieve"]
search_tool = state["subgraph_search_tool"]
retrieved_docs: list[InferenceSection] = []
if not query_to_retrieve.strip():
logger.warning("Empty query, skipping retrieval")
return DocRetrievalUpdate(
expanded_retrieval_results=[],
retrieved_documents=[],
)
query_info = None
# new db session to avoid concurrency issues
with get_session_context_manager() as db_session:
for tool_response in search_tool.run(
query=query_to_retrieve,
force_no_rerank=True,
alternate_db_session=db_session,
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
response = cast(SearchResponseSummary, tool_response.response)
retrieved_docs = response.top_sections
query_info = SearchQueryInfo(
predicted_search=response.predicted_search,
final_filters=response.final_filters,
recency_bias_multiplier=response.recency_bias_multiplier,
)
retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS]
pre_rerank_docs = retrieved_docs
if search_tool.search_pipeline is not None:
pre_rerank_docs = (
search_tool.search_pipeline._retrieved_sections or retrieved_docs
)
if AGENT_RETRIEVAL_STATS:
fit_scores = get_fit_scores(
pre_rerank_docs,
retrieved_docs,
)
else:
fit_scores = None
expanded_retrieval_result = QueryResult(
query=query_to_retrieve,
search_results=retrieved_docs,
stats=fit_scores,
query_info=query_info,
)
return DocRetrievalUpdate(
expanded_retrieval_results=[expanded_retrieval_result],
retrieved_documents=retrieved_docs,
)
def verification_kickoff(
state: ExpandedRetrievalState,
) -> Command[Literal["doc_verification"]]:
documents = state["retrieved_documents"]
verification_question = state.get(
"question", state["subgraph_config"].search_request.query
)
sub_question_id = state.get("sub_question_id")
return Command(
update={},
goto=[
Send(
node="doc_verification",
arg=DocVerificationInput(
doc_to_verify=doc,
question=verification_question,
base_search=False,
sub_question_id=sub_question_id,
**in_subgraph_extract_core_fields(state),
),
)
for doc in documents
],
)
def doc_verification(state: DocVerificationInput) -> DocVerificationUpdate:
"""
Check whether the document is relevant for the original user question
Args:
state (DocVerificationInput): The current state
Updates:
verified_documents: list[InferenceSection]
"""
question = state["question"]
doc_to_verify = state["doc_to_verify"]
document_content = doc_to_verify.combined_content
msg = [
HumanMessage(
content=VERIFIER_PROMPT.format(
question=question, document_content=document_content
)
)
]
fast_llm = state["subgraph_fast_llm"]
response = fast_llm.invoke(msg)
verified_documents = []
if isinstance(response.content, str) and "yes" in response.content.lower():
verified_documents.append(doc_to_verify)
return DocVerificationUpdate(
verified_documents=verified_documents,
)
def doc_reranking(state: ExpandedRetrievalState) -> DocRerankingUpdate:
verified_documents = state["verified_documents"]
# Rerank post retrieval and verification. First, create a search query
# then create the list of reranked sections
question = state.get("question", state["subgraph_config"].search_request.query)
with get_session_context_manager() as db_session:
_search_query = retrieval_preprocessing(
search_request=SearchRequest(query=question),
user=state["subgraph_search_tool"].user, # bit of a hack
llm=state["subgraph_fast_llm"],
db_session=db_session,
)
# skip section filtering
if (
_search_query.rerank_settings
and _search_query.rerank_settings.rerank_model_name
and _search_query.rerank_settings.num_rerank > 0
):
reranked_documents = rerank_sections(
_search_query,
verified_documents,
)
else:
logger.warning("No reranking settings found, using unranked documents")
reranked_documents = verified_documents
if AGENT_RERANKING_STATS:
fit_scores = get_fit_scores(verified_documents, reranked_documents)
else:
fit_scores = RetrievalFitStats(fit_score_lift=0, rerank_effect=0, fit_scores={})
# TODO: stream deduped docs here, or decide to use search tool ranking/verification
return DocRerankingUpdate(
reranked_documents=[
doc for doc in reranked_documents if type(doc) == InferenceSection
][:AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS],
sub_question_retrieval_stats=fit_scores,
)
def _calculate_sub_question_retrieval_stats(
verified_documents: list[InferenceSection],
expanded_retrieval_results: list[QueryResult],
) -> AgentChunkStats:
chunk_scores: dict[str, dict[str, list[int | float]]] = defaultdict(
lambda: defaultdict(list)
)
for expanded_retrieval_result in expanded_retrieval_results:
for doc in expanded_retrieval_result.search_results:
doc_chunk_id = f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}"
if doc.center_chunk.score is not None:
chunk_scores[doc_chunk_id]["score"].append(doc.center_chunk.score)
verified_doc_chunk_ids = [
f"{verified_document.center_chunk.document_id}_{verified_document.center_chunk.chunk_id}"
for verified_document in verified_documents
]
dismissed_doc_chunk_ids = []
raw_chunk_stats_counts: dict[str, int] = defaultdict(int)
raw_chunk_stats_scores: dict[str, float] = defaultdict(float)
for doc_chunk_id, chunk_data in chunk_scores.items():
if doc_chunk_id in verified_doc_chunk_ids:
raw_chunk_stats_counts["verified_count"] += 1
valid_chunk_scores = [
score for score in chunk_data["score"] if score is not None
]
raw_chunk_stats_scores["verified_scores"] += float(
np.mean(valid_chunk_scores)
)
else:
raw_chunk_stats_counts["rejected_count"] += 1
valid_chunk_scores = [
score for score in chunk_data["score"] if score is not None
]
raw_chunk_stats_scores["rejected_scores"] += float(
np.mean(valid_chunk_scores)
)
dismissed_doc_chunk_ids.append(doc_chunk_id)
if raw_chunk_stats_counts["verified_count"] == 0:
verified_avg_scores = 0.0
else:
verified_avg_scores = raw_chunk_stats_scores["verified_scores"] / float(
raw_chunk_stats_counts["verified_count"]
)
rejected_scores = raw_chunk_stats_scores.get("rejected_scores", None)
if rejected_scores is not None:
rejected_avg_scores = rejected_scores / float(
raw_chunk_stats_counts["rejected_count"]
)
else:
rejected_avg_scores = None
chunk_stats = AgentChunkStats(
verified_count=raw_chunk_stats_counts["verified_count"],
verified_avg_scores=verified_avg_scores,
rejected_count=raw_chunk_stats_counts["rejected_count"],
rejected_avg_scores=rejected_avg_scores,
verified_doc_chunk_ids=verified_doc_chunk_ids,
dismissed_doc_chunk_ids=dismissed_doc_chunk_ids,
)
return chunk_stats
def format_results(state: ExpandedRetrievalState) -> ExpandedRetrievalUpdate:
level, question_nr = parse_question_id(state.get("sub_question_id") or "0_0")
query_infos = [
result.query_info
for result in state["expanded_retrieval_results"]
if result.query_info is not None
]
if len(query_infos) == 0:
raise ValueError("No query info found")
# main question docs will be sent later after aggregation and deduping with sub-question docs
if not (level == 0 and question_nr == 0):
for tool_response in yield_search_responses(
query=state["question"],
reranked_sections=state[
"retrieved_documents"
], # TODO: rename params. this one is supposed to be the sections pre-merging
final_context_sections=state["reranked_documents"],
search_query_info=query_infos[0], # TODO: handle differing query infos?
get_section_relevance=lambda: None, # TODO: add relevance
search_tool=state["subgraph_search_tool"],
):
dispatch_custom_event(
"tool_response",
ExtendedToolResponse(
id=tool_response.id,
response=tool_response.response,
level=level,
level_question_nr=question_nr,
),
)
sub_question_retrieval_stats = _calculate_sub_question_retrieval_stats(
verified_documents=state["verified_documents"],
expanded_retrieval_results=state["expanded_retrieval_results"],
)
if sub_question_retrieval_stats is None:
sub_question_retrieval_stats = AgentChunkStats()
# else:
# sub_question_retrieval_stats = [sub_question_retrieval_stats]
return ExpandedRetrievalUpdate(
expanded_retrieval_result=ExpandedRetrievalResult(
expanded_queries_results=state["expanded_retrieval_results"],
all_documents=state["reranked_documents"],
sub_question_retrieval_stats=sub_question_retrieval_stats,
),
)

View File

@@ -0,0 +1,82 @@
from operator import add
from typing import Annotated
from typing import TypedDict
from onyx.agent_search.core_state import SubgraphCoreState
from onyx.agent_search.pro_search_a.expanded_retrieval.models import (
ExpandedRetrievalResult,
)
from onyx.agent_search.pro_search_a.expanded_retrieval.models import QueryResult
from onyx.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections
from onyx.context.search.models import InferenceSection
### States ###
## Graph Input State
class ExpandedRetrievalInput(SubgraphCoreState):
question: str
base_search: bool
sub_question_id: str | None
## Update/Return States
class QueryExpansionUpdate(TypedDict):
expanded_queries: list[str]
class DocVerificationUpdate(TypedDict):
verified_documents: Annotated[list[InferenceSection], dedup_inference_sections]
class DocRetrievalUpdate(TypedDict):
expanded_retrieval_results: Annotated[list[QueryResult], add]
retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections]
class DocRerankingUpdate(TypedDict):
reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections]
sub_question_retrieval_stats: RetrievalFitStats | None
class ExpandedRetrievalUpdate(TypedDict):
expanded_retrieval_result: ExpandedRetrievalResult
## Graph Output State
class ExpandedRetrievalOutput(TypedDict):
expanded_retrieval_result: ExpandedRetrievalResult
base_expanded_retrieval_result: ExpandedRetrievalResult
## Graph State
class ExpandedRetrievalState(
# This includes the core state
ExpandedRetrievalInput,
QueryExpansionUpdate,
DocRetrievalUpdate,
DocVerificationUpdate,
DocRerankingUpdate,
ExpandedRetrievalOutput,
):
pass
## Conditional Input States
class DocVerificationInput(ExpandedRetrievalInput):
doc_to_verify: InferenceSection
class RetrievalInput(ExpandedRetrievalInput):
query_to_retrieve: str

View File

@@ -0,0 +1,92 @@
from collections.abc import Hashable
from typing import Literal
from langgraph.types import Send
from onyx.agent_search.core_state import extract_core_fields_for_subgraph
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
AnswerQuestionInput,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
AnswerQuestionOutput,
)
from onyx.agent_search.pro_search_a.main.states import MainState
from onyx.agent_search.pro_search_a.main.states import RequireRefinedAnswerUpdate
from onyx.agent_search.shared_graph_utils.utils import make_question_id
from onyx.utils.logger import setup_logger
logger = setup_logger()
def parallelize_initial_sub_question_answering(
state: MainState,
) -> list[Send | Hashable]:
if len(state["initial_decomp_questions"]) > 0:
# sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]]
# if len(state["sub_question_records"]) == 0:
# if state["config"].use_persistence:
# raise ValueError("No sub-questions found for initial decompozed questions")
# else:
# # in this case, we are doing retrieval on the original question.
# # to make all the logic consistent, we create a new sub-question
# # with the same content as the original question
# sub_question_record_ids = [1] * len(state["initial_decomp_questions"])
return [
Send(
"answer_query_subgraph",
AnswerQuestionInput(
**extract_core_fields_for_subgraph(state),
question=question,
question_id=make_question_id(0, question_nr),
),
)
for question_nr, question in enumerate(state["initial_decomp_questions"])
]
else:
return [
Send(
"ingest_answers",
AnswerQuestionOutput(
answer_results=[],
),
)
]
# Define the function that determines whether to continue or not
def continue_to_refined_answer_or_end(
state: RequireRefinedAnswerUpdate,
) -> Literal["refined_decompose", "logging_node"]:
if state["require_refined_answer"]:
return "refined_decompose"
else:
return "logging_node"
def parallelize_refined_sub_question_answering(
state: MainState,
) -> list[Send | Hashable]:
if len(state["refined_sub_questions"]) > 0:
return [
Send(
"answer_refinement_sub_question",
AnswerQuestionInput(
**extract_core_fields_for_subgraph(state),
question=question_data.sub_question,
question_id=make_question_id(1, question_nr),
),
)
for question_nr, question_data in state["refined_sub_questions"].items()
]
else:
return [
Send(
"ingest_refined_sub_answers",
AnswerQuestionOutput(
answer_results=[],
),
)
]

View File

@@ -0,0 +1,264 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agent_search.pro_search_a.answer_initial_sub_question.graph_builder import (
answer_query_graph_builder,
)
from onyx.agent_search.pro_search_a.answer_refinement_sub_question.graph_builder import (
answer_refined_query_graph_builder,
)
from onyx.agent_search.pro_search_a.base_raw_search.graph_builder import (
base_raw_search_graph_builder,
)
from onyx.agent_search.pro_search_a.main.edges import continue_to_refined_answer_or_end
from onyx.agent_search.pro_search_a.main.edges import (
parallelize_initial_sub_question_answering,
)
from onyx.agent_search.pro_search_a.main.edges import (
parallelize_refined_sub_question_answering,
)
from onyx.agent_search.pro_search_a.main.nodes import agent_logging
from onyx.agent_search.pro_search_a.main.nodes import entity_term_extraction_llm
from onyx.agent_search.pro_search_a.main.nodes import generate_initial_answer
from onyx.agent_search.pro_search_a.main.nodes import generate_refined_answer
from onyx.agent_search.pro_search_a.main.nodes import ingest_initial_base_retrieval
from onyx.agent_search.pro_search_a.main.nodes import (
ingest_initial_sub_question_answers,
)
from onyx.agent_search.pro_search_a.main.nodes import ingest_refined_answers
from onyx.agent_search.pro_search_a.main.nodes import initial_answer_quality_check
from onyx.agent_search.pro_search_a.main.nodes import initial_sub_question_creation
from onyx.agent_search.pro_search_a.main.nodes import refined_answer_decision
from onyx.agent_search.pro_search_a.main.nodes import refined_sub_question_creation
from onyx.agent_search.pro_search_a.main.states import MainInput
from onyx.agent_search.pro_search_a.main.states import MainState
from onyx.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger
logger = setup_logger()
test_mode = False
def main_graph_builder(test_mode: bool = False) -> StateGraph:
graph = StateGraph(
state_schema=MainState,
input=MainInput,
)
graph.add_node(
node="initial_sub_question_creation",
action=initial_sub_question_creation,
)
answer_query_subgraph = answer_query_graph_builder().compile()
graph.add_node(
node="answer_query_subgraph",
action=answer_query_subgraph,
)
base_raw_search_subgraph = base_raw_search_graph_builder().compile()
graph.add_node(
node="base_raw_search_subgraph",
action=base_raw_search_subgraph,
)
# refined_answer_subgraph = refined_answers_graph_builder().compile()
# graph.add_node(
# node="refined_answer_subgraph",
# action=refined_answer_subgraph,
# )
graph.add_node(
node="refined_sub_question_creation",
action=refined_sub_question_creation,
)
answer_refined_question = answer_refined_query_graph_builder().compile()
graph.add_node(
node="answer_refined_question",
action=answer_refined_question,
)
graph.add_node(
node="ingest_refined_answers",
action=ingest_refined_answers,
)
graph.add_node(
node="generate_refined_answer",
action=generate_refined_answer,
)
# graph.add_node(
# node="check_refined_answer",
# action=check_refined_answer,
# )
graph.add_node(
node="ingest_initial_retrieval",
action=ingest_initial_base_retrieval,
)
graph.add_node(
node="ingest_initial_sub_question_answers",
action=ingest_initial_sub_question_answers,
)
graph.add_node(
node="generate_initial_answer",
action=generate_initial_answer,
)
graph.add_node(
node="initial_answer_quality_check",
action=initial_answer_quality_check,
)
graph.add_node(
node="entity_term_extraction_llm",
action=entity_term_extraction_llm,
)
graph.add_node(
node="refined_answer_decision",
action=refined_answer_decision,
)
graph.add_node(
node="logging_node",
action=agent_logging,
)
# if test_mode:
# graph.add_node(
# node="generate_initial_base_answer",
# action=generate_initial_base_answer,
# )
### Add edges ###
graph.add_edge(start_key=START, end_key="base_raw_search_subgraph")
graph.add_edge(
start_key="base_raw_search_subgraph",
end_key="ingest_initial_retrieval",
)
graph.add_edge(
start_key=START,
end_key="initial_sub_question_creation",
)
graph.add_conditional_edges(
source="initial_sub_question_creation",
path=parallelize_initial_sub_question_answering,
path_map=["answer_query_subgraph"],
)
graph.add_edge(
start_key="answer_query_subgraph",
end_key="ingest_initial_sub_question_answers",
)
graph.add_edge(
start_key=["ingest_initial_sub_question_answers", "ingest_initial_retrieval"],
end_key="generate_initial_answer",
)
graph.add_edge(
start_key=["ingest_initial_sub_question_answers", "ingest_initial_retrieval"],
end_key="entity_term_extraction_llm",
)
graph.add_edge(
start_key="generate_initial_answer",
end_key="initial_answer_quality_check",
)
graph.add_edge(
start_key=["initial_answer_quality_check", "entity_term_extraction_llm"],
end_key="refined_answer_decision",
)
graph.add_conditional_edges(
source="refined_answer_decision",
path=continue_to_refined_answer_or_end,
path_map=["refined_sub_question_creation", "logging_node"],
)
graph.add_conditional_edges(
source="refined_sub_question_creation",
path=parallelize_refined_sub_question_answering,
path_map=["answer_refined_question"],
)
graph.add_edge(
start_key="answer_refined_question",
end_key="ingest_refined_answers",
)
graph.add_edge(
start_key="ingest_refined_answers",
end_key="generate_refined_answer",
)
# graph.add_conditional_edges(
# source="refined_answer_decision",
# path=continue_to_refined_answer_or_end,
# path_map=["refined_answer_subgraph", END],
# )
# graph.add_edge(
# start_key="refined_answer_subgraph",
# end_key="generate_refined_answer",
# )
graph.add_edge(
start_key="generate_refined_answer",
end_key="logging_node",
)
graph.add_edge(
start_key="logging_node",
end_key=END,
)
# graph.add_edge(
# start_key="generate_refined_answer",
# end_key="check_refined_answer",
# )
# graph.add_edge(
# start_key="check_refined_answer",
# end_key=END,
# )
return graph
if __name__ == "__main__":
pass
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
graph = main_graph_builder()
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
with get_session_context_manager() as db_session:
search_request = SearchRequest(query="Who created Excel?")
pro_search_config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)
inputs = MainInput(
primary_llm=primary_llm,
fast_llm=fast_llm,
db_session=db_session,
config=pro_search_config,
search_tool=search_tool,
)
for thing in compiled_graph.stream(
input=inputs,
# stream_mode="debug",
# debug=True,
subgraphs=True,
):
logger.debug(thing)

View File

@@ -0,0 +1,69 @@
from pydantic import BaseModel
### Models ###
class Entity(BaseModel):
entity_name: str
entity_type: str
class Relationship(BaseModel):
relationship_name: str
relationship_type: str
relationship_entities: list[str]
class Term(BaseModel):
term_name: str
term_type: str
term_similar_to: list[str]
class EntityRelationshipTermExtraction(BaseModel):
entities: list[Entity]
relationships: list[Relationship]
terms: list[Term]
class FollowUpSubQuestion(BaseModel):
sub_question: str
sub_question_id: str
verified: bool
answered: bool
answer: str
class AgentTimings(BaseModel):
base_duration__s: float | None
refined_duration__s: float | None
full_duration__s: float | None
class AgentBaseMetrics(BaseModel):
num_verified_documents_total: int | None
num_verified_documents_core: int | None
verified_avg_score_core: float | None
num_verified_documents_base: int | float | None
verified_avg_score_base: float | None
base_doc_boost_factor: float | None
support_boost_factor: float | None
duration__s: float | None
class AgentRefinedMetrics(BaseModel):
refined_doc_boost_factor: float | None
refined_question_boost_factor: float | None
duration__s: float | None
class AgentAdditionalMetrics(BaseModel):
pass
class CombinedAgentMetrics(BaseModel):
timings: AgentTimings
base_metrics: AgentBaseMetrics
refined_metrics: AgentRefinedMetrics
additional_metrics: AgentAdditionalMetrics

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,151 @@
from datetime import datetime
from operator import add
from typing import Annotated
from typing import TypedDict
from onyx.agent_search.core_state import CoreState
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
QuestionAnswerResults,
)
from onyx.agent_search.pro_search_a.expanded_retrieval.models import (
ExpandedRetrievalResult,
)
from onyx.agent_search.pro_search_a.expanded_retrieval.models import QueryResult
from onyx.agent_search.pro_search_a.main.models import AgentBaseMetrics
from onyx.agent_search.pro_search_a.main.models import AgentRefinedMetrics
from onyx.agent_search.pro_search_a.main.models import EntityRelationshipTermExtraction
from onyx.agent_search.pro_search_a.main.models import FollowUpSubQuestion
from onyx.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agent_search.shared_graph_utils.models import RefinedAgentStats
from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections
from onyx.agent_search.shared_graph_utils.operators import dedup_question_answer_results
from onyx.context.search.models import InferenceSection
### States ###
## Update States
class RefinedAgentStartStats(TypedDict):
agent_refined_start_time: datetime | None
class RefinedAgentEndStats(TypedDict):
agent_refined_end_time: datetime | None
agent_refined_metrics: AgentRefinedMetrics
class BaseDecompUpdateBase(TypedDict):
agent_start_time: datetime
initial_decomp_questions: list[str]
class BaseDecompUpdate(
RefinedAgentStartStats, RefinedAgentEndStats, BaseDecompUpdateBase
):
pass
class InitialAnswerBASEUpdate(TypedDict):
initial_base_answer: str
class InitialAnswerUpdate(TypedDict):
initial_answer: str
initial_agent_stats: InitialAgentResultStats | None
generated_sub_questions: list[str]
agent_base_end_time: datetime
agent_base_metrics: AgentBaseMetrics
class RefinedAnswerUpdateBase(TypedDict):
refined_answer: str
refined_agent_stats: RefinedAgentStats | None
refined_answer_quality: bool
class RefinedAnswerUpdate(RefinedAgentEndStats, RefinedAnswerUpdateBase):
pass
class InitialAnswerQualityUpdate(TypedDict):
initial_answer_quality: bool
class RequireRefinedAnswerUpdate(TypedDict):
require_refined_answer: bool
class DecompAnswersUpdate(TypedDict):
documents: Annotated[list[InferenceSection], dedup_inference_sections]
decomp_answer_results: Annotated[
list[QuestionAnswerResults], dedup_question_answer_results
]
class FollowUpDecompAnswersUpdate(TypedDict):
refined_documents: Annotated[list[InferenceSection], dedup_inference_sections]
refined_decomp_answer_results: Annotated[list[QuestionAnswerResults], add]
class ExpandedRetrievalUpdate(TypedDict):
all_original_question_documents: Annotated[
list[InferenceSection], dedup_inference_sections
]
original_question_retrieval_results: list[QueryResult]
original_question_retrieval_stats: AgentChunkStats
class EntityTermExtractionUpdate(TypedDict):
entity_retlation_term_extractions: EntityRelationshipTermExtraction
class FollowUpSubQuestionsUpdateBase(TypedDict):
refined_sub_questions: dict[int, FollowUpSubQuestion]
class FollowUpSubQuestionsUpdate(
RefinedAgentStartStats, FollowUpSubQuestionsUpdateBase
):
pass
## Graph Input State
## Graph Input State
class MainInput(CoreState):
pass
## Graph State
class MainState(
# This includes the core state
MainInput,
BaseDecompUpdateBase,
InitialAnswerUpdate,
InitialAnswerBASEUpdate,
DecompAnswersUpdate,
ExpandedRetrievalUpdate,
EntityTermExtractionUpdate,
InitialAnswerQualityUpdate,
RequireRefinedAnswerUpdate,
FollowUpSubQuestionsUpdateBase,
FollowUpDecompAnswersUpdate,
RefinedAnswerUpdateBase,
RefinedAgentStartStats,
RefinedAgentEndStats,
):
# expanded_retrieval_result: Annotated[list[ExpandedRetrievalResult], add]
base_raw_search_result: Annotated[list[ExpandedRetrievalResult], add]
## Graph Output State - presently not used
class MainOutput(TypedDict):
pass

View File

@@ -0,0 +1,404 @@
import asyncio
from collections import defaultdict
from collections.abc import AsyncIterable
from collections.abc import Iterable
from datetime import datetime
from typing import cast
from langchain_core.runnables.schema import StreamEvent
from langgraph.graph.state import CompiledStateGraph
from sqlalchemy.orm import Session
from onyx.agent_search.basic.graph_builder import basic_graph_builder
from onyx.agent_search.basic.states import BasicInput
from onyx.agent_search.models import AgentDocumentCitations
from onyx.agent_search.pro_search_a.main.graph_builder import main_graph_builder
from onyx.agent_search.pro_search_a.main.states import MainInput
from onyx.agent_search.shared_graph_utils.utils import get_test_config
from onyx.chat.llm_response_handler import LLMResponseHandlerManager
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStream
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import ProSearchConfig
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionPiece
from onyx.chat.models import ToolResponse
from onyx.chat.prompt_builder.build import LLMCall
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager
from onyx.llm.interfaces import LLM
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_runner import ToolCallKickoff
from onyx.utils.logger import setup_logger
logger = setup_logger()
_COMPILED_GRAPH: CompiledStateGraph | None = None
def _set_combined_token_value(
combined_token: str, parsed_object: AgentAnswerPiece
) -> AgentAnswerPiece:
parsed_object.answer_piece = combined_token
return parsed_object
def _parse_agent_event(
event: StreamEvent,
) -> AnswerPacket | None:
"""
Parse the event into a typed object.
Return None if we are not interested in the event.
"""
event_type = event["event"]
if event_type == "on_custom_event":
# TODO: different AnswerStream types for different events
if event["name"] == "decomp_qs":
return cast(SubQuestionPiece, event["data"])
elif event["name"] == "subqueries":
return cast(SubQueryPiece, event["data"])
elif event["name"] == "sub_answers":
return cast(AgentAnswerPiece, event["data"])
elif event["name"] == "initial_agent_answer":
return cast(AgentAnswerPiece, event["data"])
elif event["name"] == "tool_response":
return cast(ToolResponse, event["data"])
elif event["name"] == "basic_response":
return cast(AnswerPacket, event["data"])
return None
def _manage_async_event_streaming(
compiled_graph: CompiledStateGraph,
graph_input: MainInput | BasicInput,
) -> Iterable[StreamEvent]:
async def _run_async_event_stream() -> AsyncIterable[StreamEvent]:
async for event in compiled_graph.astream_events(
input=graph_input,
# debug=True,
# indicating v2 here deserves further scrutiny
version="v2",
):
yield event
# This might be able to be simplified
def _yield_async_to_sync() -> Iterable[StreamEvent]:
loop = asyncio.new_event_loop()
try:
# Get the async generator
async_gen = _run_async_event_stream()
# Convert to AsyncIterator
async_iter = async_gen.__aiter__()
while True:
try:
# Create a coroutine by calling anext with the async iterator
next_coro = anext(async_iter)
# Run the coroutine to get the next event
event = loop.run_until_complete(next_coro)
yield event
except StopAsyncIteration:
break
finally:
loop.close()
return _yield_async_to_sync()
def run_graph(
compiled_graph: CompiledStateGraph,
input: MainInput | BasicInput,
) -> AnswerStream:
agent_document_citations: dict[int, dict[int, list[AgentDocumentCitations]]] = {}
agent_question_citations_used_docs: defaultdict[
int, defaultdict[int, list[str]]
] = defaultdict(lambda: defaultdict(list))
citation_potential: defaultdict[int, defaultdict[int, bool]] = defaultdict(
lambda: defaultdict(lambda: False)
)
current_yield_components: defaultdict[
int, defaultdict[int, list[str]]
] = defaultdict(lambda: defaultdict(list))
current_yield_str: defaultdict[int, defaultdict[int, str]] = defaultdict(
lambda: defaultdict(lambda: "")
)
# def _process_citation(current_yield_str: str) -> tuple[str, str]:
# """Process a citation string and return the formatted citation and remaining text."""
# section_split = current_yield_str.split(']', 1)
# citation_part = section_split[0] + ']'
# remaining_text = section_split[1] if len(section_split) > 1 else ''
# if 'D' in citation_part:
# citation_type = "Document"
# formatted_citation = citation_part.replace('[D', '[[').replace(']', ']]')
# else: # Q case
# citation_type = "Question"
# formatted_citation = citation_part.replace('[Q', '{{').replace(']', '}}')
# return f" --- CITATION: {citation_type} - {formatted_citation}", remaining_text
for event in _manage_async_event_streaming(
compiled_graph=compiled_graph, graph_input=input
):
parsed_object = _parse_agent_event(event)
if not parsed_object:
continue
level = getattr(parsed_object, "level", None)
level_question_nr = getattr(parsed_object, "level_question_nr", None)
if isinstance(parsed_object, (OnyxAnswerPiece, AgentAnswerPiece)):
# logger.debug(f"FA {parsed_object.answer_piece}")
if isinstance(parsed_object, AgentAnswerPiece):
token = parsed_object.answer_piece
level = parsed_object.level
level_question_nr = parsed_object.level_question_nr
else:
yield parsed_object
continue
# raise ValueError(
# f"Invalid parsed object type: {type(parsed_object)}"
# )
if not citation_potential[level][level_question_nr] and token:
if token.startswith(" ["):
citation_potential[level][level_question_nr] = True
current_yield_components[level][level_question_nr] = [token]
else:
yield parsed_object
elif token and citation_potential[level][level_question_nr]:
current_yield_components[level][level_question_nr].append(token)
current_yield_str[level][level_question_nr] = "".join(
current_yield_components[level][level_question_nr]
)
if current_yield_str[level][level_question_nr].strip().startswith(
"[D"
) or current_yield_str[level][level_question_nr].strip().startswith(
"[Q"
):
citation_potential[level][level_question_nr] = True
else:
citation_potential[level][level_question_nr] = False
parsed_object = _set_combined_token_value(
current_yield_str[level][level_question_nr], parsed_object
)
yield parsed_object
if (
len(current_yield_components[level][level_question_nr]) > 15
): # ??? 15?
citation_potential[level][level_question_nr] = False
parsed_object = _set_combined_token_value(
current_yield_str[level][level_question_nr], parsed_object
)
yield parsed_object
elif "]" in current_yield_str[level][level_question_nr]:
section_split = current_yield_str[level][level_question_nr].split(
"]"
)
section_split[0] + "]" # dead code?
start_of_next_section = "]".join(section_split[1:])
citation_string = current_yield_str[level][level_question_nr][
: -len(start_of_next_section)
]
if "[D" in citation_string:
cite_open_bracket_marker, cite_close_bracket_marker = (
"[",
"]",
)
cite_identifyer = "D"
try:
cited_document = int(
citation_string[level][level_question_nr][2:-1]
)
if level and level_question_nr:
link = agent_document_citations[int(level)][
int(level_question_nr)
][cited_document].link
else:
link = ""
except (ValueError, IndexError):
link = ""
elif "[Q" in citation_string:
cite_open_bracket_marker, cite_close_bracket_marker = (
"{",
"}",
)
cite_identifyer = "Q"
else:
pass
citation_string = citation_string.replace(
"[" + cite_identifyer,
cite_open_bracket_marker * 2,
).replace("]", cite_close_bracket_marker * 2)
if cite_identifyer == "D":
citation_string += f"({link})"
parsed_object = _set_combined_token_value(
citation_string, parsed_object
)
yield parsed_object
current_yield_components[level][level_question_nr] = [
start_of_next_section
]
if not start_of_next_section.strip().startswith("["):
citation_potential[level][level_question_nr] = False
elif isinstance(parsed_object, ExtendedToolResponse):
if parsed_object.id == "search_response_summary":
level = parsed_object.level
level_question_nr = parsed_object.level_question_nr
for inference_section in parsed_object.response.top_sections:
doc_link = inference_section.center_chunk.source_links[0]
doc_title = inference_section.center_chunk.title
doc_id = inference_section.center_chunk.document_id
if (
doc_id
not in agent_question_citations_used_docs[level][
level_question_nr
]
):
if level not in agent_document_citations:
agent_document_citations[level] = {}
if level_question_nr not in agent_document_citations[level]:
agent_document_citations[level][level_question_nr] = []
agent_document_citations[level][level_question_nr].append(
AgentDocumentCitations(
document_id=doc_id,
document_title=doc_title,
link=doc_link,
)
)
agent_question_citations_used_docs[level][
level_question_nr
].append(doc_id)
yield parsed_object
else:
yield parsed_object
# TODO: call this once on startup, TBD where and if it should be gated based
# on dev mode or not
def load_compiled_graph() -> CompiledStateGraph:
global _COMPILED_GRAPH
if _COMPILED_GRAPH is None:
graph = main_graph_builder()
_COMPILED_GRAPH = graph.compile()
return _COMPILED_GRAPH
def run_main_graph(
config: ProSearchConfig,
search_tool: SearchTool,
primary_llm: LLM,
fast_llm: LLM,
db_session: Session,
) -> AnswerStream:
compiled_graph = load_compiled_graph()
input = MainInput(
config=config,
primary_llm=primary_llm,
fast_llm=fast_llm,
db_session=db_session,
search_tool=search_tool,
)
return run_graph(compiled_graph, input)
def run_basic_graph(
last_llm_call: LLMCall | None,
primary_llm: LLM,
answer_style_config: AnswerStyleConfig,
response_handler_manager: LLMResponseHandlerManager,
) -> AnswerStream:
graph = basic_graph_builder()
compiled_graph = graph.compile()
input = BasicInput(
last_llm_call=last_llm_call,
llm=primary_llm,
answer_style_config=answer_style_config,
response_handler_manager=response_handler_manager,
calls=0,
)
return run_graph(compiled_graph, input)
if __name__ == "__main__":
from onyx.llm.factory import get_default_llms
now_start = datetime.now()
logger.debug(f"Start at {now_start}")
graph = main_graph_builder()
compiled_graph = graph.compile()
now_end = datetime.now()
logger.debug(f"Graph compiled in {now_end - now_start} seconds")
primary_llm, fast_llm = get_default_llms()
search_request = SearchRequest(
# query="what can you do with gitlab?",
query="What are the guiding principles behind the development of cockroachDB?",
)
# Joachim custom persona
with get_session_context_manager() as db_session:
config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)
# search_request.persona = get_persona_by_id(1, None, db_session)
config.use_persistence = True
# with open("output.txt", "w") as f:
tool_responses: list = []
input = MainInput(
config=config,
primary_llm=primary_llm,
fast_llm=fast_llm,
db_session=db_session,
search_tool=search_tool,
)
for output in run_graph(compiled_graph, input):
# pass
if isinstance(output, ToolCallKickoff):
pass
elif isinstance(output, ToolResponse):
tool_responses.append(output.response)
elif isinstance(output, SubQuestionPiece):
logger.debug(
f"SQ {output.level} - {output.level_question_nr} - {output.sub_question} | "
)
elif (
isinstance(output, AgentAnswerPiece)
and output.answer_type == "agent_sub_answer"
):
logger.debug(
f" ---- SA {output.level} - {output.level_question_nr} {output.answer_piece} | "
)
elif (
isinstance(output, AgentAnswerPiece)
and output.answer_type == "agent_level_answer"
):
logger.debug(f" ---------- FA {output.answer_piece} | ")
# for tool_response in tool_responses:
# logger.debug(tool_response)

View File

@@ -0,0 +1,34 @@
from langchain.schema import AIMessage
from langchain.schema import HumanMessage
from langchain.schema import SystemMessage
from langchain_core.messages.tool import ToolMessage
from onyx.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT_v2
from onyx.context.search.models import InferenceSection
def build_sub_question_answer_prompt(
question: str,
original_question: str,
docs: list[InferenceSection],
persona_specification: str,
) -> list[SystemMessage | HumanMessage | AIMessage | ToolMessage]:
system_message = SystemMessage(
content=persona_specification,
)
docs_format_list = [
f"""Document Number: [D{doc_nr + 1}]\n
Content: {doc.combined_content}\n\n"""
for doc_nr, doc in enumerate(docs)
]
docs_str = "\n\n".join(docs_format_list)
human_message = HumanMessage(
content=BASE_RAG_PROMPT_v2.format(
question=question, original_question=original_question, context=docs_str
)
)
return [system_message, human_message]

View File

@@ -0,0 +1,98 @@
import numpy as np
from onyx.agent_search.shared_graph_utils.models import RetrievalFitScoreMetrics
from onyx.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.chat.models import SectionRelevancePiece
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger
logger = setup_logger()
def unique_chunk_id(doc: InferenceSection) -> str:
return f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}"
def calculate_rank_shift(list1: list, list2: list, top_n: int = 20) -> float:
shift = 0
for rank_first, doc_id in enumerate(list1[:top_n], 1):
try:
rank_second = list2.index(doc_id) + 1
except ValueError:
rank_second = len(list2) # Document not found in second list
shift += np.abs(rank_first - rank_second) / np.log(1 + rank_first * rank_second)
return shift / top_n
def get_fit_scores(
pre_reranked_results: list[InferenceSection],
post_reranked_results: list[InferenceSection] | list[SectionRelevancePiece],
) -> RetrievalFitStats | None:
"""
Calculate retrieval metrics for search purposes
"""
if len(pre_reranked_results) == 0 or len(post_reranked_results) == 0:
return None
ranked_sections = {
"initial": pre_reranked_results,
"reranked": post_reranked_results,
}
fit_eval: RetrievalFitStats = RetrievalFitStats(
fit_score_lift=0,
rerank_effect=0,
fit_scores={
"initial": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
"reranked": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
},
)
for rank_type, docs in ranked_sections.items():
logger.debug(f"rank_type: {rank_type}")
for i in [1, 5, 10]:
fit_eval.fit_scores[rank_type].scores[str(i)] = (
sum(
[
float(doc.center_chunk.score)
for doc in docs[:i]
if type(doc) == InferenceSection
and doc.center_chunk.score is not None
]
)
/ i
)
fit_eval.fit_scores[rank_type].scores["fit_score"] = (
1
/ 3
* (
fit_eval.fit_scores[rank_type].scores["1"]
+ fit_eval.fit_scores[rank_type].scores["5"]
+ fit_eval.fit_scores[rank_type].scores["10"]
)
)
fit_eval.fit_scores[rank_type].scores["fit_score"] = fit_eval.fit_scores[
rank_type
].scores["1"]
fit_eval.fit_scores[rank_type].chunk_ids = [
unique_chunk_id(doc) for doc in docs if type(doc) == InferenceSection
]
fit_eval.fit_score_lift = (
fit_eval.fit_scores["reranked"].scores["fit_score"]
/ fit_eval.fit_scores["initial"].scores["fit_score"]
)
fit_eval.rerank_effect = calculate_rank_shift(
fit_eval.fit_scores["initial"].chunk_ids,
fit_eval.fit_scores["reranked"].chunk_ids,
)
return fit_eval

View File

@@ -0,0 +1,52 @@
from typing import Literal
from pydantic import BaseModel
# Pydantic models for structured outputs
class RewrittenQueries(BaseModel):
rewritten_queries: list[str]
class BinaryDecision(BaseModel):
decision: Literal["yes", "no"]
class BinaryDecisionWithReasoning(BaseModel):
reasoning: str
decision: Literal["yes", "no"]
class RetrievalFitScoreMetrics(BaseModel):
scores: dict[str, float]
chunk_ids: list[str]
class RetrievalFitStats(BaseModel):
fit_score_lift: float
rerank_effect: float
fit_scores: dict[str, RetrievalFitScoreMetrics]
class AgentChunkScores(BaseModel):
scores: dict[str, dict[str, list[int | float]]]
class AgentChunkStats(BaseModel):
verified_count: int | None
verified_avg_scores: float | None
rejected_count: int | None
rejected_avg_scores: float | None
verified_doc_chunk_ids: list[str]
dismissed_doc_chunk_ids: list[str]
class InitialAgentResultStats(BaseModel):
sub_questions: dict[str, float | int | None]
original_question: dict[str, float | int | None]
agent_effectiveness: dict[str, float | int | None]
class RefinedAgentStats(BaseModel):
revision_doc_efficiency: float | None
revision_question_efficiency: float | None

View File

@@ -0,0 +1,31 @@
from onyx.agent_search.pro_search_a.answer_initial_sub_question.models import (
QuestionAnswerResults,
)
from onyx.chat.prune_and_merge import _merge_sections
from onyx.context.search.models import InferenceSection
def dedup_inference_sections(
list1: list[InferenceSection], list2: list[InferenceSection]
) -> list[InferenceSection]:
deduped = _merge_sections(list1 + list2)
return deduped
def dedup_question_answer_results(
question_answer_results_1: list[QuestionAnswerResults],
question_answer_results_2: list[QuestionAnswerResults],
) -> list[QuestionAnswerResults]:
deduped_question_answer_results: list[
QuestionAnswerResults
] = question_answer_results_1
utilized_question_ids: set[str] = set(
[x.question_id for x in question_answer_results_1]
)
for question_answer_result in question_answer_results_2:
if question_answer_result.question_id not in utilized_question_ids:
deduped_question_answer_results.append(question_answer_result)
utilized_question_ids.add(question_answer_result.question_id)
return deduped_question_answer_results

View File

@@ -0,0 +1,697 @@
REWRITE_PROMPT_MULTI_ORIGINAL = """ \n
Please convert an initial user question into a 2-3 more appropriate short and pointed search queries for retrievel from a
document store. Particularly, try to think about resolving ambiguities and make the search queries more specific,
enabling the system to search more broadly.
Also, try to make the search queries not redundant, i.e. not too similar! \n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Formulate the queries separated by newlines (Do not say 'Query 1: ...', just write the querytext) as follows:
<query 1>
<query 2>
...
queries: """
REWRITE_PROMPT_MULTI = """ \n
Please create a list of 2-3 sample documents that could answer an original question. Each document
should be about as long as the original question. \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Formulate the sample documents separated by '--' (Do not say 'Document 1: ...', just write the text): """
# The prompt is only used if there is no persona prompt, so the placeholder is ''
BASE_RAG_PROMPT = """ \n
{persona_specification}
Use the context provided below - and only the
provided context - to answer the given question. (Note that the answer is in service of anserwing a broader
question, given below as 'motivation'.)
Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
question based on the context, say "I don't know". It is a matter of life and death that you do NOT
use your internal knowledge, just the provided information!
Make sure that you keep all relevant information, specifically as it concerns to the ultimate goal.
(But keep other details as well.)
\nContext:\n {context} \n
Motivation:\n {original_question} \n\n
\n\n
And here is the question I want you to answer based on the context above (with the motivation in mind):
\n--\n {question} \n--\n
"""
BASE_RAG_PROMPT_v2 = """ \n
Use the context provided below - and only the
provided context - to answer the given question. (Note that the answer is in service of answering a broader
question, given below as 'motivation'.)
Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
question based on the context, say "I don't know". It is a matter of life and death that you do NOT
use your internal knowledge, just the provided information!
Make sure that you keep all relevant information, specifically as it concerns to the ultimate goal.
(But keep other details as well.)
Please remember to provide inline citations in the format [D1], [D2], [D3], etc.\n\n\n
For your general information, here is the ultimate motivation:
\n--\n {original_question} \n--\n
\n\n
And here is the actual question I want you to answer based on the context above (with the motivation in mind):
\n--\n {question} \n--\n
Here is the context:
\n\n\n--\n {context} \n--\n
"""
SUB_CHECK_PROMPT = """
Your task is to see whether a given answer addresses a given question.
Please do not use any internal knowledge you may have - just focus on whether the answer
as given seems to largely address the question as given, or at least addresses part of the question.
Here is the question:
\n ------- \n
{question}
\n ------- \n
Here is the suggested answer:
\n ------- \n
{base_answer}
\n ------- \n
Does the suggested answer address the question? Please answer with yes or no:"""
BASE_CHECK_PROMPT = """ \n
Please check whether 1) the suggested answer seems to fully address the original question AND 2)the
original question requests a simple, factual answer, and there are no ambiguities, judgements,
aggregations, or any other complications that may require extra context. (I.e., if the question is
somewhat addressed, but the answer would benefit from more context, then answer with 'no'.)
Please only answer with 'yes' or 'no' \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Here is the proposed answer:
\n ------- \n
{initial_answer}
\n ------- \n
Please answer with yes or no:"""
VERIFIER_PROMPT = """
You are supposed to judge whether a document text contains data or information that is potentially relevant for a question.
Here is a document text that you can take as a fact:
--
DOCUMENT INFORMATION:
{document_content}
--
Do you think that this information is useful and relevant to answer the following question?
(Other documents may supply additional information, so do not worry if the provided information
is not enough to answer the question, but it needs to be relevant to the question.)
--
QUESTION:
{question}
--
Please answer with 'yes' or 'no':
Answer:
"""
INITIAL_DECOMPOSITION_PROMPT_BASIC = """ \n
If you think it is helpful, please decompose an initial user question into not more
than 4 appropriate sub-questions that help to answer the original question.
The purpose for this decomposition is to isolate individulal entities
(i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
for us'), etc. Each sub-question should be realistically be answerable by a good RAG system.
Importantly, if you think it is not needed or helpful, please just return an empty list. That is ok too.
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Please formulate your answer as a list of subquestions:
Answer:
"""
REWRITE_PROMPT_SINGLE = """ \n
Please convert an initial user question into a more appropriate search query for retrievel from a
document store. \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Formulate the query: """
MODIFIED_RAG_PROMPT = """You are an assistant for question-answering tasks. Use the context provided below
- and only this context - to answer the question. If you don't know the answer, just say "I don't know".
Use three sentences maximum and keep the answer concise.
Pay also particular attention to the sub-questions and their answers, at least it may enrich the answer.
Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
question based on the context, say "I don't know". It is a matter of life and death that you do NOT
use your internal knowledge, just the provided information!
\nQuestion: {question}
\nContext: {combined_context} \n
Answer:"""
ORIG_DEEP_DECOMPOSE_PROMPT = """ \n
An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
good enough. Also, some sub-questions had been answered and this information has been used to provide
the initial answer. Some other subquestions may have been suggested based on little knowledge, but they
were not directly answerable. Also, some entities, relationships and terms are givenm to you so that
you have an idea of how the avaiolable data looks like.
Your role is to generate 3-5 new sub-questions that would help to answer the initial question,
considering:
1) The initial question
2) The initial answer that was found to be unsatisfactory
3) The sub-questions that were answered
4) The sub-questions that were suggested but not answered
5) The entities, relationships and terms that were extracted from the context
The individual questions should be answerable by a good RAG system.
So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the
question for different entities that may be involved in the original question, but in a way that does
not duplicate questions that were already tried.
Additional Guidelines:
- The sub-questions should be specific to the question and provide richer context for the question,
resolve ambiguities, or address shortcoming of the initial answer
- Each sub-question - when answered - should be relevant for the answer to the original question
- The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any
other complications that may require extra context.
- The sub-questions MUST have the full context of the original question so that it can be executed by
a RAG system independently without the original question available
(Example:
- initial question: "What is the capital of France?"
- bad sub-question: "What is the name of the river there?"
- good sub-question: "What is the name of the river that flows through Paris?"
- For each sub-question, please provide a short explanation for why it is a good sub-question. So
generate a list of dictionaries with the following format:
[{{"sub_question": <sub-question>, "explanation": <explanation>, "search_term": <rewrite the
sub-question using as a search phrase for the document store>}}, ...]
\n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Here is the initial sub-optimal answer:
\n ------- \n
{base_answer}
\n ------- \n
Here are the sub-questions that were answered:
\n ------- \n
{answered_sub_questions}
\n ------- \n
Here are the sub-questions that were suggested but not answered:
\n ------- \n
{failed_sub_questions}
\n ------- \n
And here are the entities, relationships and terms extracted from the context:
\n ------- \n
{entity_term_extraction_str}
\n ------- \n
Please generate the list of good, fully contextualized sub-questions that would help to address the
main question. Again, please find questions that are NOT overlapping too much with the already answered
sub-questions or those that already were suggested and failed.
In other words - what can we try in addition to what has been tried so far?
Please think through it step by step and then generate the list of json dictionaries with the following
format:
{{"sub_questions": [{{"sub_question": <sub-question>,
"explanation": <explanation>,
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
...]}} """
DEEP_DECOMPOSE_PROMPT = """ \n
An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
good enough. Also, some sub-questions had been answered and this information has been used to provide
the initial answer. Some other subquestions may have been suggested based on little knowledge, but they
were not directly answerable. Also, some entities, relationships and terms are givenm to you so that
you have an idea of how the avaiolable data looks like.
Your role is to generate 3-5 new sub-questions that would help to answer the initial question,
considering:
1) The initial question
2) The initial answer that was found to be unsatisfactory
3) The sub-questions that were answered
4) The sub-questions that were suggested but not answered
5) The entities, relationships and terms that were extracted from the context
The individual questions should be answerable by a good RAG system.
So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the
question for different entities that may be involved in the original question, but in a way that does
not duplicate questions that were already tried.
Additional Guidelines:
- The sub-questions should be specific to the question and provide richer context for the question,
resolve ambiguities, or address shortcoming of the initial answer
- Each sub-question - when answered - should be relevant for the answer to the original question
- The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any
other complications that may require extra context.
- The sub-questions MUST have the full context of the original question so that it can be executed by
a RAG system independently without the original question available
(Example:
- initial question: "What is the capital of France?"
- bad sub-question: "What is the name of the river there?"
- good sub-question: "What is the name of the river that flows through Paris?"
- For each sub-question, please also provide a search term that can be used to retrieve relevant
documents from a document store.
- Consider specifically the sub-questions that were suggested but not answered. This is a sign that they are not
answerable with the available context, and you should not ask similar questions.
\n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Here is the initial sub-optimal answer:
\n ------- \n
{base_answer}
\n ------- \n
Here are the sub-questions that were answered:
\n ------- \n
{answered_sub_questions}
\n ------- \n
Here are the sub-questions that were suggested but not answered:
\n ------- \n
{failed_sub_questions}
\n ------- \n
And here are the entities, relationships and terms extracted from the context:
\n ------- \n
{entity_term_extraction_str}
\n ------- \n
Please generate the list of good, fully contextualized sub-questions that would help to address the
main question.
Specifically pay attention also to the entities, relationships and terms extracted, as these indicate what type of
objects/relationships/terms you can ask about! Do not ask about entities, terms or relationships that are not
mentioned in the 'entities, relationships and terms' section.
Again, please find questions that are NOT overlapping too much with the already answered
sub-questions or those that already were suggested and failed.
In other words - what can we try in addition to what has been tried so far?
Generate the list of questions separated by new lines like this:
<sub-question 1>
<sub-question 2>
<sub-question 3>
...
"""
DECOMPOSE_PROMPT = """ \n
For an initial user question, please generate at 5-10 individual sub-questions whose answers would help
\n to answer the initial question. The individual questions should be answerable by a good RAG system.
So a good idea would be to \n use the sub-questions to resolve ambiguities and/or to separate the
question for different entities that may be involved in the original question.
In order to arrive at meaningful sub-questions, please also consider the context retrieved from the
document store, expressed as entities, relationships and terms. You can also think about the types
mentioned in brackets
Guidelines:
- The sub-questions should be specific to the question and provide richer context for the question,
and or resolve ambiguities
- Each sub-question - when answered - should be relevant for the answer to the original question
- The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any
other complications that may require extra context.
- The sub-questions MUST have the full context of the original question so that it can be executed by
a RAG system independently without the original question available
(Example:
- initial question: "What is the capital of France?"
- bad sub-question: "What is the name of the river there?"
- good sub-question: "What is the name of the river that flows through Paris?"
- For each sub-question, please provide a short explanation for why it is a good sub-question. So
generate a list of dictionaries with the following format:
[{{"sub_question": <sub-question>, "explanation": <explanation>}}, ...]
\n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
And here are the entities, relationships and terms extracted from the context:
\n ------- \n
{entity_term_extraction_str}
\n ------- \n
Please generate the list of good, fully contextualized sub-questions that would help to address the
main question. Don't be too specific unless the original question is specific.
Please think through it step by step and then generate the list of json dictionaries with the following
format:
{{"sub_questions": [{{"sub_question": <sub-question>,
"explanation": <explanation>,
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
...]}} """
#### Consolidations
COMBINED_CONTEXT = """-------
Below you will find useful information to answer the original question. First, you see a number of
sub-questions with their answers. This information should be considered to be more focussed and
somewhat more specific to the original question as it tries to contextualized facts.
After that will see the documents that were considered to be relevant to answer the original question.
Here are the sub-questions and their answers:
\n\n {deep_answer_context} \n\n
\n\n Here are the documents that were considered to be relevant to answer the original question:
\n\n {formated_docs} \n\n
----------------
"""
SUB_QUESTION_EXPLANATION_RANKER_PROMPT = """-------
Below you will find a question that we ultimately want to answer (the original question) and a list of
motivations in arbitrary order for generated sub-questions that are supposed to help us answering the
original question. The motivations are formatted as <motivation number>: <motivation explanation>.
(Again, the numbering is arbitrary and does not necessarily mean that 1 is the most relevant
motivation and 2 is less relevant.)
Please rank the motivations in order of relevance for answering the original question. Also, try to
ensure that the top questions do not duplicate too much, i.e. that they are not too similar.
Ultimately, create a list with the motivation numbers where the number of the most relevant
motivations comes first.
Here is the original question:
\n\n {original_question} \n\n
\n\n Here is the list of sub-question motivations:
\n\n {sub_question_explanations} \n\n
----------------
Please think step by step and then generate the ranked list of motivations.
Please format your answer as a json object in the following format:
{{"reasonning": <explain your reasoning for the ranking>,
"ranked_motivations": <ranked list of motivation numbers>}}
"""
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS = """
If you think it is helpful, please decompose an initial user question into 2 or 4 appropriate sub-questions that help to
answer the original question. The purpose for this decomposition is to
1) isolate individual entities (i.e., 'compare sales of company A and company B' -> ['what are sales for company A',
'what are sales for company B')]
2) clarify or disambiguate ambiguous terms (i.e., 'what is our success with company A' -> ['what are our sales with company A',
'what is our market share with company A', 'is company A a reference customer for us', etc.])
3) if a term or a metric is essentially clear, but it could relate to various components of an entity and you are generally
familiar with the entity, then you can decompose the question into sub-questions that are more specific to components
(i.e., 'what do we do to improve scalability of product X', 'what do we to to improve scalability of product X',
'what do we do to improve stability of product X', ...])
If you think that a decomposition is not needed or helpful, please just return an empty string. That is ok too.
Here is the initial question:
-------
{question}
-------
Please formulate your answer as a newline-separated list of questions like so:
<sub-question>
<sub-question>
<sub-question>
Answer:"""
INITIAL_DECOMPOSITION_PROMPT = """ \n
Please decompose an initial user question into 2 or 3 appropriate sub-questions that help to
answer the original question. The purpose for this decomposition is to isolate individulal entities
(i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
for us'), etc. Each sub-question should be realistically be answerable by a good RAG system. \n
For each sub-question, please also create one search term that can be used to retrieve relevant
documents from a document store.
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Please formulate your answer as a list of json objects with the following format:
[{{"sub_question": <sub-question>, "search_term": <search term>}}, ...]
Answer:
"""
INITIAL_RAG_BASE_PROMPT = """ \n
You are an assistant for question-answering tasks. Use the information provided below - and only the
provided information - to answer the provided question.
The information provided below consists ofa number of documents that were deemed relevant for the question.
IMPORTANT RULES:
- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer.
You may give some additional facts you learned, but do not try to invent an answer.
- If the information is empty or irrelevant, just say "I don't know".
- If the information is relevant but not fully conclusive, specify that the information is not conclusive and say why.
Try to keep your answer concise.
Here is the contextual information from the document store:
\n ------- \n
{context} \n\n\n
\n ------- \n
And here is the question I want you to answer based on the context above (with the motivation in mind):
\n--\n {question} \n--\n
Answer:"""
### ANSWER GENERATION PROMPTS
# Persona specification
ASSISTANT_SYSTEM_PROMPT_DEFAULT = """
You are an assistant for question-answering tasks."""
ASSISTANT_SYSTEM_PROMPT_PERSONA = """
You are an assistant for question-answering tasks. Here is more information about you:
\n ------- \n
{persona_prompt}
\n ------- \n
"""
SUB_QUESTION_ANSWER_TEMPLATE = """
Sub-Question: Q{sub_question_nr}\n Sub-Question:\n - \n{sub_question}\n --\nAnswer:\n -\n {sub_answer}\n\n
"""
INITIAL_RAG_PROMPT = """ \n
{persona_specification}
Use the information provided below - and only the
provided information - to answer the provided question.
The information provided below consists of:
1) a number of answered sub-questions - these are very important(!) and definitely should be
considered to answer the question.
2) a number of documents that were also deemed relevant for the question.
IMPORTANT RULES:
- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer.
You may give some additional facts you learned, but do not try to invent an answer.
- If the information is empty or irrelevant, just say "I don't know".
- If the information is relevant but not fully conclusive, specify that the information is not conclusive and say why.
Remember to provide inline citations of documentsin the format [D1], [D2], [D3], etc., and [Q1], [Q2],... if
you want to cite the answer to a sub-question. If you have multiple citations, please cite for example
as [D1][Q3], or [D2][D4], etc. Feel free to cite documents in addition to the sub-questions!
Proper citations are important for the final answer to be verifiable! \n\n\n
Again, you should be sure that the answer is supported by the information provided!
Try to keep your answer concise. But also highlight uncertainties you may have should there be substantial ones,
or assumptions you made.
Here is the contextual information:
\n-------\n
*Answered Sub-questions (these should really matter!):
{answered_sub_questions}
And here are relevant document information that support the sub-question answers, or that are relevant for the actual question:\n
{relevant_docs}
\n-------\n
\n
And here is the question I want you to answer based on the information above:
\n--\n
{question}
\n--\n\n
Answer:"""
# sub_question_answer_str is empty
INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS = """{answered_sub_questions}
{persona_specification}
Use the information provided below
- and only the provided information - to answer the provided question.
The information provided below consists of a number of documents that were deemed relevant for the question.
IMPORTANT RULES:
- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer.
You may give some additional facts you learned, but do not try to invent an answer.
- If the information is irrelevant, just say "I don't know".
- If the information is relevant but not fully conclusive, specify that the information is not conclusive and say why.
Again, you should be sure that the answer is supported by the information provided!
Try to keep your answer concise.
Here are is the relevant context information:
\n-------\n
{relevant_docs}
\n-------\n
And here is the question I want you to answer based on the context above
\n--\n
{question}
\n--\n
Answer:"""
REVISED_RAG_PROMPT = """\n
{persona_specification}
Use the information provided below - and only the
provided information - to answer the provided question.
The information provided below consists of:
1) an initial answer that was given but found to be lacking in some way.
2) a number of answered sub-questions - these are very important(!) and definitely should be
considered to answer the question.
3) a number of documents that were also deemed relevant for the question.
IMPORTANT RULES:
- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer.
You may give some additional facts you learned, but do not try to invent an answer.
- If the information is empty or irrelevant, just say "I don't know".
- If the information is relevant but not fully conclusive, provide and answer to the extent you can but also
specify that the information is not conclusive and why.
Again, you should be sure that the answer is supported by the information provided!
Try to keep your answer concise. But also highlight uncertainties you may have should there be substantial ones,
or assumptions you made.
Here is the contextual information:
\n-------\n
*Initial Answer that was found to be lacking:
{initial_answer}
*Answered Sub-questions (these should really matter! They also contain questions/answers that were not available when the original
answer was constructed):
{answered_sub_questions}
And here are relevant document information that support the sub-question answers, or that are relevant for the actual question:\n
{relevant_docs}
\n-------\n
\n
Lastly, here is the question I want you to answer based on the information above:
\n--\n
{question}
\n--\n\n
Answer:"""
# sub_question_answer_str is empty
REVISED_RAG_PROMPT_NO_SUB_QUESTIONS = """{sub_question_answer_str}\n
{persona_specification}
Use the information provided below - and only the
provided information - to answer the provided question.
The information provided below consists of:
1) an initial answer that was given but found to be lacking in some way.
2) a number of documents that were also deemed relevant for the question.
IMPORTANT RULES:
- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer.
You may give some additional facts you learned, but do not try to invent an answer.
- If the information is empty or irrelevant, just say "I don't know".
- If the information is relevant but not fully conclusive, provide and answer to the extent you can but also
specify that the information is not conclusive and why.
Again, you should be sure that the answer is supported by the information provided!
Try to keep your answer concise. But also highlight uncertainties you may have should there be substantial ones,
or assumptions you made.
Here is the contextual information:
\n-------\n
*Initial Answer that was found to be lacking:
{initial_answer}
And here are relevant document information that support the sub-question answers, or that are relevant for the actual question:\n
{relevant_docs}
\n-------\n
\n
Lastly, here is the question I want you to answer based on the information above:
\n--\n
{question}
\n--\n\n
Answer:"""
ENTITY_TERM_PROMPT = """ \n
Based on the original question and the context retieved from a dataset, please generate a list of
entities (e.g. companies, organizations, industries, products, locations, etc.), terms and concepts
(e.g. sales, revenue, etc.) that are relevant for the question, plus their relations to each other.
\n\n
Here is the original question:
\n ------- \n
{question}
\n ------- \n
And here is the context retrieved:
\n ------- \n
{context}
\n ------- \n
Please format your answer as a json object in the following format:
{{"retrieved_entities_relationships": {{
"entities": [{{
"entity_name": <assign a name for the entity>,
"entity_type": <specify a short type name for the entity, such as 'company', 'location',...>
}}],
"relationships": [{{
"relationship_name": <assign a name for the relationship>,
"relationship_type": <specify a short type name for the relationship, such as 'sales_to', 'is_location_of',...>,
"relationship_entities": [<related entity name 1>, <related entity name 2>, ...]
}}],
"terms": [{{
"term_name": <assign a name for the term>,
"term_type": <specify a short type name for the term, such as 'revenue', 'market_share',...>,
"term_similar_to": <list terms that are similar to this term>
}}]
}}
}}
"""

View File

@@ -0,0 +1,243 @@
import ast
import json
import re
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from typing import Any
from typing import cast
from uuid import UUID
from langchain_core.messages import BaseMessage
from sqlalchemy.orm import Session
from onyx.agent_search.pro_search_a.main.models import EntityRelationshipTermExtraction
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationConfig
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import PromptConfig
from onyx.chat.models import ProSearchConfig
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SearchRequest
from onyx.db.persona import get_persona_by_id
from onyx.db.persona import Persona
from onyx.llm.interfaces import LLM
from onyx.tools.tool_constructor import SearchToolConfig
from onyx.tools.tool_implementations.search.search_tool import SearchTool
def normalize_whitespace(text: str) -> str:
"""Normalize whitespace in text to single spaces and strip leading/trailing whitespace."""
import re
return re.sub(r"\s+", " ", text.strip())
# Post-processing
def format_docs(docs: Sequence[InferenceSection]) -> str:
formatted_doc_list = []
for doc_nr, doc in enumerate(docs):
formatted_doc_list.append(f"Document D{doc_nr + 1}:\n{doc.combined_content}")
return "\n\n".join(formatted_doc_list)
def clean_and_parse_list_string(json_string: str) -> list[dict]:
# Remove any prefixes/labels before the actual JSON content
json_string = re.sub(r"^.*?(?=\[)", "", json_string, flags=re.DOTALL)
# Remove markdown code block markers and any newline prefixes
cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
cleaned_string = " ".join(cleaned_string.split())
# Try parsing with json.loads first, fall back to ast.literal_eval
try:
return json.loads(cleaned_string)
except json.JSONDecodeError:
try:
return ast.literal_eval(cleaned_string)
except (ValueError, SyntaxError) as e:
raise ValueError(f"Failed to parse JSON string: {cleaned_string}") from e
def clean_and_parse_json_string(json_string: str) -> dict[str, Any]:
# Remove markdown code block markers and any newline prefixes
cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
cleaned_string = " ".join(cleaned_string.split())
# Parse the cleaned string into a Python dictionary
return json.loads(cleaned_string)
def format_entity_term_extraction(
entity_term_extraction_dict: EntityRelationshipTermExtraction,
) -> str:
entities = entity_term_extraction_dict.entities
terms = entity_term_extraction_dict.terms
relationships = entity_term_extraction_dict.relationships
entity_strs = ["\nEntities:\n"]
for entity in entities:
entity_str = f"{entity.entity_name} ({entity.entity_type})"
entity_strs.append(entity_str)
entity_str = "\n - ".join(entity_strs)
relationship_strs = ["\n\nRelationships:\n"]
for relationship in relationships:
relationship_name = relationship.relationship_name
relationship_type = relationship.relationship_type
relationship_entities = relationship.relationship_entities
relationship_str = (
f"""{relationship_name} ({relationship_type}): {relationship_entities}"""
)
relationship_strs.append(relationship_str)
relationship_str = "\n - ".join(relationship_strs)
term_strs = ["\n\nTerms:\n"]
for term in terms:
term_str = f"{term.term_name} ({term.term_type}): similar to {', '.join(term.term_similar_to)}"
term_strs.append(term_str)
term_str = "\n - ".join(term_strs)
return "\n".join(entity_strs + relationship_strs + term_strs)
def _format_time_delta(time: timedelta) -> str:
seconds_from_start = f"{((time).seconds):03d}"
microseconds_from_start = f"{((time).microseconds):06d}"
return f"{seconds_from_start}.{microseconds_from_start}"
def generate_log_message(
message: str,
node_start_time: datetime,
graph_start_time: datetime | None = None,
) -> str:
current_time = datetime.now()
if graph_start_time is not None:
graph_time_str = _format_time_delta(current_time - graph_start_time)
else:
graph_time_str = "N/A"
node_time_str = _format_time_delta(current_time - node_start_time)
return f"{graph_time_str} ({node_time_str} s): {message}"
def get_test_config(
db_session: Session, primary_llm: LLM, fast_llm: LLM, search_request: SearchRequest
) -> tuple[ProSearchConfig, SearchTool]:
persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session)
document_pruning_config = DocumentPruningConfig(
max_chunks=int(
persona.num_chunks
if persona.num_chunks is not None
else MAX_CHUNKS_FED_TO_CHAT
),
max_window_percentage=CHAT_TARGET_CHUNK_PERCENTAGE,
)
answer_style_config = AnswerStyleConfig(
citation_config=CitationConfig(
# The docs retrieved by this flow are already relevance-filtered
all_docs_useful=True
),
document_pruning_config=document_pruning_config,
structured_response_format=None,
)
search_tool_config = SearchToolConfig(
answer_style_config=answer_style_config,
document_pruning_config=document_pruning_config,
retrieval_options=RetrievalDetails(), # may want to set dedupe_docs=True
rerank_settings=None, # Can use this to change reranking model
selected_sections=None,
latest_query_files=None,
bypass_acl=False,
)
prompt_config = PromptConfig.from_model(persona.prompts[0])
search_tool = SearchTool(
db_session=db_session,
user=None,
persona=persona,
retrieval_options=search_tool_config.retrieval_options,
prompt_config=prompt_config,
llm=primary_llm,
fast_llm=fast_llm,
pruning_config=search_tool_config.document_pruning_config,
answer_style_config=search_tool_config.answer_style_config,
selected_sections=search_tool_config.selected_sections,
chunks_above=search_tool_config.chunks_above,
chunks_below=search_tool_config.chunks_below,
full_doc=search_tool_config.full_doc,
evaluation_type=(
LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP
),
rerank_settings=search_tool_config.rerank_settings,
bypass_acl=search_tool_config.bypass_acl,
)
config = ProSearchConfig(
search_request=search_request,
# chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
chat_session_id=UUID("edda10d5-6cef-45d8-acfb-39317552a1f4"), # Joachim
# chat_session_id=UUID("d1acd613-2692-4bc3-9d65-c6d3da62e58e"), # Evan
message_id=1,
use_persistence=True,
)
return config, search_tool
def get_persona_prompt(persona: Persona | None) -> str:
if persona is None:
return ""
else:
return "\n".join([x.system_prompt for x in persona.prompts])
def make_question_id(level: int, question_nr: int) -> str:
return f"{level}_{question_nr}"
def parse_question_id(question_id: str) -> tuple[int, int]:
level, question_nr = question_id.split("_")
return int(level), int(question_nr)
def dispatch_separated(
token_itr: Iterator[BaseMessage],
dispatch_event: Callable[[str, int], None],
sep: str = "\n",
) -> list[str | list[str | dict[str, Any]]]:
num = 0
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
for message in token_itr:
content = cast(str, message.content)
if sep in content:
for sub_question_part in content.split(sep):
dispatch_event(sub_question_part, num)
num += 1
num -= 1 # fencepost; extra increment at end of loop
else:
dispatch_event(content, num)
streamed_tokens.append(content)
return streamed_tokens

View File

@@ -1,17 +1,21 @@
from collections.abc import Callable
from collections.abc import Iterator
from uuid import uuid4
from langchain.schema.messages import BaseMessage
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import ToolCall
from sqlalchemy.orm import Session
from onyx.agent_search.run_graph import run_basic_graph
from onyx.agent_search.run_graph import run_main_graph
from onyx.chat.llm_response_handler import LLMResponseHandlerManager
from onyx.chat.models import AnswerQuestionPossibleReturn
from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStream
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import PromptConfig
from onyx.chat.models import ProSearchConfig
from onyx.chat.prompt_builder.build import AnswerPromptBuilder
from onyx.chat.prompt_builder.build import default_build_system_message
from onyx.chat.prompt_builder.build import default_build_user_message
@@ -19,32 +23,24 @@ from onyx.chat.prompt_builder.build import LLMCall
from onyx.chat.stream_processing.answer_response_handler import (
CitationResponseHandler,
)
from onyx.chat.stream_processing.answer_response_handler import (
DummyAnswerResponseHandler,
)
from onyx.chat.stream_processing.utils import (
map_document_id_order,
)
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.tools.force import ForceUseTool
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_runner import ToolCallKickoff
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.logger import setup_logger
logger = setup_logger()
AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse]
class Answer:
def __init__(
self,
@@ -59,7 +55,6 @@ class Answer:
# newly passed in files to include as part of this question
# TODO THIS NEEDS TO BE HANDLED
latest_query_files: list[InMemoryChatFile] | None = None,
files: list[InMemoryChatFile] | None = None,
tools: list[Tool] | None = None,
# NOTE: for native tool-calling, this is only supported by OpenAI atm,
# but we only support them anyways
@@ -69,6 +64,9 @@ class Answer:
return_contexts: bool = False,
skip_gen_ai_answer_generation: bool = False,
is_connected: Callable[[], bool] | None = None,
pro_search_config: ProSearchConfig | None = None,
fast_llm: LLM | None = None,
db_session: Session | None = None,
) -> None:
if single_message_history and message_history:
raise ValueError(
@@ -79,7 +77,6 @@ class Answer:
self.is_connected: Callable[[], bool] | None = is_connected
self.latest_query_files = latest_query_files or []
self.file_id_to_file = {file.file_id: file for file in (files or [])}
self.tools = tools or []
self.force_use_tool = force_use_tool
@@ -92,6 +89,7 @@ class Answer:
self.prompt_config = prompt_config
self.llm = llm
self.fast_llm = fast_llm
self.llm_tokenizer = get_tokenizer(
provider_type=llm.config.model_provider,
model_name=llm.config.model_name,
@@ -100,9 +98,7 @@ class Answer:
self._final_prompt: list[BaseMessage] | None = None
self._streamed_output: list[str] | None = None
self._processed_stream: (
list[AnswerQuestionPossibleReturn | ToolResponse | ToolCallKickoff] | None
) = None
self._processed_stream: (list[AnswerPacket] | None) = None
self._return_contexts = return_contexts
self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
@@ -115,55 +111,28 @@ class Answer:
and not skip_explicit_tool_calling
)
self.pro_search_config = pro_search_config
self.db_session = db_session
def _get_tools_list(self) -> list[Tool]:
if not self.force_use_tool.force_use:
return self.tools
tool = next(
(t for t in self.tools if t.name == self.force_use_tool.tool_name), None
)
if tool is None:
raise RuntimeError(f"Tool '{self.force_use_tool.tool_name}' not found")
tool = get_tool_by_name(self.tools, self.force_use_tool.tool_name)
logger.info(
f"Forcefully using tool='{tool.name}'"
+ (
f" with args='{self.force_use_tool.args}'"
if self.force_use_tool.args is not None
else ""
)
args_str = (
f" with args='{self.force_use_tool.args}'"
if self.force_use_tool.args
else ""
)
logger.info(f"Forcefully using tool='{tool.name}'{args_str}")
return [tool]
def _handle_specified_tool_call(
self, llm_calls: list[LLMCall], tool: Tool, tool_args: dict
) -> AnswerStream:
current_llm_call = llm_calls[-1]
# make a dummy tool handler
tool_handler = ToolResponseHandler([tool])
dummy_tool_call_chunk = AIMessageChunk(content="")
dummy_tool_call_chunk.tool_calls = [
ToolCall(name=tool.name, args=tool_args, id=str(uuid4()))
]
response_handler_manager = LLMResponseHandlerManager(
tool_handler, DummyAnswerResponseHandler(), self.is_cancelled
)
yield from response_handler_manager.handle_llm_response(
iter([dummy_tool_call_chunk])
)
new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
if new_llm_call:
yield from self._get_response(llm_calls + [new_llm_call])
else:
raise RuntimeError("Tool call handler did not return a new LLM call")
# TODO: delete the function and move the full body to processed_streamed_output
def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream:
current_llm_call = llm_calls[-1]
tool, tool_args = None, None
# handle the case where no decision has to be made; we simply run the tool
if (
current_llm_call.force_use_tool.force_use
@@ -173,17 +142,10 @@ class Answer:
current_llm_call.force_use_tool.tool_name,
current_llm_call.force_use_tool.args,
)
tool = next(
(t for t in current_llm_call.tools if t.name == tool_name), None
)
if not tool:
raise RuntimeError(f"Tool '{tool_name}' not found")
yield from self._handle_specified_tool_call(llm_calls, tool, tool_args)
return
tool = get_tool_by_name(current_llm_call.tools, tool_name)
# special pre-logic for non-tool calling LLM case
if not self.using_tool_calling_llm and current_llm_call.tools:
elif not self.using_tool_calling_llm and current_llm_call.tools:
chosen_tool_and_args = (
ToolResponseHandler.get_tool_call_for_non_tool_calling_llm(
current_llm_call, self.llm
@@ -191,8 +153,24 @@ class Answer:
)
if chosen_tool_and_args:
tool, tool_args = chosen_tool_and_args
yield from self._handle_specified_tool_call(llm_calls, tool, tool_args)
return
if tool and tool_args:
dummy_tool_call_chunk = AIMessageChunk(content="")
dummy_tool_call_chunk.tool_calls = [
ToolCall(name=tool.name, args=tool_args, id=str(uuid4()))
]
response_handler_manager = LLMResponseHandlerManager(
ToolResponseHandler([tool]), None, self.is_cancelled
)
yield from response_handler_manager.handle_llm_response(
iter([dummy_tool_call_chunk])
)
tmp_call = response_handler_manager.next_llm_call(current_llm_call)
if tmp_call is None:
return # no more LLM calls to process
current_llm_call = tmp_call
# if we're skipping gen ai answer generation, we should break
# out unless we're forcing a tool call. If we don't, we might generate an
@@ -212,29 +190,63 @@ class Answer:
current_llm_call
) or ([], [])
# Quotes are no longer supported
# answer_handler: AnswerResponseHandler
# if self.answer_style_config.citation_config:
# answer_handler = CitationResponseHandler(
# context_docs=search_result,
# doc_id_to_rank_map=map_document_id_order(search_result),
# )
# elif self.answer_style_config.quotes_config:
# answer_handler = QuotesResponseHandler(
# context_docs=search_result,
# )
# else:
# raise ValueError("No answer style config provided")
# NEXT: we still want to handle the LLM response stream, but it is now:
# 1. handle the tool call requests
# 2. feed back the processed results
# 3. handle the citations
answer_handler = CitationResponseHandler(
context_docs=final_search_results,
final_doc_id_to_rank_map=map_document_id_order(final_search_results),
display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
)
# At the moment, this wrapper class passes streamed stuff through citation and tool handlers.
# In the future, we'll want to handle citations and tool calls in the langgraph graph.
response_handler_manager = LLMResponseHandlerManager(
tool_call_handler, answer_handler, self.is_cancelled
)
# In langgraph, whether we do the basic thing (call llm stream) or pro search
# is based on a flag in the pro search config
if self.pro_search_config:
if self.pro_search_config.search_request is None:
raise ValueError("Search request must be provided for pro search")
search_tools = [tool for tool in self.tools if isinstance(tool, SearchTool)]
if len(search_tools) == 0:
raise ValueError("No search tool found")
elif len(search_tools) > 1:
# TODO: handle multiple search tools
raise ValueError("Multiple search tools found")
search_tool = search_tools[0]
if self.db_session is None:
raise ValueError("db_session must be provided for pro search")
if self.fast_llm is None:
raise ValueError("fast_llm must be provided for pro search")
stream = run_main_graph(
config=self.pro_search_config,
primary_llm=self.llm,
fast_llm=self.fast_llm,
search_tool=search_tool,
db_session=self.db_session,
)
else:
stream = run_basic_graph(
last_llm_call=current_llm_call,
primary_llm=self.llm,
answer_style_config=self.answer_style_config,
response_handler_manager=response_handler_manager,
)
processed_stream = []
for packet in stream:
processed_stream.append(packet)
yield packet
self._processed_stream = processed_stream
return
# DEBUG: good breakpoint
stream = self.llm.stream(
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM

View File

@@ -48,6 +48,7 @@ def prepare_chat_message_request(
retrieval_details: RetrievalDetails | None,
rerank_settings: RerankingDetails | None,
db_session: Session,
use_pro_search: bool = False,
) -> CreateChatMessageRequest:
# Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
new_chat_session = create_chat_session(
@@ -72,6 +73,7 @@ def prepare_chat_message_request(
search_doc_ids=None,
retrieval_options=retrieval_details,
rerank_settings=rerank_settings,
use_pro_search=use_pro_search,
)

View File

@@ -9,18 +9,30 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.prompt_builder.build import LLMCall
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
DummyAnswerResponseHandler,
)
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
class LLMResponseHandlerManager:
"""
This class is responsible for postprocessing the LLM response stream.
In particular, we:
1. handle the tool call requests
2. handle citations
3. pass through answers generated by the LLM
4. Stop yielding if the client disconnects
"""
def __init__(
self,
tool_handler: ToolResponseHandler,
answer_handler: AnswerResponseHandler,
tool_handler: ToolResponseHandler | None,
answer_handler: AnswerResponseHandler | None,
is_cancelled: Callable[[], bool],
):
self.tool_handler = tool_handler
self.answer_handler = answer_handler
self.tool_handler = tool_handler or ToolResponseHandler([])
self.answer_handler = answer_handler or DummyAnswerResponseHandler()
self.is_cancelled = is_cancelled
def handle_llm_response(

View File

@@ -3,7 +3,9 @@ from collections.abc import Iterator
from datetime import datetime
from enum import Enum
from typing import Any
from typing import Literal
from typing import TYPE_CHECKING
from uuid import UUID
from pydantic import BaseModel
from pydantic import ConfigDict
@@ -16,6 +18,8 @@ from onyx.context.search.enums import QueryFlow
from onyx.context.search.enums import RecencyBiasSetting
from onyx.context.search.enums import SearchType
from onyx.context.search.models import RetrievalDocs
from onyx.context.search.models import SearchRequest
from onyx.llm.models import PreviousMessage
from onyx.llm.override_models import PromptOverride
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
@@ -204,6 +208,30 @@ class PersonaOverrideConfig(BaseModel):
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
class ProSearchConfig(BaseModel):
"""
Configuration for the Pro Search feature.
"""
use_agentic_search: bool = False
# For persisting agent search data
chat_session_id: UUID | None = None
# The message ID of the user message that triggered the Pro Search
message_id: int | None = None
# The search request that was used to generate the Pro Search
search_request: SearchRequest
# Whether to persistence data for the Pro Search (turned off for testing)
use_persistence: bool = True
# Whether to allow creation of refinement questions (and entity extraction, etc.)
allow_refinement: bool = False
# Message history for the current chat session
message_history: list[PreviousMessage] | None = None
AnswerQuestionPossibleReturn = (
OnyxAnswerPiece
| CitationInfo
@@ -327,3 +355,39 @@ ResponsePart = (
| ToolCallFinalResult
| StreamStopInfo
)
class SubQueryPiece(BaseModel):
sub_query: str
level: int
level_question_nr: int
query_id: int
class AgentAnswerPiece(BaseModel):
answer_piece: str
level: int
level_question_nr: int
answer_type: Literal["agent_sub_answer", "agent_level_answer"]
class SubQuestionPiece(BaseModel):
sub_question: str
level: int
level_question_nr: int
class ExtendedToolResponse(ToolResponse):
level: int
level_question_nr: int
ProSearchPacket = (
SubQuestionPiece | AgentAnswerPiece | SubQueryPiece | ExtendedToolResponse
)
AnswerPacket = (
AnswerQuestionPossibleReturn | ProSearchPacket | ToolCallKickoff | ToolResponse
)
AnswerStream = Iterator[AnswerPacket]

View File

@@ -24,6 +24,8 @@ from onyx.chat.models import MessageSpecificCitations
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.models import ProSearchConfig
from onyx.chat.models import ProSearchPacket
from onyx.chat.models import QADocsResponse
from onyx.chat.models import StreamingError
from onyx.chat.models import StreamStopInfo
@@ -33,11 +35,13 @@ from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import MessageType
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import NO_AUTH_USER_ID
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.enums import OptionalSearchSetting
from onyx.context.search.enums import QueryFlow
from onyx.context.search.enums import SearchType
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SearchRequest
from onyx.context.search.retrieval.search_runner import inference_sections_from_ids
from onyx.context.search.utils import chunks_or_sections_to_search_docs
from onyx.context.search.utils import dedupe_documents
@@ -281,6 +285,7 @@ ChatPacket = (
| MessageSpecificCitations
| MessageResponseIDInfo
| StreamStopInfo
| ProSearchPacket
)
ChatPacketStream = Iterator[ChatPacket]
@@ -683,6 +688,51 @@ def stream_chat_message_objects(
for tool_list in tool_dict.values():
tools.extend(tool_list)
message_history = [
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
]
search_request = None
pro_search_config = None
if new_msg_req.use_pro_search:
search_request = SearchRequest(
query=final_msg.message,
evaluation_type=(
LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP
),
human_selected_filters=(
retrieval_options.filters if retrieval_options else None
),
persona=persona,
offset=(retrieval_options.offset if retrieval_options else None),
limit=retrieval_options.limit if retrieval_options else None,
rerank_settings=new_msg_req.rerank_settings,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
enable_auto_detect_filters=(
retrieval_options.enable_auto_detect_filters
if retrieval_options
else None
),
)
# TODO: Since we're deleting the current main path in Answer,
# we should construct this unconditionally inside Answer instead
# Leaving it here for the time being to avoid breaking changes
pro_search_config = (
ProSearchConfig(
search_request=search_request,
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
message_history=message_history,
)
if new_msg_req.use_pro_search
else None
)
# TODO: add previous messages, answer style config, tools, etc.
# LLM prompt building, response capturing, etc.
answer = Answer(
is_connected=is_connected,
@@ -702,11 +752,12 @@ def stream_chat_message_objects(
)
)
),
message_history=[
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
],
fast_llm=fast_llm,
message_history=message_history,
tools=tools,
force_use_tool=_get_force_search_settings(new_msg_req, tools),
pro_search_config=pro_search_config,
db_session=db_session,
)
reference_db_search_docs = None
@@ -718,6 +769,7 @@ def stream_chat_message_objects(
for packet in answer.processed_streamed_output:
if isinstance(packet, ToolResponse):
# TODO: don't need to dedupe here when we do it in agent flow
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
(
qa_docs_response,
@@ -738,25 +790,30 @@ def stream_chat_message_objects(
elif packet.id == SECTION_RELEVANCE_LIST_ID:
relevance_sections = packet.response
if reference_db_search_docs is not None:
llm_indices = relevant_sections_to_indices(
relevance_sections=relevance_sections,
items=[
translate_db_search_doc_to_server_search_doc(doc)
for doc in reference_db_search_docs
],
if reference_db_search_docs is None:
logger.warning(
"No reference docs found for relevance filtering"
)
continue
llm_indices = relevant_sections_to_indices(
relevance_sections=relevance_sections,
items=[
translate_db_search_doc_to_server_search_doc(doc)
for doc in reference_db_search_docs
],
)
if dropped_indices:
llm_indices = drop_llm_indices(
llm_indices=llm_indices,
search_docs=reference_db_search_docs,
dropped_indices=dropped_indices,
)
if dropped_indices:
llm_indices = drop_llm_indices(
llm_indices=llm_indices,
search_docs=reference_db_search_docs,
dropped_indices=dropped_indices,
)
yield LLMRelevanceFilterResponse(
llm_selected_doc_indices=llm_indices
)
yield LLMRelevanceFilterResponse(
llm_selected_doc_indices=llm_indices
)
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
yield FinalUsedContextDocsResponse(
final_context_docs=packet.response

View File

@@ -147,6 +147,7 @@ class AnswerPromptBuilder:
)
# TODO: rename this? AnswerConfig maybe?
class LLMCall(BaseModel__v1):
prompt_builder: AnswerPromptBuilder
tools: list[Tool]

View File

@@ -25,6 +25,13 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
def get_tool_by_name(tools: list[Tool], tool_name: str) -> Tool:
for tool in tools:
if tool.name == tool_name:
return tool
raise RuntimeError(f"Tool '{tool_name}' not found")
class ToolResponseHandler:
def __init__(self, tools: list[Tool]):
self.tools = tools
@@ -45,18 +52,7 @@ class ToolResponseHandler:
) -> tuple[Tool, dict] | None:
if llm_call.force_use_tool.force_use:
# if we are forcing a tool, we don't need to check which tools to run
tool = next(
(
t
for t in llm_call.tools
if t.name == llm_call.force_use_tool.tool_name
),
None,
)
if not tool:
raise RuntimeError(
f"Tool '{llm_call.force_use_tool.tool_name}' not found"
)
tool = get_tool_by_name(llm_call.tools, llm_call.force_use_tool.tool_name)
tool_args = (
llm_call.force_use_tool.args
@@ -118,20 +114,17 @@ class ToolResponseHandler:
tool for tool in self.tools if tool.name == tool_call_request["name"]
]
if not known_tools_by_name:
logger.error(
"Tool call requested with unknown name field. \n"
f"self.tools: {self.tools}"
f"tool_call_request: {tool_call_request}"
)
continue
else:
if known_tools_by_name:
selected_tool = known_tools_by_name[0]
selected_tool_call_request = tool_call_request
if selected_tool and selected_tool_call_request:
break
logger.error(
"Tool call requested with unknown name field. \n"
f"self.tools: {self.tools}"
f"tool_call_request: {tool_call_request}"
)
if not selected_tool or not selected_tool_call_request:
return
@@ -171,8 +164,6 @@ class ToolResponseHandler:
else:
self.tool_call_chunk += response_item # type: ignore
return
def next_llm_call(self, current_llm_call: LLMCall) -> LLMCall | None:
if (
self.tool_runner is None

View File

@@ -0,0 +1,57 @@
import os
from .chat_configs import NUM_RETURNED_HITS
#####
# Agent Configs
#####
agent_retrieval_stats_os: bool | str | None = os.environ.get(
"AGENT_RETRIEVAL_STATS", False
)
AGENT_RETRIEVAL_STATS: bool = False
if isinstance(agent_retrieval_stats_os, str) and agent_retrieval_stats_os == "True":
AGENT_RETRIEVAL_STATS = True
elif isinstance(agent_retrieval_stats_os, bool) and agent_retrieval_stats_os:
AGENT_RETRIEVAL_STATS = True
agent_max_query_retrieval_results_os: int | str = os.environ.get(
"AGENT_MAX_QUERY_RETRIEVAL_RESULTS", NUM_RETURNED_HITS
)
AGENT_MAX_QUERY_RETRIEVAL_RESULTS: int = NUM_RETURNED_HITS
try:
atmqrr = int(agent_max_query_retrieval_results_os)
AGENT_MAX_QUERY_RETRIEVAL_RESULTS = atmqrr
except ValueError:
raise ValueError(
f"MAX_AGENT_QUERY_RETRIEVAL_RESULTS must be an integer, got {AGENT_MAX_QUERY_RETRIEVAL_RESULTS}"
)
# Reranking agent configs
agent_reranking_stats_os: bool | str | None = os.environ.get(
"AGENT_RERANKING_TEST", False
)
AGENT_RERANKING_STATS: bool = False
if isinstance(agent_reranking_stats_os, str) and agent_reranking_stats_os == "True":
AGENT_RERANKING_STATS = True
elif isinstance(agent_reranking_stats_os, bool) and agent_reranking_stats_os:
AGENT_RERANKING_STATS = True
agent_reranking_max_query_retrieval_results_os: int | str = os.environ.get(
"AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS", NUM_RETURNED_HITS
)
AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS: int = NUM_RETURNED_HITS
try:
atmqrr = int(agent_reranking_max_query_retrieval_results_os)
AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS = atmqrr
except ValueError:
raise ValueError(
f"AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS must be an integer, got {AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS}"
)

View File

@@ -406,8 +406,18 @@ class SearchPipeline:
@property
def section_relevance_list(self) -> list[bool]:
llm_indices = relevant_sections_to_indices(
relevance_sections=self.section_relevance,
items=self.final_context_sections,
return section_relevance_list_impl(
section_relevance=self.section_relevance,
final_context_sections=self.final_context_sections,
)
return [ind in llm_indices for ind in range(len(self.final_context_sections))]
def section_relevance_list_impl(
section_relevance: list[SectionRelevancePiece] | None,
final_context_sections: list[InferenceSection],
) -> list[bool]:
llm_indices = relevant_sections_to_indices(
relevance_sections=section_relevance,
items=final_context_sections,
)
return [ind in llm_indices for ind in range(len(final_context_sections))]

View File

@@ -80,7 +80,7 @@ def drop_llm_indices(
search_docs: Sequence[DBSearchDoc | SavedSearchDoc],
dropped_indices: list[int],
) -> list[int]:
llm_bools = [True if i in llm_indices else False for i in range(len(search_docs))]
llm_bools = [i in llm_indices for i in range(len(search_docs))]
if dropped_indices:
llm_bools = [
val for ind, val in enumerate(llm_bools) if ind not in dropped_indices

View File

@@ -1,6 +1,7 @@
from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from typing import Any
from uuid import UUID
from fastapi import HTTPException
@@ -15,13 +16,21 @@ from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from onyx.agent_search.pro_search_a.answer_initial_sub_question.models import (
QuestionAnswerResults,
)
from onyx.agent_search.pro_search_a.main.models import CombinedAgentMetrics
from onyx.auth.schemas import UserRole
from onyx.chat.models import DocumentRelevance
from onyx.configs.chat_configs import HARD_DELETE_CHATS
from onyx.configs.constants import MessageType
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RetrievalDocs
from onyx.context.search.models import SavedSearchDoc
from onyx.context.search.models import SearchDoc as ServerSearchDoc
from onyx.db.models import AgentSearchMetrics
from onyx.db.models import AgentSubQuery
from onyx.db.models import AgentSubQuestion
from onyx.db.models import ChatMessage
from onyx.db.models import ChatMessage__SearchDoc
from onyx.db.models import ChatSession
@@ -37,6 +46,8 @@ from onyx.file_store.models import FileDescriptor
from onyx.llm.override_models import LLMOverride
from onyx.llm.override_models import PromptOverride
from onyx.server.query_and_chat.models import ChatMessageDetail
from onyx.server.query_and_chat.models import SubQueryDetail
from onyx.server.query_and_chat.models import SubQuestionDetail
from onyx.tools.tool_runner import ToolCallFinalResult
from onyx.utils.logger import setup_logger
@@ -496,6 +507,7 @@ def get_chat_messages_by_session(
prefetch_tool_calls: bool = False,
) -> list[ChatMessage]:
if not skip_permission_check:
# bug if we ever call this expecting the permission check to not be skipped
get_chat_session_by_id(
chat_session_id=chat_session_id, user_id=user_id, db_session=db_session
)
@@ -507,7 +519,12 @@ def get_chat_messages_by_session(
)
if prefetch_tool_calls:
stmt = stmt.options(joinedload(ChatMessage.tool_call))
stmt = stmt.options(
joinedload(ChatMessage.tool_call),
joinedload(ChatMessage.sub_questions).joinedload(
AgentSubQuestion.sub_queries
),
)
result = db_session.scalars(stmt).unique().all()
else:
result = db_session.scalars(stmt).all()
@@ -837,14 +854,45 @@ def translate_db_search_doc_to_server_search_doc(
)
def get_retrieval_docs_from_chat_message(
chat_message: ChatMessage, remove_doc_content: bool = False
def translate_db_sub_questions_to_server_objects(
db_sub_questions: list[AgentSubQuestion],
) -> list[SubQuestionDetail]:
sub_questions = []
for sub_question in db_sub_questions:
sub_queries = []
docs: list[SearchDoc] = []
for sub_query in sub_question.sub_queries:
doc_ids = [doc.id for doc in sub_query.search_docs]
sub_queries.append(
SubQueryDetail(
query=sub_query.sub_query,
query_id=sub_query.id,
doc_ids=doc_ids,
)
)
docs += sub_query.search_docs
sub_questions.append(
SubQuestionDetail(
level=sub_question.level,
level_question_nr=sub_question.level_question_nr,
question=sub_question.sub_question,
answer=sub_question.sub_answer,
sub_queries=sub_queries,
context_docs=get_retrieval_docs_from_search_docs(docs),
)
)
return sub_questions
def get_retrieval_docs_from_search_docs(
search_docs: list[SearchDoc], remove_doc_content: bool = False
) -> RetrievalDocs:
top_documents = [
translate_db_search_doc_to_server_search_doc(
db_doc, remove_doc_content=remove_doc_content
)
for db_doc in chat_message.search_docs
for db_doc in search_docs
]
top_documents = sorted(top_documents, key=lambda doc: doc.score, reverse=True) # type: ignore
return RetrievalDocs(top_documents=top_documents)
@@ -861,8 +909,8 @@ def translate_db_message_to_chat_message_detail(
latest_child_message=chat_message.latest_child_message,
message=chat_message.message,
rephrased_query=chat_message.rephrased_query,
context_docs=get_retrieval_docs_from_chat_message(
chat_message, remove_doc_content=remove_doc_content
context_docs=get_retrieval_docs_from_search_docs(
chat_message.search_docs, remove_doc_content=remove_doc_content
),
message_type=chat_message.message_type,
time_sent=chat_message.time_sent,
@@ -877,6 +925,114 @@ def translate_db_message_to_chat_message_detail(
else None,
alternate_assistant_id=chat_message.alternate_assistant_id,
overridden_model=chat_message.overridden_model,
sub_questions=translate_db_sub_questions_to_server_objects(
chat_message.sub_questions
),
)
return chat_msg_detail
def log_agent_metrics(
db_session: Session,
user_id: UUID | None,
persona_id: int | None, # Can be none if temporary persona is used
agent_type: str,
start_time: datetime,
agent_metrics: CombinedAgentMetrics,
) -> AgentSearchMetrics:
agent_timings = agent_metrics.timings
agent_base_metrics = agent_metrics.base_metrics
agent_refined_metrics = agent_metrics.refined_metrics
agent_additional_metrics = agent_metrics.additional_metrics
agent_metric_tracking = AgentSearchMetrics(
user_id=user_id,
persona_id=persona_id,
agent_type=agent_type,
start_time=start_time,
base_duration__s=agent_timings.base_duration__s,
full_duration__s=agent_timings.full_duration__s,
base_metrics=vars(agent_base_metrics),
refined_metrics=vars(agent_refined_metrics),
all_metrics=vars(agent_additional_metrics),
)
db_session.add(agent_metric_tracking)
db_session.flush()
return agent_metric_tracking
def log_agent_sub_question_results(
db_session: Session,
chat_session_id: UUID | None,
primary_message_id: int | None,
sub_question_answer_results: list[QuestionAnswerResults],
) -> None:
def _create_citation_format_list(
document_citations: list[InferenceSection],
) -> list[dict[str, Any]]:
citation_list: list[dict[str, Any]] = []
for document_citation in document_citations:
document_citation_dict = {
"link": "",
"blurb": document_citation.center_chunk.blurb,
"content": document_citation.center_chunk.content,
"metadata": document_citation.center_chunk.metadata,
"updated_at": str(document_citation.center_chunk.updated_at),
"document_id": document_citation.center_chunk.document_id,
"source_type": "file",
"source_links": document_citation.center_chunk.source_links,
"match_highlights": document_citation.center_chunk.match_highlights,
"semantic_identifier": document_citation.center_chunk.semantic_identifier,
}
citation_list.append(document_citation_dict)
return citation_list
now = datetime.now()
for sub_question_answer_result in sub_question_answer_results:
level, level_question_nr = [
int(x) for x in sub_question_answer_result.question_id.split("_")
]
sub_question = sub_question_answer_result.question
sub_answer = sub_question_answer_result.answer
sub_document_results = _create_citation_format_list(
sub_question_answer_result.documents
)
sub_queries = [
x.query for x in sub_question_answer_result.expanded_retrieval_results
]
sub_question_object = AgentSubQuestion(
chat_session_id=chat_session_id,
primary_question_id=primary_message_id,
level=level,
level_question_nr=level_question_nr,
sub_question=sub_question,
sub_answer=sub_answer,
sub_question_doc_results=sub_document_results,
)
db_session.add(sub_question_object)
db_session.commit()
# db_session.flush()
sub_question_id = sub_question_object.id
for sub_query in sub_queries:
sub_query_object = AgentSubQuery(
parent_question_id=sub_question_id,
chat_session_id=chat_session_id,
sub_query=sub_query,
time_created=now,
)
db_session.add(sub_query_object)
db_session.commit()
# db_session.flush()
return None

View File

@@ -320,6 +320,17 @@ class ChatMessage__SearchDoc(Base):
)
class AgentSubQuery__SearchDoc(Base):
__tablename__ = "agent__sub_query__search_doc"
sub_query_id: Mapped[int] = mapped_column(
ForeignKey("agent__sub_query.id"), primary_key=True
)
search_doc_id: Mapped[int] = mapped_column(
ForeignKey("search_doc.id"), primary_key=True
)
class Document__Tag(Base):
__tablename__ = "document__tag"
@@ -960,6 +971,11 @@ class SearchDoc(Base):
secondary=ChatMessage__SearchDoc.__table__,
back_populates="search_docs",
)
sub_queries = relationship(
"AgentSubQuery",
secondary=AgentSubQuery__SearchDoc.__table__,
back_populates="search_docs",
)
class ToolCall(Base):
@@ -1122,6 +1138,11 @@ class ChatMessage(Base):
uselist=False,
)
sub_questions: Mapped[list["AgentSubQuestion"]] = relationship(
"AgentSubQuestion",
back_populates="primary_message",
)
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
"StandardAnswer",
secondary=ChatMessage__StandardAnswer.__table__,
@@ -1156,6 +1177,71 @@ class ChatFolder(Base):
return self.display_priority < other.display_priority
class AgentSubQuestion(Base):
"""
A sub-question is a question that is asked of the LLM to gather supporting
information to answer a primary question.
"""
__tablename__ = "agent__sub_question"
id: Mapped[int] = mapped_column(primary_key=True)
primary_question_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
chat_session_id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True), ForeignKey("chat_session.id")
)
sub_question: Mapped[str] = mapped_column(Text)
level: Mapped[int] = mapped_column(Integer)
level_question_nr: Mapped[int] = mapped_column(Integer)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
sub_answer: Mapped[str] = mapped_column(Text)
sub_question_doc_results: Mapped[JSON_ro] = mapped_column(postgresql.JSONB())
# Relationships
primary_message: Mapped["ChatMessage"] = relationship(
"ChatMessage",
foreign_keys=[primary_question_id],
back_populates="sub_questions",
)
chat_session: Mapped["ChatSession"] = relationship("ChatSession")
sub_queries: Mapped[list["AgentSubQuery"]] = relationship(
"AgentSubQuery", back_populates="parent_question"
)
class AgentSubQuery(Base):
"""
A sub-query is a vector DB query that gathers supporting information to answer a sub-question.
"""
__tablename__ = "agent__sub_query"
id: Mapped[int] = mapped_column(primary_key=True)
parent_question_id: Mapped[int] = mapped_column(
ForeignKey("agent__sub_question.id")
)
chat_session_id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True), ForeignKey("chat_session.id")
)
sub_query: Mapped[str] = mapped_column(Text)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
# Relationships
parent_question: Mapped["AgentSubQuestion"] = relationship(
"AgentSubQuestion", back_populates="sub_queries"
)
chat_session: Mapped["ChatSession"] = relationship("ChatSession")
search_docs: Mapped[list["SearchDoc"]] = relationship(
"SearchDoc",
secondary=AgentSubQuery__SearchDoc.__table__,
back_populates="sub_queries",
)
"""
Feedback, Logging, Metrics Tables
"""
@@ -1641,6 +1727,25 @@ class PGFileStore(Base):
lobj_oid: Mapped[int] = mapped_column(Integer, nullable=False)
class AgentSearchMetrics(Base):
__tablename__ = "agent__search_metrics"
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
persona_id: Mapped[int | None] = mapped_column(
ForeignKey("persona.id"), nullable=True
)
agent_type: Mapped[str] = mapped_column(String)
start_time: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
base_duration__s: Mapped[float] = mapped_column(Float)
full_duration__s: Mapped[float] = mapped_column(Float)
base_metrics: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)
refined_metrics: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)
all_metrics: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)
"""
************************************************************************
Enterprise Edition Models

View File

@@ -125,6 +125,10 @@ class CreateChatMessageRequest(ChunkContext):
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None
# If true, ignores most of the search options and uses pro search instead.
# TODO: decide how many of the above options we want to pass through to pro search
use_pro_search: bool = False
@model_validator(mode="after")
def check_search_doc_ids_or_retrieval_options(self) -> "CreateChatMessageRequest":
if self.search_doc_ids is None and self.retrieval_options is None:
@@ -190,6 +194,22 @@ class SearchFeedbackRequest(BaseModel):
return self
class SubQueryDetail(BaseModel):
query: str
query_id: int
# TODO: store these to enable per-query doc selection
doc_ids: list[int] | None = None
class SubQuestionDetail(BaseModel):
level: int
level_question_nr: int
question: str
answer: str
sub_queries: list[SubQueryDetail] | None = None
context_docs: RetrievalDocs | None = None
class ChatMessageDetail(BaseModel):
message_id: int
parent_message: int | None = None
@@ -201,9 +221,10 @@ class ChatMessageDetail(BaseModel):
time_sent: datetime
overridden_model: str | None
alternate_assistant_id: int | None = None
# Dict mapping citation number to db_doc_id
chat_session_id: UUID | None = None
# Dict mapping citation number to db_doc_id
citations: dict[int, int] | None = None
sub_questions: list[SubQuestionDetail] | None = None
files: list[FileDescriptor]
tool_call: ToolCallFinalResult | None

View File

@@ -25,6 +25,11 @@ class ToolCallSummary(BaseModel__v1):
tool_call_request: AIMessage
tool_call_result: ToolMessage
# This is a workaround to allow arbitrary types in the model
# TODO: Remove this once we have a better solution
class Config:
arbitrary_types_allowed = True
def tool_call_tokens(
tool_call_summary: ToolCallSummary, llm_tokenizer: BaseTokenizer

View File

@@ -4,6 +4,9 @@ from uuid import UUID
from pydantic import BaseModel
from pydantic import model_validator
from onyx.context.search.enums import SearchType
from onyx.context.search.models import IndexFilters
class ToolResponse(BaseModel):
id: str | None = None
@@ -45,5 +48,11 @@ class DynamicSchemaInfo(BaseModel):
message_id: int | None
class SearchQueryInfo(BaseModel):
predicted_search: SearchType | None
final_filters: IndexFilters
recency_bias_multiplier: float
CHAT_SESSION_ID_PLACEHOLDER = "CHAT_SESSION_ID"
MESSAGE_ID_PLACEHOLDER = "MESSAGE_ID"

View File

@@ -1,9 +1,9 @@
import json
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from typing import cast
from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.chat.chat_utils import llm_doc_from_inference_section
@@ -25,13 +25,13 @@ from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.enums import QueryFlow
from onyx.context.search.enums import SearchType
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SearchRequest
from onyx.context.search.pipeline import SearchPipeline
from onyx.context.search.pipeline import section_relevance_list_impl
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.llm.interfaces import LLM
@@ -39,6 +39,7 @@ from onyx.llm.models import PreviousMessage
from onyx.secondary_llm_flows.choose_search import check_if_need_search
from onyx.secondary_llm_flows.query_expansion import history_based_query_rephrase
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import SearchQueryInfo
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict
@@ -62,13 +63,10 @@ SECTION_RELEVANCE_LIST_ID = "section_relevance_list"
SEARCH_EVALUATION_ID = "llm_doc_eval"
class SearchResponseSummary(BaseModel):
class SearchResponseSummary(SearchQueryInfo):
top_sections: list[InferenceSection]
rephrased_query: str | None = None
predicted_flow: QueryFlow | None
predicted_search: SearchType | None
final_filters: IndexFilters
recency_bias_multiplier: float
SEARCH_TOOL_DESCRIPTION = """
@@ -117,6 +115,8 @@ class SearchTool(Tool):
self.fast_llm = fast_llm
self.evaluation_type = evaluation_type
self.search_pipeline: SearchPipeline | None = None
self.selected_sections = selected_sections
self.full_doc = full_doc
@@ -281,8 +281,10 @@ class SearchTool(Tool):
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)
def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
query = cast(str, kwargs["query"])
force_no_rerank = cast(bool, kwargs.get("force_no_rerank", False))
alternate_db_session = cast(Session, kwargs.get("alternate_db_session", None))
if self.selected_sections:
yield from self._build_response_for_specified_sections(query)
@@ -291,7 +293,9 @@ class SearchTool(Tool):
search_pipeline = SearchPipeline(
search_request=SearchRequest(
query=query,
evaluation_type=self.evaluation_type,
evaluation_type=LLMEvaluationType.SKIP
if force_no_rerank
else self.evaluation_type,
human_selected_filters=(
self.retrieval_options.filters if self.retrieval_options else None
),
@@ -300,7 +304,16 @@ class SearchTool(Tool):
self.retrieval_options.offset if self.retrieval_options else None
),
limit=self.retrieval_options.limit if self.retrieval_options else None,
rerank_settings=self.rerank_settings,
rerank_settings=RerankingDetails(
rerank_model_name=None,
rerank_api_url=None,
rerank_provider_type=None,
rerank_api_key=None,
num_rerank=0,
disable_rerank_for_streaming=True,
)
if force_no_rerank
else self.rerank_settings,
chunks_above=self.chunks_above,
chunks_below=self.chunks_below,
full_doc=self.full_doc,
@@ -314,57 +327,25 @@ class SearchTool(Tool):
llm=self.llm,
fast_llm=self.fast_llm,
bypass_acl=self.bypass_acl,
db_session=self.db_session,
db_session=alternate_db_session or self.db_session,
prompt_config=self.prompt_config,
)
self.search_pipeline = search_pipeline # used for agent_search metrics
yield ToolResponse(
id=SEARCH_RESPONSE_SUMMARY_ID,
response=SearchResponseSummary(
rephrased_query=query,
top_sections=search_pipeline.final_context_sections,
predicted_flow=search_pipeline.predicted_flow,
predicted_search=search_pipeline.predicted_search_type,
final_filters=search_pipeline.search_query.filters,
recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier,
),
search_query_info = SearchQueryInfo(
predicted_search=search_pipeline.search_query.search_type,
final_filters=search_pipeline.search_query.filters,
recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier,
)
yield ToolResponse(
id=SEARCH_DOC_CONTENT_ID,
response=OnyxContexts(
contexts=[
OnyxContext(
content=section.combined_content,
document_id=section.center_chunk.document_id,
semantic_identifier=section.center_chunk.semantic_identifier,
blurb=section.center_chunk.blurb,
)
for section in search_pipeline.reranked_sections
]
),
yield from yield_search_responses(
query,
search_pipeline.reranked_sections,
search_pipeline.final_context_sections,
search_query_info,
lambda: search_pipeline.section_relevance,
self,
)
yield ToolResponse(
id=SECTION_RELEVANCE_LIST_ID,
response=search_pipeline.section_relevance,
)
pruned_sections = prune_sections(
sections=search_pipeline.final_context_sections,
section_relevance_list=search_pipeline.section_relevance_list,
prompt_config=self.prompt_config,
llm_config=self.llm.config,
question=query,
contextual_pruning_config=self.contextual_pruning_config,
)
llm_docs = [
llm_doc_from_inference_section(section) for section in pruned_sections
]
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)
def final_result(self, *args: ToolResponse) -> JSON_ro:
final_docs = cast(
list[LlmDoc],
@@ -425,3 +406,64 @@ class SearchTool(Tool):
initial_search_results = cast(list[LlmDoc], initial_search_results)
return final_search_results, initial_search_results
# Allows yielding the same responses as a SearchTool without being a SearchTool.
# SearchTool passed in to allow for access to SearchTool properties.
# We can't just call SearchTool methods in the graph because we're operating on
# the retrieved docs (reranking, deduping, etc.) after the SearchTool has run.
def yield_search_responses(
query: str,
reranked_sections: list[InferenceSection],
final_context_sections: list[InferenceSection],
search_query_info: SearchQueryInfo,
get_section_relevance: Callable[[], list[SectionRelevancePiece] | None],
search_tool: SearchTool,
) -> Generator[ToolResponse, None, None]:
yield ToolResponse(
id=SEARCH_RESPONSE_SUMMARY_ID,
response=SearchResponseSummary(
rephrased_query=query,
top_sections=final_context_sections,
predicted_flow=QueryFlow.QUESTION_ANSWER,
predicted_search=search_query_info.predicted_search,
final_filters=search_query_info.final_filters,
recency_bias_multiplier=search_query_info.recency_bias_multiplier,
),
)
yield ToolResponse(
id=SEARCH_DOC_CONTENT_ID,
response=OnyxContexts(
contexts=[
OnyxContext(
content=section.combined_content,
document_id=section.center_chunk.document_id,
semantic_identifier=section.center_chunk.semantic_identifier,
blurb=section.center_chunk.blurb,
)
for section in reranked_sections
]
),
)
section_relevance = get_section_relevance()
yield ToolResponse(
id=SECTION_RELEVANCE_LIST_ID,
response=section_relevance,
)
pruned_sections = prune_sections(
sections=final_context_sections,
section_relevance_list=section_relevance_list_impl(
section_relevance, final_context_sections
),
prompt_config=search_tool.prompt_config,
llm_config=search_tool.llm.config,
question=query,
contextual_pruning_config=search_tool.contextual_pruning_config,
)
llm_docs = [llm_doc_from_inference_section(section) for section in pruned_sections]
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)

View File

@@ -49,11 +49,7 @@ def build_next_prompt_for_search_like_tool(
message=prompt_builder.user_message_and_token_cnt[0],
prompt_config=prompt_config,
context_docs=final_context_documents,
all_doc_useful=(
answer_style_config.citation_config.all_docs_useful
if answer_style_config.citation_config
else False
),
all_doc_useful=(answer_style_config.citation_config.all_docs_useful),
history_message=prompt_builder.single_message_history or "",
)
)

View File

@@ -29,9 +29,14 @@ inflection==0.5.1
jira==3.5.1
jsonref==1.1.0
trafilatura==1.12.2
langchain==0.1.17
langchain-core==0.1.50
langchain-text-splitters==0.0.1
langchain==0.3.7
langchain-core==0.3.24
langchain-openai==0.2.9
langchain-text-splitters==0.3.2
langchainhub==0.1.21
langgraph==0.2.59
langgraph-checkpoint==2.0.5
langgraph-sdk==0.1.44
litellm==1.55.4
lxml==5.3.0
lxml_html_clean==0.2.2

View File

@@ -0,0 +1,128 @@
import csv
import datetime
import json
import os
import yaml
from onyx.agent_search.pro_search_a.main.graph_builder import main_graph_builder
from onyx.agent_search.pro_search_a.main.states import MainInput
from onyx.chat.models import ProSearchConfig
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
cwd = os.getcwd()
CONFIG = yaml.safe_load(
open(f"{cwd}/backend/tests/regression/answer_quality/search_test_config.yaml")
)
INPUT_DIR = CONFIG["agent_test_input_folder"]
OUTPUT_DIR = CONFIG["agent_test_output_folder"]
graph = main_graph_builder(test_mode=True)
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
# create a local json test data file and use it here
input_file_object = open(
f"{INPUT_DIR}/agent_test_data.json",
)
output_file = f"{OUTPUT_DIR}/agent_test_output.csv"
test_data = json.load(input_file_object)
example_data = test_data["examples"]
example_ids = test_data["example_ids"]
with get_session_context_manager() as db_session:
output_data = []
for example in example_data:
example_id = example["id"]
if len(example_ids) > 0 and example_id not in example_ids:
continue
example_question = example["question"]
target_sub_questions = example.get("target_sub_questions", [])
num_target_sub_questions = len(target_sub_questions)
search_request = SearchRequest(query=example_question)
config = ProSearchConfig(
search_request=search_request,
message_id=None,
chat_session_id=None,
use_persistence=False,
)
inputs = MainInput(
config=config,
primary_llm=primary_llm,
fast_llm=fast_llm,
db_session=db_session,
)
start_time = datetime.datetime.now()
question_result = compiled_graph.invoke(input=inputs)
end_time = datetime.datetime.now()
duration = end_time - start_time
if num_target_sub_questions > 0:
chunk_expansion_ratio = (
question_result["initial_agent_stats"]
.get("agent_effectiveness", {})
.get("utilized_chunk_ratio", None)
)
support_effectiveness_ratio = (
question_result["initial_agent_stats"]
.get("agent_effectiveness", {})
.get("support_ratio", None)
)
else:
chunk_expansion_ratio = None
support_effectiveness_ratio = None
generated_sub_questions = question_result.get("generated_sub_questions", [])
num_generated_sub_questions = len(generated_sub_questions)
base_answer = question_result["initial_base_answer"].split("==")[-1]
agent_answer = question_result["initial_answer"].split("==")[-1]
output_point = {
"example_id": example_id,
"question": example_question,
"duration": duration,
"target_sub_questions": target_sub_questions,
"generated_sub_questions": generated_sub_questions,
"num_target_sub_questions": num_target_sub_questions,
"num_generated_sub_questions": num_generated_sub_questions,
"chunk_expansion_ratio": chunk_expansion_ratio,
"support_effectiveness_ratio": support_effectiveness_ratio,
"base_answer": base_answer,
"agent_answer": agent_answer,
}
output_data.append(output_point)
with open(output_file, "w", newline="") as csvfile:
fieldnames = [
"example_id",
"question",
"duration",
"target_sub_questions",
"generated_sub_questions",
"num_target_sub_questions",
"num_generated_sub_questions",
"chunk_expansion_ratio",
"support_effectiveness_ratio",
"base_answer",
"agent_answer",
]
writer = csv.DictWriter(csvfile, fieldnames=fieldnames, delimiter="\t")
writer.writeheader()
writer.writerows(output_data)
print("DONE")

View File

@@ -180,6 +180,7 @@ export function ChatPage({
const [documentSidebarToggled, setDocumentSidebarToggled] = useState(false);
const [filtersToggled, setFiltersToggled] = useState(false);
const [langgraphEnabled, setLanggraphEnabled] = useState(false);
const [userSettingsToggled, setUserSettingsToggled] = useState(false);
@@ -1264,6 +1265,7 @@ export function ChatPage({
systemPromptOverride:
searchParams.get(SEARCH_PARAM_NAMES.SYSTEM_PROMPT) || undefined,
useExistingUserMessage: isSeededChat,
useLanggraph: langgraphEnabled,
});
const delay = (ms: number) => {
@@ -2245,6 +2247,17 @@ export function ChatPage({
hideUserDropdown={user?.is_anonymous_user}
/>
)}
<div className="flex items-center justify-end px-4 py-2">
<label className="flex items-center cursor-pointer">
<span className="mr-2 text-sm">Langgraph</span>
<input
type="checkbox"
checked={langgraphEnabled}
onChange={(e) => setLanggraphEnabled(e.target.checked)}
className="form-checkbox h-4 w-4"
/>
</label>
</div>
{documentSidebarInitialWidth !== undefined && isReady ? (
<Dropzone onDrop={handleImageUpload} noClick>

View File

@@ -128,6 +128,7 @@ export async function* sendMessage({
useExistingUserMessage,
alternateAssistantId,
signal,
useLanggraph,
}: {
regenerate: boolean;
message: string;
@@ -146,6 +147,7 @@ export async function* sendMessage({
useExistingUserMessage?: boolean;
alternateAssistantId?: number;
signal?: AbortSignal;
useLanggraph?: boolean;
}): AsyncGenerator<PacketType, void, unknown> {
const documentsAreSelected =
selectedDocumentIds && selectedDocumentIds.length > 0;
@@ -186,6 +188,7 @@ export async function* sendMessage({
}
: null,
use_existing_user_message: useExistingUserMessage,
use_pro_search: useLanggraph,
});
const response = await fetch(`/api/chat/send-message`, {