Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-03-25 01:22:45 +00:00)

Compare commits: v0.20.0-cl ... evan_answe
91 commits
| SHA1 |
|---|
| 75aea29ccf |
| 89c60db0cc |
| 635ba351d0 |
| 4ab99fb4a7 |
| aee625d525 |
| d379453d64 |
| 685f12c531 |
| 686eddfa52 |
| 5f3e877833 |
| df75a3115b |
| 6ff439c342 |
| 44ccb5ef0f |
| 8045d52090 |
| 0d059cf835 |
| 969d02767f |
| 0bc3bb5558 |
| 4c19b19488 |
| 79e7b73db1 |
| 7ce0436d71 |
| 1849069c5f |
| 5788902d86 |
| 80542859b6 |
| 1996e22f9b |
| 682b145a6a |
| ef67f9cd1e |
| 295417f85d |
| 8650b8ff51 |
| 462db23683 |
| f2507755c3 |
| fd1191637b |
| 7898f38a5d |
| 3a38407bca |
| 239f2f2718 |
| 0ff44c7661 |
| d697ad0fc8 |
| 6c0d051f80 |
| 2f7f4917e3 |
| 695d07f0f9 |
| dd2c9425bd |
| e468cac28c |
| d35aa1eab9 |
| ae65c739de |
| 328f4758ae |
| 6422ad90a5 |
| 46850cc2ac |
| 72e56aa4ca |
| 13a5a86dec |
| a206723191 |
| 60207589d2 |
| d773163502 |
| 121827e34c |
| dd64c3a175 |
| 38a616c87c |
| d7812ee807 |
| 11db9647f3 |
| d6a385b837 |
| d68cf98e77 |
| 821b226d25 |
| 6dc81bbb7c |
| 6989441851 |
| 683978ddb0 |
| 568bc16536 |
| 0333ff648a |
| cc76486d21 |
| 901d8c22c4 |
| 21928133e0 |
| c4af11c19b |
| ca3f3beabe |
| fa481019e8 |
| f4c826c4e5 |
| 2a3328fc3d |
| 34aa054c5d |
| cebe237705 |
| c759fb5709 |
| ffc81f6e45 |
| 2d6f746259 |
| bca02ebec6 |
| 0c75ca0579 |
| 9d3220fcfc |
| 50a216f554 |
| 8399d2ee0a |
| fd694bea8f |
| e76cbec53c |
| d66180fe13 |
| 442c94727e |
| 2f2b9a862a |
| 1f88b60abd |
| ff03d717f3 |
| 82914ad365 |
| 11ce2a62ab |
| 6311b70cc6 |
4  .gitignore (vendored)

@@ -7,4 +7,6 @@
.vscode/
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml
/web/test-results/
backend/onyx/agent_search/main/test_data.json
backend/tests/regression/answer_quality/test_data.json
6  .vscode/env_template.txt (vendored)

@@ -51,3 +51,9 @@ BING_API_KEY=<REPLACE THIS>
# Enable the full set of Danswer Enterprise Edition features
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False

# Agent Search configs # TODO: Remove give proper namings
AGENT_RETRIEVAL_STATS=False # Note: This setting will incur substantial re-ranking effort
AGENT_RERANKING_STATS=True
AGENT_MAX_QUERY_RETRIEVAL_RESULTS=20
AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS=20
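Editorial note: the four AGENT_* flags added here are consumed later in this diff via imports from onyx.configs.dev_configs (see the expanded-retrieval nodes below). A minimal sketch of how such env vars are typically parsed follows; the exact parsing and defaults inside dev_configs are assumptions, not taken from this PR.

```python
# Hypothetical sketch only — the real onyx.configs.dev_configs may differ.
import os

AGENT_RETRIEVAL_STATS = os.environ.get("AGENT_RETRIEVAL_STATS", "").lower() == "true"
AGENT_RERANKING_STATS = os.environ.get("AGENT_RERANKING_STATS", "").lower() == "true"
AGENT_MAX_QUERY_RETRIEVAL_RESULTS = int(
    os.environ.get("AGENT_MAX_QUERY_RETRIEVAL_RESULTS", "20")
)
AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS = int(
    os.environ.get("AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS", "20")
)
```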
@@ -0,0 +1,29 @@
"""agent_doc_result_col

Revision ID: 1adf5ea20d2b
Revises: e9cf2bd7baed
Create Date: 2025-01-05 13:14:58.344316

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "1adf5ea20d2b"
down_revision = "e9cf2bd7baed"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add the new column with JSONB type
    op.add_column(
        "sub_question",
        sa.Column("sub_question_doc_results", postgresql.JSONB(), nullable=True),
    )


def downgrade() -> None:
    # Drop the column
    op.drop_column("sub_question", "sub_question_doc_results")
@@ -0,0 +1,35 @@
"""agent_metric_col_rename__s

Revision ID: 925b58bd75b6
Revises: 9787be927e58
Create Date: 2025-01-06 11:20:26.752441

"""
from alembic import op

# revision identifiers, used by Alembic.
revision = "925b58bd75b6"
down_revision = "9787be927e58"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Rename columns using PostgreSQL syntax
    op.alter_column(
        "agent__search_metrics", "base_duration_s", new_column_name="base_duration__s"
    )
    op.alter_column(
        "agent__search_metrics", "full_duration_s", new_column_name="full_duration__s"
    )


def downgrade() -> None:
    # Revert the column renames
    op.alter_column(
        "agent__search_metrics", "base_duration__s", new_column_name="base_duration_s"
    )
    op.alter_column(
        "agent__search_metrics", "full_duration__s", new_column_name="full_duration_s"
    )
@@ -0,0 +1,25 @@
"""agent_metric_table_renames__agent__

Revision ID: 9787be927e58
Revises: bceb76d618ec
Create Date: 2025-01-06 11:01:44.210160

"""
from alembic import op

# revision identifiers, used by Alembic.
revision = "9787be927e58"
down_revision = "bceb76d618ec"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Rename table from agent_search_metrics to agent__search_metrics
    op.rename_table("agent_search_metrics", "agent__search_metrics")


def downgrade() -> None:
    # Rename table back from agent__search_metrics to agent_search_metrics
    op.rename_table("agent__search_metrics", "agent_search_metrics")
42  backend/alembic/versions/98a5008d8711_agent_tracking.py (new file)

@@ -0,0 +1,42 @@
"""agent_tracking

Revision ID: 98a5008d8711
Revises: 2955778aa44c
Create Date: 2025-01-04 14:41:52.732238

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "98a5008d8711"
down_revision = "2955778aa44c"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "agent_search_metrics",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=True),
        sa.Column("persona_id", sa.Integer(), nullable=True),
        sa.Column("agent_type", sa.String(), nullable=False),
        sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
        sa.Column("base_duration_s", sa.Float(), nullable=False),
        sa.Column("full_duration_s", sa.Float(), nullable=False),
        sa.Column("base_metrics", postgresql.JSONB(), nullable=True),
        sa.Column("refined_metrics", postgresql.JSONB(), nullable=True),
        sa.Column("all_metrics", postgresql.JSONB(), nullable=True),
        sa.ForeignKeyConstraint(
            ["persona_id"],
            ["persona.id"],
        ),
        sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
    )


def downgrade() -> None:
    op.drop_table("agent_search_metrics")
@@ -0,0 +1,84 @@
"""agent_table_renames__agent__

Revision ID: bceb76d618ec
Revises: c0132518a25b
Create Date: 2025-01-06 10:50:48.109285

"""
from alembic import op

# revision identifiers, used by Alembic.
revision = "bceb76d618ec"
down_revision = "c0132518a25b"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.drop_constraint(
        "sub_query__search_doc_sub_query_id_fkey",
        "sub_query__search_doc",
        type_="foreignkey",
    )
    op.drop_constraint(
        "sub_query__search_doc_search_doc_id_fkey",
        "sub_query__search_doc",
        type_="foreignkey",
    )
    # Rename tables
    op.rename_table("sub_query", "agent__sub_query")
    op.rename_table("sub_question", "agent__sub_question")
    op.rename_table("sub_query__search_doc", "agent__sub_query__search_doc")

    # Update both foreign key constraints for agent__sub_query__search_doc

    # Create new foreign keys with updated names
    op.create_foreign_key(
        "agent__sub_query__search_doc_sub_query_id_fkey",
        "agent__sub_query__search_doc",
        "agent__sub_query",
        ["sub_query_id"],
        ["id"],
    )
    op.create_foreign_key(
        "agent__sub_query__search_doc_search_doc_id_fkey",
        "agent__sub_query__search_doc",
        "search_doc",  # This table name doesn't change
        ["search_doc_id"],
        ["id"],
    )


def downgrade() -> None:
    # Update foreign key constraints for sub_query__search_doc
    op.drop_constraint(
        "agent__sub_query__search_doc_sub_query_id_fkey",
        "agent__sub_query__search_doc",
        type_="foreignkey",
    )
    op.drop_constraint(
        "agent__sub_query__search_doc_search_doc_id_fkey",
        "agent__sub_query__search_doc",
        type_="foreignkey",
    )

    # Rename tables back
    op.rename_table("agent__sub_query__search_doc", "sub_query__search_doc")
    op.rename_table("agent__sub_question", "sub_question")
    op.rename_table("agent__sub_query", "sub_query")

    op.create_foreign_key(
        "sub_query__search_doc_sub_query_id_fkey",
        "sub_query__search_doc",
        "sub_query",
        ["sub_query_id"],
        ["id"],
    )
    op.create_foreign_key(
        "sub_query__search_doc_search_doc_id_fkey",
        "sub_query__search_doc",
        "search_doc",  # This table name doesn't change
        ["search_doc_id"],
        ["id"],
    )
@@ -0,0 +1,40 @@
"""agent_table_changes_rename_level

Revision ID: c0132518a25b
Revises: 1adf5ea20d2b
Create Date: 2025-01-05 16:38:37.660152

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "c0132518a25b"
down_revision = "1adf5ea20d2b"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add level and level_question_nr columns with NOT NULL constraint
    op.add_column(
        "sub_question",
        sa.Column("level", sa.Integer(), nullable=False, server_default="0"),
    )
    op.add_column(
        "sub_question",
        sa.Column(
            "level_question_nr", sa.Integer(), nullable=False, server_default="0"
        ),
    )

    # Remove the server_default after the columns are created
    op.alter_column("sub_question", "level", server_default=None)
    op.alter_column("sub_question", "level_question_nr", server_default=None)


def downgrade() -> None:
    # Remove the columns
    op.drop_column("sub_question", "level_question_nr")
    op.drop_column("sub_question", "level")
@@ -0,0 +1,68 @@
"""create pro search persistence tables

Revision ID: e9cf2bd7baed
Revises: 98a5008d8711
Create Date: 2025-01-02 17:55:56.544246

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID

# revision identifiers, used by Alembic.
revision = "e9cf2bd7baed"
down_revision = "98a5008d8711"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Create sub_question table
    op.create_table(
        "sub_question",
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("primary_question_id", sa.Integer, sa.ForeignKey("chat_message.id")),
        sa.Column(
            "chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id")
        ),
        sa.Column("sub_question", sa.Text),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.func.now()
        ),
        sa.Column("sub_answer", sa.Text),
    )

    # Create sub_query table
    op.create_table(
        "sub_query",
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("parent_question_id", sa.Integer, sa.ForeignKey("sub_question.id")),
        sa.Column(
            "chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id")
        ),
        sa.Column("sub_query", sa.Text),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.func.now()
        ),
    )

    # Create sub_query__search_doc association table
    op.create_table(
        "sub_query__search_doc",
        sa.Column(
            "sub_query_id", sa.Integer, sa.ForeignKey("sub_query.id"), primary_key=True
        ),
        sa.Column(
            "search_doc_id",
            sa.Integer,
            sa.ForeignKey("search_doc.id"),
            primary_key=True,
        ),
    )


def downgrade() -> None:
    op.drop_table("sub_query__search_doc")
    op.drop_table("sub_query")
    op.drop_table("sub_question")
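Editorial note: read together, the Revises fields above give the new migration chain 2955778aa44c → 98a5008d8711 → e9cf2bd7baed → 1adf5ea20d2b → c0132518a25b → bceb76d618ec → 9787be927e58 → 925b58bd75b6. A minimal sketch of applying the chain programmatically with Alembic's command API follows; the alembic.ini location is an assumption, and running `alembic upgrade head` from the backend directory is the usual CLI equivalent.

```python
# Minimal sketch: apply all pending revisions up to head.
# Assumes an alembic.ini at the backend root; adjust the path for your checkout.
from alembic import command
from alembic.config import Config

alembic_cfg = Config("backend/alembic.ini")
command.upgrade(alembic_cfg, "head")    # run all upgrade() steps in revision order
# command.downgrade(alembic_cfg, "-1")  # or step back one revision
```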
@@ -179,6 +179,7 @@ def handle_simplified_chat_message(
        chunks_below=0,
        full_doc=chat_message_req.full_doc,
        structured_response_format=chat_message_req.structured_response_format,
        use_pro_search=chat_message_req.use_pro_search,
    )

    packets = stream_chat_message_objects(

@@ -301,6 +302,7 @@ def handle_send_message_simple_with_history(
        chunks_below=0,
        full_doc=req.full_doc,
        structured_response_format=req.structured_response_format,
        use_pro_search=req.use_pro_search,
    )

    packets = stream_chat_message_objects(
@@ -57,6 +57,9 @@ class BasicCreateChatMessageRequest(ChunkContext):
    # https://platform.openai.com/docs/guides/structured-outputs/introduction
    structured_response_format: dict | None = None

    # If True, uses pro search instead of basic search
    use_pro_search: bool = False


class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
    # Last element is the new query. All previous elements are historical context

@@ -71,6 +74,8 @@ class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
    # only works if using an OpenAI model. See the following for more details:
    # https://platform.openai.com/docs/guides/structured-outputs/introduction
    structured_response_format: dict | None = None
    # If True, uses pro search instead of basic search
    use_pro_search: bool = False


class SimpleDoc(BaseModel):

@@ -123,6 +128,9 @@ class OneShotQARequest(ChunkContext):
    # If True, skips generative an AI response to the search query
    skip_gen_ai_answer_generation: bool = False

    # If True, uses pro search instead of basic search
    use_pro_search: bool = False

    @model_validator(mode="after")
    def check_persona_fields(self) -> "OneShotQARequest":
        if self.persona_override_config is None and self.persona_id is None:
@@ -196,6 +196,7 @@ def get_answer_stream(
        retrieval_details=query_request.retrieval_options,
        rerank_settings=query_request.rerank_settings,
        db_session=db_session,
        use_pro_search=query_request.use_pro_search,
    )

    packets = stream_chat_message_objects(
78  backend/onyx/agent_search/basic/graph_builder.py (new file)

@@ -0,0 +1,78 @@
from langchain_core.callbacks.manager import dispatch_custom_event
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agent_search.basic.states import BasicInput
from onyx.agent_search.basic.states import BasicOutput
from onyx.agent_search.basic.states import BasicState
from onyx.agent_search.basic.states import BasicStateUpdate


def basic_graph_builder() -> StateGraph:
    graph = StateGraph(
        state_schema=BasicState,
        input=BasicInput,
        output=BasicOutput,
    )

    ### Add nodes ###
    graph.add_node(
        node="get_response",
        action=get_response,
    )

    ### Add edges ###
    graph.add_edge(start_key=START, end_key="get_response")

    graph.add_conditional_edges("get_response", should_continue, ["get_response", END])
    graph.add_edge(
        start_key="get_response",
        end_key=END,
    )

    return graph


def should_continue(state: BasicState) -> str:
    return (
        END if state["last_llm_call"] is None or state["calls"] > 0 else "get_response"
    )


def get_response(state: BasicState) -> BasicStateUpdate:
    llm = state["llm"]
    current_llm_call = state["last_llm_call"]
    if current_llm_call is None:
        raise ValueError("last_llm_call is None")
    answer_style_config = state["answer_style_config"]
    response_handler_manager = state["response_handler_manager"]
    # DEBUG: good breakpoint
    stream = llm.stream(
        # For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
        # may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
        prompt=current_llm_call.prompt_builder.build(),
        tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
        tool_choice=(
            "required"
            if current_llm_call.tools and current_llm_call.force_use_tool.force_use
            else None
        ),
        structured_response_format=answer_style_config.structured_response_format,
    )

    for response in response_handler_manager.handle_llm_response(stream):
        dispatch_custom_event(
            "basic_response",
            response,
        )
    return BasicStateUpdate(
        last_llm_call=response_handler_manager.next_llm_call(current_llm_call),
        calls=state["calls"] + 1,
    )


if __name__ == "__main__":
    pass
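Editorial note: a minimal sketch of building and compiling the basic graph above and inspecting its nodes. Running it for real requires the LLM call, answer-style config, and response handler manager that the chat pipeline assembles, which are deliberately not constructed here.

```python
# Sketch only: compile the graph and list its nodes (START -> get_response -> END).
from onyx.agent_search.basic.graph_builder import basic_graph_builder

compiled = basic_graph_builder().compile()
print(list(compiled.get_graph().nodes))  # e.g. ['__start__', 'get_response', '__end__']
```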
41  backend/onyx/agent_search/basic/states.py (new file)

@@ -0,0 +1,41 @@
from typing import TypedDict

from onyx.chat.llm_response_handler import LLMResponseHandlerManager
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.prompt_builder.build import LLMCall
from onyx.llm.chat_llm import LLM

## Update States


## Graph Input State


class BasicInput(TypedDict):
    last_llm_call: LLMCall | None
    llm: LLM
    answer_style_config: AnswerStyleConfig
    response_handler_manager: LLMResponseHandlerManager
    calls: int


## Graph Output State


class BasicOutput(TypedDict):
    pass


class BasicStateUpdate(TypedDict):
    last_llm_call: LLMCall | None
    calls: int


## Graph State


class BasicState(
    BasicInput,
    BasicOutput,
):
    pass
66  backend/onyx/agent_search/core_state.py (new file)

@@ -0,0 +1,66 @@
from operator import add
from typing import Annotated
from typing import TypedDict
from typing import TypeVar

from sqlalchemy.orm import Session

from onyx.chat.models import ProSearchConfig
from onyx.db.models import User
from onyx.llm.interfaces import LLM
from onyx.tools.tool_implementations.search.search_tool import SearchTool


class CoreState(TypedDict, total=False):
    """
    This is the core state that is shared across all subgraphs.
    """

    config: ProSearchConfig
    primary_llm: LLM
    fast_llm: LLM
    # a single session for the entire agent search
    # is fine if we are only reading
    db_session: Session
    user: User | None
    log_messages: Annotated[list[str], add]
    search_tool: SearchTool


class SubgraphCoreState(TypedDict, total=False):
    """
    This is the core state that is shared across all subgraphs.
    """

    subgraph_config: ProSearchConfig
    subgraph_primary_llm: LLM
    subgraph_fast_llm: LLM
    # a single session for the entire agent search
    # is fine if we are only reading
    subgraph_db_session: Session

    subgraph_search_tool: SearchTool


# This ensures that the state passed in extends the CoreState
T = TypeVar("T", bound=CoreState)
T_SUBGRAPH = TypeVar("T_SUBGRAPH", bound=SubgraphCoreState)


def extract_core_fields(state: T) -> CoreState:
    filtered_dict = {k: v for k, v in state.items() if k in CoreState.__annotations__}
    return CoreState(**dict(filtered_dict))  # type: ignore


def extract_core_fields_for_subgraph(state: T) -> SubgraphCoreState:
    filtered_dict = {
        "subgraph_" + k: v for k, v in state.items() if k in CoreState.__annotations__
    }
    return SubgraphCoreState(**dict(filtered_dict))  # type: ignore


def in_subgraph_extract_core_fields(state: T_SUBGRAPH) -> SubgraphCoreState:
    filtered_dict = {
        k: v for k, v in state.items() if k in SubgraphCoreState.__annotations__
    }
    return SubgraphCoreState(**dict(filtered_dict))  # type: ignore
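Editorial note: a small sketch of the prefixing convention these helpers implement. extract_core_fields_for_subgraph copies any CoreState keys present and prepends `subgraph_`, which is how the parent graph's state is handed to subgraphs throughout this PR. Both state classes are total=False, so a partial state is enough for illustration; real states also carry LLMs, a db session, and the search tool.

```python
# Illustration only — values are stand-ins, not real objects.
from onyx.agent_search.core_state import CoreState, extract_core_fields_for_subgraph

state = CoreState(log_messages=["agent search started"])
print(extract_core_fields_for_subgraph(state))
# {'subgraph_log_messages': ['agent search started']}
```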
66  backend/onyx/agent_search/db_operations.py (new file)

@@ -0,0 +1,66 @@
from uuid import UUID

from sqlalchemy.orm import Session

from onyx.db.models import AgentSubQuery
from onyx.db.models import AgentSubQuestion


def create_sub_question(
    db_session: Session,
    chat_session_id: UUID,
    primary_message_id: int,
    sub_question: str,
    sub_answer: str,
) -> AgentSubQuestion:
    """Create a new sub-question record in the database."""
    sub_q = AgentSubQuestion(
        chat_session_id=chat_session_id,
        primary_question_id=primary_message_id,
        sub_question=sub_question,
        sub_answer=sub_answer,
    )
    db_session.add(sub_q)
    db_session.flush()
    return sub_q


def create_sub_query(
    db_session: Session,
    chat_session_id: UUID,
    parent_question_id: int,
    sub_query: str,
) -> AgentSubQuery:
    """Create a new sub-query record in the database."""
    sub_q = AgentSubQuery(
        chat_session_id=chat_session_id,
        parent_question_id=parent_question_id,
        sub_query=sub_query,
    )
    db_session.add(sub_q)
    db_session.flush()
    return sub_q


def get_sub_questions_for_message(
    db_session: Session,
    primary_message_id: int,
) -> list[AgentSubQuestion]:
    """Get all sub-questions for a given primary message."""
    return (
        db_session.query(AgentSubQuestion)
        .filter(AgentSubQuestion.primary_question_id == primary_message_id)
        .all()
    )


def get_sub_queries_for_question(
    db_session: Session,
    sub_question_id: int,
) -> list[AgentSubQuery]:
    """Get all sub-queries for a given sub-question."""
    return (
        db_session.query(AgentSubQuery)
        .filter(AgentSubQuery.parent_question_id == sub_question_id)
        .all()
    )
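Editorial note: a minimal usage sketch for the helpers above, under the assumption that a chat session and chat message already exist. The UUID and message id below are placeholders, and get_session_context_manager follows the same import used in this PR's graph-builder mainlines.

```python
# Sketch only — chat_session_id and the primary message id must reference existing rows.
from uuid import UUID

from onyx.agent_search.db_operations import create_sub_query, create_sub_question
from onyx.db.engine import get_session_context_manager

with get_session_context_manager() as db_session:
    sub_question = create_sub_question(
        db_session=db_session,
        chat_session_id=UUID("00000000-0000-0000-0000-000000000000"),  # placeholder
        primary_message_id=1,                                           # placeholder
        sub_question="What connectors does Onyx support?",
        sub_answer="",
    )
    create_sub_query(
        db_session=db_session,
        chat_session_id=UUID("00000000-0000-0000-0000-000000000000"),  # placeholder
        parent_question_id=sub_question.id,
        sub_query="onyx supported connectors list",
    )
    db_session.commit()
```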
7  backend/onyx/agent_search/models.py (new file)

@@ -0,0 +1,7 @@
from pydantic import BaseModel


class AgentDocumentCitations(BaseModel):
    document_id: str
    document_title: str
    link: str
@@ -0,0 +1,28 @@
from collections.abc import Hashable

from langgraph.types import Send

from onyx.agent_search.core_state import in_subgraph_extract_core_fields
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
    AnswerQuestionInput,
)
from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
    ExpandedRetrievalInput,
)
from onyx.utils.logger import setup_logger

logger = setup_logger()


def send_to_expanded_retrieval(state: AnswerQuestionInput) -> Send | Hashable:
    logger.debug("sending to expanded retrieval via edge")

    return Send(
        "initial_sub_question_expanded_retrieval",
        ExpandedRetrievalInput(
            **in_subgraph_extract_core_fields(state),
            question=state["question"],
            base_search=False,
            sub_question_id=state["question_id"],
        ),
    )
@@ -0,0 +1,129 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agent_search.pro_search_a.answer_initial_sub_question.edges import (
    send_to_expanded_retrieval,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.nodes.answer_check import (
    answer_check,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.nodes.answer_generation import (
    answer_generation,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.nodes.format_answer import (
    format_answer,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.nodes.ingest_retrieval import (
    ingest_retrieval,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
    AnswerQuestionInput,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
    AnswerQuestionOutput,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
    AnswerQuestionState,
)
from onyx.agent_search.pro_search_a.expanded_retrieval.graph_builder import (
    expanded_retrieval_graph_builder,
)
from onyx.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger

logger = setup_logger()


def answer_query_graph_builder() -> StateGraph:
    graph = StateGraph(
        state_schema=AnswerQuestionState,
        input=AnswerQuestionInput,
        output=AnswerQuestionOutput,
    )

    ### Add nodes ###

    expanded_retrieval = expanded_retrieval_graph_builder().compile()
    graph.add_node(
        node="initial_sub_question_expanded_retrieval",
        action=expanded_retrieval,
    )
    graph.add_node(
        node="answer_check",
        action=answer_check,
    )
    graph.add_node(
        node="answer_generation",
        action=answer_generation,
    )
    graph.add_node(
        node="format_answer",
        action=format_answer,
    )
    graph.add_node(
        node="ingest_retrieval",
        action=ingest_retrieval,
    )

    ### Add edges ###

    graph.add_conditional_edges(
        source=START,
        path=send_to_expanded_retrieval,
        path_map=["initial_sub_question_expanded_retrieval"],
    )
    graph.add_edge(
        start_key="initial_sub_question_expanded_retrieval",
        end_key="ingest_retrieval",
    )
    graph.add_edge(
        start_key="ingest_retrieval",
        end_key="answer_generation",
    )
    graph.add_edge(
        start_key="answer_generation",
        end_key="answer_check",
    )
    graph.add_edge(
        start_key="answer_check",
        end_key="format_answer",
    )
    graph.add_edge(
        start_key="format_answer",
        end_key=END,
    )

    return graph


if __name__ == "__main__":
    from onyx.db.engine import get_session_context_manager
    from onyx.llm.factory import get_default_llms
    from onyx.context.search.models import SearchRequest

    graph = answer_query_graph_builder()
    compiled_graph = graph.compile()
    primary_llm, fast_llm = get_default_llms()
    search_request = SearchRequest(
        query="what can you do with onyx or danswer?",
    )
    with get_session_context_manager() as db_session:
        pro_search_config, search_tool = get_test_config(
            db_session, primary_llm, fast_llm, search_request
        )
        inputs = AnswerQuestionInput(
            question="what can you do with onyx?",
            subgraph_fast_llm=fast_llm,
            subgraph_primary_llm=primary_llm,
            subgraph_config=pro_search_config,
            subgraph_search_tool=search_tool,
            subgraph_db_session=db_session,
            question_id="0_0",
        )
        for thing in compiled_graph.stream(
            input=inputs,
            # debug=True,
            # subgraphs=True,
        ):
            logger.debug(thing)
@@ -0,0 +1,21 @@
from pydantic import BaseModel

from onyx.agent_search.pro_search_a.expanded_retrieval.models import QueryResult
from onyx.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.context.search.models import InferenceSection

### Models ###


class AnswerRetrievalStats(BaseModel):
    answer_retrieval_stats: dict[str, float | int]


class QuestionAnswerResults(BaseModel):
    question: str
    question_id: str
    answer: str
    quality: str
    expanded_retrieval_results: list[QueryResult]
    documents: list[InferenceSection]
    sub_question_retrieval_stats: AgentChunkStats
@@ -0,0 +1,34 @@
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs

from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
    AnswerQuestionState,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
    QACheckUpdate,
)
from onyx.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT


def answer_check(state: AnswerQuestionState) -> QACheckUpdate:
    msg = [
        HumanMessage(
            content=SUB_CHECK_PROMPT.format(
                question=state["question"],
                base_answer=state["answer"],
            )
        )
    ]

    fast_llm = state["subgraph_fast_llm"]
    response = list(
        fast_llm.stream(
            prompt=msg,
        )
    )

    quality_str = merge_message_runs(response, chunk_separator="")[0].content

    return QACheckUpdate(
        answer_quality=quality_str,
    )
@@ -0,0 +1,77 @@
import datetime
from typing import Any

from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import merge_message_runs

from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
    AnswerQuestionState,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
    QAGenerationUpdate,
)
from onyx.agent_search.shared_graph_utils.agent_prompt_ops import (
    build_sub_question_answer_prompt,
)
from onyx.agent_search.shared_graph_utils.prompts import ASSISTANT_SYSTEM_PROMPT_DEFAULT
from onyx.agent_search.shared_graph_utils.prompts import ASSISTANT_SYSTEM_PROMPT_PERSONA
from onyx.agent_search.shared_graph_utils.utils import get_persona_prompt
from onyx.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.chat.models import AgentAnswerPiece
from onyx.utils.logger import setup_logger

logger = setup_logger()


def answer_generation(state: AnswerQuestionState) -> QAGenerationUpdate:
    now_start = datetime.datetime.now()
    logger.debug(f"--------{now_start}--------START ANSWER GENERATION---")

    question = state["question"]
    docs = state["documents"]
    level, question_nr = parse_question_id(state["question_id"])
    persona_prompt = get_persona_prompt(state["subgraph_config"].search_request.persona)

    if len(persona_prompt) > 0:
        persona_specification = ASSISTANT_SYSTEM_PROMPT_DEFAULT
    else:
        persona_specification = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
            persona_prompt=persona_prompt
        )

    logger.debug(f"Number of verified retrieval docs: {len(docs)}")

    msg = build_sub_question_answer_prompt(
        question=question,
        original_question=state["subgraph_config"].search_request.query,
        docs=docs,
        persona_specification=persona_specification,
    )

    fast_llm = state["subgraph_fast_llm"]
    response: list[str | list[str | dict[str, Any]]] = []
    for message in fast_llm.stream(
        prompt=msg,
    ):
        # TODO: in principle, the answer here COULD contain images, but we don't support that yet
        content = message.content
        if not isinstance(content, str):
            raise ValueError(
                f"Expected content to be a string, but got {type(content)}"
            )
        dispatch_custom_event(
            "sub_answers",
            AgentAnswerPiece(
                answer_piece=content,
                level=level,
                level_question_nr=question_nr,
                answer_type="agent_sub_answer",
            ),
        )
        response.append(content)

    answer_str = merge_message_runs(response, chunk_separator="")[0].content

    return QAGenerationUpdate(
        answer=answer_str,
    )
@@ -0,0 +1,25 @@
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
    AnswerQuestionOutput,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
    AnswerQuestionState,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
    QuestionAnswerResults,
)


def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput:
    return AnswerQuestionOutput(
        answer_results=[
            QuestionAnswerResults(
                question=state["question"],
                question_id=state["question_id"],
                quality=state.get("answer_quality", "No"),
                answer=state["answer"],
                expanded_retrieval_results=state["expanded_retrieval_results"],
                documents=state["documents"],
                sub_question_retrieval_stats=state["sub_question_retrieval_stats"],
            )
        ],
    )
@@ -0,0 +1,23 @@
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
    RetrievalIngestionUpdate,
)
from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
    ExpandedRetrievalOutput,
)
from onyx.agent_search.shared_graph_utils.models import AgentChunkStats


def ingest_retrieval(state: ExpandedRetrievalOutput) -> RetrievalIngestionUpdate:
    sub_question_retrieval_stats = state[
        "expanded_retrieval_result"
    ].sub_question_retrieval_stats
    if sub_question_retrieval_stats is None:
        sub_question_retrieval_stats = [AgentChunkStats()]

    return RetrievalIngestionUpdate(
        expanded_retrieval_results=state[
            "expanded_retrieval_result"
        ].expanded_queries_results,
        documents=state["expanded_retrieval_result"].all_documents,
        sub_question_retrieval_stats=sub_question_retrieval_stats,
    )
@@ -0,0 +1,63 @@
from operator import add
from typing import Annotated
from typing import TypedDict

from onyx.agent_search.core_state import SubgraphCoreState
from onyx.agent_search.pro_search_a.answer_initial_sub_question.models import (
    QuestionAnswerResults,
)
from onyx.agent_search.pro_search_a.expanded_retrieval.models import QueryResult
from onyx.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections
from onyx.context.search.models import InferenceSection


## Update States
class QACheckUpdate(TypedDict):
    answer_quality: str


class QAGenerationUpdate(TypedDict):
    answer: str
    # answer_stat: AnswerStats


class RetrievalIngestionUpdate(TypedDict):
    expanded_retrieval_results: list[QueryResult]
    documents: Annotated[list[InferenceSection], dedup_inference_sections]
    sub_question_retrieval_stats: AgentChunkStats


## Graph Input State


class AnswerQuestionInput(SubgraphCoreState):
    question: str
    question_id: str  # 0_0 is original question, everything else is <level>_<question_num>.
    # level 0 is original question and first decomposition, level 1 is follow up, etc
    # question_num is a unique number per original question per level.


## Graph State


class AnswerQuestionState(
    AnswerQuestionInput,
    QAGenerationUpdate,
    QACheckUpdate,
    RetrievalIngestionUpdate,
):
    pass


## Graph Output State


class AnswerQuestionOutput(TypedDict):
    """
    This is a list of results even though each call of this subgraph only returns one result.
    This is because if we parallelize the answer query subgraph, there will be multiple
    results in a list so the add operator is used to add them together.
    """

    answer_results: Annotated[list[QuestionAnswerResults], add]
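Editorial note: a short sketch of the `<level>_<question_num>` convention documented in AnswerQuestionInput above. parse_question_id is the helper imported from onyx.agent_search.shared_graph_utils.utils elsewhere in this diff; its exact return types are assumed here from its call sites.

```python
# Sketch: "0_0" is the original question; "1_2" would be the third question at refinement level 1.
from onyx.agent_search.shared_graph_utils.utils import parse_question_id

level, question_nr = parse_question_id("1_2")
print(level, question_nr)  # expected: 1 2 (per the unpacking used in answer_generation)
```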
@@ -0,0 +1,28 @@
from collections.abc import Hashable

from langgraph.types import Send

from onyx.agent_search.core_state import in_subgraph_extract_core_fields
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
    AnswerQuestionInput,
)
from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
    ExpandedRetrievalInput,
)
from onyx.utils.logger import setup_logger

logger = setup_logger()


def send_to_expanded_refined_retrieval(state: AnswerQuestionInput) -> Send | Hashable:
    logger.debug("sending to expanded retrieval for follow up question via edge")

    return Send(
        "refined_sub_question_expanded_retrieval",
        ExpandedRetrievalInput(
            **in_subgraph_extract_core_fields(state),
            question=state["question"],
            sub_question_id=state["question_id"],
            base_search=False,
        ),
    )
@@ -0,0 +1,122 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agent_search.pro_search_a.answer_initial_sub_question.nodes.answer_check import (
    answer_check,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.nodes.answer_generation import (
    answer_generation,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.nodes.format_answer import (
    format_answer,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.nodes.ingest_retrieval import (
    ingest_retrieval,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
    AnswerQuestionInput,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
    AnswerQuestionOutput,
)
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
    AnswerQuestionState,
)
from onyx.agent_search.pro_search_a.answer_refinement_sub_question.edges import (
    send_to_expanded_refined_retrieval,
)
from onyx.agent_search.pro_search_a.expanded_retrieval.graph_builder import (
    expanded_retrieval_graph_builder,
)
from onyx.utils.logger import setup_logger

logger = setup_logger()


def answer_refined_query_graph_builder() -> StateGraph:
    graph = StateGraph(
        state_schema=AnswerQuestionState,
        input=AnswerQuestionInput,
        output=AnswerQuestionOutput,
    )

    ### Add nodes ###

    expanded_retrieval = expanded_retrieval_graph_builder().compile()
    graph.add_node(
        node="refined_sub_question_expanded_retrieval",
        action=expanded_retrieval,
    )
    graph.add_node(
        node="refined_sub_answer_check",
        action=answer_check,
    )
    graph.add_node(
        node="refined_sub_answer_generation",
        action=answer_generation,
    )
    graph.add_node(
        node="format_refined_sub_answer",
        action=format_answer,
    )
    graph.add_node(
        node="ingest_refined_retrieval",
        action=ingest_retrieval,
    )

    ### Add edges ###

    graph.add_conditional_edges(
        source=START,
        path=send_to_expanded_refined_retrieval,
        path_map=["refined_sub_question_expanded_retrieval"],
    )
    graph.add_edge(
        start_key="refined_sub_question_expanded_retrieval",
        end_key="ingest_refined_retrieval",
    )
    graph.add_edge(
        start_key="ingest_refined_retrieval",
        end_key="refined_sub_answer_generation",
    )
    graph.add_edge(
        start_key="refined_sub_answer_generation",
        end_key="refined_sub_answer_check",
    )
    graph.add_edge(
        start_key="refined_sub_answer_check",
        end_key="format_refined_sub_answer",
    )
    graph.add_edge(
        start_key="format_refined_sub_answer",
        end_key=END,
    )

    return graph


if __name__ == "__main__":
    from onyx.db.engine import get_session_context_manager
    from onyx.llm.factory import get_default_llms
    from onyx.context.search.models import SearchRequest

    graph = answer_refined_query_graph_builder()
    compiled_graph = graph.compile()
    primary_llm, fast_llm = get_default_llms()
    search_request = SearchRequest(
        query="what can you do with onyx or danswer?",
    )
    with get_session_context_manager() as db_session:
        inputs = AnswerQuestionInput(
            question="what can you do with onyx?",
            question_id="0_0",
        )
        for thing in compiled_graph.stream(
            input=inputs,
            # debug=True,
            # subgraphs=True,
        ):
            logger.debug(thing)
        # output = compiled_graph.invoke(inputs)
        # logger.debug(output)
@@ -0,0 +1,19 @@
from pydantic import BaseModel

from onyx.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.context.search.models import InferenceSection

### Models ###


class AnswerRetrievalStats(BaseModel):
    answer_retrieval_stats: dict[str, float | int]


class QuestionAnswerResults(BaseModel):
    question: str
    answer: str
    quality: str
    # expanded_retrieval_results: list[QueryResult]
    documents: list[InferenceSection]
    sub_question_retrieval_stats: AgentChunkStats
@@ -0,0 +1,70 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agent_search.pro_search_a.base_raw_search.nodes.format_raw_search_results import (
    format_raw_search_results,
)
from onyx.agent_search.pro_search_a.base_raw_search.nodes.generate_raw_search_data import (
    generate_raw_search_data,
)
from onyx.agent_search.pro_search_a.base_raw_search.states import BaseRawSearchInput
from onyx.agent_search.pro_search_a.base_raw_search.states import BaseRawSearchOutput
from onyx.agent_search.pro_search_a.base_raw_search.states import BaseRawSearchState
from onyx.agent_search.pro_search_a.expanded_retrieval.graph_builder import (
    expanded_retrieval_graph_builder,
)


def base_raw_search_graph_builder() -> StateGraph:
    graph = StateGraph(
        state_schema=BaseRawSearchState,
        input=BaseRawSearchInput,
        output=BaseRawSearchOutput,
    )

    ### Add nodes ###

    expanded_retrieval = expanded_retrieval_graph_builder().compile()
    graph.add_node(
        node="generate_raw_search_data",
        action=generate_raw_search_data,
    )

    graph.add_node(
        node="expanded_retrieval_base_search",
        action=expanded_retrieval,
    )
    graph.add_node(
        node="format_raw_search_results",
        action=format_raw_search_results,
    )

    ### Add edges ###

    graph.add_edge(start_key=START, end_key="generate_raw_search_data")

    graph.add_edge(
        start_key="generate_raw_search_data",
        end_key="expanded_retrieval_base_search",
    )
    graph.add_edge(
        start_key="expanded_retrieval_base_search",
        end_key="format_raw_search_results",
    )

    # graph.add_edge(
    #     start_key="expanded_retrieval_base_search",
    #     end_key=END,
    # )

    graph.add_edge(
        start_key="format_raw_search_results",
        end_key=END,
    )

    return graph


if __name__ == "__main__":
    pass
@@ -0,0 +1,20 @@
from pydantic import BaseModel

from onyx.agent_search.pro_search_a.expanded_retrieval.models import QueryResult
from onyx.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.context.search.models import InferenceSection

### Models ###


class AnswerRetrievalStats(BaseModel):
    answer_retrieval_stats: dict[str, float | int]


class QuestionAnswerResults(BaseModel):
    question: str
    answer: str
    quality: str
    expanded_retrieval_results: list[QueryResult]
    documents: list[InferenceSection]
    sub_question_retrieval_stats: list[AgentChunkStats]
@@ -0,0 +1,16 @@
from onyx.agent_search.pro_search_a.base_raw_search.states import BaseRawSearchOutput
from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
    ExpandedRetrievalOutput,
)
from onyx.utils.logger import setup_logger

logger = setup_logger()


def format_raw_search_results(state: ExpandedRetrievalOutput) -> BaseRawSearchOutput:
    logger.debug("format_raw_search_results")
    return BaseRawSearchOutput(
        base_expanded_retrieval_result=state["expanded_retrieval_result"],
        # base_retrieval_results=[state["expanded_retrieval_result"]],
        # base_search_documents=[],
    )
@@ -0,0 +1,21 @@
from onyx.agent_search.core_state import CoreState
from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
    ExpandedRetrievalInput,
)
from onyx.utils.logger import setup_logger

logger = setup_logger()


def generate_raw_search_data(state: CoreState) -> ExpandedRetrievalInput:
    logger.debug("generate_raw_search_data")
    return ExpandedRetrievalInput(
        subgraph_config=state["config"],
        subgraph_primary_llm=state["primary_llm"],
        subgraph_fast_llm=state["fast_llm"],
        subgraph_db_session=state["db_session"],
        question=state["config"].search_request.query,
        base_search=True,
        subgraph_search_tool=state["search_tool"],
        sub_question_id=None,  # This graph is always and only used for the original question
    )
@@ -0,0 +1,42 @@
from typing import TypedDict

from onyx.agent_search.core_state import CoreState
from onyx.agent_search.core_state import SubgraphCoreState
from onyx.agent_search.pro_search_a.expanded_retrieval.models import (
    ExpandedRetrievalResult,
)


## Update States


## Graph Input State


class BaseRawSearchInput(CoreState, SubgraphCoreState):
    pass


## Graph Output State


class BaseRawSearchOutput(TypedDict):
    """
    This is a list of results even though each call of this subgraph only returns one result.
    This is because if we parallelize the answer query subgraph, there will be multiple
    results in a list so the add operator is used to add them together.
    """

    # base_search_documents: Annotated[list[InferenceSection], dedup_inference_sections]
    # base_retrieval_results: Annotated[list[ExpandedRetrievalResult], add]
    base_expanded_retrieval_result: ExpandedRetrievalResult


## Graph State


class BaseRawSearchState(
    BaseRawSearchInput,
    BaseRawSearchOutput,
):
    pass
@@ -0,0 +1,28 @@
from collections.abc import Hashable

from langgraph.types import Send

from onyx.agent_search.core_state import in_subgraph_extract_core_fields
from onyx.agent_search.pro_search_a.expanded_retrieval.nodes import RetrievalInput
from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
    ExpandedRetrievalState,
)


def parallel_retrieval_edge(state: ExpandedRetrievalState) -> list[Send | Hashable]:
    question = state.get("question", state["subgraph_config"].search_request.query)

    query_expansions = state.get("expanded_queries", []) + [question]
    return [
        Send(
            "doc_retrieval",
            RetrievalInput(
                query_to_retrieve=query,
                question=question,
                **in_subgraph_extract_core_fields(state),
                base_search=False,
                sub_question_id=state.get("sub_question_id"),
            ),
        )
        for query in query_expansions
    ]
@@ -0,0 +1,126 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agent_search.pro_search_a.expanded_retrieval.edges import (
    parallel_retrieval_edge,
)
from onyx.agent_search.pro_search_a.expanded_retrieval.nodes import doc_reranking
from onyx.agent_search.pro_search_a.expanded_retrieval.nodes import doc_retrieval
from onyx.agent_search.pro_search_a.expanded_retrieval.nodes import doc_verification
from onyx.agent_search.pro_search_a.expanded_retrieval.nodes import expand_queries
from onyx.agent_search.pro_search_a.expanded_retrieval.nodes import format_results
from onyx.agent_search.pro_search_a.expanded_retrieval.nodes import verification_kickoff
from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
    ExpandedRetrievalInput,
)
from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
    ExpandedRetrievalOutput,
)
from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
    ExpandedRetrievalState,
)
from onyx.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger

logger = setup_logger()


def expanded_retrieval_graph_builder() -> StateGraph:
    graph = StateGraph(
        state_schema=ExpandedRetrievalState,
        input=ExpandedRetrievalInput,
        output=ExpandedRetrievalOutput,
    )

    ### Add nodes ###

    graph.add_node(
        node="expand_queries",
        action=expand_queries,
    )

    graph.add_node(
        node="doc_retrieval",
        action=doc_retrieval,
    )
    graph.add_node(
        node="verification_kickoff",
        action=verification_kickoff,
    )
    graph.add_node(
        node="doc_verification",
        action=doc_verification,
    )
    graph.add_node(
        node="doc_reranking",
        action=doc_reranking,
    )
    graph.add_node(
        node="format_results",
        action=format_results,
    )

    ### Add edges ###
    graph.add_edge(
        start_key=START,
        end_key="expand_queries",
    )

    graph.add_conditional_edges(
        source="expand_queries",
        path=parallel_retrieval_edge,
        path_map=["doc_retrieval"],
    )
    graph.add_edge(
        start_key="doc_retrieval",
        end_key="verification_kickoff",
    )
    graph.add_edge(
        start_key="doc_verification",
        end_key="doc_reranking",
    )
    graph.add_edge(
        start_key="doc_reranking",
        end_key="format_results",
    )
    graph.add_edge(
        start_key="format_results",
        end_key=END,
    )

    return graph


if __name__ == "__main__":
    from onyx.db.engine import get_session_context_manager
    from onyx.llm.factory import get_default_llms
    from onyx.context.search.models import SearchRequest

    graph = expanded_retrieval_graph_builder()
    compiled_graph = graph.compile()
    primary_llm, fast_llm = get_default_llms()
    search_request = SearchRequest(
        query="what can you do with onyx or danswer?",
    )

    with get_session_context_manager() as db_session:
        pro_search_config, search_tool = get_test_config(
            db_session, primary_llm, fast_llm, search_request
        )
        inputs = ExpandedRetrievalInput(
            question="what can you do with onyx?",
            base_search=False,
            subgraph_fast_llm=fast_llm,
            subgraph_primary_llm=primary_llm,
            subgraph_db_session=db_session,
            subgraph_config=pro_search_config,
            subgraph_search_tool=search_tool,
            sub_question_id=None,
        )
        for thing in compiled_graph.stream(
            input=inputs,
            # debug=True,
            subgraphs=True,
        ):
            logger.debug(thing)
@@ -0,0 +1,21 @@
from pydantic import BaseModel

from onyx.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.context.search.models import InferenceSection
from onyx.tools.models import SearchQueryInfo

### Models ###


class QueryResult(BaseModel):
    query: str
    search_results: list[InferenceSection]
    stats: RetrievalFitStats | None
    query_info: SearchQueryInfo | None


class ExpandedRetrievalResult(BaseModel):
    expanded_queries_results: list[QueryResult]
    all_documents: list[InferenceSection]
    sub_question_retrieval_stats: AgentChunkStats
@@ -0,0 +1,410 @@
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_message_runs
|
||||
from langgraph.types import Command
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agent_search.core_state import in_subgraph_extract_core_fields
|
||||
from onyx.agent_search.pro_search_a.expanded_retrieval.models import (
|
||||
ExpandedRetrievalResult,
|
||||
)
|
||||
from onyx.agent_search.pro_search_a.expanded_retrieval.models import QueryResult
|
||||
from onyx.agent_search.pro_search_a.expanded_retrieval.states import DocRerankingUpdate
|
||||
from onyx.agent_search.pro_search_a.expanded_retrieval.states import DocRetrievalUpdate
|
||||
from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
|
||||
DocVerificationInput,
|
||||
)
|
||||
from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
|
||||
DocVerificationUpdate,
|
||||
)
|
||||
from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalInput,
|
||||
)
|
||||
from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalState,
|
||||
)
|
||||
from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalUpdate,
|
||||
)
|
||||
from onyx.agent_search.pro_search_a.expanded_retrieval.states import InferenceSection
|
||||
from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
|
||||
QueryExpansionUpdate,
|
||||
)
|
||||
from onyx.agent_search.pro_search_a.expanded_retrieval.states import RetrievalInput
|
||||
from onyx.agent_search.shared_graph_utils.calculations import get_fit_scores
|
||||
from onyx.agent_search.shared_graph_utils.models import AgentChunkStats
|
||||
from onyx.agent_search.shared_graph_utils.models import RetrievalFitStats
|
||||
from onyx.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI_ORIGINAL
|
||||
from onyx.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
|
||||
from onyx.agent_search.shared_graph_utils.utils import dispatch_separated
|
||||
from onyx.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.chat.models import SubQueryPiece
|
||||
from onyx.configs.dev_configs import AGENT_MAX_QUERY_RETRIEVAL_RESULTS
|
||||
from onyx.configs.dev_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS
|
||||
from onyx.configs.dev_configs import AGENT_RERANKING_STATS
|
||||
from onyx.configs.dev_configs import AGENT_RETRIEVAL_STATS
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.context.search.pipeline import retrieval_preprocessing
|
||||
from onyx.context.search.postprocessing.postprocessing import rerank_sections
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.tools.models import SearchQueryInfo
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def dispatch_subquery(level: int, question_nr: int) -> Callable[[str, int], None]:
|
||||
def helper(token: str, num: int) -> None:
|
||||
dispatch_custom_event(
|
||||
"subqueries",
|
||||
SubQueryPiece(
|
||||
sub_query=token,
|
||||
level=level,
|
||||
level_question_nr=question_nr,
|
||||
query_id=num,
|
||||
),
|
||||
)
|
||||
|
||||
return helper
|
||||
|
||||
|
||||
def expand_queries(state: ExpandedRetrievalInput) -> QueryExpansionUpdate:
|
||||
# Sometimes we want to expand the original question, sometimes we want to expand a sub-question.
|
||||
# When we are running this node on the original question, no question is explicitly passed in.
|
||||
# Instead, we use the original question from the search request.
|
||||
question = state.get("question", state["subgraph_config"].search_request.query)
|
||||
llm: LLM = state["subgraph_fast_llm"]
|
||||
state["subgraph_db_session"]
|
||||
chat_session_id = state["subgraph_config"].chat_session_id
|
||||
sub_question_id = state.get("sub_question_id")
|
||||
if sub_question_id is None:
|
||||
level, question_nr = 0, 0
|
||||
else:
|
||||
level, question_nr = parse_question_id(sub_question_id)
|
||||
|
||||
if chat_session_id is None:
|
||||
raise ValueError("chat_session_id must be provided for agent search")
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question),
|
||||
)
|
||||
]
|
||||
|
||||
llm_response_list = dispatch_separated(
|
||||
llm.stream(prompt=msg), dispatch_subquery(level, question_nr)
|
||||
)
|
||||
|
||||
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
|
||||
|
||||
rewritten_queries = llm_response.split("\n")
|
||||
|
||||
return QueryExpansionUpdate(
|
||||
expanded_queries=rewritten_queries,
|
||||
)
|
||||
|
||||
|
||||
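For concreteness, a hedged illustration of the update expand_queries returns (the query strings below are hypothetical model output; the real values come from the fast LLM's newline-separated response):

# Hypothetical example of the node's return value after splitting the LLM
# response on newlines - downstream retrieval fans out over this list.
example_update = QueryExpansionUpdate(
    expanded_queries=[
        "onyx feature overview",
        "danswer supported connectors",
        "onyx chat and document search capabilities",
    ]
)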
def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate:
|
||||
"""
|
||||
Retrieve documents
|
||||
|
||||
Args:
|
||||
state (RetrievalInput): Primary state + the query to retrieve
|
||||
|
||||
Updates:
|
||||
expanded_retrieval_results: list[ExpandedRetrievalResult]
|
||||
retrieved_documents: list[InferenceSection]
|
||||
"""
|
||||
query_to_retrieve = state["query_to_retrieve"]
|
||||
search_tool = state["subgraph_search_tool"]
|
||||
|
||||
retrieved_docs: list[InferenceSection] = []
|
||||
if not query_to_retrieve.strip():
|
||||
logger.warning("Empty query, skipping retrieval")
|
||||
return DocRetrievalUpdate(
|
||||
expanded_retrieval_results=[],
|
||||
retrieved_documents=[],
|
||||
)
|
||||
|
||||
query_info = None
|
||||
# new db session to avoid concurrency issues
|
||||
with get_session_context_manager() as db_session:
|
||||
for tool_response in search_tool.run(
|
||||
query=query_to_retrieve,
|
||||
force_no_rerank=True,
|
||||
alternate_db_session=db_session,
|
||||
):
|
||||
# get retrieved docs to send to the rest of the graph
|
||||
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
response = cast(SearchResponseSummary, tool_response.response)
|
||||
retrieved_docs = response.top_sections
|
||||
query_info = SearchQueryInfo(
|
||||
predicted_search=response.predicted_search,
|
||||
final_filters=response.final_filters,
|
||||
recency_bias_multiplier=response.recency_bias_multiplier,
|
||||
)
|
||||
|
||||
retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS]
|
||||
pre_rerank_docs = retrieved_docs
|
||||
if search_tool.search_pipeline is not None:
|
||||
pre_rerank_docs = (
|
||||
search_tool.search_pipeline._retrieved_sections or retrieved_docs
|
||||
)
|
||||
|
||||
if AGENT_RETRIEVAL_STATS:
|
||||
fit_scores = get_fit_scores(
|
||||
pre_rerank_docs,
|
||||
retrieved_docs,
|
||||
)
|
||||
else:
|
||||
fit_scores = None
|
||||
|
||||
expanded_retrieval_result = QueryResult(
|
||||
query=query_to_retrieve,
|
||||
search_results=retrieved_docs,
|
||||
stats=fit_scores,
|
||||
query_info=query_info,
|
||||
)
|
||||
return DocRetrievalUpdate(
|
||||
expanded_retrieval_results=[expanded_retrieval_result],
|
||||
retrieved_documents=retrieved_docs,
|
||||
)
|
||||
|
||||
|
||||
def verification_kickoff(
|
||||
state: ExpandedRetrievalState,
|
||||
) -> Command[Literal["doc_verification"]]:
|
||||
documents = state["retrieved_documents"]
|
||||
verification_question = state.get(
|
||||
"question", state["subgraph_config"].search_request.query
|
||||
)
|
||||
sub_question_id = state.get("sub_question_id")
|
||||
return Command(
|
||||
update={},
|
||||
goto=[
|
||||
Send(
|
||||
node="doc_verification",
|
||||
arg=DocVerificationInput(
|
||||
doc_to_verify=doc,
|
||||
question=verification_question,
|
||||
base_search=False,
|
||||
sub_question_id=sub_question_id,
|
||||
**in_subgraph_extract_core_fields(state),
|
||||
),
|
||||
)
|
||||
for doc in documents
|
||||
],
|
||||
)
|
||||
|
||||
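The Command/Send combination above is LangGraph's fan-out mechanism: one doc_verification run is scheduled per retrieved document and executed in parallel, and the resulting DocVerificationUpdate values are merged by the reducer on verified_documents. A minimal hedged sketch of the same pattern with toy payloads (not the real DocVerificationInput):

# Three documents -> three parallel "doc_verification" invocations.
toy_command = Command(
    update={},
    goto=[
        Send(node="doc_verification", arg={"doc_to_verify": doc_id})
        for doc_id in ["doc_a", "doc_b", "doc_c"]
    ],
)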
|
||||
def doc_verification(state: DocVerificationInput) -> DocVerificationUpdate:
|
||||
"""
|
||||
Check whether the document is relevant for the original user question
|
||||
|
||||
Args:
|
||||
state (DocVerificationInput): The current state
|
||||
|
||||
Updates:
|
||||
verified_documents: list[InferenceSection]
|
||||
"""
|
||||
|
||||
question = state["question"]
|
||||
doc_to_verify = state["doc_to_verify"]
|
||||
document_content = doc_to_verify.combined_content
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=VERIFIER_PROMPT.format(
|
||||
question=question, document_content=document_content
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
fast_llm = state["subgraph_fast_llm"]
|
||||
|
||||
response = fast_llm.invoke(msg)
|
||||
|
||||
verified_documents = []
|
||||
if isinstance(response.content, str) and "yes" in response.content.lower():
|
||||
verified_documents.append(doc_to_verify)
|
||||
|
||||
return DocVerificationUpdate(
|
||||
verified_documents=verified_documents,
|
||||
)
|
||||
|
||||
|
||||
def doc_reranking(state: ExpandedRetrievalState) -> DocRerankingUpdate:
|
||||
verified_documents = state["verified_documents"]
|
||||
|
||||
# Rerank post retrieval and verification. First, create a search query
|
||||
# then create the list of reranked sections
|
||||
|
||||
question = state.get("question", state["subgraph_config"].search_request.query)
|
||||
with get_session_context_manager() as db_session:
|
||||
_search_query = retrieval_preprocessing(
|
||||
search_request=SearchRequest(query=question),
|
||||
user=state["subgraph_search_tool"].user, # bit of a hack
|
||||
llm=state["subgraph_fast_llm"],
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# skip section filtering
|
||||
|
||||
if (
|
||||
_search_query.rerank_settings
|
||||
and _search_query.rerank_settings.rerank_model_name
|
||||
and _search_query.rerank_settings.num_rerank > 0
|
||||
):
|
||||
reranked_documents = rerank_sections(
|
||||
_search_query,
|
||||
verified_documents,
|
||||
)
|
||||
else:
|
||||
logger.warning("No reranking settings found, using unranked documents")
|
||||
reranked_documents = verified_documents
|
||||
|
||||
if AGENT_RERANKING_STATS:
|
||||
fit_scores = get_fit_scores(verified_documents, reranked_documents)
|
||||
else:
|
||||
fit_scores = RetrievalFitStats(fit_score_lift=0, rerank_effect=0, fit_scores={})
|
||||
|
||||
# TODO: stream deduped docs here, or decide to use search tool ranking/verification
|
||||
|
||||
return DocRerankingUpdate(
|
||||
reranked_documents=[
|
||||
doc for doc in reranked_documents if type(doc) == InferenceSection
|
||||
][:AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS],
|
||||
sub_question_retrieval_stats=fit_scores,
|
||||
)
|
||||
|
||||
|
||||
def _calculate_sub_question_retrieval_stats(
|
||||
verified_documents: list[InferenceSection],
|
||||
expanded_retrieval_results: list[QueryResult],
|
||||
) -> AgentChunkStats:
|
||||
chunk_scores: dict[str, dict[str, list[int | float]]] = defaultdict(
|
||||
lambda: defaultdict(list)
|
||||
)
|
||||
|
||||
for expanded_retrieval_result in expanded_retrieval_results:
|
||||
for doc in expanded_retrieval_result.search_results:
|
||||
doc_chunk_id = f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}"
|
||||
if doc.center_chunk.score is not None:
|
||||
chunk_scores[doc_chunk_id]["score"].append(doc.center_chunk.score)
|
||||
|
||||
verified_doc_chunk_ids = [
|
||||
f"{verified_document.center_chunk.document_id}_{verified_document.center_chunk.chunk_id}"
|
||||
for verified_document in verified_documents
|
||||
]
|
||||
dismissed_doc_chunk_ids = []
|
||||
|
||||
raw_chunk_stats_counts: dict[str, int] = defaultdict(int)
|
||||
raw_chunk_stats_scores: dict[str, float] = defaultdict(float)
|
||||
for doc_chunk_id, chunk_data in chunk_scores.items():
|
||||
if doc_chunk_id in verified_doc_chunk_ids:
|
||||
raw_chunk_stats_counts["verified_count"] += 1
|
||||
|
||||
valid_chunk_scores = [
|
||||
score for score in chunk_data["score"] if score is not None
|
||||
]
|
||||
raw_chunk_stats_scores["verified_scores"] += float(
|
||||
np.mean(valid_chunk_scores)
|
||||
)
|
||||
else:
|
||||
raw_chunk_stats_counts["rejected_count"] += 1
|
||||
valid_chunk_scores = [
|
||||
score for score in chunk_data["score"] if score is not None
|
||||
]
|
||||
raw_chunk_stats_scores["rejected_scores"] += float(
|
||||
np.mean(valid_chunk_scores)
|
||||
)
|
||||
dismissed_doc_chunk_ids.append(doc_chunk_id)
|
||||
|
||||
if raw_chunk_stats_counts["verified_count"] == 0:
|
||||
verified_avg_scores = 0.0
|
||||
else:
|
||||
verified_avg_scores = raw_chunk_stats_scores["verified_scores"] / float(
|
||||
raw_chunk_stats_counts["verified_count"]
|
||||
)
|
||||
|
||||
rejected_scores = raw_chunk_stats_scores.get("rejected_scores", None)
|
||||
if rejected_scores is not None:
|
||||
rejected_avg_scores = rejected_scores / float(
|
||||
raw_chunk_stats_counts["rejected_count"]
|
||||
)
|
||||
else:
|
||||
rejected_avg_scores = None
|
||||
|
||||
chunk_stats = AgentChunkStats(
|
||||
verified_count=raw_chunk_stats_counts["verified_count"],
|
||||
verified_avg_scores=verified_avg_scores,
|
||||
rejected_count=raw_chunk_stats_counts["rejected_count"],
|
||||
rejected_avg_scores=rejected_avg_scores,
|
||||
verified_doc_chunk_ids=verified_doc_chunk_ids,
|
||||
dismissed_doc_chunk_ids=dismissed_doc_chunk_ids,
|
||||
)
|
||||
|
||||
return chunk_stats
|
||||
|
||||
|
||||
def format_results(state: ExpandedRetrievalState) -> ExpandedRetrievalUpdate:
|
||||
level, question_nr = parse_question_id(state.get("sub_question_id") or "0_0")
|
||||
query_infos = [
|
||||
result.query_info
|
||||
for result in state["expanded_retrieval_results"]
|
||||
if result.query_info is not None
|
||||
]
|
||||
if len(query_infos) == 0:
|
||||
raise ValueError("No query info found")
|
||||
|
||||
# main question docs will be sent later after aggregation and deduping with sub-question docs
|
||||
if not (level == 0 and question_nr == 0):
|
||||
for tool_response in yield_search_responses(
|
||||
query=state["question"],
|
||||
reranked_sections=state[
|
||||
"retrieved_documents"
|
||||
], # TODO: rename params. this one is supposed to be the sections pre-merging
|
||||
final_context_sections=state["reranked_documents"],
|
||||
search_query_info=query_infos[0], # TODO: handle differing query infos?
|
||||
get_section_relevance=lambda: None, # TODO: add relevance
|
||||
search_tool=state["subgraph_search_tool"],
|
||||
):
|
||||
dispatch_custom_event(
|
||||
"tool_response",
|
||||
ExtendedToolResponse(
|
||||
id=tool_response.id,
|
||||
response=tool_response.response,
|
||||
level=level,
|
||||
level_question_nr=question_nr,
|
||||
),
|
||||
)
|
||||
sub_question_retrieval_stats = _calculate_sub_question_retrieval_stats(
|
||||
verified_documents=state["verified_documents"],
|
||||
expanded_retrieval_results=state["expanded_retrieval_results"],
|
||||
)
|
||||
|
||||
if sub_question_retrieval_stats is None:
|
||||
sub_question_retrieval_stats = AgentChunkStats()
|
||||
# else:
|
||||
# sub_question_retrieval_stats = [sub_question_retrieval_stats]
|
||||
|
||||
return ExpandedRetrievalUpdate(
|
||||
expanded_retrieval_result=ExpandedRetrievalResult(
|
||||
expanded_queries_results=state["expanded_retrieval_results"],
|
||||
all_documents=state["reranked_documents"],
|
||||
sub_question_retrieval_stats=sub_question_retrieval_stats,
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,82 @@
from operator import add
from typing import Annotated
from typing import TypedDict

from onyx.agent_search.core_state import SubgraphCoreState
from onyx.agent_search.pro_search_a.expanded_retrieval.models import (
    ExpandedRetrievalResult,
)
from onyx.agent_search.pro_search_a.expanded_retrieval.models import QueryResult
from onyx.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections
from onyx.context.search.models import InferenceSection


### States ###

## Graph Input State


class ExpandedRetrievalInput(SubgraphCoreState):
    question: str
    base_search: bool
    sub_question_id: str | None


## Update/Return States


class QueryExpansionUpdate(TypedDict):
    expanded_queries: list[str]


class DocVerificationUpdate(TypedDict):
    verified_documents: Annotated[list[InferenceSection], dedup_inference_sections]


class DocRetrievalUpdate(TypedDict):
    expanded_retrieval_results: Annotated[list[QueryResult], add]
    retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections]


class DocRerankingUpdate(TypedDict):
    reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections]
    sub_question_retrieval_stats: RetrievalFitStats | None


class ExpandedRetrievalUpdate(TypedDict):
    expanded_retrieval_result: ExpandedRetrievalResult


## Graph Output State


class ExpandedRetrievalOutput(TypedDict):
    expanded_retrieval_result: ExpandedRetrievalResult
    base_expanded_retrieval_result: ExpandedRetrievalResult


## Graph State


class ExpandedRetrievalState(
    # This includes the core state
    ExpandedRetrievalInput,
    QueryExpansionUpdate,
    DocRetrievalUpdate,
    DocVerificationUpdate,
    DocRerankingUpdate,
    ExpandedRetrievalOutput,
):
    pass


## Conditional Input States


class DocVerificationInput(ExpandedRetrievalInput):
    doc_to_verify: InferenceSection


class RetrievalInput(ExpandedRetrievalInput):
    query_to_retrieve: str
92
backend/onyx/agent_search/pro_search_a/main/edges.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from collections.abc import Hashable
|
||||
from typing import Literal
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agent_search.core_state import extract_core_fields_for_subgraph
|
||||
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
|
||||
AnswerQuestionInput,
|
||||
)
|
||||
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
|
||||
AnswerQuestionOutput,
|
||||
)
|
||||
from onyx.agent_search.pro_search_a.main.states import MainState
|
||||
from onyx.agent_search.pro_search_a.main.states import RequireRefinedAnswerUpdate
|
||||
from onyx.agent_search.shared_graph_utils.utils import make_question_id
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def parallelize_initial_sub_question_answering(
|
||||
state: MainState,
|
||||
) -> list[Send | Hashable]:
|
||||
if len(state["initial_decomp_questions"]) > 0:
|
||||
# sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]]
|
||||
# if len(state["sub_question_records"]) == 0:
|
||||
# if state["config"].use_persistence:
|
||||
# raise ValueError("No sub-questions found for initial decomposed questions")
|
||||
# else:
|
||||
# # in this case, we are doing retrieval on the original question.
|
||||
# # to make all the logic consistent, we create a new sub-question
|
||||
# # with the same content as the original question
|
||||
# sub_question_record_ids = [1] * len(state["initial_decomp_questions"])
|
||||
|
||||
return [
|
||||
Send(
|
||||
"answer_query_subgraph",
|
||||
AnswerQuestionInput(
|
||||
**extract_core_fields_for_subgraph(state),
|
||||
question=question,
|
||||
question_id=make_question_id(0, question_nr),
|
||||
),
|
||||
)
|
||||
for question_nr, question in enumerate(state["initial_decomp_questions"])
|
||||
]
|
||||
|
||||
else:
|
||||
return [
|
||||
Send(
|
||||
"ingest_answers",
|
||||
AnswerQuestionOutput(
|
||||
answer_results=[],
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
# Define the function that determines whether to continue or not
|
||||
def continue_to_refined_answer_or_end(
|
||||
state: RequireRefinedAnswerUpdate,
|
||||
) -> Literal["refined_decompose", "logging_node"]:
|
||||
if state["require_refined_answer"]:
|
||||
return "refined_decompose"
|
||||
else:
|
||||
return "logging_node"
|
||||
|
||||
|
||||
def parallelize_refined_sub_question_answering(
|
||||
state: MainState,
|
||||
) -> list[Send | Hashable]:
|
||||
if len(state["refined_sub_questions"]) > 0:
|
||||
return [
|
||||
Send(
|
||||
"answer_refinement_sub_question",
|
||||
AnswerQuestionInput(
|
||||
**extract_core_fields_for_subgraph(state),
|
||||
question=question_data.sub_question,
|
||||
question_id=make_question_id(1, question_nr),
|
||||
),
|
||||
)
|
||||
for question_nr, question_data in state["refined_sub_questions"].items()
|
||||
]
|
||||
|
||||
else:
|
||||
return [
|
||||
Send(
|
||||
"ingest_refined_sub_answers",
|
||||
AnswerQuestionOutput(
|
||||
answer_results=[],
|
||||
),
|
||||
)
|
||||
]
|
||||
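make_question_id and parse_question_id are imported from shared_graph_utils.utils and are not shown in this diff. Based on the "0_0" fallback and the make_question_id(0, question_nr) calls above, the id convention appears to be "<level>_<question_nr>"; the following is only a hedged sketch of that assumption, not the actual implementation:

def make_question_id_sketch(level: int, question_nr: int) -> str:
    # e.g. level 0, question 2 -> "0_2"
    return f"{level}_{question_nr}"


def parse_question_id_sketch(question_id: str) -> tuple[int, int]:
    level_str, question_nr_str = question_id.split("_", 1)
    return int(level_str), int(question_nr_str)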
264
backend/onyx/agent_search/pro_search_a/main/graph_builder.py
Normal file
@@ -0,0 +1,264 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agent_search.pro_search_a.answer_initial_sub_question.graph_builder import (
|
||||
answer_query_graph_builder,
|
||||
)
|
||||
from onyx.agent_search.pro_search_a.answer_refinement_sub_question.graph_builder import (
|
||||
answer_refined_query_graph_builder,
|
||||
)
|
||||
from onyx.agent_search.pro_search_a.base_raw_search.graph_builder import (
|
||||
base_raw_search_graph_builder,
|
||||
)
|
||||
from onyx.agent_search.pro_search_a.main.edges import continue_to_refined_answer_or_end
|
||||
from onyx.agent_search.pro_search_a.main.edges import (
|
||||
parallelize_initial_sub_question_answering,
|
||||
)
|
||||
from onyx.agent_search.pro_search_a.main.edges import (
|
||||
parallelize_refined_sub_question_answering,
|
||||
)
|
||||
from onyx.agent_search.pro_search_a.main.nodes import agent_logging
|
||||
from onyx.agent_search.pro_search_a.main.nodes import entity_term_extraction_llm
|
||||
from onyx.agent_search.pro_search_a.main.nodes import generate_initial_answer
|
||||
from onyx.agent_search.pro_search_a.main.nodes import generate_refined_answer
|
||||
from onyx.agent_search.pro_search_a.main.nodes import ingest_initial_base_retrieval
|
||||
from onyx.agent_search.pro_search_a.main.nodes import (
|
||||
ingest_initial_sub_question_answers,
|
||||
)
|
||||
from onyx.agent_search.pro_search_a.main.nodes import ingest_refined_answers
|
||||
from onyx.agent_search.pro_search_a.main.nodes import initial_answer_quality_check
|
||||
from onyx.agent_search.pro_search_a.main.nodes import initial_sub_question_creation
|
||||
from onyx.agent_search.pro_search_a.main.nodes import refined_answer_decision
|
||||
from onyx.agent_search.pro_search_a.main.nodes import refined_sub_question_creation
|
||||
from onyx.agent_search.pro_search_a.main.states import MainInput
|
||||
from onyx.agent_search.pro_search_a.main.states import MainState
|
||||
from onyx.agent_search.shared_graph_utils.utils import get_test_config
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
test_mode = False
|
||||
|
||||
|
||||
def main_graph_builder(test_mode: bool = False) -> StateGraph:
|
||||
graph = StateGraph(
|
||||
state_schema=MainState,
|
||||
input=MainInput,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="initial_sub_question_creation",
|
||||
action=initial_sub_question_creation,
|
||||
)
|
||||
answer_query_subgraph = answer_query_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="answer_query_subgraph",
|
||||
action=answer_query_subgraph,
|
||||
)
|
||||
|
||||
base_raw_search_subgraph = base_raw_search_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="base_raw_search_subgraph",
|
||||
action=base_raw_search_subgraph,
|
||||
)
|
||||
|
||||
# refined_answer_subgraph = refined_answers_graph_builder().compile()
|
||||
# graph.add_node(
|
||||
# node="refined_answer_subgraph",
|
||||
# action=refined_answer_subgraph,
|
||||
# )
|
||||
|
||||
graph.add_node(
|
||||
node="refined_sub_question_creation",
|
||||
action=refined_sub_question_creation,
|
||||
)
|
||||
|
||||
answer_refined_question = answer_refined_query_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="answer_refined_question",
|
||||
action=answer_refined_question,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="ingest_refined_answers",
|
||||
action=ingest_refined_answers,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="generate_refined_answer",
|
||||
action=generate_refined_answer,
|
||||
)
|
||||
|
||||
# graph.add_node(
|
||||
# node="check_refined_answer",
|
||||
# action=check_refined_answer,
|
||||
# )
|
||||
|
||||
graph.add_node(
|
||||
node="ingest_initial_retrieval",
|
||||
action=ingest_initial_base_retrieval,
|
||||
)
|
||||
graph.add_node(
|
||||
node="ingest_initial_sub_question_answers",
|
||||
action=ingest_initial_sub_question_answers,
|
||||
)
|
||||
graph.add_node(
|
||||
node="generate_initial_answer",
|
||||
action=generate_initial_answer,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="initial_answer_quality_check",
|
||||
action=initial_answer_quality_check,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="entity_term_extraction_llm",
|
||||
action=entity_term_extraction_llm,
|
||||
)
|
||||
graph.add_node(
|
||||
node="refined_answer_decision",
|
||||
action=refined_answer_decision,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="logging_node",
|
||||
action=agent_logging,
|
||||
)
|
||||
# if test_mode:
|
||||
# graph.add_node(
|
||||
# node="generate_initial_base_answer",
|
||||
# action=generate_initial_base_answer,
|
||||
# )
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_edge(start_key=START, end_key="base_raw_search_subgraph")
|
||||
|
||||
graph.add_edge(
|
||||
start_key="base_raw_search_subgraph",
|
||||
end_key="ingest_initial_retrieval",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key=START,
|
||||
end_key="initial_sub_question_creation",
|
||||
)
|
||||
graph.add_conditional_edges(
|
||||
source="initial_sub_question_creation",
|
||||
path=parallelize_initial_sub_question_answering,
|
||||
path_map=["answer_query_subgraph"],
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="answer_query_subgraph",
|
||||
end_key="ingest_initial_sub_question_answers",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key=["ingest_initial_sub_question_answers", "ingest_initial_retrieval"],
|
||||
end_key="generate_initial_answer",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key=["ingest_initial_sub_question_answers", "ingest_initial_retrieval"],
|
||||
end_key="entity_term_extraction_llm",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="generate_initial_answer",
|
||||
end_key="initial_answer_quality_check",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key=["initial_answer_quality_check", "entity_term_extraction_llm"],
|
||||
end_key="refined_answer_decision",
|
||||
)
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source="refined_answer_decision",
|
||||
path=continue_to_refined_answer_or_end,
|
||||
path_map=["refined_sub_question_creation", "logging_node"],
|
||||
)
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source="refined_sub_question_creation",
|
||||
path=parallelize_refined_sub_question_answering,
|
||||
path_map=["answer_refined_question"],
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="answer_refined_question",
|
||||
end_key="ingest_refined_answers",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="ingest_refined_answers",
|
||||
end_key="generate_refined_answer",
|
||||
)
|
||||
|
||||
# graph.add_conditional_edges(
|
||||
# source="refined_answer_decision",
|
||||
# path=continue_to_refined_answer_or_end,
|
||||
# path_map=["refined_answer_subgraph", END],
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="refined_answer_subgraph",
|
||||
# end_key="generate_refined_answer",
|
||||
# )
|
||||
|
||||
graph.add_edge(
|
||||
start_key="generate_refined_answer",
|
||||
end_key="logging_node",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="logging_node",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="generate_refined_answer",
|
||||
# end_key="check_refined_answer",
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="check_refined_answer",
|
||||
# end_key=END,
|
||||
# )
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.context.search.models import SearchRequest
|
||||
|
||||
graph = main_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
|
||||
with get_session_context_manager() as db_session:
|
||||
search_request = SearchRequest(query="Who created Excel?")
|
||||
pro_search_config, search_tool = get_test_config(
|
||||
db_session, primary_llm, fast_llm, search_request
|
||||
)
|
||||
|
||||
inputs = MainInput(
|
||||
primary_llm=primary_llm,
|
||||
fast_llm=fast_llm,
|
||||
db_session=db_session,
|
||||
config=pro_search_config,
|
||||
search_tool=search_tool,
|
||||
)
|
||||
|
||||
for thing in compiled_graph.stream(
|
||||
input=inputs,
|
||||
# stream_mode="debug",
|
||||
# debug=True,
|
||||
subgraphs=True,
|
||||
):
|
||||
logger.debug(thing)
|
||||
69
backend/onyx/agent_search/pro_search_a/main/models.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
### Models ###
|
||||
|
||||
|
||||
class Entity(BaseModel):
|
||||
entity_name: str
|
||||
entity_type: str
|
||||
|
||||
|
||||
class Relationship(BaseModel):
|
||||
relationship_name: str
|
||||
relationship_type: str
|
||||
relationship_entities: list[str]
|
||||
|
||||
|
||||
class Term(BaseModel):
|
||||
term_name: str
|
||||
term_type: str
|
||||
term_similar_to: list[str]
|
||||
|
||||
|
||||
class EntityRelationshipTermExtraction(BaseModel):
|
||||
entities: list[Entity]
|
||||
relationships: list[Relationship]
|
||||
terms: list[Term]
|
||||
|
||||
|
||||
class FollowUpSubQuestion(BaseModel):
|
||||
sub_question: str
|
||||
sub_question_id: str
|
||||
verified: bool
|
||||
answered: bool
|
||||
answer: str
|
||||
|
||||
|
||||
class AgentTimings(BaseModel):
|
||||
base_duration__s: float | None
|
||||
refined_duration__s: float | None
|
||||
full_duration__s: float | None
|
||||
|
||||
|
||||
class AgentBaseMetrics(BaseModel):
|
||||
num_verified_documents_total: int | None
|
||||
num_verified_documents_core: int | None
|
||||
verified_avg_score_core: float | None
|
||||
num_verified_documents_base: int | float | None
|
||||
verified_avg_score_base: float | None
|
||||
base_doc_boost_factor: float | None
|
||||
support_boost_factor: float | None
|
||||
duration__s: float | None
|
||||
|
||||
|
||||
class AgentRefinedMetrics(BaseModel):
|
||||
refined_doc_boost_factor: float | None
|
||||
refined_question_boost_factor: float | None
|
||||
duration__s: float | None
|
||||
|
||||
|
||||
class AgentAdditionalMetrics(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class CombinedAgentMetrics(BaseModel):
|
||||
timings: AgentTimings
|
||||
base_metrics: AgentBaseMetrics
|
||||
refined_metrics: AgentRefinedMetrics
|
||||
additional_metrics: AgentAdditionalMetrics
|
||||
1056
backend/onyx/agent_search/pro_search_a/main/nodes.py
Normal file
File diff suppressed because it is too large
151
backend/onyx/agent_search/pro_search_a/main/states.py
Normal file
@@ -0,0 +1,151 @@
|
||||
from datetime import datetime
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
from typing import TypedDict
|
||||
|
||||
from onyx.agent_search.core_state import CoreState
|
||||
from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
|
||||
QuestionAnswerResults,
|
||||
)
|
||||
from onyx.agent_search.pro_search_a.expanded_retrieval.models import (
|
||||
ExpandedRetrievalResult,
|
||||
)
|
||||
from onyx.agent_search.pro_search_a.expanded_retrieval.models import QueryResult
|
||||
from onyx.agent_search.pro_search_a.main.models import AgentBaseMetrics
|
||||
from onyx.agent_search.pro_search_a.main.models import AgentRefinedMetrics
|
||||
from onyx.agent_search.pro_search_a.main.models import EntityRelationshipTermExtraction
|
||||
from onyx.agent_search.pro_search_a.main.models import FollowUpSubQuestion
|
||||
from onyx.agent_search.shared_graph_utils.models import AgentChunkStats
|
||||
from onyx.agent_search.shared_graph_utils.models import InitialAgentResultStats
|
||||
from onyx.agent_search.shared_graph_utils.models import RefinedAgentStats
|
||||
from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections
|
||||
from onyx.agent_search.shared_graph_utils.operators import dedup_question_answer_results
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
### States ###
|
||||
|
||||
## Update States
|
||||
|
||||
|
||||
class RefinedAgentStartStats(TypedDict):
|
||||
agent_refined_start_time: datetime | None
|
||||
|
||||
|
||||
class RefinedAgentEndStats(TypedDict):
|
||||
agent_refined_end_time: datetime | None
|
||||
agent_refined_metrics: AgentRefinedMetrics
|
||||
|
||||
|
||||
class BaseDecompUpdateBase(TypedDict):
|
||||
agent_start_time: datetime
|
||||
initial_decomp_questions: list[str]
|
||||
|
||||
|
||||
class BaseDecompUpdate(
|
||||
RefinedAgentStartStats, RefinedAgentEndStats, BaseDecompUpdateBase
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class InitialAnswerBASEUpdate(TypedDict):
|
||||
initial_base_answer: str
|
||||
|
||||
|
||||
class InitialAnswerUpdate(TypedDict):
|
||||
initial_answer: str
|
||||
initial_agent_stats: InitialAgentResultStats | None
|
||||
generated_sub_questions: list[str]
|
||||
agent_base_end_time: datetime
|
||||
agent_base_metrics: AgentBaseMetrics
|
||||
|
||||
|
||||
class RefinedAnswerUpdateBase(TypedDict):
|
||||
refined_answer: str
|
||||
refined_agent_stats: RefinedAgentStats | None
|
||||
refined_answer_quality: bool
|
||||
|
||||
|
||||
class RefinedAnswerUpdate(RefinedAgentEndStats, RefinedAnswerUpdateBase):
|
||||
pass
|
||||
|
||||
|
||||
class InitialAnswerQualityUpdate(TypedDict):
|
||||
initial_answer_quality: bool
|
||||
|
||||
|
||||
class RequireRefinedAnswerUpdate(TypedDict):
|
||||
require_refined_answer: bool
|
||||
|
||||
|
||||
class DecompAnswersUpdate(TypedDict):
|
||||
documents: Annotated[list[InferenceSection], dedup_inference_sections]
|
||||
decomp_answer_results: Annotated[
|
||||
list[QuestionAnswerResults], dedup_question_answer_results
|
||||
]
|
||||
|
||||
|
||||
class FollowUpDecompAnswersUpdate(TypedDict):
|
||||
refined_documents: Annotated[list[InferenceSection], dedup_inference_sections]
|
||||
refined_decomp_answer_results: Annotated[list[QuestionAnswerResults], add]
|
||||
|
||||
|
||||
class ExpandedRetrievalUpdate(TypedDict):
|
||||
all_original_question_documents: Annotated[
|
||||
list[InferenceSection], dedup_inference_sections
|
||||
]
|
||||
original_question_retrieval_results: list[QueryResult]
|
||||
original_question_retrieval_stats: AgentChunkStats
|
||||
|
||||
|
||||
class EntityTermExtractionUpdate(TypedDict):
|
||||
entity_retlation_term_extractions: EntityRelationshipTermExtraction
|
||||
|
||||
|
||||
class FollowUpSubQuestionsUpdateBase(TypedDict):
|
||||
refined_sub_questions: dict[int, FollowUpSubQuestion]
|
||||
|
||||
|
||||
class FollowUpSubQuestionsUpdate(
|
||||
RefinedAgentStartStats, FollowUpSubQuestionsUpdateBase
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
## Graph Input State
|
||||
|
||||
|
||||
class MainInput(CoreState):
|
||||
pass
|
||||
|
||||
|
||||
## Graph State
|
||||
|
||||
|
||||
class MainState(
|
||||
# This includes the core state
|
||||
MainInput,
|
||||
BaseDecompUpdateBase,
|
||||
InitialAnswerUpdate,
|
||||
InitialAnswerBASEUpdate,
|
||||
DecompAnswersUpdate,
|
||||
ExpandedRetrievalUpdate,
|
||||
EntityTermExtractionUpdate,
|
||||
InitialAnswerQualityUpdate,
|
||||
RequireRefinedAnswerUpdate,
|
||||
FollowUpSubQuestionsUpdateBase,
|
||||
FollowUpDecompAnswersUpdate,
|
||||
RefinedAnswerUpdateBase,
|
||||
RefinedAgentStartStats,
|
||||
RefinedAgentEndStats,
|
||||
):
|
||||
# expanded_retrieval_result: Annotated[list[ExpandedRetrievalResult], add]
|
||||
base_raw_search_result: Annotated[list[ExpandedRetrievalResult], add]
|
||||
|
||||
|
||||
## Graph Output State - presently not used
|
||||
|
||||
|
||||
class MainOutput(TypedDict):
|
||||
pass
|
||||
404
backend/onyx/agent_search/run_graph.py
Normal file
@@ -0,0 +1,404 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from collections.abc import AsyncIterable
|
||||
from collections.abc import Iterable
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables.schema import StreamEvent
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agent_search.basic.graph_builder import basic_graph_builder
|
||||
from onyx.agent_search.basic.states import BasicInput
|
||||
from onyx.agent_search.models import AgentDocumentCitations
|
||||
from onyx.agent_search.pro_search_a.main.graph_builder import main_graph_builder
|
||||
from onyx.agent_search.pro_search_a.main.states import MainInput
|
||||
from onyx.agent_search.shared_graph_utils.utils import get_test_config
|
||||
from onyx.chat.llm_response_handler import LLMResponseHandlerManager
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import AnswerPacket
|
||||
from onyx.chat.models import AnswerStream
|
||||
from onyx.chat.models import AnswerStyleConfig
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.models import ProSearchConfig
|
||||
from onyx.chat.models import SubQueryPiece
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.chat.models import ToolResponse
|
||||
from onyx.chat.prompt_builder.build import LLMCall
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_runner import ToolCallKickoff
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_COMPILED_GRAPH: CompiledStateGraph | None = None
|
||||
|
||||
|
||||
def _set_combined_token_value(
|
||||
combined_token: str, parsed_object: AgentAnswerPiece
|
||||
) -> AgentAnswerPiece:
|
||||
parsed_object.answer_piece = combined_token
|
||||
|
||||
return parsed_object
|
||||
|
||||
|
||||
def _parse_agent_event(
|
||||
event: StreamEvent,
|
||||
) -> AnswerPacket | None:
|
||||
"""
|
||||
Parse the event into a typed object.
|
||||
Return None if we are not interested in the event.
|
||||
"""
|
||||
|
||||
event_type = event["event"]
|
||||
|
||||
if event_type == "on_custom_event":
|
||||
# TODO: different AnswerStream types for different events
|
||||
if event["name"] == "decomp_qs":
|
||||
return cast(SubQuestionPiece, event["data"])
|
||||
elif event["name"] == "subqueries":
|
||||
return cast(SubQueryPiece, event["data"])
|
||||
elif event["name"] == "sub_answers":
|
||||
return cast(AgentAnswerPiece, event["data"])
|
||||
elif event["name"] == "initial_agent_answer":
|
||||
return cast(AgentAnswerPiece, event["data"])
|
||||
elif event["name"] == "tool_response":
|
||||
return cast(ToolResponse, event["data"])
|
||||
elif event["name"] == "basic_response":
|
||||
return cast(AnswerPacket, event["data"])
|
||||
return None
|
||||
|
||||
|
||||
def _manage_async_event_streaming(
|
||||
compiled_graph: CompiledStateGraph,
|
||||
graph_input: MainInput | BasicInput,
|
||||
) -> Iterable[StreamEvent]:
|
||||
async def _run_async_event_stream() -> AsyncIterable[StreamEvent]:
|
||||
async for event in compiled_graph.astream_events(
|
||||
input=graph_input,
|
||||
# debug=True,
|
||||
# indicating v2 here deserves further scrutiny
|
||||
version="v2",
|
||||
):
|
||||
yield event
|
||||
|
||||
# This might be able to be simplified
|
||||
def _yield_async_to_sync() -> Iterable[StreamEvent]:
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
# Get the async generator
|
||||
async_gen = _run_async_event_stream()
|
||||
# Convert to AsyncIterator
|
||||
async_iter = async_gen.__aiter__()
|
||||
while True:
|
||||
try:
|
||||
# Create a coroutine by calling anext with the async iterator
|
||||
next_coro = anext(async_iter)
|
||||
# Run the coroutine to get the next event
|
||||
event = loop.run_until_complete(next_coro)
|
||||
yield event
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
return _yield_async_to_sync()
|
||||
|
||||
|
||||
def run_graph(
|
||||
compiled_graph: CompiledStateGraph,
|
||||
input: MainInput | BasicInput,
|
||||
) -> AnswerStream:
|
||||
agent_document_citations: dict[int, dict[int, list[AgentDocumentCitations]]] = {}
|
||||
agent_question_citations_used_docs: defaultdict[
|
||||
int, defaultdict[int, list[str]]
|
||||
] = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
citation_potential: defaultdict[int, defaultdict[int, bool]] = defaultdict(
|
||||
lambda: defaultdict(lambda: False)
|
||||
)
|
||||
|
||||
current_yield_components: defaultdict[
|
||||
int, defaultdict[int, list[str]]
|
||||
] = defaultdict(lambda: defaultdict(list))
|
||||
current_yield_str: defaultdict[int, defaultdict[int, str]] = defaultdict(
|
||||
lambda: defaultdict(lambda: "")
|
||||
)
|
||||
|
||||
# def _process_citation(current_yield_str: str) -> tuple[str, str]:
|
||||
# """Process a citation string and return the formatted citation and remaining text."""
|
||||
# section_split = current_yield_str.split(']', 1)
|
||||
# citation_part = section_split[0] + ']'
|
||||
# remaining_text = section_split[1] if len(section_split) > 1 else ''
|
||||
|
||||
# if 'D' in citation_part:
|
||||
# citation_type = "Document"
|
||||
# formatted_citation = citation_part.replace('[D', '[[').replace(']', ']]')
|
||||
# else: # Q case
|
||||
# citation_type = "Question"
|
||||
# formatted_citation = citation_part.replace('[Q', '{{').replace(']', '}}')
|
||||
|
||||
# return f" --- CITATION: {citation_type} - {formatted_citation}", remaining_text
|
||||
|
||||
for event in _manage_async_event_streaming(
|
||||
compiled_graph=compiled_graph, graph_input=input
|
||||
):
|
||||
parsed_object = _parse_agent_event(event)
|
||||
if not parsed_object:
|
||||
continue
|
||||
|
||||
level = getattr(parsed_object, "level", None)
|
||||
level_question_nr = getattr(parsed_object, "level_question_nr", None)
|
||||
|
||||
if isinstance(parsed_object, (OnyxAnswerPiece, AgentAnswerPiece)):
|
||||
# logger.debug(f"FA {parsed_object.answer_piece}")
|
||||
|
||||
if isinstance(parsed_object, AgentAnswerPiece):
|
||||
token = parsed_object.answer_piece
|
||||
level = parsed_object.level
|
||||
level_question_nr = parsed_object.level_question_nr
|
||||
else:
|
||||
yield parsed_object
|
||||
continue
|
||||
# raise ValueError(
|
||||
# f"Invalid parsed object type: {type(parsed_object)}"
|
||||
# )
|
||||
|
||||
if not citation_potential[level][level_question_nr] and token:
|
||||
if token.startswith(" ["):
|
||||
citation_potential[level][level_question_nr] = True
|
||||
current_yield_components[level][level_question_nr] = [token]
|
||||
else:
|
||||
yield parsed_object
|
||||
elif token and citation_potential[level][level_question_nr]:
|
||||
current_yield_components[level][level_question_nr].append(token)
|
||||
current_yield_str[level][level_question_nr] = "".join(
|
||||
current_yield_components[level][level_question_nr]
|
||||
)
|
||||
|
||||
if current_yield_str[level][level_question_nr].strip().startswith(
|
||||
"[D"
|
||||
) or current_yield_str[level][level_question_nr].strip().startswith(
|
||||
"[Q"
|
||||
):
|
||||
citation_potential[level][level_question_nr] = True
|
||||
|
||||
else:
|
||||
citation_potential[level][level_question_nr] = False
|
||||
parsed_object = _set_combined_token_value(
|
||||
current_yield_str[level][level_question_nr], parsed_object
|
||||
)
|
||||
yield parsed_object
|
||||
|
||||
if (
|
||||
len(current_yield_components[level][level_question_nr]) > 15
|
||||
): # ??? 15?
|
||||
citation_potential[level][level_question_nr] = False
|
||||
parsed_object = _set_combined_token_value(
|
||||
current_yield_str[level][level_question_nr], parsed_object
|
||||
)
|
||||
yield parsed_object
|
||||
elif "]" in current_yield_str[level][level_question_nr]:
|
||||
section_split = current_yield_str[level][level_question_nr].split(
|
||||
"]"
|
||||
)
|
||||
section_split[0] + "]" # dead code?
|
||||
start_of_next_section = "]".join(section_split[1:])
|
||||
citation_string = current_yield_str[level][level_question_nr][
|
||||
: -len(start_of_next_section)
|
||||
]
|
||||
if "[D" in citation_string:
|
||||
cite_open_bracket_marker, cite_close_bracket_marker = (
|
||||
"[",
|
||||
"]",
|
||||
)
|
||||
cite_identifyer = "D"
|
||||
|
||||
try:
|
||||
cited_document = int(
|
||||
citation_string[level][level_question_nr][2:-1]
|
||||
)
|
||||
if level and level_question_nr:
|
||||
link = agent_document_citations[int(level)][
|
||||
int(level_question_nr)
|
||||
][cited_document].link
|
||||
else:
|
||||
link = ""
|
||||
except (ValueError, IndexError):
|
||||
link = ""
|
||||
elif "[Q" in citation_string:
|
||||
cite_open_bracket_marker, cite_close_bracket_marker = (
|
||||
"{",
|
||||
"}",
|
||||
)
|
||||
cite_identifyer = "Q"
|
||||
else:
|
||||
pass
|
||||
|
||||
citation_string = citation_string.replace(
|
||||
"[" + cite_identifyer,
|
||||
cite_open_bracket_marker * 2,
|
||||
).replace("]", cite_close_bracket_marker * 2)
|
||||
|
||||
if cite_identifyer == "D":
|
||||
citation_string += f"({link})"
|
||||
|
||||
parsed_object = _set_combined_token_value(
|
||||
citation_string, parsed_object
|
||||
)
|
||||
|
||||
yield parsed_object
|
||||
|
||||
current_yield_components[level][level_question_nr] = [
|
||||
start_of_next_section
|
||||
]
|
||||
if not start_of_next_section.strip().startswith("["):
|
||||
citation_potential[level][level_question_nr] = False
|
||||
|
||||
elif isinstance(parsed_object, ExtendedToolResponse):
|
||||
if parsed_object.id == "search_response_summary":
|
||||
level = parsed_object.level
|
||||
level_question_nr = parsed_object.level_question_nr
|
||||
for inference_section in parsed_object.response.top_sections:
|
||||
doc_link = inference_section.center_chunk.source_links[0]
|
||||
doc_title = inference_section.center_chunk.title
|
||||
doc_id = inference_section.center_chunk.document_id
|
||||
|
||||
if (
|
||||
doc_id
|
||||
not in agent_question_citations_used_docs[level][
|
||||
level_question_nr
|
||||
]
|
||||
):
|
||||
if level not in agent_document_citations:
|
||||
agent_document_citations[level] = {}
|
||||
if level_question_nr not in agent_document_citations[level]:
|
||||
agent_document_citations[level][level_question_nr] = []
|
||||
|
||||
agent_document_citations[level][level_question_nr].append(
|
||||
AgentDocumentCitations(
|
||||
document_id=doc_id,
|
||||
document_title=doc_title,
|
||||
link=doc_link,
|
||||
)
|
||||
)
|
||||
agent_question_citations_used_docs[level][
|
||||
level_question_nr
|
||||
].append(doc_id)
|
||||
|
||||
yield parsed_object
|
||||
|
||||
else:
|
||||
yield parsed_object
|
||||
|
||||
|
||||
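The streaming branch above buffers tokens that may open a citation marker and rewrites them on the fly. As a hedged, non-streaming sketch of the intended transformation (a hypothetical helper, not part of this module): document markers such as "[D3]" become "[[3]](<link>)" and sub-question markers such as "[Q2]" become "{{2}}".

import re

def rewrite_citations_offline(text: str, doc_links: dict[int, str]) -> str:
    # Illustrative only: the real code performs this incrementally per token.
    text = re.sub(
        r"\[D(\d+)\]",
        lambda m: f"[[{m.group(1)}]]({doc_links.get(int(m.group(1)), '')})",
        text,
    )
    return re.sub(r"\[Q(\d+)\]", lambda m: "{{" + m.group(1) + "}}", text)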
# TODO: call this once on startup, TBD where and if it should be gated based
|
||||
# on dev mode or not
|
||||
def load_compiled_graph() -> CompiledStateGraph:
|
||||
global _COMPILED_GRAPH
|
||||
if _COMPILED_GRAPH is None:
|
||||
graph = main_graph_builder()
|
||||
_COMPILED_GRAPH = graph.compile()
|
||||
return _COMPILED_GRAPH
|
||||
|
||||
|
||||
def run_main_graph(
|
||||
config: ProSearchConfig,
|
||||
search_tool: SearchTool,
|
||||
primary_llm: LLM,
|
||||
fast_llm: LLM,
|
||||
db_session: Session,
|
||||
) -> AnswerStream:
|
||||
compiled_graph = load_compiled_graph()
|
||||
input = MainInput(
|
||||
config=config,
|
||||
primary_llm=primary_llm,
|
||||
fast_llm=fast_llm,
|
||||
db_session=db_session,
|
||||
search_tool=search_tool,
|
||||
)
|
||||
return run_graph(compiled_graph, input)
|
||||
|
||||
|
||||
def run_basic_graph(
|
||||
last_llm_call: LLMCall | None,
|
||||
primary_llm: LLM,
|
||||
answer_style_config: AnswerStyleConfig,
|
||||
response_handler_manager: LLMResponseHandlerManager,
|
||||
) -> AnswerStream:
|
||||
graph = basic_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
input = BasicInput(
|
||||
last_llm_call=last_llm_call,
|
||||
llm=primary_llm,
|
||||
answer_style_config=answer_style_config,
|
||||
response_handler_manager=response_handler_manager,
|
||||
calls=0,
|
||||
)
|
||||
return run_graph(compiled_graph, input)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from onyx.llm.factory import get_default_llms
|
||||
|
||||
now_start = datetime.now()
|
||||
logger.debug(f"Start at {now_start}")
|
||||
|
||||
graph = main_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
now_end = datetime.now()
|
||||
logger.debug(f"Graph compiled in {now_end - now_start} seconds")
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
search_request = SearchRequest(
|
||||
# query="what can you do with gitlab?",
|
||||
query="What are the guiding principles behind the development of cockroachDB?",
|
||||
)
|
||||
# Joachim custom persona
|
||||
|
||||
with get_session_context_manager() as db_session:
|
||||
config, search_tool = get_test_config(
|
||||
db_session, primary_llm, fast_llm, search_request
|
||||
)
|
||||
# search_request.persona = get_persona_by_id(1, None, db_session)
|
||||
config.use_persistence = True
|
||||
|
||||
# with open("output.txt", "w") as f:
|
||||
tool_responses: list = []
|
||||
input = MainInput(
|
||||
config=config,
|
||||
primary_llm=primary_llm,
|
||||
fast_llm=fast_llm,
|
||||
db_session=db_session,
|
||||
search_tool=search_tool,
|
||||
)
|
||||
for output in run_graph(compiled_graph, input):
|
||||
# pass
|
||||
|
||||
if isinstance(output, ToolCallKickoff):
|
||||
pass
|
||||
elif isinstance(output, ToolResponse):
|
||||
tool_responses.append(output.response)
|
||||
elif isinstance(output, SubQuestionPiece):
|
||||
logger.debug(
|
||||
f"SQ {output.level} - {output.level_question_nr} - {output.sub_question} | "
|
||||
)
|
||||
elif (
|
||||
isinstance(output, AgentAnswerPiece)
|
||||
and output.answer_type == "agent_sub_answer"
|
||||
):
|
||||
logger.debug(
|
||||
f" ---- SA {output.level} - {output.level_question_nr} {output.answer_piece} | "
|
||||
)
|
||||
elif (
|
||||
isinstance(output, AgentAnswerPiece)
|
||||
and output.answer_type == "agent_level_answer"
|
||||
):
|
||||
logger.debug(f" ---------- FA {output.answer_piece} | ")
|
||||
|
||||
# for tool_response in tool_responses:
|
||||
# logger.debug(tool_response)
|
||||
@@ -0,0 +1,34 @@
from langchain.schema import AIMessage
from langchain.schema import HumanMessage
from langchain.schema import SystemMessage
from langchain_core.messages.tool import ToolMessage

from onyx.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT_v2
from onyx.context.search.models import InferenceSection


def build_sub_question_answer_prompt(
    question: str,
    original_question: str,
    docs: list[InferenceSection],
    persona_specification: str,
) -> list[SystemMessage | HumanMessage | AIMessage | ToolMessage]:
    system_message = SystemMessage(
        content=persona_specification,
    )

    docs_format_list = [
        f"""Document Number: [D{doc_nr + 1}]\n
        Content: {doc.combined_content}\n\n"""
        for doc_nr, doc in enumerate(docs)
    ]

    docs_str = "\n\n".join(docs_format_list)

    human_message = HumanMessage(
        content=BASE_RAG_PROMPT_v2.format(
            question=question, original_question=original_question, context=docs_str
        )
    )

    return [system_message, human_message]
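A hedged usage sketch of the helper above (hypothetical question and persona text, empty document list):

# Returns [SystemMessage(persona), HumanMessage(BASE_RAG_PROMPT_v2 filled with
# the sub-question, the original question, and the formatted document context)].
messages = build_sub_question_answer_prompt(
    question="Which connectors does Onyx ship with?",
    original_question="what can you do with onyx?",
    docs=[],
    persona_specification="You are a helpful assistant.",
)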
98
backend/onyx/agent_search/shared_graph_utils/calculations.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import numpy as np
|
||||
|
||||
from onyx.agent_search.shared_graph_utils.models import RetrievalFitScoreMetrics
|
||||
from onyx.agent_search.shared_graph_utils.models import RetrievalFitStats
|
||||
from onyx.chat.models import SectionRelevancePiece
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def unique_chunk_id(doc: InferenceSection) -> str:
|
||||
return f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}"
|
||||
|
||||
|
||||
def calculate_rank_shift(list1: list, list2: list, top_n: int = 20) -> float:
|
||||
shift = 0
|
||||
for rank_first, doc_id in enumerate(list1[:top_n], 1):
|
||||
try:
|
||||
rank_second = list2.index(doc_id) + 1
|
||||
except ValueError:
|
||||
rank_second = len(list2) # Document not found in second list
|
||||
|
||||
shift += np.abs(rank_first - rank_second) / np.log(1 + rank_first * rank_second)
|
||||
|
||||
return shift / top_n
|
||||
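A small hedged worked example of calculate_rank_shift with hypothetical chunk ids: identical orderings give 0, while swapping the top two documents gives 2 * (1 / ln(3)) / top_n, roughly 0.09 with the default top_n of 20.

# Hypothetical ids - only the relative ordering matters.
assert calculate_rank_shift(["c1", "c2"], ["c1", "c2"]) == 0.0
swapped_shift = calculate_rank_shift(["c1", "c2"], ["c2", "c1"])  # ~0.091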
|
||||
|
||||
def get_fit_scores(
|
||||
pre_reranked_results: list[InferenceSection],
|
||||
post_reranked_results: list[InferenceSection] | list[SectionRelevancePiece],
|
||||
) -> RetrievalFitStats | None:
|
||||
"""
|
||||
Calculate retrieval metrics for search purposes
|
||||
"""
|
||||
|
||||
if len(pre_reranked_results) == 0 or len(post_reranked_results) == 0:
|
||||
return None
|
||||
|
||||
ranked_sections = {
|
||||
"initial": pre_reranked_results,
|
||||
"reranked": post_reranked_results,
|
||||
}
|
||||
|
||||
fit_eval: RetrievalFitStats = RetrievalFitStats(
|
||||
fit_score_lift=0,
|
||||
rerank_effect=0,
|
||||
fit_scores={
|
||||
"initial": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
|
||||
"reranked": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
|
||||
},
|
||||
)
|
||||
|
||||
for rank_type, docs in ranked_sections.items():
|
||||
logger.debug(f"rank_type: {rank_type}")
|
||||
|
||||
for i in [1, 5, 10]:
|
||||
fit_eval.fit_scores[rank_type].scores[str(i)] = (
|
||||
sum(
|
||||
[
|
||||
float(doc.center_chunk.score)
|
||||
for doc in docs[:i]
|
||||
if type(doc) == InferenceSection
|
||||
and doc.center_chunk.score is not None
|
||||
]
|
||||
)
|
||||
/ i
|
||||
)
|
||||
|
||||
fit_eval.fit_scores[rank_type].scores["fit_score"] = (
|
||||
1
|
||||
/ 3
|
||||
* (
|
||||
fit_eval.fit_scores[rank_type].scores["1"]
|
||||
+ fit_eval.fit_scores[rank_type].scores["5"]
|
||||
+ fit_eval.fit_scores[rank_type].scores["10"]
|
||||
)
|
||||
)
|
||||
|
||||
fit_eval.fit_scores[rank_type].scores["fit_score"] = fit_eval.fit_scores[
|
||||
rank_type
|
||||
].scores["1"]
|
||||
|
||||
fit_eval.fit_scores[rank_type].chunk_ids = [
|
||||
unique_chunk_id(doc) for doc in docs if type(doc) == InferenceSection
|
||||
]
|
||||
|
||||
fit_eval.fit_score_lift = (
|
||||
fit_eval.fit_scores["reranked"].scores["fit_score"]
|
||||
/ fit_eval.fit_scores["initial"].scores["fit_score"]
|
||||
)
|
||||
|
||||
fit_eval.rerank_effect = calculate_rank_shift(
|
||||
fit_eval.fit_scores["initial"].chunk_ids,
|
||||
fit_eval.fit_scores["reranked"].chunk_ids,
|
||||
)
|
||||
|
||||
return fit_eval
|
||||
52
backend/onyx/agent_search/shared_graph_utils/models.py
Normal file
@@ -0,0 +1,52 @@
from typing import Literal

from pydantic import BaseModel


# Pydantic models for structured outputs
class RewrittenQueries(BaseModel):
    rewritten_queries: list[str]


class BinaryDecision(BaseModel):
    decision: Literal["yes", "no"]


class BinaryDecisionWithReasoning(BaseModel):
    reasoning: str
    decision: Literal["yes", "no"]


class RetrievalFitScoreMetrics(BaseModel):
    scores: dict[str, float]
    chunk_ids: list[str]


class RetrievalFitStats(BaseModel):
    fit_score_lift: float
    rerank_effect: float
    fit_scores: dict[str, RetrievalFitScoreMetrics]


class AgentChunkScores(BaseModel):
    scores: dict[str, dict[str, list[int | float]]]


class AgentChunkStats(BaseModel):
    verified_count: int | None
    verified_avg_scores: float | None
    rejected_count: int | None
    rejected_avg_scores: float | None
    verified_doc_chunk_ids: list[str]
    dismissed_doc_chunk_ids: list[str]


class InitialAgentResultStats(BaseModel):
    sub_questions: dict[str, float | int | None]
    original_question: dict[str, float | int | None]
    agent_effectiveness: dict[str, float | int | None]


class RefinedAgentStats(BaseModel):
    revision_doc_efficiency: float | None
    revision_question_efficiency: float | None
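For orientation, a minimal sketch of how these stats models nest; all scores and chunk ids below are invented for illustration.

example_stats = RetrievalFitStats(
    fit_score_lift=1.2,
    rerank_effect=0.05,
    fit_scores={
        "initial": RetrievalFitScoreMetrics(scores={"1": 0.61, "fit_score": 0.61}, chunk_ids=["doc1__0"]),
        "reranked": RetrievalFitScoreMetrics(scores={"1": 0.73, "fit_score": 0.73}, chunk_ids=["doc1__0"]),
    },
)
print(example_stats.fit_scores["reranked"].scores["fit_score"])  # 0.73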
backend/onyx/agent_search/shared_graph_utils/operators.py (new file, 31 lines)
@@ -0,0 +1,31 @@
from onyx.agent_search.pro_search_a.answer_initial_sub_question.models import (
    QuestionAnswerResults,
)
from onyx.chat.prune_and_merge import _merge_sections
from onyx.context.search.models import InferenceSection


def dedup_inference_sections(
    list1: list[InferenceSection], list2: list[InferenceSection]
) -> list[InferenceSection]:
    deduped = _merge_sections(list1 + list2)
    return deduped


def dedup_question_answer_results(
    question_answer_results_1: list[QuestionAnswerResults],
    question_answer_results_2: list[QuestionAnswerResults],
) -> list[QuestionAnswerResults]:
    deduped_question_answer_results: list[
        QuestionAnswerResults
    ] = question_answer_results_1
    utilized_question_ids: set[str] = set(
        [x.question_id for x in question_answer_results_1]
    )

    for question_answer_result in question_answer_results_2:
        if question_answer_result.question_id not in utilized_question_ids:
            deduped_question_answer_results.append(question_answer_result)
            utilized_question_ids.add(question_answer_result.question_id)

    return deduped_question_answer_results
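A hypothetical usage sketch of the question-id based dedup above; _StubResult stands in for QuestionAnswerResults (only question_id matters to the logic), and the import is assumed to work in your environment.

from dataclasses import dataclass

from onyx.agent_search.shared_graph_utils.operators import dedup_question_answer_results

@dataclass
class _StubResult:  # illustration-only stand-in for QuestionAnswerResults
    question_id: str
    answer: str

first = [_StubResult("0_1", "A"), _StubResult("0_2", "B")]
second = [_StubResult("0_2", "B-duplicate"), _StubResult("0_3", "C")]
merged = dedup_question_answer_results(first, second)  # keeps 0_1, 0_2 (original), 0_3

Note that, as written, the helper extends the first list in place rather than copying it, so callers should not expect question_answer_results_1 to remain unchanged.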
backend/onyx/agent_search/shared_graph_utils/prompts.py (new file, 697 lines)
@@ -0,0 +1,697 @@
|
||||
REWRITE_PROMPT_MULTI_ORIGINAL = """ \n
|
||||
Please convert an initial user question into 2-3 more appropriate, short and pointed search queries for retrieval from a
|
||||
document store. Particularly, try to think about resolving ambiguities and make the search queries more specific,
|
||||
enabling the system to search more broadly.
|
||||
Also, try to make the search queries not redundant, i.e. not too similar! \n\n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
Formulate the queries separated by newlines (Do not say 'Query 1: ...', just write the query text) as follows:
|
||||
<query 1>
|
||||
<query 2>
|
||||
...
|
||||
queries: """
|
||||
|
||||
REWRITE_PROMPT_MULTI = """ \n
|
||||
Please create a list of 2-3 sample documents that could answer an original question. Each document
|
||||
should be about as long as the original question. \n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
Formulate the sample documents separated by '--' (Do not say 'Document 1: ...', just write the text): """
|
||||
|
||||
# The prompt is only used if there is no persona prompt, so the placeholder is ''
|
||||
BASE_RAG_PROMPT = """ \n
|
||||
{persona_specification}
|
||||
Use the context provided below - and only the
|
||||
provided context - to answer the given question. (Note that the answer is in service of answering a broader
|
||||
question, given below as 'motivation'.)
|
||||
|
||||
Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
|
||||
question based on the context, say "I don't know". It is a matter of life and death that you do NOT
|
||||
use your internal knowledge, just the provided information!
|
||||
|
||||
Make sure that you keep all relevant information, specifically as it concerns the ultimate goal.
|
||||
(But keep other details as well.)
|
||||
|
||||
\nContext:\n {context} \n
|
||||
|
||||
Motivation:\n {original_question} \n\n
|
||||
\n\n
|
||||
And here is the question I want you to answer based on the context above (with the motivation in mind):
|
||||
\n--\n {question} \n--\n
|
||||
"""
|
||||
|
||||
BASE_RAG_PROMPT_v2 = """ \n
|
||||
Use the context provided below - and only the
|
||||
provided context - to answer the given question. (Note that the answer is in service of answering a broader
|
||||
question, given below as 'motivation'.)
|
||||
|
||||
Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
|
||||
question based on the context, say "I don't know". It is a matter of life and death that you do NOT
|
||||
use your internal knowledge, just the provided information!
|
||||
|
||||
Make sure that you keep all relevant information, specifically as it concerns the ultimate goal.
|
||||
(But keep other details as well.)
|
||||
|
||||
Please remember to provide inline citations in the format [D1], [D2], [D3], etc.\n\n\n
|
||||
|
||||
For your general information, here is the ultimate motivation:
|
||||
\n--\n {original_question} \n--\n
|
||||
\n\n
|
||||
And here is the actual question I want you to answer based on the context above (with the motivation in mind):
|
||||
\n--\n {question} \n--\n
|
||||
|
||||
Here is the context:
|
||||
\n\n\n--\n {context} \n--\n
|
||||
"""
|
||||
|
||||
|
||||
SUB_CHECK_PROMPT = """
|
||||
Your task is to see whether a given answer addresses a given question.
|
||||
Please do not use any internal knowledge you may have - just focus on whether the answer
|
||||
as given seems to largely address the question as given, or at least addresses part of the question.
|
||||
Here is the question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
Here is the suggested answer:
|
||||
\n ------- \n
|
||||
{base_answer}
|
||||
\n ------- \n
|
||||
Does the suggested answer address the question? Please answer with yes or no:"""
|
||||
|
||||
|
||||
BASE_CHECK_PROMPT = """ \n
|
||||
Please check whether 1) the suggested answer seems to fully address the original question AND 2) the
|
||||
original question requests a simple, factual answer, and there are no ambiguities, judgements,
|
||||
aggregations, or any other complications that may require extra context. (I.e., if the question is
|
||||
somewhat addressed, but the answer would benefit from more context, then answer with 'no'.)
|
||||
|
||||
Please only answer with 'yes' or 'no' \n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
Here is the proposed answer:
|
||||
\n ------- \n
|
||||
{initial_answer}
|
||||
\n ------- \n
|
||||
Please answer with yes or no:"""
|
||||
|
||||
VERIFIER_PROMPT = """
|
||||
You are supposed to judge whether a document text contains data or information that is potentially relevant for a question.
|
||||
|
||||
Here is a document text that you can take as a fact:
|
||||
--
|
||||
DOCUMENT INFORMATION:
|
||||
{document_content}
|
||||
--
|
||||
|
||||
Do you think that this information is useful and relevant to answer the following question?
|
||||
(Other documents may supply additional information, so do not worry if the provided information
|
||||
is not enough to answer the question, but it needs to be relevant to the question.)
|
||||
--
|
||||
QUESTION:
|
||||
{question}
|
||||
--
|
||||
|
||||
Please answer with 'yes' or 'no':
|
||||
|
||||
Answer:
|
||||
|
||||
"""
|
||||
|
||||
INITIAL_DECOMPOSITION_PROMPT_BASIC = """ \n
|
||||
If you think it is helpful, please decompose an initial user question into not more
|
||||
than 4 appropriate sub-questions that help to answer the original question.
|
||||
The purpose for this decomposition is to isolate individual entities
|
||||
(i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
|
||||
for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
|
||||
sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
|
||||
for us'), etc. Each sub-question should realistically be answerable by a good RAG system.
|
||||
|
||||
Importantly, if you think it is not needed or helpful, please just return an empty list. That is ok too.
|
||||
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
|
||||
Please formulate your answer as a list of subquestions:
|
||||
|
||||
Answer:
|
||||
"""
|
||||
|
||||
REWRITE_PROMPT_SINGLE = """ \n
|
||||
Please convert an initial user question into a more appropriate search query for retrieval from a
|
||||
document store. \n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
|
||||
Formulate the query: """
|
||||
|
||||
MODIFIED_RAG_PROMPT = """You are an assistant for question-answering tasks. Use the context provided below
|
||||
- and only this context - to answer the question. If you don't know the answer, just say "I don't know".
|
||||
Use three sentences maximum and keep the answer concise.
|
||||
Also pay particular attention to the sub-questions and their answers, as they may enrich the answer.
|
||||
Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
|
||||
question based on the context, say "I don't know". It is a matter of life and death that you do NOT
|
||||
use your internal knowledge, just the provided information!
|
||||
|
||||
\nQuestion: {question}
|
||||
\nContext: {combined_context} \n
|
||||
|
||||
Answer:"""
|
||||
|
||||
ORIG_DEEP_DECOMPOSE_PROMPT = """ \n
|
||||
An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
|
||||
good enough. Also, some sub-questions had been answered and this information has been used to provide
|
||||
the initial answer. Some other subquestions may have been suggested based on little knowledge, but they
|
||||
were not directly answerable. Also, some entities, relationships and terms are given to you so that
|
||||
you have an idea of what the available data looks like.
|
||||
|
||||
Your role is to generate 3-5 new sub-questions that would help to answer the initial question,
|
||||
considering:
|
||||
|
||||
1) The initial question
|
||||
2) The initial answer that was found to be unsatisfactory
|
||||
3) The sub-questions that were answered
|
||||
4) The sub-questions that were suggested but not answered
|
||||
5) The entities, relationships and terms that were extracted from the context
|
||||
|
||||
The individual questions should be answerable by a good RAG system.
|
||||
So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the
|
||||
question for different entities that may be involved in the original question, but in a way that does
|
||||
not duplicate questions that were already tried.
|
||||
|
||||
Additional Guidelines:
|
||||
- The sub-questions should be specific to the question and provide richer context for the question,
|
||||
resolve ambiguities, or address shortcomings of the initial answer
|
||||
- Each sub-question - when answered - should be relevant for the answer to the original question
|
||||
- The sub-questions should be free from comparisons, ambiguities, judgements, aggregations, or any
|
||||
other complications that may require extra context.
|
||||
- The sub-questions MUST have the full context of the original question so that it can be executed by
|
||||
a RAG system independently without the original question available
|
||||
(Example:
|
||||
- initial question: "What is the capital of France?"
|
||||
- bad sub-question: "What is the name of the river there?"
|
||||
- good sub-question: "What is the name of the river that flows through Paris?"
|
||||
- For each sub-question, please provide a short explanation for why it is a good sub-question. So
|
||||
generate a list of dictionaries with the following format:
|
||||
[{{"sub_question": <sub-question>, "explanation": <explanation>, "search_term": <rewrite the
|
||||
sub-question using as a search phrase for the document store>}}, ...]
|
||||
|
||||
\n\n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
|
||||
Here is the initial sub-optimal answer:
|
||||
\n ------- \n
|
||||
{base_answer}
|
||||
\n ------- \n
|
||||
|
||||
Here are the sub-questions that were answered:
|
||||
\n ------- \n
|
||||
{answered_sub_questions}
|
||||
\n ------- \n
|
||||
|
||||
Here are the sub-questions that were suggested but not answered:
|
||||
\n ------- \n
|
||||
{failed_sub_questions}
|
||||
\n ------- \n
|
||||
|
||||
And here are the entities, relationships and terms extracted from the context:
|
||||
\n ------- \n
|
||||
{entity_term_extraction_str}
|
||||
\n ------- \n
|
||||
|
||||
Please generate the list of good, fully contextualized sub-questions that would help to address the
|
||||
main question. Again, please find questions that are NOT overlapping too much with the already answered
|
||||
sub-questions or those that already were suggested and failed.
|
||||
In other words - what can we try in addition to what has been tried so far?
|
||||
|
||||
Please think through it step by step and then generate the list of json dictionaries with the following
|
||||
format:
|
||||
|
||||
{{"sub_questions": [{{"sub_question": <sub-question>,
|
||||
"explanation": <explanation>,
|
||||
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
|
||||
...]}} """
|
||||
|
||||
DEEP_DECOMPOSE_PROMPT = """ \n
|
||||
An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
|
||||
good enough. Also, some sub-questions had been answered and this information has been used to provide
|
||||
the initial answer. Some other subquestions may have been suggested based on little knowledge, but they
|
||||
were not directly answerable. Also, some entities, relationships and terms are given to you so that
|
||||
you have an idea of what the available data looks like.
|
||||
|
||||
Your role is to generate 3-5 new sub-questions that would help to answer the initial question,
|
||||
considering:
|
||||
|
||||
1) The initial question
|
||||
2) The initial answer that was found to be unsatisfactory
|
||||
3) The sub-questions that were answered
|
||||
4) The sub-questions that were suggested but not answered
|
||||
5) The entities, relationships and terms that were extracted from the context
|
||||
|
||||
The individual questions should be answerable by a good RAG system.
|
||||
So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the
|
||||
question for different entities that may be involved in the original question, but in a way that does
|
||||
not duplicate questions that were already tried.
|
||||
|
||||
Additional Guidelines:
|
||||
- The sub-questions should be specific to the question and provide richer context for the question,
|
||||
resolve ambiguities, or address shortcomings of the initial answer
|
||||
- Each sub-question - when answered - should be relevant for the answer to the original question
|
||||
- The sub-questions should be free from comparisons, ambiguities, judgements, aggregations, or any
|
||||
other complications that may require extra context.
|
||||
- The sub-questions MUST have the full context of the original question so that it can be executed by
|
||||
a RAG system independently without the original question available
|
||||
(Example:
|
||||
- initial question: "What is the capital of France?"
|
||||
- bad sub-question: "What is the name of the river there?"
|
||||
- good sub-question: "What is the name of the river that flows through Paris?"
|
||||
- For each sub-question, please also provide a search term that can be used to retrieve relevant
|
||||
documents from a document store.
|
||||
- Consider specifically the sub-questions that were suggested but not answered. This is a sign that they are not
|
||||
answerable with the available context, and you should not ask similar questions.
|
||||
\n\n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
|
||||
Here is the initial sub-optimal answer:
|
||||
\n ------- \n
|
||||
{base_answer}
|
||||
\n ------- \n
|
||||
|
||||
Here are the sub-questions that were answered:
|
||||
\n ------- \n
|
||||
{answered_sub_questions}
|
||||
\n ------- \n
|
||||
|
||||
Here are the sub-questions that were suggested but not answered:
|
||||
\n ------- \n
|
||||
{failed_sub_questions}
|
||||
\n ------- \n
|
||||
|
||||
And here are the entities, relationships and terms extracted from the context:
|
||||
\n ------- \n
|
||||
{entity_term_extraction_str}
|
||||
\n ------- \n
|
||||
|
||||
Please generate the list of good, fully contextualized sub-questions that would help to address the
|
||||
main question.
|
||||
|
||||
Specifically pay attention also to the entities, relationships and terms extracted, as these indicate what type of
|
||||
objects/relationships/terms you can ask about! Do not ask about entities, terms or relationships that are not
|
||||
mentioned in the 'entities, relationships and terms' section.
|
||||
|
||||
Again, please find questions that are NOT overlapping too much with the already answered
|
||||
sub-questions or those that already were suggested and failed.
|
||||
In other words - what can we try in addition to what has been tried so far?
|
||||
|
||||
Generate the list of questions separated by new lines like this:
|
||||
|
||||
<sub-question 1>
|
||||
<sub-question 2>
|
||||
<sub-question 3>
|
||||
...
|
||||
"""
|
||||
|
||||
DECOMPOSE_PROMPT = """ \n
|
||||
For an initial user question, please generate 5-10 individual sub-questions whose answers would help
|
||||
\n to answer the initial question. The individual questions should be answerable by a good RAG system.
|
||||
So a good idea would be to \n use the sub-questions to resolve ambiguities and/or to separate the
|
||||
question for different entities that may be involved in the original question.
|
||||
|
||||
In order to arrive at meaningful sub-questions, please also consider the context retrieved from the
|
||||
document store, expressed as entities, relationships and terms. You can also think about the types
|
||||
mentioned in brackets
|
||||
|
||||
Guidelines:
|
||||
- The sub-questions should be specific to the question and provide richer context for the question,
|
||||
and/or resolve ambiguities
|
||||
- Each sub-question - when answered - should be relevant for the answer to the original question
|
||||
- The sub-questions should be free from comparisons, ambiguities, judgements, aggregations, or any
|
||||
other complications that may require extra context.
|
||||
- The sub-questions MUST have the full context of the original question so that it can be executed by
|
||||
a RAG system independently without the original question available
|
||||
(Example:
|
||||
- initial question: "What is the capital of France?"
|
||||
- bad sub-question: "What is the name of the river there?"
|
||||
- good sub-question: "What is the name of the river that flows through Paris?"
|
||||
- For each sub-question, please provide a short explanation for why it is a good sub-question. So
|
||||
generate a list of dictionaries with the following format:
|
||||
[{{"sub_question": <sub-question>, "explanation": <explanation>}}, ...]
|
||||
|
||||
\n\n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
|
||||
And here are the entities, relationships and terms extracted from the context:
|
||||
\n ------- \n
|
||||
{entity_term_extraction_str}
|
||||
\n ------- \n
|
||||
|
||||
Please generate the list of good, fully contextualized sub-questions that would help to address the
|
||||
main question. Don't be too specific unless the original question is specific.
|
||||
Please think through it step by step and then generate the list of json dictionaries with the following
|
||||
format:
|
||||
{{"sub_questions": [{{"sub_question": <sub-question>,
|
||||
"explanation": <explanation>,
|
||||
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
|
||||
...]}} """
|
||||
|
||||
#### Consolidations
|
||||
COMBINED_CONTEXT = """-------
|
||||
Below you will find useful information to answer the original question. First, you see a number of
|
||||
sub-questions with their answers. This information should be considered to be more focussed and
|
||||
somewhat more specific to the original question as it tries to contextualize facts.
|
||||
After that you will see the documents that were considered to be relevant to answer the original question.
|
||||
|
||||
Here are the sub-questions and their answers:
|
||||
\n\n {deep_answer_context} \n\n
|
||||
\n\n Here are the documents that were considered to be relevant to answer the original question:
|
||||
\n\n {formated_docs} \n\n
|
||||
----------------
|
||||
"""
|
||||
|
||||
SUB_QUESTION_EXPLANATION_RANKER_PROMPT = """-------
|
||||
Below you will find a question that we ultimately want to answer (the original question) and a list of
|
||||
motivations in arbitrary order for generated sub-questions that are supposed to help us answer the
|
||||
original question. The motivations are formatted as <motivation number>: <motivation explanation>.
|
||||
(Again, the numbering is arbitrary and does not necessarily mean that 1 is the most relevant
|
||||
motivation and 2 is less relevant.)
|
||||
|
||||
Please rank the motivations in order of relevance for answering the original question. Also, try to
|
||||
ensure that the top questions do not duplicate too much, i.e. that they are not too similar.
|
||||
Ultimately, create a list with the motivation numbers where the number of the most relevant
|
||||
motivations comes first.
|
||||
|
||||
Here is the original question:
|
||||
\n\n {original_question} \n\n
|
||||
\n\n Here is the list of sub-question motivations:
|
||||
\n\n {sub_question_explanations} \n\n
|
||||
----------------
|
||||
|
||||
Please think step by step and then generate the ranked list of motivations.
|
||||
|
||||
Please format your answer as a json object in the following format:
|
||||
{{"reasonning": <explain your reasoning for the ranking>,
|
||||
"ranked_motivations": <ranked list of motivation numbers>}}
|
||||
"""
|
||||
|
||||
|
||||
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS = """
|
||||
If you think it is helpful, please decompose an initial user question into 2 to 4 appropriate sub-questions that help to
|
||||
answer the original question. The purpose for this decomposition is to
|
||||
1) isolate individual entities (i.e., 'compare sales of company A and company B' -> ['what are sales for company A',
|
||||
'what are sales for company B')]
|
||||
2) clarify or disambiguate ambiguous terms (i.e., 'what is our success with company A' -> ['what are our sales with company A',
|
||||
'what is our market share with company A', 'is company A a reference customer for us', etc.])
|
||||
3) if a term or a metric is essentially clear, but it could relate to various components of an entity and you are generally
|
||||
familiar with the entity, then you can decompose the question into sub-questions that are more specific to components
|
||||
(i.e., 'what do we do to improve scalability of product X', 'what do we do to improve performance of product X',
|
||||
'what do we do to improve stability of product X', ...])
|
||||
|
||||
If you think that a decomposition is not needed or helpful, please just return an empty string. That is ok too.
|
||||
|
||||
Here is the initial question:
|
||||
-------
|
||||
{question}
|
||||
-------
|
||||
Please formulate your answer as a newline-separated list of questions like so:
|
||||
<sub-question>
|
||||
<sub-question>
|
||||
<sub-question>
|
||||
|
||||
Answer:"""
|
||||
|
||||
INITIAL_DECOMPOSITION_PROMPT = """ \n
|
||||
Please decompose an initial user question into 2 or 3 appropriate sub-questions that help to
|
||||
answer the original question. The purpose for this decomposition is to isolate individual entities
|
||||
(i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
|
||||
for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
|
||||
sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
|
||||
for us'), etc. Each sub-question should realistically be answerable by a good RAG system. \n
|
||||
|
||||
For each sub-question, please also create one search term that can be used to retrieve relevant
|
||||
documents from a document store.
|
||||
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
|
||||
Please formulate your answer as a list of json objects with the following format:
|
||||
|
||||
[{{"sub_question": <sub-question>, "search_term": <search term>}}, ...]
|
||||
|
||||
Answer:
|
||||
"""
|
||||
|
||||
INITIAL_RAG_BASE_PROMPT = """ \n
|
||||
You are an assistant for question-answering tasks. Use the information provided below - and only the
|
||||
provided information - to answer the provided question.
|
||||
|
||||
The information provided below consists of a number of documents that were deemed relevant for the question.
|
||||
|
||||
IMPORTANT RULES:
|
||||
- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer.
|
||||
You may give some additional facts you learned, but do not try to invent an answer.
|
||||
- If the information is empty or irrelevant, just say "I don't know".
|
||||
- If the information is relevant but not fully conclusive, specify that the information is not conclusive and say why.
|
||||
|
||||
Try to keep your answer concise.
|
||||
|
||||
Here is the contextual information from the document store:
|
||||
\n ------- \n
|
||||
{context} \n\n\n
|
||||
\n ------- \n
|
||||
And here is the question I want you to answer based on the context above:
|
||||
\n--\n {question} \n--\n
|
||||
Answer:"""
|
||||
|
||||
|
||||
### ANSWER GENERATION PROMPTS
|
||||
|
||||
# Persona specification
|
||||
ASSISTANT_SYSTEM_PROMPT_DEFAULT = """
|
||||
You are an assistant for question-answering tasks."""
|
||||
|
||||
ASSISTANT_SYSTEM_PROMPT_PERSONA = """
|
||||
You are an assistant for question-answering tasks. Here is more information about you:
|
||||
\n ------- \n
|
||||
{persona_prompt}
|
||||
\n ------- \n
|
||||
"""
|
||||
|
||||
SUB_QUESTION_ANSWER_TEMPLATE = """
|
||||
Sub-Question: Q{sub_question_nr}\n Sub-Question:\n - \n{sub_question}\n --\nAnswer:\n -\n {sub_answer}\n\n
|
||||
"""
|
||||
|
||||
INITIAL_RAG_PROMPT = """ \n
|
||||
{persona_specification}
|
||||
|
||||
Use the information provided below - and only the
|
||||
provided information - to answer the provided question.
|
||||
|
||||
The information provided below consists of:
|
||||
1) a number of answered sub-questions - these are very important(!) and definitely should be
|
||||
considered to answer the question.
|
||||
2) a number of documents that were also deemed relevant for the question.
|
||||
|
||||
IMPORTANT RULES:
|
||||
- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer.
|
||||
You may give some additional facts you learned, but do not try to invent an answer.
|
||||
- If the information is empty or irrelevant, just say "I don't know".
|
||||
- If the information is relevant but not fully conclusive, specify that the information is not conclusive and say why.
|
||||
|
||||
Remember to provide inline citations of documents in the format [D1], [D2], [D3], etc., and [Q1], [Q2],... if
|
||||
you want to cite the answer to a sub-question. If you have multiple citations, please cite for example
|
||||
as [D1][Q3], or [D2][D4], etc. Feel free to cite documents in addition to the sub-questions!
|
||||
Proper citations are important for the final answer to be verifiable! \n\n\n
|
||||
|
||||
Again, you should be sure that the answer is supported by the information provided!
|
||||
|
||||
Try to keep your answer concise. But also highlight uncertainties you may have should there be substantial ones,
|
||||
or assumptions you made.
|
||||
|
||||
Here is the contextual information:
|
||||
\n-------\n
|
||||
*Answered Sub-questions (these should really matter!):
|
||||
{answered_sub_questions}
|
||||
|
||||
And here is relevant document information that supports the sub-question answers, or that is relevant to the actual question:\n
|
||||
|
||||
{relevant_docs}
|
||||
|
||||
\n-------\n
|
||||
\n
|
||||
And here is the question I want you to answer based on the information above:
|
||||
\n--\n
|
||||
{question}
|
||||
\n--\n\n
|
||||
Answer:"""
|
||||
|
||||
# sub_question_answer_str is empty
|
||||
INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS = """{answered_sub_questions}
|
||||
{persona_specification}
|
||||
Use the information provided below
|
||||
- and only the provided information - to answer the provided question.
|
||||
The information provided below consists of a number of documents that were deemed relevant for the question.
|
||||
|
||||
IMPORTANT RULES:
|
||||
- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer.
|
||||
You may give some additional facts you learned, but do not try to invent an answer.
|
||||
- If the information is irrelevant, just say "I don't know".
|
||||
- If the information is relevant but not fully conclusive, specify that the information is not conclusive and say why.
|
||||
|
||||
Again, you should be sure that the answer is supported by the information provided!
|
||||
|
||||
Try to keep your answer concise.
|
||||
|
||||
Here is the relevant context information:
|
||||
\n-------\n
|
||||
{relevant_docs}
|
||||
\n-------\n
|
||||
|
||||
And here is the question I want you to answer based on the context above
|
||||
\n--\n
|
||||
{question}
|
||||
\n--\n
|
||||
|
||||
Answer:"""
|
||||
|
||||
REVISED_RAG_PROMPT = """\n
|
||||
{persona_specification}
|
||||
Use the information provided below - and only the
|
||||
provided information - to answer the provided question.
|
||||
|
||||
The information provided below consists of:
|
||||
1) an initial answer that was given but found to be lacking in some way.
|
||||
2) a number of answered sub-questions - these are very important(!) and definitely should be
|
||||
considered to answer the question.
|
||||
3) a number of documents that were also deemed relevant for the question.
|
||||
|
||||
IMPORTANT RULES:
|
||||
- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer.
|
||||
You may give some additional facts you learned, but do not try to invent an answer.
|
||||
- If the information is empty or irrelevant, just say "I don't know".
|
||||
- If the information is relevant but not fully conclusive, provide an answer to the extent you can but also
|
||||
specify that the information is not conclusive and why.
|
||||
|
||||
Again, you should be sure that the answer is supported by the information provided!
|
||||
|
||||
Try to keep your answer concise. But also highlight uncertainties you may have should there be substantial ones,
|
||||
or assumptions you made.
|
||||
|
||||
Here is the contextual information:
|
||||
\n-------\n
|
||||
|
||||
*Initial Answer that was found to be lacking:
|
||||
{initial_answer}
|
||||
|
||||
*Answered Sub-questions (these should really matter! They also contain questions/answers that were not available when the original
|
||||
answer was constructed):
|
||||
{answered_sub_questions}
|
||||
|
||||
And here is relevant document information that supports the sub-question answers, or that is relevant to the actual question:\n
|
||||
|
||||
{relevant_docs}
|
||||
|
||||
\n-------\n
|
||||
\n
|
||||
Lastly, here is the question I want you to answer based on the information above:
|
||||
\n--\n
|
||||
{question}
|
||||
\n--\n\n
|
||||
Answer:"""
|
||||
|
||||
# sub_question_answer_str is empty
|
||||
REVISED_RAG_PROMPT_NO_SUB_QUESTIONS = """{sub_question_answer_str}\n
|
||||
{persona_specification}
|
||||
Use the information provided below - and only the
|
||||
provided information - to answer the provided question.
|
||||
|
||||
The information provided below consists of:
|
||||
1) an initial answer that was given but found to be lacking in some way.
|
||||
2) a number of documents that were also deemed relevant for the question.
|
||||
|
||||
IMPORTANT RULES:
|
||||
- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer.
|
||||
You may give some additional facts you learned, but do not try to invent an answer.
|
||||
- If the information is empty or irrelevant, just say "I don't know".
|
||||
- If the information is relevant but not fully conclusive, provide an answer to the extent you can but also
|
||||
specify that the information is not conclusive and why.
|
||||
|
||||
Again, you should be sure that the answer is supported by the information provided!
|
||||
|
||||
Try to keep your answer concise. But also highlight uncertainties you may have should there be substantial ones,
|
||||
or assumptions you made.
|
||||
|
||||
Here is the contextual information:
|
||||
\n-------\n
|
||||
|
||||
*Initial Answer that was found to be lacking:
|
||||
{initial_answer}
|
||||
|
||||
And here is relevant document information that supports the sub-question answers, or that is relevant to the actual question:\n
|
||||
|
||||
{relevant_docs}
|
||||
|
||||
\n-------\n
|
||||
\n
|
||||
Lastly, here is the question I want you to answer based on the information above:
|
||||
\n--\n
|
||||
{question}
|
||||
\n--\n\n
|
||||
Answer:"""
|
||||
|
||||
|
||||
ENTITY_TERM_PROMPT = """ \n
|
||||
Based on the original question and the context retrieved from a dataset, please generate a list of
|
||||
entities (e.g. companies, organizations, industries, products, locations, etc.), terms and concepts
|
||||
(e.g. sales, revenue, etc.) that are relevant for the question, plus their relations to each other.
|
||||
|
||||
\n\n
|
||||
Here is the original question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
And here is the context retrieved:
|
||||
\n ------- \n
|
||||
{context}
|
||||
\n ------- \n
|
||||
|
||||
Please format your answer as a json object in the following format:
|
||||
|
||||
{{"retrieved_entities_relationships": {{
|
||||
"entities": [{{
|
||||
"entity_name": <assign a name for the entity>,
|
||||
"entity_type": <specify a short type name for the entity, such as 'company', 'location',...>
|
||||
}}],
|
||||
"relationships": [{{
|
||||
"relationship_name": <assign a name for the relationship>,
|
||||
"relationship_type": <specify a short type name for the relationship, such as 'sales_to', 'is_location_of',...>,
|
||||
"relationship_entities": [<related entity name 1>, <related entity name 2>, ...]
|
||||
}}],
|
||||
"terms": [{{
|
||||
"term_name": <assign a name for the term>,
|
||||
"term_type": <specify a short type name for the term, such as 'revenue', 'market_share',...>,
|
||||
"term_similar_to": <list terms that are similar to this term>
|
||||
}}]
|
||||
}}
|
||||
}}
|
||||
"""
backend/onyx/agent_search/shared_graph_utils/utils.py (new file, 243 lines)
@@ -0,0 +1,243 @@
|
||||
import ast
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agent_search.pro_search_a.main.models import EntityRelationshipTermExtraction
|
||||
from onyx.chat.models import AnswerStyleConfig
|
||||
from onyx.chat.models import CitationConfig
|
||||
from onyx.chat.models import DocumentPruningConfig
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.models import ProSearchConfig
|
||||
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.context.search.enums import LLMEvaluationType
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.persona import Persona
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.tools.tool_constructor import SearchToolConfig
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
|
||||
|
||||
def normalize_whitespace(text: str) -> str:
|
||||
"""Normalize whitespace in text to single spaces and strip leading/trailing whitespace."""
|
||||
import re
|
||||
|
||||
return re.sub(r"\s+", " ", text.strip())
|
||||
|
||||
|
||||
# Post-processing
|
||||
def format_docs(docs: Sequence[InferenceSection]) -> str:
|
||||
formatted_doc_list = []
|
||||
|
||||
for doc_nr, doc in enumerate(docs):
|
||||
formatted_doc_list.append(f"Document D{doc_nr + 1}:\n{doc.combined_content}")
|
||||
|
||||
return "\n\n".join(formatted_doc_list)
|
||||
|
||||
|
||||
def clean_and_parse_list_string(json_string: str) -> list[dict]:
|
||||
# Remove any prefixes/labels before the actual JSON content
|
||||
json_string = re.sub(r"^.*?(?=\[)", "", json_string, flags=re.DOTALL)
|
||||
|
||||
# Remove markdown code block markers and any newline prefixes
|
||||
cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
|
||||
cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
|
||||
cleaned_string = " ".join(cleaned_string.split())
|
||||
|
||||
# Try parsing with json.loads first, fall back to ast.literal_eval
|
||||
try:
|
||||
return json.loads(cleaned_string)
|
||||
except json.JSONDecodeError:
|
||||
try:
|
||||
return ast.literal_eval(cleaned_string)
|
||||
except (ValueError, SyntaxError) as e:
|
||||
raise ValueError(f"Failed to parse JSON string: {cleaned_string}") from e
|
||||
|
||||
|
||||
def clean_and_parse_json_string(json_string: str) -> dict[str, Any]:
|
||||
# Remove markdown code block markers and any newline prefixes
|
||||
cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
|
||||
cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
|
||||
cleaned_string = " ".join(cleaned_string.split())
|
||||
# Parse the cleaned string into a Python dictionary
|
||||
return json.loads(cleaned_string)
|
||||
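As a rough illustration, the parsing helpers above are meant to tolerate typical LLM decorations around JSON; the sample string below is invented, and the import is assumed to resolve in your environment.

from onyx.agent_search.shared_graph_utils.utils import clean_and_parse_list_string

raw = """Here are the sub-questions:
```json
[{"sub_question": "What are sales for company A?", "search_term": "company A sales"}]
```"""
parsed = clean_and_parse_list_string(raw)  # leading prose and code fences are stripped first
print(parsed[0]["search_term"])  # company A sales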
|
||||
|
||||
def format_entity_term_extraction(
|
||||
entity_term_extraction_dict: EntityRelationshipTermExtraction,
|
||||
) -> str:
|
||||
entities = entity_term_extraction_dict.entities
|
||||
terms = entity_term_extraction_dict.terms
|
||||
relationships = entity_term_extraction_dict.relationships
|
||||
|
||||
entity_strs = ["\nEntities:\n"]
|
||||
for entity in entities:
|
||||
entity_str = f"{entity.entity_name} ({entity.entity_type})"
|
||||
entity_strs.append(entity_str)
|
||||
|
||||
entity_str = "\n - ".join(entity_strs)
|
||||
|
||||
relationship_strs = ["\n\nRelationships:\n"]
|
||||
for relationship in relationships:
|
||||
relationship_name = relationship.relationship_name
|
||||
relationship_type = relationship.relationship_type
|
||||
relationship_entities = relationship.relationship_entities
|
||||
relationship_str = (
|
||||
f"""{relationship_name} ({relationship_type}): {relationship_entities}"""
|
||||
)
|
||||
relationship_strs.append(relationship_str)
|
||||
|
||||
relationship_str = "\n - ".join(relationship_strs)
|
||||
|
||||
term_strs = ["\n\nTerms:\n"]
|
||||
for term in terms:
|
||||
term_str = f"{term.term_name} ({term.term_type}): similar to {', '.join(term.term_similar_to)}"
|
||||
term_strs.append(term_str)
|
||||
|
||||
term_str = "\n - ".join(term_strs)
|
||||
|
||||
return "\n".join(entity_strs + relationship_strs + term_strs)
|
||||
|
||||
|
||||
def _format_time_delta(time: timedelta) -> str:
|
||||
seconds_from_start = f"{((time).seconds):03d}"
|
||||
microseconds_from_start = f"{((time).microseconds):06d}"
|
||||
return f"{seconds_from_start}.{microseconds_from_start}"
|
||||
|
||||
|
||||
def generate_log_message(
|
||||
message: str,
|
||||
node_start_time: datetime,
|
||||
graph_start_time: datetime | None = None,
|
||||
) -> str:
|
||||
current_time = datetime.now()
|
||||
|
||||
if graph_start_time is not None:
|
||||
graph_time_str = _format_time_delta(current_time - graph_start_time)
|
||||
else:
|
||||
graph_time_str = "N/A"
|
||||
|
||||
node_time_str = _format_time_delta(current_time - node_start_time)
|
||||
|
||||
return f"{graph_time_str} ({node_time_str} s): {message}"
|
||||
|
||||
|
||||
def get_test_config(
|
||||
db_session: Session, primary_llm: LLM, fast_llm: LLM, search_request: SearchRequest
|
||||
) -> tuple[ProSearchConfig, SearchTool]:
|
||||
persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session)
|
||||
document_pruning_config = DocumentPruningConfig(
|
||||
max_chunks=int(
|
||||
persona.num_chunks
|
||||
if persona.num_chunks is not None
|
||||
else MAX_CHUNKS_FED_TO_CHAT
|
||||
),
|
||||
max_window_percentage=CHAT_TARGET_CHUNK_PERCENTAGE,
|
||||
)
|
||||
|
||||
answer_style_config = AnswerStyleConfig(
|
||||
citation_config=CitationConfig(
|
||||
# The docs retrieved by this flow are already relevance-filtered
|
||||
all_docs_useful=True
|
||||
),
|
||||
document_pruning_config=document_pruning_config,
|
||||
structured_response_format=None,
|
||||
)
|
||||
|
||||
search_tool_config = SearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
document_pruning_config=document_pruning_config,
|
||||
retrieval_options=RetrievalDetails(), # may want to set dedupe_docs=True
|
||||
rerank_settings=None, # Can use this to change reranking model
|
||||
selected_sections=None,
|
||||
latest_query_files=None,
|
||||
bypass_acl=False,
|
||||
)
|
||||
|
||||
prompt_config = PromptConfig.from_model(persona.prompts[0])
|
||||
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=None,
|
||||
persona=persona,
|
||||
retrieval_options=search_tool_config.retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=primary_llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=search_tool_config.document_pruning_config,
|
||||
answer_style_config=search_tool_config.answer_style_config,
|
||||
selected_sections=search_tool_config.selected_sections,
|
||||
chunks_above=search_tool_config.chunks_above,
|
||||
chunks_below=search_tool_config.chunks_below,
|
||||
full_doc=search_tool_config.full_doc,
|
||||
evaluation_type=(
|
||||
LLMEvaluationType.BASIC
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP
|
||||
),
|
||||
rerank_settings=search_tool_config.rerank_settings,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
)
|
||||
|
||||
config = ProSearchConfig(
|
||||
search_request=search_request,
|
||||
# chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
|
||||
chat_session_id=UUID("edda10d5-6cef-45d8-acfb-39317552a1f4"), # Joachim
|
||||
# chat_session_id=UUID("d1acd613-2692-4bc3-9d65-c6d3da62e58e"), # Evan
|
||||
message_id=1,
|
||||
use_persistence=True,
|
||||
)
|
||||
|
||||
return config, search_tool
|
||||
|
||||
|
||||
def get_persona_prompt(persona: Persona | None) -> str:
|
||||
if persona is None:
|
||||
return ""
|
||||
else:
|
||||
return "\n".join([x.system_prompt for x in persona.prompts])
|
||||
|
||||
|
||||
def make_question_id(level: int, question_nr: int) -> str:
|
||||
return f"{level}_{question_nr}"
|
||||
|
||||
|
||||
def parse_question_id(question_id: str) -> tuple[int, int]:
|
||||
level, question_nr = question_id.split("_")
|
||||
return int(level), int(question_nr)
|
||||
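A tiny round-trip example of the question-id helpers above:

qid = make_question_id(1, 3)                  # "1_3"
level, question_nr = parse_question_id(qid)   # (1, 3)
assert (level, question_nr) == (1, 3)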
|
||||
|
||||
def dispatch_separated(
|
||||
token_itr: Iterator[BaseMessage],
|
||||
dispatch_event: Callable[[str, int], None],
|
||||
sep: str = "\n",
|
||||
) -> list[str | list[str | dict[str, Any]]]:
|
||||
num = 0
|
||||
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
|
||||
for message in token_itr:
|
||||
content = cast(str, message.content)
|
||||
if sep in content:
|
||||
for sub_question_part in content.split(sep):
|
||||
dispatch_event(sub_question_part, num)
|
||||
num += 1
|
||||
num -= 1 # fencepost; extra increment at end of loop
|
||||
else:
|
||||
dispatch_event(content, num)
|
||||
streamed_tokens.append(content)
|
||||
|
||||
return streamed_tokens
|
||||
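To illustrate the separator-based dispatch above, here is a hypothetical snippet with a fake token stream; it assumes the new utils module is importable and that langchain_core is installed.

from langchain_core.messages import AIMessageChunk

from onyx.agent_search.shared_graph_utils.utils import dispatch_separated

events: list[tuple[int, str]] = []
tokens = iter(
    [
        AIMessageChunk(content="What are sales for company A?\nWhat are sal"),
        AIMessageChunk(content="es for company B?"),
    ]
)
dispatch_separated(tokens, lambda part, num: events.append((num, part)))
# events == [(0, "What are sales for company A?"), (1, "What are sal"), (1, "es for company B?")]
# i.e. each streamed fragment is dispatched with the index of the sub-question it belongs to.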
@@ -1,17 +1,21 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from uuid import uuid4
|
||||
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages import ToolCall
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agent_search.run_graph import run_basic_graph
|
||||
from onyx.agent_search.run_graph import run_main_graph
|
||||
from onyx.chat.llm_response_handler import LLMResponseHandlerManager
|
||||
from onyx.chat.models import AnswerQuestionPossibleReturn
|
||||
from onyx.chat.models import AnswerPacket
|
||||
from onyx.chat.models import AnswerStream
|
||||
from onyx.chat.models import AnswerStyleConfig
|
||||
from onyx.chat.models import CitationInfo
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.models import ProSearchConfig
|
||||
from onyx.chat.prompt_builder.build import AnswerPromptBuilder
|
||||
from onyx.chat.prompt_builder.build import default_build_system_message
|
||||
from onyx.chat.prompt_builder.build import default_build_user_message
|
||||
@@ -19,32 +23,24 @@ from onyx.chat.prompt_builder.build import LLMCall
|
||||
from onyx.chat.stream_processing.answer_response_handler import (
|
||||
CitationResponseHandler,
|
||||
)
|
||||
from onyx.chat.stream_processing.answer_response_handler import (
|
||||
DummyAnswerResponseHandler,
|
||||
)
|
||||
from onyx.chat.stream_processing.utils import (
|
||||
map_document_id_order,
|
||||
)
|
||||
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
|
||||
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
|
||||
from onyx.file_store.utils import InMemoryChatFile
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import PreviousMessage
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_runner import ToolCallKickoff
|
||||
from onyx.tools.utils import explicit_tool_calling_supported
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse]
|
||||
|
||||
|
||||
class Answer:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -59,7 +55,6 @@ class Answer:
|
||||
# newly passed in files to include as part of this question
|
||||
# TODO THIS NEEDS TO BE HANDLED
|
||||
latest_query_files: list[InMemoryChatFile] | None = None,
|
||||
files: list[InMemoryChatFile] | None = None,
|
||||
tools: list[Tool] | None = None,
|
||||
# NOTE: for native tool-calling, this is only supported by OpenAI atm,
|
||||
# but we only support them anyways
|
||||
@@ -69,6 +64,9 @@ class Answer:
|
||||
return_contexts: bool = False,
|
||||
skip_gen_ai_answer_generation: bool = False,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
pro_search_config: ProSearchConfig | None = None,
|
||||
fast_llm: LLM | None = None,
|
||||
db_session: Session | None = None,
|
||||
) -> None:
|
||||
if single_message_history and message_history:
|
||||
raise ValueError(
|
||||
@@ -79,7 +77,6 @@ class Answer:
|
||||
self.is_connected: Callable[[], bool] | None = is_connected
|
||||
|
||||
self.latest_query_files = latest_query_files or []
|
||||
self.file_id_to_file = {file.file_id: file for file in (files or [])}
|
||||
|
||||
self.tools = tools or []
|
||||
self.force_use_tool = force_use_tool
|
||||
@@ -92,6 +89,7 @@ class Answer:
|
||||
self.prompt_config = prompt_config
|
||||
|
||||
self.llm = llm
|
||||
self.fast_llm = fast_llm
|
||||
self.llm_tokenizer = get_tokenizer(
|
||||
provider_type=llm.config.model_provider,
|
||||
model_name=llm.config.model_name,
|
||||
@@ -100,9 +98,7 @@ class Answer:
|
||||
self._final_prompt: list[BaseMessage] | None = None
|
||||
|
||||
self._streamed_output: list[str] | None = None
|
||||
self._processed_stream: (
|
||||
list[AnswerQuestionPossibleReturn | ToolResponse | ToolCallKickoff] | None
|
||||
) = None
|
||||
self._processed_stream: (list[AnswerPacket] | None) = None
|
||||
|
||||
self._return_contexts = return_contexts
|
||||
self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
|
||||
@@ -115,55 +111,28 @@ class Answer:
|
||||
and not skip_explicit_tool_calling
|
||||
)
|
||||
|
||||
self.pro_search_config = pro_search_config
|
||||
self.db_session = db_session
|
||||
|
||||
def _get_tools_list(self) -> list[Tool]:
|
||||
if not self.force_use_tool.force_use:
|
||||
return self.tools
|
||||
|
||||
tool = next(
|
||||
(t for t in self.tools if t.name == self.force_use_tool.tool_name), None
|
||||
)
|
||||
if tool is None:
|
||||
raise RuntimeError(f"Tool '{self.force_use_tool.tool_name}' not found")
|
||||
tool = get_tool_by_name(self.tools, self.force_use_tool.tool_name)
|
||||
|
||||
logger.info(
|
||||
f"Forcefully using tool='{tool.name}'"
|
||||
+ (
|
||||
f" with args='{self.force_use_tool.args}'"
|
||||
if self.force_use_tool.args is not None
|
||||
else ""
|
||||
)
|
||||
args_str = (
|
||||
f" with args='{self.force_use_tool.args}'"
|
||||
if self.force_use_tool.args
|
||||
else ""
|
||||
)
|
||||
logger.info(f"Forcefully using tool='{tool.name}'{args_str}")
|
||||
return [tool]
|
||||
|
||||
def _handle_specified_tool_call(
|
||||
self, llm_calls: list[LLMCall], tool: Tool, tool_args: dict
|
||||
) -> AnswerStream:
|
||||
current_llm_call = llm_calls[-1]
|
||||
|
||||
# make a dummy tool handler
|
||||
tool_handler = ToolResponseHandler([tool])
|
||||
|
||||
dummy_tool_call_chunk = AIMessageChunk(content="")
|
||||
dummy_tool_call_chunk.tool_calls = [
|
||||
ToolCall(name=tool.name, args=tool_args, id=str(uuid4()))
|
||||
]
|
||||
|
||||
response_handler_manager = LLMResponseHandlerManager(
|
||||
tool_handler, DummyAnswerResponseHandler(), self.is_cancelled
|
||||
)
|
||||
yield from response_handler_manager.handle_llm_response(
|
||||
iter([dummy_tool_call_chunk])
|
||||
)
|
||||
|
||||
new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
|
||||
if new_llm_call:
|
||||
yield from self._get_response(llm_calls + [new_llm_call])
|
||||
else:
|
||||
raise RuntimeError("Tool call handler did not return a new LLM call")
|
||||
|
||||
# TODO: delete the function and move the full body to processed_streamed_output
|
||||
def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream:
|
||||
current_llm_call = llm_calls[-1]
|
||||
|
||||
tool, tool_args = None, None
|
||||
# handle the case where no decision has to be made; we simply run the tool
|
||||
if (
|
||||
current_llm_call.force_use_tool.force_use
|
||||
@@ -173,17 +142,10 @@ class Answer:
|
||||
current_llm_call.force_use_tool.tool_name,
|
||||
current_llm_call.force_use_tool.args,
|
||||
)
|
||||
tool = next(
|
||||
(t for t in current_llm_call.tools if t.name == tool_name), None
|
||||
)
|
||||
if not tool:
|
||||
raise RuntimeError(f"Tool '{tool_name}' not found")
|
||||
|
||||
yield from self._handle_specified_tool_call(llm_calls, tool, tool_args)
|
||||
return
|
||||
tool = get_tool_by_name(current_llm_call.tools, tool_name)
|
||||
|
||||
# special pre-logic for non-tool calling LLM case
|
||||
if not self.using_tool_calling_llm and current_llm_call.tools:
|
||||
elif not self.using_tool_calling_llm and current_llm_call.tools:
|
||||
chosen_tool_and_args = (
|
||||
ToolResponseHandler.get_tool_call_for_non_tool_calling_llm(
|
||||
current_llm_call, self.llm
|
||||
@@ -191,8 +153,24 @@ class Answer:
|
||||
)
|
||||
if chosen_tool_and_args:
|
||||
tool, tool_args = chosen_tool_and_args
|
||||
yield from self._handle_specified_tool_call(llm_calls, tool, tool_args)
|
||||
return
|
||||
|
||||
if tool and tool_args:
|
||||
dummy_tool_call_chunk = AIMessageChunk(content="")
|
||||
dummy_tool_call_chunk.tool_calls = [
|
||||
ToolCall(name=tool.name, args=tool_args, id=str(uuid4()))
]

response_handler_manager = LLMResponseHandlerManager(
ToolResponseHandler([tool]), None, self.is_cancelled
)
yield from response_handler_manager.handle_llm_response(
iter([dummy_tool_call_chunk])
)

tmp_call = response_handler_manager.next_llm_call(current_llm_call)
if tmp_call is None:
return # no more LLM calls to process
current_llm_call = tmp_call

# if we're skipping gen ai answer generation, we should break
# out unless we're forcing a tool call. If we don't, we might generate an
@@ -212,29 +190,63 @@ class Answer:
current_llm_call
) or ([], [])

# Quotes are no longer supported
# answer_handler: AnswerResponseHandler
# if self.answer_style_config.citation_config:
# answer_handler = CitationResponseHandler(
# context_docs=search_result,
# doc_id_to_rank_map=map_document_id_order(search_result),
# )
# elif self.answer_style_config.quotes_config:
# answer_handler = QuotesResponseHandler(
# context_docs=search_result,
# )
# else:
# raise ValueError("No answer style config provided")
# NEXT: we still want to handle the LLM response stream, but it is now:
# 1. handle the tool call requests
# 2. feed back the processed results
# 3. handle the citations

answer_handler = CitationResponseHandler(
context_docs=final_search_results,
final_doc_id_to_rank_map=map_document_id_order(final_search_results),
display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
)

# At the moment, this wrapper class passes streamed stuff through citation and tool handlers.
# In the future, we'll want to handle citations and tool calls in the langgraph graph.
response_handler_manager = LLMResponseHandlerManager(
tool_call_handler, answer_handler, self.is_cancelled
)

# In langgraph, whether we do the basic thing (call llm stream) or pro search
# is based on a flag in the pro search config

if self.pro_search_config:
if self.pro_search_config.search_request is None:
raise ValueError("Search request must be provided for pro search")
search_tools = [tool for tool in self.tools if isinstance(tool, SearchTool)]
if len(search_tools) == 0:
raise ValueError("No search tool found")
elif len(search_tools) > 1:
# TODO: handle multiple search tools
raise ValueError("Multiple search tools found")

search_tool = search_tools[0]
if self.db_session is None:
raise ValueError("db_session must be provided for pro search")
if self.fast_llm is None:
raise ValueError("fast_llm must be provided for pro search")

stream = run_main_graph(
config=self.pro_search_config,
primary_llm=self.llm,
fast_llm=self.fast_llm,
search_tool=search_tool,
db_session=self.db_session,
)
else:
stream = run_basic_graph(
last_llm_call=current_llm_call,
primary_llm=self.llm,
answer_style_config=self.answer_style_config,
response_handler_manager=response_handler_manager,
)

processed_stream = []
for packet in stream:
processed_stream.append(packet)
yield packet
self._processed_stream = processed_stream
return
# DEBUG: good breakpoint
stream = self.llm.stream(
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
@@ -48,6 +48,7 @@ def prepare_chat_message_request(
retrieval_details: RetrievalDetails | None,
rerank_settings: RerankingDetails | None,
db_session: Session,
use_pro_search: bool = False,
) -> CreateChatMessageRequest:
# Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
new_chat_session = create_chat_session(
@@ -72,6 +73,7 @@ def prepare_chat_message_request(
search_doc_ids=None,
retrieval_options=retrieval_details,
rerank_settings=rerank_settings,
use_pro_search=use_pro_search,
)

@@ -9,18 +9,30 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.prompt_builder.build import LLMCall
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
DummyAnswerResponseHandler,
)
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler


class LLMResponseHandlerManager:
"""
This class is responsible for postprocessing the LLM response stream.
In particular, we:
1. handle the tool call requests
2. handle citations
3. pass through answers generated by the LLM
4. stop yielding if the client disconnects
"""

def __init__(
self,
tool_handler: ToolResponseHandler,
answer_handler: AnswerResponseHandler,
tool_handler: ToolResponseHandler | None,
answer_handler: AnswerResponseHandler | None,
is_cancelled: Callable[[], bool],
):
self.tool_handler = tool_handler
self.answer_handler = answer_handler
self.tool_handler = tool_handler or ToolResponseHandler([])
self.answer_handler = answer_handler or DummyAnswerResponseHandler()
self.is_cancelled = is_cancelled

def handle_llm_response(
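Illustrative wiring of the handler manager above (a sketch, not part of this change): the tool and answer handlers may be passed as None, in which case the fallbacks shown in __init__ are used; `search_tool` and `response_stream` are assumptions, with `response_stream` standing in for an iterator of LLM message chunks.

# Sketch only - `search_tool` and `response_stream` are assumed to exist.
manager = LLMResponseHandlerManager(
    tool_handler=ToolResponseHandler([search_tool]),
    answer_handler=None,  # falls back to DummyAnswerResponseHandler()
    is_cancelled=lambda: False,
)
for packet in manager.handle_llm_response(response_stream):
    ...  # tool kickoffs, citations, and answer pieces are yielded here
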
@@ -3,7 +3,9 @@ from collections.abc import Iterator
from datetime import datetime
from enum import Enum
from typing import Any
from typing import Literal
from typing import TYPE_CHECKING
from uuid import UUID

from pydantic import BaseModel
from pydantic import ConfigDict
@@ -16,6 +18,8 @@ from onyx.context.search.enums import QueryFlow
from onyx.context.search.enums import RecencyBiasSetting
from onyx.context.search.enums import SearchType
from onyx.context.search.models import RetrievalDocs
from onyx.context.search.models import SearchRequest
from onyx.llm.models import PreviousMessage
from onyx.llm.override_models import PromptOverride
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
@@ -204,6 +208,30 @@ class PersonaOverrideConfig(BaseModel):
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)


class ProSearchConfig(BaseModel):
"""
Configuration for the Pro Search feature.
"""

use_agentic_search: bool = False

# For persisting agent search data
chat_session_id: UUID | None = None
# The message ID of the user message that triggered the Pro Search
message_id: int | None = None
# The search request that was used to generate the Pro Search
search_request: SearchRequest

# Whether to persist data for the Pro Search (turned off for testing)
use_persistence: bool = True

# Whether to allow creation of refinement questions (and entity extraction, etc.)
allow_refinement: bool = False

# Message history for the current chat session
message_history: list[PreviousMessage] | None = None

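As an illustration of the model above (mirroring how the agent regression test later in this diff builds it), a minimal config for an ad-hoc agent run might look like the sketch below; the query string is made up.

# Sketch only - not part of this change.
config = ProSearchConfig(
    search_request=SearchRequest(query="example question"),
    chat_session_id=None,  # nothing to persist against
    message_id=None,
    use_persistence=False,  # skip DB writes, e.g. for tests
)
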

AnswerQuestionPossibleReturn = (
OnyxAnswerPiece
| CitationInfo
@@ -327,3 +355,39 @@ ResponsePart = (
| ToolCallFinalResult
| StreamStopInfo
)


class SubQueryPiece(BaseModel):
sub_query: str
level: int
level_question_nr: int
query_id: int


class AgentAnswerPiece(BaseModel):
answer_piece: str
level: int
level_question_nr: int
answer_type: Literal["agent_sub_answer", "agent_level_answer"]


class SubQuestionPiece(BaseModel):
sub_question: str
level: int
level_question_nr: int


class ExtendedToolResponse(ToolResponse):
level: int
level_question_nr: int


ProSearchPacket = (
SubQuestionPiece | AgentAnswerPiece | SubQueryPiece | ExtendedToolResponse
)

AnswerPacket = (
AnswerQuestionPossibleReturn | ProSearchPacket | ToolCallKickoff | ToolResponse
)

AnswerStream = Iterator[AnswerPacket]
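A hedged sketch of consuming the packet union above; only the Pro Search piece types defined in this hunk are handled, everything else is passed over.

# Sketch only - how a consumer might branch on ProSearchPacket members.
def debug_print_packets(stream: AnswerStream) -> None:
    for packet in stream:
        if isinstance(packet, SubQuestionPiece):
            print(f"[sub-question {packet.level}.{packet.level_question_nr}] {packet.sub_question}")
        elif isinstance(packet, SubQueryPiece):
            print(f"[sub-query {packet.level}.{packet.level_question_nr}] {packet.sub_query}")
        elif isinstance(packet, AgentAnswerPiece):
            print(packet.answer_piece, end="")
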
@@ -24,6 +24,8 @@ from onyx.chat.models import MessageSpecificCitations
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.models import OnyxContexts
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.models import ProSearchConfig
|
||||
from onyx.chat.models import ProSearchPacket
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
@@ -33,11 +35,13 @@ from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import NO_AUTH_USER_ID
|
||||
from onyx.context.search.enums import LLMEvaluationType
|
||||
from onyx.context.search.enums import OptionalSearchSetting
|
||||
from onyx.context.search.enums import QueryFlow
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.context.search.retrieval.search_runner import inference_sections_from_ids
|
||||
from onyx.context.search.utils import chunks_or_sections_to_search_docs
|
||||
from onyx.context.search.utils import dedupe_documents
|
||||
@@ -281,6 +285,7 @@ ChatPacket = (
|
||||
| MessageSpecificCitations
|
||||
| MessageResponseIDInfo
|
||||
| StreamStopInfo
|
||||
| ProSearchPacket
|
||||
)
|
||||
ChatPacketStream = Iterator[ChatPacket]
|
||||
|
||||
@@ -683,6 +688,51 @@ def stream_chat_message_objects(
for tool_list in tool_dict.values():
tools.extend(tool_list)

message_history = [
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
]

search_request = None
pro_search_config = None
if new_msg_req.use_pro_search:
search_request = SearchRequest(
query=final_msg.message,
evaluation_type=(
LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP
),
human_selected_filters=(
retrieval_options.filters if retrieval_options else None
),
persona=persona,
offset=(retrieval_options.offset if retrieval_options else None),
limit=retrieval_options.limit if retrieval_options else None,
rerank_settings=new_msg_req.rerank_settings,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
enable_auto_detect_filters=(
retrieval_options.enable_auto_detect_filters
if retrieval_options
else None
),
)
# TODO: Since we're deleting the current main path in Answer,
# we should construct this unconditionally inside Answer instead
# Leaving it here for the time being to avoid breaking changes
pro_search_config = (
ProSearchConfig(
search_request=search_request,
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
message_history=message_history,
)
if new_msg_req.use_pro_search
else None
)
# TODO: add previous messages, answer style config, tools, etc.

# LLM prompt building, response capturing, etc.
answer = Answer(
is_connected=is_connected,
@@ -702,11 +752,12 @@ def stream_chat_message_objects(
|
||||
)
|
||||
)
|
||||
),
|
||||
message_history=[
|
||||
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
|
||||
],
|
||||
fast_llm=fast_llm,
|
||||
message_history=message_history,
|
||||
tools=tools,
|
||||
force_use_tool=_get_force_search_settings(new_msg_req, tools),
|
||||
pro_search_config=pro_search_config,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
reference_db_search_docs = None
|
||||
@@ -718,6 +769,7 @@ def stream_chat_message_objects(
|
||||
|
||||
for packet in answer.processed_streamed_output:
|
||||
if isinstance(packet, ToolResponse):
|
||||
# TODO: don't need to dedupe here when we do it in agent flow
|
||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
(
|
||||
qa_docs_response,
|
||||
@@ -738,25 +790,30 @@ def stream_chat_message_objects(
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
relevance_sections = packet.response
|
||||
|
||||
if reference_db_search_docs is not None:
|
||||
llm_indices = relevant_sections_to_indices(
|
||||
relevance_sections=relevance_sections,
|
||||
items=[
|
||||
translate_db_search_doc_to_server_search_doc(doc)
|
||||
for doc in reference_db_search_docs
|
||||
],
|
||||
if reference_db_search_docs is None:
|
||||
logger.warning(
|
||||
"No reference docs found for relevance filtering"
|
||||
)
|
||||
continue
|
||||
|
||||
llm_indices = relevant_sections_to_indices(
|
||||
relevance_sections=relevance_sections,
|
||||
items=[
|
||||
translate_db_search_doc_to_server_search_doc(doc)
|
||||
for doc in reference_db_search_docs
|
||||
],
|
||||
)
|
||||
|
||||
if dropped_indices:
|
||||
llm_indices = drop_llm_indices(
|
||||
llm_indices=llm_indices,
|
||||
search_docs=reference_db_search_docs,
|
||||
dropped_indices=dropped_indices,
|
||||
)
|
||||
|
||||
if dropped_indices:
|
||||
llm_indices = drop_llm_indices(
|
||||
llm_indices=llm_indices,
|
||||
search_docs=reference_db_search_docs,
|
||||
dropped_indices=dropped_indices,
|
||||
)
|
||||
|
||||
yield LLMRelevanceFilterResponse(
|
||||
llm_selected_doc_indices=llm_indices
|
||||
)
|
||||
yield LLMRelevanceFilterResponse(
|
||||
llm_selected_doc_indices=llm_indices
|
||||
)
|
||||
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
yield FinalUsedContextDocsResponse(
|
||||
final_context_docs=packet.response
|
||||
|
||||
@@ -147,6 +147,7 @@ class AnswerPromptBuilder:
|
||||
)
|
||||
|
||||
|
||||
# TODO: rename this? AnswerConfig maybe?
|
||||
class LLMCall(BaseModel__v1):
|
||||
prompt_builder: AnswerPromptBuilder
|
||||
tools: list[Tool]
|
||||
|
||||
@@ -25,6 +25,13 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_tool_by_name(tools: list[Tool], tool_name: str) -> Tool:
|
||||
for tool in tools:
|
||||
if tool.name == tool_name:
|
||||
return tool
|
||||
raise RuntimeError(f"Tool '{tool_name}' not found")
|
||||
|
||||
|
||||
class ToolResponseHandler:
|
||||
def __init__(self, tools: list[Tool]):
|
||||
self.tools = tools
|
||||
@@ -45,18 +52,7 @@ class ToolResponseHandler:
|
||||
) -> tuple[Tool, dict] | None:
|
||||
if llm_call.force_use_tool.force_use:
|
||||
# if we are forcing a tool, we don't need to check which tools to run
|
||||
tool = next(
|
||||
(
|
||||
t
|
||||
for t in llm_call.tools
|
||||
if t.name == llm_call.force_use_tool.tool_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not tool:
|
||||
raise RuntimeError(
|
||||
f"Tool '{llm_call.force_use_tool.tool_name}' not found"
|
||||
)
|
||||
tool = get_tool_by_name(llm_call.tools, llm_call.force_use_tool.tool_name)
|
||||
|
||||
tool_args = (
|
||||
llm_call.force_use_tool.args
|
||||
@@ -118,20 +114,17 @@ class ToolResponseHandler:
|
||||
tool for tool in self.tools if tool.name == tool_call_request["name"]
|
||||
]
|
||||
|
||||
if not known_tools_by_name:
|
||||
logger.error(
|
||||
"Tool call requested with unknown name field. \n"
|
||||
f"self.tools: {self.tools}"
|
||||
f"tool_call_request: {tool_call_request}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
if known_tools_by_name:
|
||||
selected_tool = known_tools_by_name[0]
|
||||
selected_tool_call_request = tool_call_request
|
||||
|
||||
if selected_tool and selected_tool_call_request:
|
||||
break
|
||||
|
||||
logger.error(
|
||||
"Tool call requested with unknown name field. \n"
|
||||
f"self.tools: {self.tools}"
|
||||
f"tool_call_request: {tool_call_request}"
|
||||
)
|
||||
|
||||
if not selected_tool or not selected_tool_call_request:
|
||||
return
|
||||
|
||||
@@ -171,8 +164,6 @@ class ToolResponseHandler:
|
||||
else:
|
||||
self.tool_call_chunk += response_item # type: ignore
|
||||
|
||||
return
|
||||
|
||||
def next_llm_call(self, current_llm_call: LLMCall) -> LLMCall | None:
|
||||
if (
|
||||
self.tool_runner is None
|
||||
|
||||
backend/onyx/configs/dev_configs.py (new file, 57 lines)
@@ -0,0 +1,57 @@
import os

from .chat_configs import NUM_RETURNED_HITS


#####
# Agent Configs
#####

agent_retrieval_stats_os: bool | str | None = os.environ.get(
"AGENT_RETRIEVAL_STATS", False
)

AGENT_RETRIEVAL_STATS: bool = False
if isinstance(agent_retrieval_stats_os, str) and agent_retrieval_stats_os == "True":
AGENT_RETRIEVAL_STATS = True
elif isinstance(agent_retrieval_stats_os, bool) and agent_retrieval_stats_os:
AGENT_RETRIEVAL_STATS = True

agent_max_query_retrieval_results_os: int | str = os.environ.get(
"AGENT_MAX_QUERY_RETRIEVAL_RESULTS", NUM_RETURNED_HITS
)

AGENT_MAX_QUERY_RETRIEVAL_RESULTS: int = NUM_RETURNED_HITS
try:
atmqrr = int(agent_max_query_retrieval_results_os)
AGENT_MAX_QUERY_RETRIEVAL_RESULTS = atmqrr
except ValueError:
raise ValueError(
f"AGENT_MAX_QUERY_RETRIEVAL_RESULTS must be an integer, got {agent_max_query_retrieval_results_os}"
)


# Reranking agent configs
agent_reranking_stats_os: bool | str | None = os.environ.get(
"AGENT_RERANKING_TEST", False
)
AGENT_RERANKING_STATS: bool = False
if isinstance(agent_reranking_stats_os, str) and agent_reranking_stats_os == "True":
AGENT_RERANKING_STATS = True
elif isinstance(agent_reranking_stats_os, bool) and agent_reranking_stats_os:
AGENT_RERANKING_STATS = True


agent_reranking_max_query_retrieval_results_os: int | str = os.environ.get(
"AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS", NUM_RETURNED_HITS
)

AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS: int = NUM_RETURNED_HITS

try:
atmqrr = int(agent_reranking_max_query_retrieval_results_os)
AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS = atmqrr
except ValueError:
raise ValueError(
f"AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS must be an integer, got {agent_reranking_max_query_retrieval_results_os}"
)
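The four blocks above repeat the same string/bool and int coercion per variable; a small pair of helpers with the same semantics (hypothetical, not part of this diff) would read:

# Hypothetical helpers - names are not from the diff, behavior matches the blocks above.
def _env_bool(name: str, default: bool = False) -> bool:
    raw = os.environ.get(name, default)
    if isinstance(raw, str):
        return raw == "True"
    return bool(raw)


def _env_int(name: str, default: int) -> int:
    raw = os.environ.get(name, default)
    try:
        return int(raw)
    except ValueError:
        raise ValueError(f"{name} must be an integer, got {raw!r}")


AGENT_RETRIEVAL_STATS = _env_bool("AGENT_RETRIEVAL_STATS")
AGENT_MAX_QUERY_RETRIEVAL_RESULTS = _env_int("AGENT_MAX_QUERY_RETRIEVAL_RESULTS", NUM_RETURNED_HITS)
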
@@ -406,8 +406,18 @@ class SearchPipeline:

@property
def section_relevance_list(self) -> list[bool]:
llm_indices = relevant_sections_to_indices(
relevance_sections=self.section_relevance,
items=self.final_context_sections,
return section_relevance_list_impl(
section_relevance=self.section_relevance,
final_context_sections=self.final_context_sections,
)
return [ind in llm_indices for ind in range(len(self.final_context_sections))]


def section_relevance_list_impl(
section_relevance: list[SectionRelevancePiece] | None,
final_context_sections: list[InferenceSection],
) -> list[bool]:
llm_indices = relevant_sections_to_indices(
relevance_sections=section_relevance,
items=final_context_sections,
)
return [ind in llm_indices for ind in range(len(final_context_sections))]

@@ -80,7 +80,7 @@ def drop_llm_indices(
search_docs: Sequence[DBSearchDoc | SavedSearchDoc],
dropped_indices: list[int],
) -> list[int]:
llm_bools = [True if i in llm_indices else False for i in range(len(search_docs))]
llm_bools = [i in llm_indices for i in range(len(search_docs))]
if dropped_indices:
llm_bools = [
val for ind, val in enumerate(llm_bools) if ind not in dropped_indices

@@ -1,6 +1,7 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
@@ -15,13 +16,21 @@ from sqlalchemy.exc import MultipleResultsFound
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agent_search.pro_search_a.answer_initial_sub_question.models import (
|
||||
QuestionAnswerResults,
|
||||
)
|
||||
from onyx.agent_search.pro_search_a.main.models import CombinedAgentMetrics
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.chat.models import DocumentRelevance
|
||||
from onyx.configs.chat_configs import HARD_DELETE_CHATS
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import RetrievalDocs
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.context.search.models import SearchDoc as ServerSearchDoc
|
||||
from onyx.db.models import AgentSearchMetrics
|
||||
from onyx.db.models import AgentSubQuery
|
||||
from onyx.db.models import AgentSubQuestion
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatMessage__SearchDoc
|
||||
from onyx.db.models import ChatSession
|
||||
@@ -37,6 +46,8 @@ from onyx.file_store.models import FileDescriptor
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.override_models import PromptOverride
|
||||
from onyx.server.query_and_chat.models import ChatMessageDetail
|
||||
from onyx.server.query_and_chat.models import SubQueryDetail
|
||||
from onyx.server.query_and_chat.models import SubQuestionDetail
|
||||
from onyx.tools.tool_runner import ToolCallFinalResult
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -496,6 +507,7 @@ def get_chat_messages_by_session(
|
||||
prefetch_tool_calls: bool = False,
|
||||
) -> list[ChatMessage]:
|
||||
if not skip_permission_check:
|
||||
# bug if we ever call this expecting the permission check to not be skipped
|
||||
get_chat_session_by_id(
|
||||
chat_session_id=chat_session_id, user_id=user_id, db_session=db_session
|
||||
)
|
||||
@@ -507,7 +519,12 @@ def get_chat_messages_by_session(
|
||||
)
|
||||
|
||||
if prefetch_tool_calls:
|
||||
stmt = stmt.options(joinedload(ChatMessage.tool_call))
|
||||
stmt = stmt.options(
|
||||
joinedload(ChatMessage.tool_call),
|
||||
joinedload(ChatMessage.sub_questions).joinedload(
|
||||
AgentSubQuestion.sub_queries
|
||||
),
|
||||
)
|
||||
result = db_session.scalars(stmt).unique().all()
|
||||
else:
|
||||
result = db_session.scalars(stmt).all()
|
||||
@@ -837,14 +854,45 @@ def translate_db_search_doc_to_server_search_doc(
|
||||
)
|
||||
|
||||
|
||||
def get_retrieval_docs_from_chat_message(
|
||||
chat_message: ChatMessage, remove_doc_content: bool = False
|
||||
def translate_db_sub_questions_to_server_objects(
|
||||
db_sub_questions: list[AgentSubQuestion],
|
||||
) -> list[SubQuestionDetail]:
|
||||
sub_questions = []
|
||||
for sub_question in db_sub_questions:
|
||||
sub_queries = []
|
||||
docs: list[SearchDoc] = []
|
||||
for sub_query in sub_question.sub_queries:
|
||||
doc_ids = [doc.id for doc in sub_query.search_docs]
|
||||
sub_queries.append(
|
||||
SubQueryDetail(
|
||||
query=sub_query.sub_query,
|
||||
query_id=sub_query.id,
|
||||
doc_ids=doc_ids,
|
||||
)
|
||||
)
|
||||
docs += sub_query.search_docs
|
||||
|
||||
sub_questions.append(
|
||||
SubQuestionDetail(
|
||||
level=sub_question.level,
|
||||
level_question_nr=sub_question.level_question_nr,
|
||||
question=sub_question.sub_question,
|
||||
answer=sub_question.sub_answer,
|
||||
sub_queries=sub_queries,
|
||||
context_docs=get_retrieval_docs_from_search_docs(docs),
|
||||
)
|
||||
)
|
||||
return sub_questions
|
||||
|
||||
|
||||
def get_retrieval_docs_from_search_docs(
|
||||
search_docs: list[SearchDoc], remove_doc_content: bool = False
|
||||
) -> RetrievalDocs:
|
||||
top_documents = [
|
||||
translate_db_search_doc_to_server_search_doc(
|
||||
db_doc, remove_doc_content=remove_doc_content
|
||||
)
|
||||
for db_doc in chat_message.search_docs
|
||||
for db_doc in search_docs
|
||||
]
|
||||
top_documents = sorted(top_documents, key=lambda doc: doc.score, reverse=True) # type: ignore
|
||||
return RetrievalDocs(top_documents=top_documents)
|
||||
@@ -861,8 +909,8 @@ def translate_db_message_to_chat_message_detail(
|
||||
latest_child_message=chat_message.latest_child_message,
|
||||
message=chat_message.message,
|
||||
rephrased_query=chat_message.rephrased_query,
|
||||
context_docs=get_retrieval_docs_from_chat_message(
|
||||
chat_message, remove_doc_content=remove_doc_content
|
||||
context_docs=get_retrieval_docs_from_search_docs(
|
||||
chat_message.search_docs, remove_doc_content=remove_doc_content
|
||||
),
|
||||
message_type=chat_message.message_type,
|
||||
time_sent=chat_message.time_sent,
|
||||
@@ -877,6 +925,114 @@ def translate_db_message_to_chat_message_detail(
|
||||
else None,
|
||||
alternate_assistant_id=chat_message.alternate_assistant_id,
|
||||
overridden_model=chat_message.overridden_model,
|
||||
sub_questions=translate_db_sub_questions_to_server_objects(
|
||||
chat_message.sub_questions
|
||||
),
|
||||
)
|
||||
|
||||
return chat_msg_detail
|
||||
|
||||
|
||||
def log_agent_metrics(
|
||||
db_session: Session,
|
||||
user_id: UUID | None,
|
||||
persona_id: int | None, # Can be none if temporary persona is used
|
||||
agent_type: str,
|
||||
start_time: datetime,
|
||||
agent_metrics: CombinedAgentMetrics,
|
||||
) -> AgentSearchMetrics:
|
||||
agent_timings = agent_metrics.timings
|
||||
agent_base_metrics = agent_metrics.base_metrics
|
||||
agent_refined_metrics = agent_metrics.refined_metrics
|
||||
agent_additional_metrics = agent_metrics.additional_metrics
|
||||
|
||||
agent_metric_tracking = AgentSearchMetrics(
|
||||
user_id=user_id,
|
||||
persona_id=persona_id,
|
||||
agent_type=agent_type,
|
||||
start_time=start_time,
|
||||
base_duration__s=agent_timings.base_duration__s,
|
||||
full_duration__s=agent_timings.full_duration__s,
|
||||
base_metrics=vars(agent_base_metrics),
|
||||
refined_metrics=vars(agent_refined_metrics),
|
||||
all_metrics=vars(agent_additional_metrics),
|
||||
)
|
||||
|
||||
db_session.add(agent_metric_tracking)
|
||||
db_session.flush()
|
||||
|
||||
return agent_metric_tracking
|
||||
|
||||
|
||||
def log_agent_sub_question_results(
|
||||
db_session: Session,
|
||||
chat_session_id: UUID | None,
|
||||
primary_message_id: int | None,
|
||||
sub_question_answer_results: list[QuestionAnswerResults],
|
||||
) -> None:
|
||||
def _create_citation_format_list(
|
||||
document_citations: list[InferenceSection],
|
||||
) -> list[dict[str, Any]]:
|
||||
citation_list: list[dict[str, Any]] = []
|
||||
for document_citation in document_citations:
|
||||
document_citation_dict = {
|
||||
"link": "",
|
||||
"blurb": document_citation.center_chunk.blurb,
|
||||
"content": document_citation.center_chunk.content,
|
||||
"metadata": document_citation.center_chunk.metadata,
|
||||
"updated_at": str(document_citation.center_chunk.updated_at),
|
||||
"document_id": document_citation.center_chunk.document_id,
|
||||
"source_type": "file",
|
||||
"source_links": document_citation.center_chunk.source_links,
|
||||
"match_highlights": document_citation.center_chunk.match_highlights,
|
||||
"semantic_identifier": document_citation.center_chunk.semantic_identifier,
|
||||
}
|
||||
|
||||
citation_list.append(document_citation_dict)
|
||||
|
||||
return citation_list
|
||||
|
||||
now = datetime.now()
|
||||
|
||||
for sub_question_answer_result in sub_question_answer_results:
|
||||
level, level_question_nr = [
|
||||
int(x) for x in sub_question_answer_result.question_id.split("_")
|
||||
]
|
||||
sub_question = sub_question_answer_result.question
|
||||
sub_answer = sub_question_answer_result.answer
|
||||
sub_document_results = _create_citation_format_list(
|
||||
sub_question_answer_result.documents
|
||||
)
|
||||
sub_queries = [
|
||||
x.query for x in sub_question_answer_result.expanded_retrieval_results
|
||||
]
|
||||
|
||||
sub_question_object = AgentSubQuestion(
|
||||
chat_session_id=chat_session_id,
|
||||
primary_question_id=primary_message_id,
|
||||
level=level,
|
||||
level_question_nr=level_question_nr,
|
||||
sub_question=sub_question,
|
||||
sub_answer=sub_answer,
|
||||
sub_question_doc_results=sub_document_results,
|
||||
)
|
||||
|
||||
db_session.add(sub_question_object)
|
||||
db_session.commit()
|
||||
# db_session.flush()
|
||||
|
||||
sub_question_id = sub_question_object.id
|
||||
|
||||
for sub_query in sub_queries:
|
||||
sub_query_object = AgentSubQuery(
|
||||
parent_question_id=sub_question_id,
|
||||
chat_session_id=chat_session_id,
|
||||
sub_query=sub_query,
|
||||
time_created=now,
|
||||
)
|
||||
|
||||
db_session.add(sub_query_object)
|
||||
db_session.commit()
|
||||
# db_session.flush()
|
||||
|
||||
return None
|
||||
|
||||
@@ -320,6 +320,17 @@ class ChatMessage__SearchDoc(Base):
)


class AgentSubQuery__SearchDoc(Base):
__tablename__ = "agent__sub_query__search_doc"

sub_query_id: Mapped[int] = mapped_column(
ForeignKey("agent__sub_query.id"), primary_key=True
)
search_doc_id: Mapped[int] = mapped_column(
ForeignKey("search_doc.id"), primary_key=True
)


class Document__Tag(Base):
__tablename__ = "document__tag"

@@ -960,6 +971,11 @@ class SearchDoc(Base):
secondary=ChatMessage__SearchDoc.__table__,
back_populates="search_docs",
)
sub_queries = relationship(
"AgentSubQuery",
secondary=AgentSubQuery__SearchDoc.__table__,
back_populates="search_docs",
)


class ToolCall(Base):
@@ -1122,6 +1138,11 @@ class ChatMessage(Base):
uselist=False,
)

sub_questions: Mapped[list["AgentSubQuestion"]] = relationship(
"AgentSubQuestion",
back_populates="primary_message",
)

standard_answers: Mapped[list["StandardAnswer"]] = relationship(
"StandardAnswer",
secondary=ChatMessage__StandardAnswer.__table__,
@@ -1156,6 +1177,71 @@ class ChatFolder(Base):
return self.display_priority < other.display_priority


class AgentSubQuestion(Base):
"""
A sub-question is a question that is asked of the LLM to gather supporting
information to answer a primary question.
"""

__tablename__ = "agent__sub_question"

id: Mapped[int] = mapped_column(primary_key=True)
primary_question_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
chat_session_id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True), ForeignKey("chat_session.id")
)
sub_question: Mapped[str] = mapped_column(Text)
level: Mapped[int] = mapped_column(Integer)
level_question_nr: Mapped[int] = mapped_column(Integer)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
sub_answer: Mapped[str] = mapped_column(Text)
sub_question_doc_results: Mapped[JSON_ro] = mapped_column(postgresql.JSONB())

# Relationships
primary_message: Mapped["ChatMessage"] = relationship(
"ChatMessage",
foreign_keys=[primary_question_id],
back_populates="sub_questions",
)
chat_session: Mapped["ChatSession"] = relationship("ChatSession")
sub_queries: Mapped[list["AgentSubQuery"]] = relationship(
"AgentSubQuery", back_populates="parent_question"
)


class AgentSubQuery(Base):
"""
A sub-query is a vector DB query that gathers supporting information to answer a sub-question.
"""

__tablename__ = "agent__sub_query"

id: Mapped[int] = mapped_column(primary_key=True)
parent_question_id: Mapped[int] = mapped_column(
ForeignKey("agent__sub_question.id")
)
chat_session_id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True), ForeignKey("chat_session.id")
)
sub_query: Mapped[str] = mapped_column(Text)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)

# Relationships
parent_question: Mapped["AgentSubQuestion"] = relationship(
"AgentSubQuestion", back_populates="sub_queries"
)
chat_session: Mapped["ChatSession"] = relationship("ChatSession")
search_docs: Mapped[list["SearchDoc"]] = relationship(
"SearchDoc",
secondary=AgentSubQuery__SearchDoc.__table__,
back_populates="sub_queries",
)

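A sketch of reading the new tables back through the relationships declared above; `db_session` and `message_id` are assumed to exist and are not part of this change.

# Illustrative query only.
sub_questions = (
    db_session.query(AgentSubQuestion)
    .filter(AgentSubQuestion.primary_question_id == message_id)
    .order_by(AgentSubQuestion.level, AgentSubQuestion.level_question_nr)
    .all()
)
for sq in sub_questions:
    sub_queries = [q.sub_query for q in sq.sub_queries]
    docs = [doc for q in sq.sub_queries for doc in q.search_docs]
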
"""
|
||||
Feedback, Logging, Metrics Tables
|
||||
"""
|
||||
@@ -1641,6 +1727,25 @@ class PGFileStore(Base):
|
||||
lobj_oid: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
|
||||
|
||||
class AgentSearchMetrics(Base):
|
||||
__tablename__ = "agent__search_metrics"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
persona_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("persona.id"), nullable=True
|
||||
)
|
||||
agent_type: Mapped[str] = mapped_column(String)
|
||||
start_time: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
|
||||
base_duration__s: Mapped[float] = mapped_column(Float)
|
||||
full_duration__s: Mapped[float] = mapped_column(Float)
|
||||
base_metrics: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)
|
||||
refined_metrics: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)
|
||||
all_metrics: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)
|
||||
|
||||
|
||||
"""
|
||||
************************************************************************
|
||||
Enterprise Edition Models
|
||||
|
||||
@@ -125,6 +125,10 @@ class CreateChatMessageRequest(ChunkContext):
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None

# If true, ignores most of the search options and uses pro search instead.
# TODO: decide how many of the above options we want to pass through to pro search
use_pro_search: bool = False

@model_validator(mode="after")
def check_search_doc_ids_or_retrieval_options(self) -> "CreateChatMessageRequest":
if self.search_doc_ids is None and self.retrieval_options is None:
@@ -190,6 +194,22 @@ class SearchFeedbackRequest(BaseModel):
return self


class SubQueryDetail(BaseModel):
query: str
query_id: int
# TODO: store these to enable per-query doc selection
doc_ids: list[int] | None = None


class SubQuestionDetail(BaseModel):
level: int
level_question_nr: int
question: str
answer: str
sub_queries: list[SubQueryDetail] | None = None
context_docs: RetrievalDocs | None = None


class ChatMessageDetail(BaseModel):
message_id: int
parent_message: int | None = None
@@ -201,9 +221,10 @@ class ChatMessageDetail(BaseModel):
time_sent: datetime
overridden_model: str | None
alternate_assistant_id: int | None = None
# Dict mapping citation number to db_doc_id
chat_session_id: UUID | None = None
# Dict mapping citation number to db_doc_id
citations: dict[int, int] | None = None
sub_questions: list[SubQuestionDetail] | None = None
files: list[FileDescriptor]
tool_call: ToolCallFinalResult | None

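For illustration, one entry of ChatMessageDetail.sub_questions built from the models above might look like the sketch below; the field values are made up.

# Sketch only - values are invented.
example_sub_question = SubQuestionDetail(
    level=0,
    level_question_nr=1,
    question="Which connectors are configured?",
    answer="Three connectors are configured ...",
    sub_queries=[SubQueryDetail(query="configured connectors", query_id=17)],
    context_docs=None,
)
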
@@ -25,6 +25,11 @@ class ToolCallSummary(BaseModel__v1):
tool_call_request: AIMessage
tool_call_result: ToolMessage

# This is a workaround to allow arbitrary types in the model
# TODO: Remove this once we have a better solution
class Config:
arbitrary_types_allowed = True


def tool_call_tokens(
tool_call_summary: ToolCallSummary, llm_tokenizer: BaseTokenizer

@@ -4,6 +4,9 @@ from uuid import UUID
from pydantic import BaseModel
from pydantic import model_validator

from onyx.context.search.enums import SearchType
from onyx.context.search.models import IndexFilters


class ToolResponse(BaseModel):
id: str | None = None
@@ -45,5 +48,11 @@ class DynamicSchemaInfo(BaseModel):
message_id: int | None


class SearchQueryInfo(BaseModel):
predicted_search: SearchType | None
final_filters: IndexFilters
recency_bias_multiplier: float


CHAT_SESSION_ID_PLACEHOLDER = "CHAT_SESSION_ID"
MESSAGE_ID_PLACEHOLDER = "MESSAGE_ID"

@@ -1,9 +1,9 @@
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.chat_utils import llm_doc_from_inference_section
|
||||
@@ -25,13 +25,13 @@ from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
|
||||
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
from onyx.context.search.enums import LLMEvaluationType
|
||||
from onyx.context.search.enums import QueryFlow
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.context.search.pipeline import SearchPipeline
|
||||
from onyx.context.search.pipeline import section_relevance_list_impl
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.llm.interfaces import LLM
|
||||
@@ -39,6 +39,7 @@ from onyx.llm.models import PreviousMessage
|
||||
from onyx.secondary_llm_flows.choose_search import check_if_need_search
|
||||
from onyx.secondary_llm_flows.query_expansion import history_based_query_rephrase
|
||||
from onyx.tools.message import ToolCallSummary
|
||||
from onyx.tools.models import SearchQueryInfo
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict
|
||||
@@ -62,13 +63,10 @@ SECTION_RELEVANCE_LIST_ID = "section_relevance_list"
|
||||
SEARCH_EVALUATION_ID = "llm_doc_eval"
|
||||
|
||||
|
||||
class SearchResponseSummary(BaseModel):
|
||||
class SearchResponseSummary(SearchQueryInfo):
|
||||
top_sections: list[InferenceSection]
|
||||
rephrased_query: str | None = None
|
||||
predicted_flow: QueryFlow | None
|
||||
predicted_search: SearchType | None
|
||||
final_filters: IndexFilters
|
||||
recency_bias_multiplier: float
|
||||
|
||||
|
||||
SEARCH_TOOL_DESCRIPTION = """
|
||||
@@ -117,6 +115,8 @@ class SearchTool(Tool):
|
||||
self.fast_llm = fast_llm
|
||||
self.evaluation_type = evaluation_type
|
||||
|
||||
self.search_pipeline: SearchPipeline | None = None
|
||||
|
||||
self.selected_sections = selected_sections
|
||||
|
||||
self.full_doc = full_doc
|
||||
@@ -281,8 +281,10 @@ class SearchTool(Tool):
|
||||
|
||||
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)
|
||||
|
||||
def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
|
||||
def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
|
||||
query = cast(str, kwargs["query"])
|
||||
force_no_rerank = cast(bool, kwargs.get("force_no_rerank", False))
|
||||
alternate_db_session = cast(Session, kwargs.get("alternate_db_session", None))
|
||||
|
||||
if self.selected_sections:
|
||||
yield from self._build_response_for_specified_sections(query)
|
||||
@@ -291,7 +293,9 @@ class SearchTool(Tool):
|
||||
search_pipeline = SearchPipeline(
|
||||
search_request=SearchRequest(
|
||||
query=query,
|
||||
evaluation_type=self.evaluation_type,
|
||||
evaluation_type=LLMEvaluationType.SKIP
|
||||
if force_no_rerank
|
||||
else self.evaluation_type,
|
||||
human_selected_filters=(
|
||||
self.retrieval_options.filters if self.retrieval_options else None
|
||||
),
|
||||
@@ -300,7 +304,16 @@ class SearchTool(Tool):
|
||||
self.retrieval_options.offset if self.retrieval_options else None
|
||||
),
|
||||
limit=self.retrieval_options.limit if self.retrieval_options else None,
|
||||
rerank_settings=self.rerank_settings,
|
||||
rerank_settings=RerankingDetails(
|
||||
rerank_model_name=None,
|
||||
rerank_api_url=None,
|
||||
rerank_provider_type=None,
|
||||
rerank_api_key=None,
|
||||
num_rerank=0,
|
||||
disable_rerank_for_streaming=True,
|
||||
)
|
||||
if force_no_rerank
|
||||
else self.rerank_settings,
|
||||
chunks_above=self.chunks_above,
|
||||
chunks_below=self.chunks_below,
|
||||
full_doc=self.full_doc,
|
||||
@@ -314,57 +327,25 @@ class SearchTool(Tool):
|
||||
llm=self.llm,
|
||||
fast_llm=self.fast_llm,
|
||||
bypass_acl=self.bypass_acl,
|
||||
db_session=self.db_session,
|
||||
db_session=alternate_db_session or self.db_session,
|
||||
prompt_config=self.prompt_config,
|
||||
)
|
||||
self.search_pipeline = search_pipeline # used for agent_search metrics
|
||||
|
||||
yield ToolResponse(
|
||||
id=SEARCH_RESPONSE_SUMMARY_ID,
|
||||
response=SearchResponseSummary(
|
||||
rephrased_query=query,
|
||||
top_sections=search_pipeline.final_context_sections,
|
||||
predicted_flow=search_pipeline.predicted_flow,
|
||||
predicted_search=search_pipeline.predicted_search_type,
|
||||
final_filters=search_pipeline.search_query.filters,
|
||||
recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier,
|
||||
),
|
||||
search_query_info = SearchQueryInfo(
|
||||
predicted_search=search_pipeline.search_query.search_type,
|
||||
final_filters=search_pipeline.search_query.filters,
|
||||
recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier,
|
||||
)
|
||||
|
||||
yield ToolResponse(
|
||||
id=SEARCH_DOC_CONTENT_ID,
|
||||
response=OnyxContexts(
|
||||
contexts=[
|
||||
OnyxContext(
|
||||
content=section.combined_content,
|
||||
document_id=section.center_chunk.document_id,
|
||||
semantic_identifier=section.center_chunk.semantic_identifier,
|
||||
blurb=section.center_chunk.blurb,
|
||||
)
|
||||
for section in search_pipeline.reranked_sections
|
||||
]
|
||||
),
|
||||
yield from yield_search_responses(
|
||||
query,
|
||||
search_pipeline.reranked_sections,
|
||||
search_pipeline.final_context_sections,
|
||||
search_query_info,
|
||||
lambda: search_pipeline.section_relevance,
|
||||
self,
|
||||
)
|
||||
|
||||
yield ToolResponse(
|
||||
id=SECTION_RELEVANCE_LIST_ID,
|
||||
response=search_pipeline.section_relevance,
|
||||
)
|
||||
|
||||
pruned_sections = prune_sections(
|
||||
sections=search_pipeline.final_context_sections,
|
||||
section_relevance_list=search_pipeline.section_relevance_list,
|
||||
prompt_config=self.prompt_config,
|
||||
llm_config=self.llm.config,
|
||||
question=query,
|
||||
contextual_pruning_config=self.contextual_pruning_config,
|
||||
)
|
||||
|
||||
llm_docs = [
|
||||
llm_doc_from_inference_section(section) for section in pruned_sections
|
||||
]
|
||||
|
||||
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)
|
||||
|
||||
def final_result(self, *args: ToolResponse) -> JSON_ro:
|
||||
final_docs = cast(
|
||||
list[LlmDoc],
|
||||
@@ -425,3 +406,64 @@ class SearchTool(Tool):
initial_search_results = cast(list[LlmDoc], initial_search_results)

return final_search_results, initial_search_results


# Allows yielding the same responses as a SearchTool without being a SearchTool.
# SearchTool passed in to allow for access to SearchTool properties.
# We can't just call SearchTool methods in the graph because we're operating on
# the retrieved docs (reranking, deduping, etc.) after the SearchTool has run.
def yield_search_responses(
query: str,
reranked_sections: list[InferenceSection],
final_context_sections: list[InferenceSection],
search_query_info: SearchQueryInfo,
get_section_relevance: Callable[[], list[SectionRelevancePiece] | None],
search_tool: SearchTool,
) -> Generator[ToolResponse, None, None]:
yield ToolResponse(
id=SEARCH_RESPONSE_SUMMARY_ID,
response=SearchResponseSummary(
rephrased_query=query,
top_sections=final_context_sections,
predicted_flow=QueryFlow.QUESTION_ANSWER,
predicted_search=search_query_info.predicted_search,
final_filters=search_query_info.final_filters,
recency_bias_multiplier=search_query_info.recency_bias_multiplier,
),
)

yield ToolResponse(
id=SEARCH_DOC_CONTENT_ID,
response=OnyxContexts(
contexts=[
OnyxContext(
content=section.combined_content,
document_id=section.center_chunk.document_id,
semantic_identifier=section.center_chunk.semantic_identifier,
blurb=section.center_chunk.blurb,
)
for section in reranked_sections
]
),
)

section_relevance = get_section_relevance()
yield ToolResponse(
id=SECTION_RELEVANCE_LIST_ID,
response=section_relevance,
)

pruned_sections = prune_sections(
sections=final_context_sections,
section_relevance_list=section_relevance_list_impl(
section_relevance, final_context_sections
),
prompt_config=search_tool.prompt_config,
llm_config=search_tool.llm.config,
question=query,
contextual_pruning_config=search_tool.contextual_pruning_config,
)

llm_docs = [llm_doc_from_inference_section(section) for section in pruned_sections]

yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)

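A sketch of how graph-side code might call yield_search_responses after doing its own reranking and deduping; all argument values here are assumptions, not code from this change.

# Sketch only - `reranked`, `final_sections`, `index_filters`, and `search_tool`
# are assumed to come from the agent graph.
query_info = SearchQueryInfo(
    predicted_search=None,
    final_filters=index_filters,
    recency_bias_multiplier=1.0,
)
for tool_response in yield_search_responses(
    query="original user question",
    reranked_sections=reranked,
    final_context_sections=final_sections,
    search_query_info=query_info,
    get_section_relevance=lambda: None,  # relevance filtering skipped in this sketch
    search_tool=search_tool,
):
    ...  # forward to the stream, e.g. wrapped in ExtendedToolResponse
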
@@ -49,11 +49,7 @@ def build_next_prompt_for_search_like_tool(
message=prompt_builder.user_message_and_token_cnt[0],
prompt_config=prompt_config,
context_docs=final_context_documents,
all_doc_useful=(
answer_style_config.citation_config.all_docs_useful
if answer_style_config.citation_config
else False
),
all_doc_useful=(answer_style_config.citation_config.all_docs_useful),
history_message=prompt_builder.single_message_history or "",
)
)

@@ -29,9 +29,14 @@ inflection==0.5.1
jira==3.5.1
jsonref==1.1.0
trafilatura==1.12.2
langchain==0.1.17
langchain-core==0.1.50
langchain-text-splitters==0.0.1
langchain==0.3.7
langchain-core==0.3.24
langchain-openai==0.2.9
langchain-text-splitters==0.3.2
langchainhub==0.1.21
langgraph==0.2.59
langgraph-checkpoint==2.0.5
langgraph-sdk==0.1.44
litellm==1.55.4
lxml==5.3.0
lxml_html_clean==0.2.2

backend/tests/regression/answer_quality/agent_test.py (new file, 128 lines)
@@ -0,0 +1,128 @@
|
||||
import csv
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
|
||||
import yaml
|
||||
|
||||
from onyx.agent_search.pro_search_a.main.graph_builder import main_graph_builder
|
||||
from onyx.agent_search.pro_search_a.main.states import MainInput
|
||||
from onyx.chat.models import ProSearchConfig
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.llm.factory import get_default_llms
|
||||
|
||||
cwd = os.getcwd()
|
||||
CONFIG = yaml.safe_load(
|
||||
open(f"{cwd}/backend/tests/regression/answer_quality/search_test_config.yaml")
|
||||
)
|
||||
INPUT_DIR = CONFIG["agent_test_input_folder"]
|
||||
OUTPUT_DIR = CONFIG["agent_test_output_folder"]
|
||||
|
||||
|
||||
graph = main_graph_builder(test_mode=True)
|
||||
compiled_graph = graph.compile()
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
|
||||
# create a local json test data file and use it here
|
||||
|
||||
|
||||
input_file_object = open(
|
||||
f"{INPUT_DIR}/agent_test_data.json",
|
||||
)
|
||||
output_file = f"{OUTPUT_DIR}/agent_test_output.csv"
|
||||
|
||||
|
||||
test_data = json.load(input_file_object)
|
||||
example_data = test_data["examples"]
|
||||
example_ids = test_data["example_ids"]
|
||||
|
||||
with get_session_context_manager() as db_session:
|
||||
output_data = []
|
||||
|
||||
for example in example_data:
|
||||
example_id = example["id"]
|
||||
if len(example_ids) > 0 and example_id not in example_ids:
|
||||
continue
|
||||
|
||||
example_question = example["question"]
|
||||
target_sub_questions = example.get("target_sub_questions", [])
|
||||
num_target_sub_questions = len(target_sub_questions)
|
||||
search_request = SearchRequest(query=example_question)
|
||||
|
||||
config = ProSearchConfig(
|
||||
search_request=search_request,
|
||||
message_id=None,
|
||||
chat_session_id=None,
|
||||
use_persistence=False,
|
||||
)
|
||||
inputs = MainInput(
|
||||
config=config,
|
||||
primary_llm=primary_llm,
|
||||
fast_llm=fast_llm,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
start_time = datetime.datetime.now()
|
||||
|
||||
question_result = compiled_graph.invoke(input=inputs)
|
||||
end_time = datetime.datetime.now()
|
||||
|
||||
duration = end_time - start_time
|
||||
if num_target_sub_questions > 0:
|
||||
chunk_expansion_ratio = (
|
||||
question_result["initial_agent_stats"]
|
||||
.get("agent_effectiveness", {})
|
||||
.get("utilized_chunk_ratio", None)
|
||||
)
|
||||
support_effectiveness_ratio = (
|
||||
question_result["initial_agent_stats"]
|
||||
.get("agent_effectiveness", {})
|
||||
.get("support_ratio", None)
|
||||
)
|
||||
else:
|
||||
chunk_expansion_ratio = None
|
||||
support_effectiveness_ratio = None
|
||||
|
||||
generated_sub_questions = question_result.get("generated_sub_questions", [])
|
||||
num_generated_sub_questions = len(generated_sub_questions)
|
||||
base_answer = question_result["initial_base_answer"].split("==")[-1]
|
||||
agent_answer = question_result["initial_answer"].split("==")[-1]
|
||||
|
||||
output_point = {
|
||||
"example_id": example_id,
|
||||
"question": example_question,
|
||||
"duration": duration,
|
||||
"target_sub_questions": target_sub_questions,
|
||||
"generated_sub_questions": generated_sub_questions,
|
||||
"num_target_sub_questions": num_target_sub_questions,
|
||||
"num_generated_sub_questions": num_generated_sub_questions,
|
||||
"chunk_expansion_ratio": chunk_expansion_ratio,
|
||||
"support_effectiveness_ratio": support_effectiveness_ratio,
|
||||
"base_answer": base_answer,
|
||||
"agent_answer": agent_answer,
|
||||
}
|
||||
|
||||
output_data.append(output_point)
|
||||
|
||||
|
||||
with open(output_file, "w", newline="") as csvfile:
|
||||
fieldnames = [
|
||||
"example_id",
|
||||
"question",
|
||||
"duration",
|
||||
"target_sub_questions",
|
||||
"generated_sub_questions",
|
||||
"num_target_sub_questions",
|
||||
"num_generated_sub_questions",
|
||||
"chunk_expansion_ratio",
|
||||
"support_effectiveness_ratio",
|
||||
"base_answer",
|
||||
"agent_answer",
|
||||
]
|
||||
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames, delimiter="\t")
|
||||
writer.writeheader()
|
||||
writer.writerows(output_data)
|
||||
|
||||
print("DONE")
|
||||
@@ -180,6 +180,7 @@ export function ChatPage({

const [documentSidebarToggled, setDocumentSidebarToggled] = useState(false);
const [filtersToggled, setFiltersToggled] = useState(false);
const [langgraphEnabled, setLanggraphEnabled] = useState(false);

const [userSettingsToggled, setUserSettingsToggled] = useState(false);

@@ -1264,6 +1265,7 @@ export function ChatPage({
systemPromptOverride:
searchParams.get(SEARCH_PARAM_NAMES.SYSTEM_PROMPT) || undefined,
useExistingUserMessage: isSeededChat,
useLanggraph: langgraphEnabled,
});

const delay = (ms: number) => {
@@ -2245,6 +2247,17 @@ export function ChatPage({
hideUserDropdown={user?.is_anonymous_user}
/>
)}
<div className="flex items-center justify-end px-4 py-2">
<label className="flex items-center cursor-pointer">
<span className="mr-2 text-sm">Langgraph</span>
<input
type="checkbox"
checked={langgraphEnabled}
onChange={(e) => setLanggraphEnabled(e.target.checked)}
className="form-checkbox h-4 w-4"
/>
</label>
</div>

{documentSidebarInitialWidth !== undefined && isReady ? (
<Dropzone onDrop={handleImageUpload} noClick>

@@ -128,6 +128,7 @@ export async function* sendMessage({
useExistingUserMessage,
alternateAssistantId,
signal,
useLanggraph,
}: {
regenerate: boolean;
message: string;
@@ -146,6 +147,7 @@ export async function* sendMessage({
useExistingUserMessage?: boolean;
alternateAssistantId?: number;
signal?: AbortSignal;
useLanggraph?: boolean;
}): AsyncGenerator<PacketType, void, unknown> {
const documentsAreSelected =
selectedDocumentIds && selectedDocumentIds.length > 0;
@@ -186,6 +188,7 @@ export async function* sendMessage({
}
: null,
use_existing_user_message: useExistingUserMessage,
use_pro_search: useLanggraph,
});

const response = await fetch(`/api/chat/send-message`, {