Compare commits

..

50 Commits

Author SHA1 Message Date
pablodanswer
e55dd89444 k 2025-01-23 18:51:14 -08:00
pablodanswer
811d564a0f updated + functional 2025-01-23 18:51:14 -08:00
pablodanswer
206b247ca5 update- reorg 2025-01-23 18:51:14 -08:00
pablodanswer
cd445f0e3f k 2025-01-23 18:51:14 -08:00
pablodanswer
967b0e5b0f build fix 2025-01-23 18:05:34 -08:00
joachim-danswer
fac9525833 Merge pull request #3757 from onyx-dot-app/agent-search-feature-jr-1
Changes addressing YS questions from 01/22/25
2025-01-23 15:20:16 -08:00
joachim-danswer
79fc4ae47d EL comments addressed 2025-01-23 15:18:43 -08:00
joachim-danswer
1c09c75e5f loser verification prompt 2025-01-23 14:11:34 -08:00
joachim-danswer
23ec33e411 turning off initial search pre route decision 2025-01-23 13:29:13 -08:00
joachim-danswer
fea429e11b change of sub-question answer if no docs recovered 2025-01-23 13:23:10 -08:00
joachim-danswer
f9d7d21d8e various fixes from Yuhong's list 2025-01-23 13:06:26 -08:00
Yuhong Sun
d2a8938545 Copy changes 2025-01-23 11:25:06 -08:00
evan-danswer
ac909f8437 Merge pull request #3752 from onyx-dot-app/asf-evan-async-task-cleanup
async task cleanup + basic citations
2025-01-23 10:47:29 -08:00
Evan Lohn
2f32111169 removed print statements, fixed pass through handling 2025-01-23 10:38:09 -08:00
Evan Lohn
23acb163f5 fixed basic flow citations and second test 2025-01-23 10:18:22 -08:00
Evan Lohn
ebe15b42d2 fix for early cancellation test; solves issue with tasks being destroyed while pending 2025-01-22 21:15:56 -08:00
pablodanswer
6cbb237945 add agent search frontend 2025-01-22 18:31:30 -08:00
Evan Lohn
6803548066 fix alembic history 2025-01-22 17:35:26 -08:00
joachim-danswer
1111ce6ce4 streaming + saving of search docs of no verified ones available
- sub-questions only
2025-01-22 17:30:36 -08:00
Evan Lohn
3f68e8ea8e reworked history messages in agent config 2025-01-22 17:30:36 -08:00
Evan Lohn
06a8373ff4 missed files from prev commit 2025-01-22 17:30:36 -08:00
Evan Lohn
86e770d968 basic search restructure: WIP on fixing tests 2025-01-22 17:30:36 -08:00
joachim-danswer
f11216132e prompts that even further motivates to cite docs over sub-q's 2025-01-22 17:30:36 -08:00
joachim-danswer
1f7d05cd75 pydantic for LangGraph + changed ERT extraction flow 2025-01-22 17:30:36 -08:00
joachim-danswer
c8bf051fb6 history added to agent flow 2025-01-22 17:30:36 -08:00
pablodanswer
14b54db033 minor fixes to branch 2025-01-22 17:30:36 -08:00
Evan Lohn
0e9f9301ba second clean commit 2025-01-22 17:30:36 -08:00
rkuo-danswer
69c60feda4 cloud check for migrations (#3734)
* cloud check for migrations

* fix table declaration

* change back interval

* Fix usage of POSTGRES_DEFAULT_SCHEMA

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-01-22 22:41:28 +00:00
pablonyx
a215ea9143 Performance monitoring (#3725)
* nit

* minimal

* config

* not too big a change

* k

* update

* update web push

* node options

* k

* update config

* attempt fix
2025-01-22 19:54:07 +00:00
pablonyx
f81a42b4e8 fix image edge case width screen size (#3738) 2025-01-22 18:54:00 +00:00
rkuo-danswer
b095e17827 Bugfix/watchdog signal (#3699)
* signal from the watchdog so that the monitor task doesn't try to clean up before it can exit

* ttl constants

* improve comment

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-01-22 17:51:06 +00:00
pablonyx
2a758ae33f Slack doc set fix (#3737) 2025-01-22 09:57:21 -08:00
hagen-danswer
3e58cf2667 Added ability to use a tag to insert the current datetime in prompts (#3697)
* Added ability to use a tag to insert the current datetime in prompts

* made tagging logic more robust

* rename

* k

---------

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2025-01-22 16:17:20 +00:00
hagen-danswer
b9c29f2a36 Fix pagination for index attempts table DAN-1284 (#3722)
* Fix pagination for index attempts table

* fixed index attempts pagination

* fixed query history table

* query clearnup

* fixed test

* fixed weird tests???
2025-01-22 01:51:16 +00:00
Yuhong Sun
647adb9ba0 Change Persona to Assistant for Analytics Page (#3741) 2025-01-21 17:08:03 -08:00
pablonyx
7d6d73529b fix gmail connector (#3733) 2025-01-21 20:43:25 +00:00
Chris Weaver
420476ad92 Add basic passthrough auth (#3731)
* Add basic passthrough auth

* Add server-side validation

* Disallow for non-oauth

* Fix npm build
2025-01-20 23:39:23 -08:00
pablonyx
4ca7325d1a Finalize ux rework (#3720)
* colors

* nit

* finalize chat ux

* fix seeding waiting

* update chat input bar icons

* k

* Revert "fix seeding waiting"

This reverts commit e1aa93ff0c.
2025-01-21 01:09:16 +00:00
pablonyx
8ddd95d0d4 Fix exceptional seeding delay (#3723)
* fix seeding waiting

* k

* updated
2025-01-21 01:02:13 +00:00
Weves
1378364686 Pass in tenant_id to kv_store in monitoring job 2025-01-20 15:23:16 -08:00
pablonyx
cc4953b560 Slackbot optimization (#3696)
* initial pass

* update

* nit

* nit

* bot -> app

* nit

* quick update

* various improvements

* k

* k

* nit
2025-01-20 19:46:52 +00:00
pablonyx
fe3eae3680 Update JWT expiry time config (#3717)
* update redis configs

* update comment
2025-01-20 11:12:48 -08:00
hagen-danswer
2a7a22d953 fixed broken zendesk connector tests 2025-01-20 11:09:04 -08:00
pablonyx
f163b798ea Input Formik + hidden screen (#3715) 2025-01-20 10:16:10 -08:00
pablonyx
d4563b8693 Add linear check to PRs (#3708)
* add linear check

* Update pull_request_template.md
2025-01-20 03:48:22 +00:00
Weves
a54ed77140 Enhance airtable connector 2025-01-19 18:57:48 -08:00
Devin AI
f27979ef7f docs: fix typo in README.md ('Any many' -> 'And many')
Co-Authored-By: Chris Weaver <chris@onyx.app>
2025-01-19 14:26:39 -08:00
pablonyx
122a9af9b3 Polish (#3692) 2025-01-19 14:22:08 -08:00
pablodanswer
32a97e5479 fix bug 2025-01-19 13:42:23 -08:00
Chris Weaver
bf30dab9c4 Enable location support for Vertex AI (#3707) 2025-01-19 17:41:35 +00:00
267 changed files with 16743 additions and 4744 deletions

View File

@@ -11,5 +11,4 @@
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)
- [ ] I have included a link to a Linear ticket in my description.
- [ ] [Optional] Override Linear Check

View File

@@ -67,6 +67,7 @@ jobs:
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_GTM_ENABLED=true
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NODE_OPTIONS=--max-old-space-size=8192
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}

View File

@@ -60,6 +60,8 @@ jobs:
push: true
build-args: |
ONYX_VERSION=${{ github.ref_name }}
NODE_OPTIONS=--max-old-space-size=8192
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}

4
.gitignore vendored
View File

@@ -7,4 +7,6 @@
.vscode/
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml
/web/test-results/
/web/test-results/
backend/onyx/agent_search/main/test_data.json
backend/tests/regression/answer_quality/test_data.json

View File

@@ -52,3 +52,9 @@ BING_API_KEY=<REPLACE THIS>
# Enable the full set of Danswer Enterprise Edition features
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False
# Agent Search configs # TODO: Remove give proper namings
AGENT_RETRIEVAL_STATS=False # Note: This setting will incur substantial re-ranking effort
AGENT_RERANKING_STATS=True
AGENT_MAX_QUERY_RETRIEVAL_RESULTS=20
AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS=20

View File

@@ -119,7 +119,7 @@ There are two editions of Onyx:
- Whitelabeling
- API key authentication
- Encryption of secrets
- Any many more! Checkout [our website](https://www.onyx.app/) for the latest.
- And many more! Checkout [our website](https://www.onyx.app/) for the latest.
To try the Onyx Enterprise Edition:

1
Untitled-12 Normal file
View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,29 @@
"""agent_doc_result_col
Revision ID: 1adf5ea20d2b
Revises: e9cf2bd7baed
Create Date: 2025-01-05 13:14:58.344316
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "1adf5ea20d2b"
down_revision = "e9cf2bd7baed"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add the new column with JSONB type
op.add_column(
"sub_question",
sa.Column("sub_question_doc_results", postgresql.JSONB(), nullable=True),
)
def downgrade() -> None:
# Drop the column
op.drop_column("sub_question", "sub_question_doc_results")

View File

@@ -0,0 +1,35 @@
"""agent_metric_col_rename__s
Revision ID: 925b58bd75b6
Revises: 9787be927e58
Create Date: 2025-01-06 11:20:26.752441
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "925b58bd75b6"
down_revision = "9787be927e58"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Rename columns using PostgreSQL syntax
op.alter_column(
"agent__search_metrics", "base_duration_s", new_column_name="base_duration__s"
)
op.alter_column(
"agent__search_metrics", "full_duration_s", new_column_name="full_duration__s"
)
def downgrade() -> None:
# Revert the column renames
op.alter_column(
"agent__search_metrics", "base_duration__s", new_column_name="base_duration_s"
)
op.alter_column(
"agent__search_metrics", "full_duration__s", new_column_name="full_duration_s"
)

View File

@@ -0,0 +1,25 @@
"""agent_metric_table_renames__agent__
Revision ID: 9787be927e58
Revises: bceb76d618ec
Create Date: 2025-01-06 11:01:44.210160
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "9787be927e58"
down_revision = "bceb76d618ec"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Rename table from agent_search_metrics to agent__search_metrics
op.rename_table("agent_search_metrics", "agent__search_metrics")
def downgrade() -> None:
# Rename table back from agent__search_metrics to agent_search_metrics
op.rename_table("agent__search_metrics", "agent_search_metrics")

View File

@@ -0,0 +1,42 @@
"""agent_tracking
Revision ID: 98a5008d8711
Revises: f1ca58b2f2ec
Create Date: 2025-01-04 14:41:52.732238
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "98a5008d8711"
down_revision = "f1ca58b2f2ec"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"agent_search_metrics",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=True),
sa.Column("persona_id", sa.Integer(), nullable=True),
sa.Column("agent_type", sa.String(), nullable=False),
sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
sa.Column("base_duration_s", sa.Float(), nullable=False),
sa.Column("full_duration_s", sa.Float(), nullable=False),
sa.Column("base_metrics", postgresql.JSONB(), nullable=True),
sa.Column("refined_metrics", postgresql.JSONB(), nullable=True),
sa.Column("all_metrics", postgresql.JSONB(), nullable=True),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
def downgrade() -> None:
op.drop_table("agent_search_metrics")

View File

@@ -0,0 +1,84 @@
"""agent_table_renames__agent__
Revision ID: bceb76d618ec
Revises: c0132518a25b
Create Date: 2025-01-06 10:50:48.109285
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "bceb76d618ec"
down_revision = "c0132518a25b"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_constraint(
"sub_query__search_doc_sub_query_id_fkey",
"sub_query__search_doc",
type_="foreignkey",
)
op.drop_constraint(
"sub_query__search_doc_search_doc_id_fkey",
"sub_query__search_doc",
type_="foreignkey",
)
# Rename tables
op.rename_table("sub_query", "agent__sub_query")
op.rename_table("sub_question", "agent__sub_question")
op.rename_table("sub_query__search_doc", "agent__sub_query__search_doc")
# Update both foreign key constraints for agent__sub_query__search_doc
# Create new foreign keys with updated names
op.create_foreign_key(
"agent__sub_query__search_doc_sub_query_id_fkey",
"agent__sub_query__search_doc",
"agent__sub_query",
["sub_query_id"],
["id"],
)
op.create_foreign_key(
"agent__sub_query__search_doc_search_doc_id_fkey",
"agent__sub_query__search_doc",
"search_doc", # This table name doesn't change
["search_doc_id"],
["id"],
)
def downgrade() -> None:
# Update foreign key constraints for sub_query__search_doc
op.drop_constraint(
"agent__sub_query__search_doc_sub_query_id_fkey",
"agent__sub_query__search_doc",
type_="foreignkey",
)
op.drop_constraint(
"agent__sub_query__search_doc_search_doc_id_fkey",
"agent__sub_query__search_doc",
type_="foreignkey",
)
# Rename tables back
op.rename_table("agent__sub_query__search_doc", "sub_query__search_doc")
op.rename_table("agent__sub_question", "sub_question")
op.rename_table("agent__sub_query", "sub_query")
op.create_foreign_key(
"sub_query__search_doc_sub_query_id_fkey",
"sub_query__search_doc",
"sub_query",
["sub_query_id"],
["id"],
)
op.create_foreign_key(
"sub_query__search_doc_search_doc_id_fkey",
"sub_query__search_doc",
"search_doc", # This table name doesn't change
["search_doc_id"],
["id"],
)

View File

@@ -0,0 +1,40 @@
"""agent_table_changes_rename_level
Revision ID: c0132518a25b
Revises: 1adf5ea20d2b
Create Date: 2025-01-05 16:38:37.660152
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c0132518a25b"
down_revision = "1adf5ea20d2b"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add level and level_question_nr columns with NOT NULL constraint
op.add_column(
"sub_question",
sa.Column("level", sa.Integer(), nullable=False, server_default="0"),
)
op.add_column(
"sub_question",
sa.Column(
"level_question_nr", sa.Integer(), nullable=False, server_default="0"
),
)
# Remove the server_default after the columns are created
op.alter_column("sub_question", "level", server_default=None)
op.alter_column("sub_question", "level_question_nr", server_default=None)
def downgrade() -> None:
# Remove the columns
op.drop_column("sub_question", "level_question_nr")
op.drop_column("sub_question", "level")

View File

@@ -0,0 +1,68 @@
"""create pro search persistence tables
Revision ID: e9cf2bd7baed
Revises: 98a5008d8711
Create Date: 2025-01-02 17:55:56.544246
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID
# revision identifiers, used by Alembic.
revision = "e9cf2bd7baed"
down_revision = "98a5008d8711"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create sub_question table
op.create_table(
"sub_question",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("primary_question_id", sa.Integer, sa.ForeignKey("chat_message.id")),
sa.Column(
"chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id")
),
sa.Column("sub_question", sa.Text),
sa.Column(
"time_created", sa.DateTime(timezone=True), server_default=sa.func.now()
),
sa.Column("sub_answer", sa.Text),
)
# Create sub_query table
op.create_table(
"sub_query",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("parent_question_id", sa.Integer, sa.ForeignKey("sub_question.id")),
sa.Column(
"chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id")
),
sa.Column("sub_query", sa.Text),
sa.Column(
"time_created", sa.DateTime(timezone=True), server_default=sa.func.now()
),
)
# Create sub_query__search_doc association table
op.create_table(
"sub_query__search_doc",
sa.Column(
"sub_query_id", sa.Integer, sa.ForeignKey("sub_query.id"), primary_key=True
),
sa.Column(
"search_doc_id",
sa.Integer,
sa.ForeignKey("search_doc.id"),
primary_key=True,
),
)
def downgrade() -> None:
op.drop_table("sub_query__search_doc")
op.drop_table("sub_query")
op.drop_table("sub_question")

View File

@@ -0,0 +1,33 @@
"""add passthrough auth to tool
Revision ID: f1ca58b2f2ec
Revises: c7bf5721733e
Create Date: 2024-03-19
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "f1ca58b2f2ec"
down_revision: Union[str, None] = "c7bf5721733e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add passthrough_auth column to tool table with default value of False
op.add_column(
"tool",
sa.Column(
"passthrough_auth", sa.Boolean(), nullable=False, server_default=sa.false()
),
)
def downgrade() -> None:
# Remove passthrough_auth column from tool table
op.drop_column("tool", "passthrough_auth")

370
backend/chat_packets.log Normal file

File diff suppressed because one or more lines are too long

View File

@@ -98,10 +98,9 @@ def get_page_of_chat_sessions(
conditions = _build_filter_conditions(start_time, end_time, feedback_filter)
subquery = (
select(ChatSession.id, ChatSession.time_created)
select(ChatSession.id)
.filter(*conditions)
.order_by(ChatSession.id, desc(ChatSession.time_created))
.distinct(ChatSession.id)
.order_by(desc(ChatSession.time_created), ChatSession.id)
.limit(page_size)
.offset(page_num * page_size)
.subquery()
@@ -118,7 +117,11 @@ def get_page_of_chat_sessions(
ChatMessage.chat_message_feedbacks
),
)
.order_by(desc(ChatSession.time_created), asc(ChatMessage.id))
.order_by(
desc(ChatSession.time_created),
ChatSession.id,
asc(ChatMessage.id), # Ensure chronological message order
)
)
return db_session.scalars(stmt).unique().all()

View File

@@ -179,6 +179,7 @@ def handle_simplified_chat_message(
chunks_below=0,
full_doc=chat_message_req.full_doc,
structured_response_format=chat_message_req.structured_response_format,
use_agentic_search=chat_message_req.use_agentic_search,
)
packets = stream_chat_message_objects(
@@ -301,6 +302,7 @@ def handle_send_message_simple_with_history(
chunks_below=0,
full_doc=req.full_doc,
structured_response_format=req.structured_response_format,
use_agentic_search=req.use_agentic_search,
)
packets = stream_chat_message_objects(

View File

@@ -57,6 +57,9 @@ class BasicCreateChatMessageRequest(ChunkContext):
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None
# If True, uses agentic search instead of basic search
use_agentic_search: bool = False
class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
# Last element is the new query. All previous elements are historical context
@@ -71,6 +74,8 @@ class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
# only works if using an OpenAI model. See the following for more details:
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None
# If True, uses agentic search instead of basic search
use_agentic_search: bool = False
class SimpleDoc(BaseModel):
@@ -123,6 +128,9 @@ class OneShotQARequest(ChunkContext):
# If True, skips generative an AI response to the search query
skip_gen_ai_answer_generation: bool = False
# If True, uses pro search instead of basic search
use_agentic_search: bool = False
@model_validator(mode="after")
def check_persona_fields(self) -> "OneShotQARequest":
if self.persona_override_config is None and self.persona_id is None:

View File

@@ -196,6 +196,8 @@ def get_answer_stream(
retrieval_details=query_request.retrieval_options,
rerank_settings=query_request.rerank_settings,
db_session=db_session,
use_agentic_search=query_request.use_agentic_search,
skip_gen_ai_answer_generation=query_request.skip_gen_ai_answer_generation,
)
packets = stream_chat_message_objects(

View File

@@ -0,0 +1,71 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.basic.nodes.basic_use_tool_response import (
basic_use_tool_response,
)
from onyx.agents.agent_search.basic.nodes.llm_tool_choice import llm_tool_choice
from onyx.agents.agent_search.basic.nodes.tool_call import tool_call
from onyx.agents.agent_search.basic.states import BasicInput
from onyx.agents.agent_search.basic.states import BasicOutput
from onyx.agents.agent_search.basic.states import BasicState
from onyx.utils.logger import setup_logger
logger = setup_logger()
def basic_graph_builder() -> StateGraph:
graph = StateGraph(
state_schema=BasicState,
input=BasicInput,
output=BasicOutput,
)
### Add nodes ###
graph.add_node(
node="llm_tool_choice",
action=llm_tool_choice,
)
graph.add_node(
node="tool_call",
action=tool_call,
)
graph.add_node(
node="basic_use_tool_response",
action=basic_use_tool_response,
)
### Add edges ###
graph.add_edge(start_key=START, end_key="llm_tool_choice")
graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])
graph.add_edge(
start_key="tool_call",
end_key="basic_use_tool_response",
)
graph.add_edge(
start_key="basic_use_tool_response",
end_key=END,
)
return graph
def should_continue(state: BasicState) -> str:
return (
# If there are no tool calls, basic graph already streamed the answer
END
if state["tool_choice"] is None
else "tool_call"
)
if __name__ == "__main__":
pass

View File

@@ -0,0 +1,63 @@
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.basic.states import BasicOutput
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.chat.models import LlmDoc
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_DOC_CONTENT_ID,
)
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicOutput:
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
structured_response_format = agent_config.structured_response_format
llm = agent_config.primary_llm
tool_choice = state["tool_choice"]
if tool_choice is None:
raise ValueError("Tool choice is None")
tool = tool_choice["tool"]
prompt_builder = agent_config.prompt_builder
tool_call_summary = state["tool_call_summary"]
tool_call_responses = state["tool_call_responses"]
state["tool_call_final_result"]
new_prompt_builder = tool.build_next_prompt(
prompt_builder=prompt_builder,
tool_call_summary=tool_call_summary,
tool_responses=tool_call_responses,
using_tool_calling_llm=agent_config.using_tool_calling_llm,
)
final_search_results = []
initial_search_results = []
for yield_item in tool_call_responses:
if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
final_search_results = cast(list[LlmDoc], yield_item.response)
elif yield_item.id == SEARCH_DOC_CONTENT_ID:
search_contexts = yield_item.response.contexts
for doc in search_contexts:
if doc.document_id not in initial_search_results:
initial_search_results.append(doc)
initial_search_results = cast(list[LlmDoc], initial_search_results)
stream = llm.stream(
prompt=new_prompt_builder.build(),
structured_response_format=structured_response_format,
)
# For now, we don't do multiple tool calls, so we ignore the tool_message
process_llm_stream(
stream,
True,
final_search_results=final_search_results,
displayed_search_results=initial_search_results,
)
return BasicOutput()

View File

@@ -0,0 +1,134 @@
from typing import cast
from uuid import uuid4
from langchain_core.messages import ToolCall
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.states import ToolChoice
from onyx.agents.agent_search.basic.states import ToolChoiceUpdate
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.chat.tool_handling.tool_response_handler import (
get_tool_call_for_non_tool_calling_llm_impl,
)
from onyx.tools.tool import Tool
from onyx.utils.logger import setup_logger
logger = setup_logger()
# TODO: break this out into an implementation function
# and a function that handles extracting the necessary fields
# from the state and config
# TODO: fan-out to multiple tool call nodes? Make this configurable?
def llm_tool_choice(state: BasicState, config: RunnableConfig) -> ToolChoiceUpdate:
"""
This node is responsible for calling the LLM to choose a tool. If no tool is chosen,
The node MAY emit an answer, depending on whether state["should_stream_answer"] is set.
"""
should_stream_answer = state["should_stream_answer"]
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
using_tool_calling_llm = agent_config.using_tool_calling_llm
prompt_builder = agent_config.prompt_builder
llm = agent_config.primary_llm
skip_gen_ai_answer_generation = agent_config.skip_gen_ai_answer_generation
structured_response_format = agent_config.structured_response_format
tools = agent_config.tools or []
force_use_tool = agent_config.force_use_tool
tool, tool_args = None, None
if force_use_tool.force_use and force_use_tool.args is not None:
tool_name, tool_args = (
force_use_tool.tool_name,
force_use_tool.args,
)
tool = get_tool_by_name(tools, tool_name)
# special pre-logic for non-tool calling LLM case
elif not using_tool_calling_llm and tools:
chosen_tool_and_args = get_tool_call_for_non_tool_calling_llm_impl(
force_use_tool=force_use_tool,
tools=tools,
prompt_builder=prompt_builder,
llm=llm,
)
if chosen_tool_and_args:
tool, tool_args = chosen_tool_and_args
# If we have a tool and tool args, we are redy to request a tool call.
# This only happens if the tool call was forced or we are using a non-tool calling LLM.
if tool and tool_args:
return ToolChoiceUpdate(
tool_choice=ToolChoice(
tool=tool,
tool_args=tool_args,
id=str(uuid4()),
),
)
# if we're skipping gen ai answer generation, we should only
# continue if we're forcing a tool call (which will be emitted by
# the tool calling llm in the stream() below)
if skip_gen_ai_answer_generation and not force_use_tool.force_use:
return ToolChoiceUpdate(
tool_choice=None,
)
# At this point, we are either using a tool calling LLM or we are skipping the tool call.
# DEBUG: good breakpoint
stream = llm.stream(
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
prompt=prompt_builder.build(),
tools=[tool.tool_definition() for tool in tools] or None,
tool_choice=("required" if tools and force_use_tool.force_use else None),
structured_response_format=structured_response_format,
)
tool_message = process_llm_stream(stream, should_stream_answer)
# If no tool calls are emitted by the LLM, we should not choose a tool
if len(tool_message.tool_calls) == 0:
return ToolChoiceUpdate(
tool_choice=None,
)
# TODO: here we could handle parallel tool calls. Right now
# we just pick the first one that matches.
selected_tool: Tool | None = None
selected_tool_call_request: ToolCall | None = None
for tool_call_request in tool_message.tool_calls:
known_tools_by_name = [
tool for tool in tools if tool.name == tool_call_request["name"]
]
if known_tools_by_name:
selected_tool = known_tools_by_name[0]
selected_tool_call_request = tool_call_request
break
logger.error(
"Tool call requested with unknown name field. \n"
f"tools: {tools}"
f"tool_call_request: {tool_call_request}"
)
if not selected_tool or not selected_tool_call_request:
raise ValueError(
f"Tool call attempted with tool {selected_tool}, request {selected_tool_call_request}"
)
logger.info(f"Selected tool: {selected_tool.name}")
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
return ToolChoiceUpdate(
tool_choice=ToolChoice(
tool=selected_tool,
tool_args=selected_tool_call_request["args"],
id=selected_tool_call_request["id"],
),
)

View File

@@ -0,0 +1,69 @@
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import AIMessageChunk
from langchain_core.messages.tool import ToolCall
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.states import ToolCallUpdate
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.chat.models import AnswerPacket
from onyx.tools.message import build_tool_message
from onyx.tools.message import ToolCallSummary
from onyx.tools.tool_runner import ToolRunner
from onyx.utils.logger import setup_logger
logger = setup_logger()
def emit_packet(packet: AnswerPacket) -> None:
dispatch_custom_event("basic_response", packet)
# TODO: handle is_cancelled
def tool_call(state: BasicState, config: RunnableConfig) -> ToolCallUpdate:
"""Calls the tool specified in the state and updates the state with the result"""
# TODO: implement
cast(AgentSearchConfig, config["metadata"]["config"])
# Unnecessary now, node should only be called if there is a tool call
# if not self.tool_call_chunk or not self.tool_call_chunk.tool_calls:
# return
tool_choice = state["tool_choice"]
if tool_choice is None:
raise ValueError("Cannot invoke tool call node without a tool choice")
tool = tool_choice["tool"]
tool_args = tool_choice["tool_args"]
tool_id = tool_choice["id"]
tool_runner = ToolRunner(tool, tool_args)
tool_kickoff = tool_runner.kickoff()
# TODO: custom events for yields
emit_packet(tool_kickoff)
tool_responses = []
for response in tool_runner.tool_responses():
tool_responses.append(response)
emit_packet(response)
tool_final_result = tool_runner.tool_final_result()
emit_packet(tool_final_result)
tool_call = ToolCall(name=tool.name, args=tool_args, id=tool_id)
tool_call_summary = ToolCallSummary(
tool_call_request=AIMessageChunk(content="", tool_calls=[tool_call]),
tool_call_result=build_tool_message(
tool_call, tool_runner.tool_message_content()
),
)
return ToolCallUpdate(
tool_call_summary=tool_call_summary,
tool_call_kickoff=tool_kickoff,
tool_call_responses=tool_responses,
tool_call_final_result=tool_final_result,
)

View File

@@ -0,0 +1,55 @@
from typing import TypedDict
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
# States contain values that change over the course of graph execution,
# Config is for values that are set at the start and never change.
# If you are using a value from the config and realize it needs to change,
# you should add it to the state and use/update the version in the state.
## Graph Input State
class BasicInput(TypedDict):
should_stream_answer: bool
## Graph Output State
class BasicOutput(TypedDict):
pass
## Update States
class ToolCallUpdate(TypedDict):
tool_call_summary: ToolCallSummary
tool_call_kickoff: ToolCallKickoff
tool_call_responses: list[ToolResponse]
tool_call_final_result: ToolCallFinalResult
class ToolChoice(TypedDict):
tool: Tool
tool_args: dict
id: str | None
class ToolChoiceUpdate(TypedDict):
tool_choice: ToolChoice | None
## Graph State
class BasicState(
BasicInput,
ToolCallUpdate,
ToolChoiceUpdate,
BasicOutput,
):
pass

View File

@@ -0,0 +1,67 @@
from collections.abc import Iterator
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from onyx.chat.models import LlmDoc
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
PassThroughAnswerResponseHandler,
)
from onyx.chat.stream_processing.utils import map_document_id_order
from onyx.utils.logger import setup_logger
logger = setup_logger()
# TODO: handle citations here; below is what was previously passed in
# see basic_use_tool_response.py for where these variables come from
# answer_handler = CitationResponseHandler(
# context_docs=final_search_results,
# final_doc_id_to_rank_map=map_document_id_order(final_search_results),
# display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
# )
def process_llm_stream(
stream: Iterator[BaseMessage],
should_stream_answer: bool,
final_search_results: list[LlmDoc] | None = None,
displayed_search_results: list[LlmDoc] | None = None,
) -> AIMessageChunk:
tool_call_chunk = AIMessageChunk(content="")
# for response in response_handler_manager.handle_llm_response(stream):
if final_search_results and displayed_search_results:
answer_handler: AnswerResponseHandler = CitationResponseHandler(
context_docs=final_search_results,
final_doc_id_to_rank_map=map_document_id_order(final_search_results),
display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
)
else:
answer_handler = PassThroughAnswerResponseHandler()
# This stream will be the llm answer if no tool is chosen. When a tool is chosen,
# the stream will contain AIMessageChunks with tool call information.
for response in stream:
answer_piece = response.content
if not isinstance(answer_piece, str):
# TODO: handle non-string content
logger.warning(f"Received non-string content: {type(answer_piece)}")
answer_piece = str(answer_piece)
if isinstance(response, AIMessageChunk) and (
response.tool_call_chunks or response.tool_calls
):
tool_call_chunk += response # type: ignore
elif should_stream_answer:
# TODO: handle emitting of CitationInfo
for response_part in answer_handler.handle_response_part(response, []):
dispatch_custom_event(
"basic_response",
response_part,
)
return cast(AIMessageChunk, tool_call_chunk)

View File

@@ -0,0 +1,21 @@
from operator import add
from typing import Annotated
from pydantic import BaseModel
class CoreState(BaseModel):
"""
This is the core state that is shared across all subgraphs.
"""
base_question: str = ""
log_messages: Annotated[list[str], add] = []
class SubgraphCoreState(BaseModel):
"""
This is the core state that is shared across all subgraphs.
"""
log_messages: Annotated[list[str], add]

View File

@@ -0,0 +1,66 @@
from uuid import UUID
from sqlalchemy.orm import Session
from onyx.db.models import AgentSubQuery
from onyx.db.models import AgentSubQuestion
def create_sub_question(
db_session: Session,
chat_session_id: UUID,
primary_message_id: int,
sub_question: str,
sub_answer: str,
) -> AgentSubQuestion:
"""Create a new sub-question record in the database."""
sub_q = AgentSubQuestion(
chat_session_id=chat_session_id,
primary_question_id=primary_message_id,
sub_question=sub_question,
sub_answer=sub_answer,
)
db_session.add(sub_q)
db_session.flush()
return sub_q
def create_sub_query(
db_session: Session,
chat_session_id: UUID,
parent_question_id: int,
sub_query: str,
) -> AgentSubQuery:
"""Create a new sub-query record in the database."""
sub_q = AgentSubQuery(
chat_session_id=chat_session_id,
parent_question_id=parent_question_id,
sub_query=sub_query,
)
db_session.add(sub_q)
db_session.flush()
return sub_q
def get_sub_questions_for_message(
db_session: Session,
primary_message_id: int,
) -> list[AgentSubQuestion]:
"""Get all sub-questions for a given primary message."""
return (
db_session.query(AgentSubQuestion)
.filter(AgentSubQuestion.primary_question_id == primary_message_id)
.all()
)
def get_sub_queries_for_question(
db_session: Session,
sub_question_id: int,
) -> list[AgentSubQuery]:
"""Get all sub-queries for a given sub-question."""
return (
db_session.query(AgentSubQuery)
.filter(AgentSubQuery.parent_question_id == sub_question_id)
.all()
)

View File

@@ -0,0 +1,29 @@
from collections.abc import Hashable
from datetime import datetime
from langgraph.types import Send
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
AnswerQuestionInput,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def send_to_expanded_retrieval(state: AnswerQuestionInput) -> Send | Hashable:
logger.debug("sending to expanded retrieval via edge")
now_start = datetime.now()
return Send(
"initial_sub_question_expanded_retrieval",
ExpandedRetrievalInput(
question=state.question,
base_search=False,
sub_question_id=state.question_id,
log_messages=[f"{now_start} -- Sending to expanded retrieval"],
),
)

View File

@@ -0,0 +1,126 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.edges import (
send_to_expanded_retrieval,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.answer_check import (
answer_check,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.answer_generation import (
answer_generation,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.format_answer import (
format_answer,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.ingest_retrieval import (
ingest_retrieval,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
AnswerQuestionInput,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.graph_builder import (
expanded_retrieval_graph_builder,
)
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger
logger = setup_logger()
def answer_query_graph_builder() -> StateGraph:
graph = StateGraph(
state_schema=AnswerQuestionState,
input=AnswerQuestionInput,
output=AnswerQuestionOutput,
)
### Add nodes ###
expanded_retrieval = expanded_retrieval_graph_builder().compile()
graph.add_node(
node="initial_sub_question_expanded_retrieval",
action=expanded_retrieval,
)
graph.add_node(
node="answer_check",
action=answer_check,
)
graph.add_node(
node="answer_generation",
action=answer_generation,
)
graph.add_node(
node="format_answer",
action=format_answer,
)
graph.add_node(
node="ingest_retrieval",
action=ingest_retrieval,
)
### Add edges ###
graph.add_conditional_edges(
source=START,
path=send_to_expanded_retrieval,
path_map=["initial_sub_question_expanded_retrieval"],
)
graph.add_edge(
start_key="initial_sub_question_expanded_retrieval",
end_key="ingest_retrieval",
)
graph.add_edge(
start_key="ingest_retrieval",
end_key="answer_generation",
)
graph.add_edge(
start_key="answer_generation",
end_key="answer_check",
)
graph.add_edge(
start_key="answer_check",
end_key="format_answer",
)
graph.add_edge(
start_key="format_answer",
end_key=END,
)
return graph
if __name__ == "__main__":
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
graph = answer_query_graph_builder()
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
search_request = SearchRequest(
query="what can you do with onyx or danswer?",
)
with get_session_context_manager() as db_session:
agent_search_config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)
inputs = AnswerQuestionInput(
question="what can you do with onyx?",
question_id="0_0",
log_messages=[],
)
for thing in compiled_graph.stream(
input=inputs,
config={"configurable": {"config": agent_search_config}},
# debug=True,
# subgraphs=True,
):
logger.debug(thing)

View File

@@ -0,0 +1,8 @@
from pydantic import BaseModel
### Models ###
class AnswerRetrievalStats(BaseModel):
answer_retrieval_stats: dict[str, float | int]

View File

@@ -0,0 +1,59 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
QACheckUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_CHECK_NO
from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
def answer_check(state: AnswerQuestionState, config: RunnableConfig) -> QACheckUpdate:
now_start = datetime.now()
level, question_num = parse_question_id(state.question_id)
if state.answer == UNKNOWN_ANSWER:
now_end = datetime.now()
return QACheckUpdate(
answer_quality=SUB_CHECK_NO,
log_messages=[
f"{now_end} -- Answer check SQ-{level}-{question_num} - unknown answer, Time taken: {now_end - now_start}"
],
)
msg = [
HumanMessage(
content=SUB_CHECK_PROMPT.format(
question=state.question,
base_answer=state.answer,
)
)
]
agent_searchch_config = cast(AgentSearchConfig, config["metadata"]["config"])
fast_llm = agent_searchch_config.fast_llm
response = list(
fast_llm.stream(
prompt=msg,
)
)
quality_str = merge_message_runs(response, chunk_separator="")[0].content
now_end = datetime.now()
return QACheckUpdate(
answer_quality=quality_str,
log_messages=[
f"""{now_end} -- Answer check SQ-{level}-{question_num} - Answer quality: {quality_str},
Time taken: {now_end - now_start}"""
],
)

View File

@@ -0,0 +1,116 @@
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
QAGenerationUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_sub_question_answer_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
ASSISTANT_SYSTEM_PROMPT_PERSONA,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import NO_RECOVERED_DOCS
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_prompt
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.utils.logger import setup_logger
logger = setup_logger()
def answer_generation(
state: AnswerQuestionState, config: RunnableConfig
) -> QAGenerationUpdate:
now_start = datetime.now()
logger.debug(f"--------{now_start}--------START ANSWER GENERATION---")
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = state.question
docs = state.documents
level, question_nr = parse_question_id(state.question_id)
context_docs = state.context_documents
persona_prompt = get_persona_prompt(agent_search_config.search_request.persona)
if len(context_docs) == 0:
answer_str = NO_RECOVERED_DOCS
dispatch_custom_event(
"sub_answers",
AgentAnswerPiece(
answer_piece=answer_str,
level=level,
level_question_nr=question_nr,
answer_type="agent_sub_answer",
),
)
else:
if len(persona_prompt) > 0:
persona_specification = ASSISTANT_SYSTEM_PROMPT_DEFAULT
else:
persona_specification = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
persona_prompt=persona_prompt
)
logger.debug(f"Number of verified retrieval docs: {len(docs)}")
fast_llm = agent_search_config.fast_llm
msg = build_sub_question_answer_prompt(
question=question,
original_question=agent_search_config.search_request.query,
docs=docs,
persona_specification=persona_specification,
config=fast_llm.config,
)
response: list[str | list[str | dict[str, Any]]] = []
for message in fast_llm.stream(
prompt=msg,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
dispatch_custom_event(
"sub_answers",
AgentAnswerPiece(
answer_piece=content,
level=level,
level_question_nr=question_nr,
answer_type="agent_sub_answer",
),
)
response.append(content)
answer_str = merge_message_runs(response, chunk_separator="")[0].content
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type="sub_answer",
level=level,
level_question_nr=question_nr,
)
dispatch_custom_event("stream_finished", stop_event)
now_end = datetime.now()
return QAGenerationUpdate(
answer=answer_str,
log_messages=[
f"{now_end} -- Answer generation SQ-{level} - Q{question_nr} - Time taken: {now_end - now_start}"
],
)

View File

@@ -0,0 +1,28 @@
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.shared_graph_utils.models import (
QuestionAnswerResults,
)
def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput:
return AnswerQuestionOutput(
answer_results=[
QuestionAnswerResults(
question=state.question,
question_id=state.question_id,
quality=state.answer_quality
if hasattr(state, "answer_quality")
else "No",
answer=state.answer,
expanded_retrieval_results=state.expanded_retrieval_results,
documents=state.documents,
context_documents=state.context_documents,
sub_question_retrieval_stats=state.sub_question_retrieval_stats,
)
],
)

View File

@@ -0,0 +1,22 @@
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
RetrievalIngestionUpdate,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
ExpandedRetrievalOutput,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
def ingest_retrieval(state: ExpandedRetrievalOutput) -> RetrievalIngestionUpdate:
sub_question_retrieval_stats = (
state.expanded_retrieval_result.sub_question_retrieval_stats
)
if sub_question_retrieval_stats is None:
sub_question_retrieval_stats = [AgentChunkStats()]
return RetrievalIngestionUpdate(
expanded_retrieval_results=state.expanded_retrieval_result.expanded_queries_results,
documents=state.expanded_retrieval_result.all_documents,
context_documents=state.expanded_retrieval_result.context_documents,
sub_question_retrieval_stats=sub_question_retrieval_stats,
)

View File

@@ -0,0 +1,71 @@
from operator import add
from typing import Annotated
from pydantic import BaseModel
from onyx.agents.agent_search.core_state import SubgraphCoreState
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.agents.agent_search.shared_graph_utils.models import (
QuestionAnswerResults,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.context.search.models import InferenceSection
## Update States
class QACheckUpdate(BaseModel):
answer_quality: str = ""
log_messages: list[str] = []
class QAGenerationUpdate(BaseModel):
answer: str = ""
log_messages: list[str] = []
# answer_stat: AnswerStats
class RetrievalIngestionUpdate(BaseModel):
expanded_retrieval_results: list[QueryResult] = []
documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
context_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
sub_question_retrieval_stats: AgentChunkStats = AgentChunkStats()
## Graph Input State
class AnswerQuestionInput(SubgraphCoreState):
question: str = ""
question_id: str = (
"" # 0_0 is original question, everything else is <level>_<question_num>.
)
# level 0 is original question and first decomposition, level 1 is follow up, etc
# question_num is a unique number per original question per level.
## Graph State
class AnswerQuestionState(
AnswerQuestionInput,
QAGenerationUpdate,
QACheckUpdate,
RetrievalIngestionUpdate,
):
pass
## Graph Output State
class AnswerQuestionOutput(BaseModel):
"""
This is a list of results even though each call of this subgraph only returns one result.
This is because if we parallelize the answer query subgraph, there will be multiple
results in a list so the add operator is used to add them together.
"""
answer_results: Annotated[list[QuestionAnswerResults], add] = []

View File

@@ -0,0 +1,28 @@
from collections.abc import Hashable
from datetime import datetime
from langgraph.types import Send
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
AnswerQuestionInput,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def send_to_expanded_refined_retrieval(state: AnswerQuestionInput) -> Send | Hashable:
logger.debug("sending to expanded retrieval for follow up question via edge")
datetime.now()
return Send(
"refined_sub_question_expanded_retrieval",
ExpandedRetrievalInput(
question=state.question,
sub_question_id=state.question_id,
base_search=False,
log_messages=[f"{datetime.now()} -- Sending to expanded retrieval"],
),
)

View File

@@ -0,0 +1,123 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.answer_check import (
answer_check,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.answer_generation import (
answer_generation,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.format_answer import (
format_answer,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.ingest_retrieval import (
ingest_retrieval,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
AnswerQuestionInput,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search_a.answer_refinement_sub_question.edges import (
send_to_expanded_refined_retrieval,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.graph_builder import (
expanded_retrieval_graph_builder,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def answer_refined_query_graph_builder() -> StateGraph:
graph = StateGraph(
state_schema=AnswerQuestionState,
input=AnswerQuestionInput,
output=AnswerQuestionOutput,
)
### Add nodes ###
expanded_retrieval = expanded_retrieval_graph_builder().compile()
graph.add_node(
node="refined_sub_question_expanded_retrieval",
action=expanded_retrieval,
)
graph.add_node(
node="refined_sub_answer_check",
action=answer_check,
)
graph.add_node(
node="refined_sub_answer_generation",
action=answer_generation,
)
graph.add_node(
node="format_refined_sub_answer",
action=format_answer,
)
graph.add_node(
node="ingest_refined_retrieval",
action=ingest_retrieval,
)
### Add edges ###
graph.add_conditional_edges(
source=START,
path=send_to_expanded_refined_retrieval,
path_map=["refined_sub_question_expanded_retrieval"],
)
graph.add_edge(
start_key="refined_sub_question_expanded_retrieval",
end_key="ingest_refined_retrieval",
)
graph.add_edge(
start_key="ingest_refined_retrieval",
end_key="refined_sub_answer_generation",
)
graph.add_edge(
start_key="refined_sub_answer_generation",
end_key="refined_sub_answer_check",
)
graph.add_edge(
start_key="refined_sub_answer_check",
end_key="format_refined_sub_answer",
)
graph.add_edge(
start_key="format_refined_sub_answer",
end_key=END,
)
return graph
if __name__ == "__main__":
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
graph = answer_refined_query_graph_builder()
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
search_request = SearchRequest(
query="what can you do with onyx or danswer?",
)
with get_session_context_manager() as db_session:
inputs = AnswerQuestionInput(
question="what can you do with onyx?",
question_id="0_0",
log_messages=[],
)
for thing in compiled_graph.stream(
input=inputs,
# debug=True,
# subgraphs=True,
):
logger.debug(thing)
# output = compiled_graph.invoke(inputs)
# logger.debug(output)

View File

@@ -0,0 +1,19 @@
from pydantic import BaseModel
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.context.search.models import InferenceSection
### Models ###
class AnswerRetrievalStats(BaseModel):
answer_retrieval_stats: dict[str, float | int]
class QuestionAnswerResults(BaseModel):
question: str
answer: str
quality: str
# expanded_retrieval_results: list[QueryResult]
documents: list[InferenceSection]
sub_question_retrieval_stats: AgentChunkStats

View File

@@ -0,0 +1,76 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search_a.base_raw_search.nodes.format_raw_search_results import (
format_raw_search_results,
)
from onyx.agents.agent_search.deep_search_a.base_raw_search.nodes.generate_raw_search_data import (
generate_raw_search_data,
)
from onyx.agents.agent_search.deep_search_a.base_raw_search.states import (
BaseRawSearchInput,
)
from onyx.agents.agent_search.deep_search_a.base_raw_search.states import (
BaseRawSearchOutput,
)
from onyx.agents.agent_search.deep_search_a.base_raw_search.states import (
BaseRawSearchState,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.graph_builder import (
expanded_retrieval_graph_builder,
)
def base_raw_search_graph_builder() -> StateGraph:
graph = StateGraph(
state_schema=BaseRawSearchState,
input=BaseRawSearchInput,
output=BaseRawSearchOutput,
)
### Add nodes ###
graph.add_node(
node="generate_raw_search_data",
action=generate_raw_search_data,
)
expanded_retrieval = expanded_retrieval_graph_builder().compile()
graph.add_node(
node="expanded_retrieval_base_search",
action=expanded_retrieval,
)
graph.add_node(
node="format_raw_search_results",
action=format_raw_search_results,
)
### Add edges ###
graph.add_edge(start_key=START, end_key="generate_raw_search_data")
graph.add_edge(
start_key="generate_raw_search_data",
end_key="expanded_retrieval_base_search",
)
graph.add_edge(
start_key="expanded_retrieval_base_search",
end_key="format_raw_search_results",
)
# graph.add_edge(
# start_key="expanded_retrieval_base_search",
# end_key=END,
# )
graph.add_edge(
start_key="format_raw_search_results",
end_key=END,
)
return graph
if __name__ == "__main__":
pass

View File

@@ -0,0 +1,20 @@
from pydantic import BaseModel
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.context.search.models import InferenceSection
### Models ###
class AnswerRetrievalStats(BaseModel):
answer_retrieval_stats: dict[str, float | int]
class QuestionAnswerResults(BaseModel):
question: str
answer: str
quality: str
expanded_retrieval_results: list[QueryResult]
documents: list[InferenceSection]
sub_question_retrieval_stats: list[AgentChunkStats]

View File

@@ -0,0 +1,18 @@
from onyx.agents.agent_search.deep_search_a.base_raw_search.states import (
BaseRawSearchOutput,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
ExpandedRetrievalOutput,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def format_raw_search_results(state: ExpandedRetrievalOutput) -> BaseRawSearchOutput:
logger.debug("format_raw_search_results")
return BaseRawSearchOutput(
base_expanded_retrieval_result=state.expanded_retrieval_result,
# base_retrieval_results=[state.expanded_retrieval_result],
# base_search_documents=[],
)

View File

@@ -0,0 +1,25 @@
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.utils.logger import setup_logger
logger = setup_logger()
def generate_raw_search_data(
state: CoreState, config: RunnableConfig
) -> ExpandedRetrievalInput:
logger.debug("generate_raw_search_data")
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
return ExpandedRetrievalInput(
question=agent_a_config.search_request.query,
base_search=True,
sub_question_id=None, # This graph is always and only used for the original question
log_messages=[],
)

View File

@@ -0,0 +1,43 @@
from pydantic import BaseModel
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import (
ExpandedRetrievalResult,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
## Update States
## Graph Input State
class BaseRawSearchInput(ExpandedRetrievalInput):
pass
## Graph Output State
class BaseRawSearchOutput(BaseModel):
"""
This is a list of results even though each call of this subgraph only returns one result.
This is because if we parallelize the answer query subgraph, there will be multiple
results in a list so the add operator is used to add them together.
"""
# base_search_documents: Annotated[list[InferenceSection], dedup_inference_sections]
# base_retrieval_results: Annotated[list[ExpandedRetrievalResult], add]
base_expanded_retrieval_result: ExpandedRetrievalResult = ExpandedRetrievalResult()
## Graph State
class BaseRawSearchState(
BaseRawSearchInput,
BaseRawSearchOutput,
):
pass

View File

@@ -0,0 +1,37 @@
from collections.abc import Hashable
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import Send
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
RetrievalInput,
)
from onyx.agents.agent_search.models import AgentSearchConfig
def parallel_retrieval_edge(
state: ExpandedRetrievalState, config: RunnableConfig
) -> list[Send | Hashable]:
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = state.question if state.question else agent_a_config.search_request.query
query_expansions = (
state.expanded_queries if state.expanded_queries else [] + [question]
)
return [
Send(
"doc_retrieval",
RetrievalInput(
query_to_retrieve=query,
question=question,
base_search=False,
sub_question_id=state.sub_question_id,
log_messages=[],
),
)
for query in query_expansions
]

View File

@@ -0,0 +1,147 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.edges import (
parallel_retrieval_edge,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.doc_reranking import (
doc_reranking,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.doc_retrieval import (
doc_retrieval,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.doc_verification import (
doc_verification,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.dummy import (
dummy,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.expand_queries import (
expand_queries,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.format_results import (
format_results,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.verification_kickoff import (
verification_kickoff,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
ExpandedRetrievalOutput,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger
logger = setup_logger()
def expanded_retrieval_graph_builder() -> StateGraph:
graph = StateGraph(
state_schema=ExpandedRetrievalState,
input=ExpandedRetrievalInput,
output=ExpandedRetrievalOutput,
)
### Add nodes ###
graph.add_node(
node="expand_queries",
action=expand_queries,
)
graph.add_node(
node="dummy",
action=dummy,
)
graph.add_node(
node="doc_retrieval",
action=doc_retrieval,
)
graph.add_node(
node="verification_kickoff",
action=verification_kickoff,
)
graph.add_node(
node="doc_verification",
action=doc_verification,
)
graph.add_node(
node="doc_reranking",
action=doc_reranking,
)
graph.add_node(
node="format_results",
action=format_results,
)
### Add edges ###
graph.add_edge(
start_key=START,
end_key="expand_queries",
)
graph.add_edge(
start_key="expand_queries",
end_key="dummy",
)
graph.add_conditional_edges(
source="dummy",
path=parallel_retrieval_edge,
path_map=["doc_retrieval"],
)
graph.add_edge(
start_key="doc_retrieval",
end_key="verification_kickoff",
)
graph.add_edge(
start_key="doc_verification",
end_key="doc_reranking",
)
graph.add_edge(
start_key="doc_reranking",
end_key="format_results",
)
graph.add_edge(
start_key="format_results",
end_key=END,
)
return graph
if __name__ == "__main__":
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
graph = expanded_retrieval_graph_builder()
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
search_request = SearchRequest(
query="what can you do with onyx or danswer?",
)
with get_session_context_manager() as db_session:
agent_a_config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)
inputs = ExpandedRetrievalInput(
question="what can you do with onyx?",
base_search=False,
sub_question_id=None,
log_messages=[],
)
for thing in compiled_graph.stream(
input=inputs,
config={"configurable": {"config": agent_a_config}},
# debug=True,
subgraphs=True,
):
logger.debug(thing)

View File

@@ -0,0 +1,12 @@
from pydantic import BaseModel
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.context.search.models import InferenceSection
class ExpandedRetrievalResult(BaseModel):
expanded_queries_results: list[QueryResult] = []
all_documents: list[InferenceSection] = []
context_documents: list[InferenceSection] = []
sub_question_retrieval_stats: AgentChunkStats = AgentChunkStats()

View File

@@ -0,0 +1,74 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.operations import logger
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
DocRerankingUpdate,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.calculations import get_fit_scores
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.configs.dev_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS
from onyx.configs.dev_configs import AGENT_RERANKING_STATS
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import SearchRequest
from onyx.context.search.pipeline import retrieval_preprocessing
from onyx.context.search.postprocessing.postprocessing import rerank_sections
from onyx.db.engine import get_session_context_manager
def doc_reranking(
state: ExpandedRetrievalState, config: RunnableConfig
) -> DocRerankingUpdate:
now_start = datetime.now()
verified_documents = state.verified_documents
# Rerank post retrieval and verification. First, create a search query
# then create the list of reranked sections
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = state.question if state.question else agent_a_config.search_request.query
with get_session_context_manager() as db_session:
_search_query = retrieval_preprocessing(
search_request=SearchRequest(query=question),
user=agent_a_config.search_tool.user, # bit of a hack
llm=agent_a_config.fast_llm,
db_session=db_session,
)
# skip section filtering
if (
_search_query.rerank_settings
and _search_query.rerank_settings.rerank_model_name
and _search_query.rerank_settings.num_rerank > 0
):
reranked_documents = rerank_sections(
_search_query,
verified_documents,
)
else:
logger.warning("No reranking settings found, using unranked documents")
reranked_documents = verified_documents
if AGENT_RERANKING_STATS:
fit_scores = get_fit_scores(verified_documents, reranked_documents)
else:
fit_scores = RetrievalFitStats(fit_score_lift=0, rerank_effect=0, fit_scores={})
# TODO: stream deduped docs here, or decide to use search tool ranking/verification
now_end = datetime.now()
return DocRerankingUpdate(
reranked_documents=[
doc for doc in reranked_documents if type(doc) == InferenceSection
][:AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS],
sub_question_retrieval_stats=fit_scores,
log_messages=[
f"{now_end} -- Expanded Retrieval - Reranking - Time taken: {now_end - now_start}"
],
)

View File

@@ -0,0 +1,103 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.operations import logger
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
DocRetrievalUpdate,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
RetrievalInput,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.calculations import get_fit_scores
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.configs.dev_configs import AGENT_MAX_QUERY_RETRIEVAL_RESULTS
from onyx.configs.dev_configs import AGENT_RETRIEVAL_STATS
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_context_manager
from onyx.tools.models import SearchQueryInfo
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
def doc_retrieval(state: RetrievalInput, config: RunnableConfig) -> DocRetrievalUpdate:
"""
Retrieve documents
Args:
state (RetrievalInput): Primary state + the query to retrieve
config (RunnableConfig): Configuration containing ProSearchConfig
Updates:
expanded_retrieval_results: list[ExpandedRetrievalResult]
retrieved_documents: list[InferenceSection]
"""
now_start = datetime.now()
query_to_retrieve = state.query_to_retrieve
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
search_tool = agent_a_config.search_tool
retrieved_docs: list[InferenceSection] = []
if not query_to_retrieve.strip():
logger.warning("Empty query, skipping retrieval")
now_end = datetime.now()
return DocRetrievalUpdate(
expanded_retrieval_results=[],
retrieved_documents=[],
log_messages=[
f"{now_end} -- Expanded Retrieval - Retrieval - Empty Query - Time taken: {now_end - now_start}"
],
)
query_info = None
# new db session to avoid concurrency issues
with get_session_context_manager() as db_session:
for tool_response in search_tool.run(
query=query_to_retrieve,
force_no_rerank=True,
alternate_db_session=db_session,
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
response = cast(SearchResponseSummary, tool_response.response)
retrieved_docs = response.top_sections
query_info = SearchQueryInfo(
predicted_search=response.predicted_search,
final_filters=response.final_filters,
recency_bias_multiplier=response.recency_bias_multiplier,
)
break
retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS]
pre_rerank_docs = retrieved_docs
if search_tool.search_pipeline is not None:
pre_rerank_docs = (
search_tool.search_pipeline._retrieved_sections or retrieved_docs
)
if AGENT_RETRIEVAL_STATS:
fit_scores = get_fit_scores(
pre_rerank_docs,
retrieved_docs,
)
else:
fit_scores = None
expanded_retrieval_result = QueryResult(
query=query_to_retrieve,
search_results=retrieved_docs,
stats=fit_scores,
query_info=query_info,
)
now_end = datetime.now()
return DocRetrievalUpdate(
expanded_retrieval_results=[expanded_retrieval_result],
retrieved_documents=retrieved_docs,
log_messages=[
f"{now_end} -- Expanded Retrieval - Retrieval - Time taken: {now_end - now_start}"
],
)

View File

@@ -0,0 +1,60 @@
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
DocVerificationInput,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
DocVerificationUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
def doc_verification(
state: DocVerificationInput, config: RunnableConfig
) -> DocVerificationUpdate:
"""
Check whether the document is relevant for the original user question
Args:
state (DocVerificationInput): The current state
config (RunnableConfig): Configuration containing ProSearchConfig
Updates:
verified_documents: list[InferenceSection]
"""
question = state.question
doc_to_verify = state.doc_to_verify
document_content = doc_to_verify.combined_content
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
fast_llm = agent_a_config.fast_llm
document_content = trim_prompt_piece(
fast_llm.config, document_content, VERIFIER_PROMPT + question
)
msg = [
HumanMessage(
content=VERIFIER_PROMPT.format(
question=question, document_content=document_content
)
)
]
response = fast_llm.invoke(msg)
verified_documents = []
if isinstance(response.content, str) and "yes" in response.content.lower():
verified_documents.append(doc_to_verify)
return DocVerificationUpdate(
verified_documents=verified_documents,
)

View File

@@ -0,0 +1,16 @@
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
QueryExpansionUpdate,
)
def dummy(
state: ExpandedRetrievalState, config: RunnableConfig
) -> QueryExpansionUpdate:
return QueryExpansionUpdate(
expanded_queries=state.expanded_queries,
)

View File

@@ -0,0 +1,68 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.operations import (
dispatch_subquery,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
QueryExpansionUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.prompts import (
REWRITE_PROMPT_MULTI_ORIGINAL,
)
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
def expand_queries(
state: ExpandedRetrievalInput, config: RunnableConfig
) -> QueryExpansionUpdate:
# Sometimes we want to expand the original question, sometimes we want to expand a sub-question.
# When we are running this node on the original question, no question is explictly passed in.
# Instead, we use the original question from the search request.
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
now_start = datetime.now()
question = (
state.question
if hasattr(state, "question")
else agent_a_config.search_request.query
)
llm = agent_a_config.fast_llm
chat_session_id = agent_a_config.chat_session_id
sub_question_id = state.sub_question_id
if sub_question_id is None:
level, question_nr = 0, 0
else:
level, question_nr = parse_question_id(sub_question_id)
if chat_session_id is None:
raise ValueError("chat_session_id must be provided for agent search")
msg = [
HumanMessage(
content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question),
)
]
llm_response_list = dispatch_separated(
llm.stream(prompt=msg), dispatch_subquery(level, question_nr)
)
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
rewritten_queries = llm_response.split("\n")
now_end = datetime.now()
return QueryExpansionUpdate(
expanded_queries=rewritten_queries,
log_messages=[
f"{now_end} -- Expanded Retrieval - Query Expansion - Time taken: {now_end - now_start}"
],
)

View File

@@ -0,0 +1,82 @@
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import (
ExpandedRetrievalResult,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.operations import (
calculate_sub_question_retrieval_stats,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
ExpandedRetrievalUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.chat.models import ExtendedToolResponse
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
def format_results(
state: ExpandedRetrievalState, config: RunnableConfig
) -> ExpandedRetrievalUpdate:
level, question_nr = parse_question_id(state.sub_question_id or "0_0")
query_infos = [
result.query_info
for result in state.expanded_retrieval_results
if result.query_info is not None
]
if len(query_infos) == 0:
raise ValueError("No query info found")
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
# main question docs will be sent later after aggregation and deduping with sub-question docs
stream_documents = state.reranked_documents
if not (level == 0 and question_nr == 0):
if len(stream_documents) == 0:
# The sub-question is used as the last query. If no verified documents are found, stream
# the top 3 for that one. We may want to revisit this.
stream_documents = state.expanded_retrieval_results[-1].search_results[:3]
for tool_response in yield_search_responses(
query=state.question,
reranked_sections=state.retrieved_documents, # TODO: rename params. (sections pre-merging here.)
final_context_sections=stream_documents,
search_query_info=query_infos[0], # TODO: handle differing query infos?
get_section_relevance=lambda: None, # TODO: add relevance
search_tool=agent_a_config.search_tool,
):
dispatch_custom_event(
"tool_response",
ExtendedToolResponse(
id=tool_response.id,
response=tool_response.response,
level=level,
level_question_nr=question_nr,
),
)
sub_question_retrieval_stats = calculate_sub_question_retrieval_stats(
verified_documents=state.verified_documents,
expanded_retrieval_results=state.expanded_retrieval_results,
)
if sub_question_retrieval_stats is None:
sub_question_retrieval_stats = AgentChunkStats()
# else:
# sub_question_retrieval_stats = [sub_question_retrieval_stats]
return ExpandedRetrievalUpdate(
expanded_retrieval_result=ExpandedRetrievalResult(
expanded_queries_results=state.expanded_retrieval_results,
all_documents=stream_documents,
context_documents=state.reranked_documents,
sub_question_retrieval_stats=sub_question_retrieval_stats,
),
)

View File

@@ -0,0 +1,44 @@
from typing import cast
from typing import Literal
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import Command
from langgraph.types import Send
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
DocVerificationInput,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.models import AgentSearchConfig
def verification_kickoff(
state: ExpandedRetrievalState,
config: RunnableConfig,
) -> Command[Literal["doc_verification"]]:
documents = state.retrieved_documents
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
verification_question = (
state.question
if hasattr(state, "question")
else agent_a_config.search_request.query
)
sub_question_id = state.sub_question_id
return Command(
update={},
goto=[
Send(
node="doc_verification",
arg=DocVerificationInput(
doc_to_verify=doc,
question=verification_question,
base_search=False,
sub_question_id=sub_question_id,
log_messages=[],
),
)
for doc in documents
],
)

View File

@@ -0,0 +1,97 @@
from collections import defaultdict
from collections.abc import Callable
import numpy as np
from langchain_core.callbacks.manager import dispatch_custom_event
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.chat.models import SubQueryPiece
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger
logger = setup_logger()
def dispatch_subquery(level: int, question_nr: int) -> Callable[[str, int], None]:
def helper(token: str, num: int) -> None:
dispatch_custom_event(
"subqueries",
SubQueryPiece(
sub_query=token,
level=level,
level_question_nr=question_nr,
query_id=num,
),
)
return helper
def calculate_sub_question_retrieval_stats(
verified_documents: list[InferenceSection],
expanded_retrieval_results: list[QueryResult],
) -> AgentChunkStats:
chunk_scores: dict[str, dict[str, list[int | float]]] = defaultdict(
lambda: defaultdict(list)
)
for expanded_retrieval_result in expanded_retrieval_results:
for doc in expanded_retrieval_result.search_results:
doc_chunk_id = f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}"
if doc.center_chunk.score is not None:
chunk_scores[doc_chunk_id]["score"].append(doc.center_chunk.score)
verified_doc_chunk_ids = [
f"{verified_document.center_chunk.document_id}_{verified_document.center_chunk.chunk_id}"
for verified_document in verified_documents
]
dismissed_doc_chunk_ids = []
raw_chunk_stats_counts: dict[str, int] = defaultdict(int)
raw_chunk_stats_scores: dict[str, float] = defaultdict(float)
for doc_chunk_id, chunk_data in chunk_scores.items():
if doc_chunk_id in verified_doc_chunk_ids:
raw_chunk_stats_counts["verified_count"] += 1
valid_chunk_scores = [
score for score in chunk_data["score"] if score is not None
]
raw_chunk_stats_scores["verified_scores"] += float(
np.mean(valid_chunk_scores)
)
else:
raw_chunk_stats_counts["rejected_count"] += 1
valid_chunk_scores = [
score for score in chunk_data["score"] if score is not None
]
raw_chunk_stats_scores["rejected_scores"] += float(
np.mean(valid_chunk_scores)
)
dismissed_doc_chunk_ids.append(doc_chunk_id)
if raw_chunk_stats_counts["verified_count"] == 0:
verified_avg_scores = 0.0
else:
verified_avg_scores = raw_chunk_stats_scores["verified_scores"] / float(
raw_chunk_stats_counts["verified_count"]
)
rejected_scores = raw_chunk_stats_scores.get("rejected_scores", None)
if rejected_scores is not None:
rejected_avg_scores = rejected_scores / float(
raw_chunk_stats_counts["rejected_count"]
)
else:
rejected_avg_scores = None
chunk_stats = AgentChunkStats(
verified_count=raw_chunk_stats_counts["verified_count"],
verified_avg_scores=verified_avg_scores,
rejected_count=raw_chunk_stats_counts["rejected_count"],
rejected_avg_scores=rejected_avg_scores,
verified_doc_chunk_ids=verified_doc_chunk_ids,
dismissed_doc_chunk_ids=dismissed_doc_chunk_ids,
)
return chunk_stats

View File

@@ -0,0 +1,91 @@
from operator import add
from typing import Annotated
from pydantic import BaseModel
from onyx.agents.agent_search.core_state import SubgraphCoreState
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import (
ExpandedRetrievalResult,
)
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.context.search.models import InferenceSection
### States ###
## Graph Input State
class ExpandedRetrievalInput(SubgraphCoreState):
question: str = ""
base_search: bool = False
sub_question_id: str | None = None
## Update/Return States
class QueryExpansionUpdate(BaseModel):
expanded_queries: list[str] = ["aaa", "bbb"]
log_messages: list[str] = []
class DocVerificationUpdate(BaseModel):
verified_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
class DocRetrievalUpdate(BaseModel):
expanded_retrieval_results: Annotated[list[QueryResult], add] = []
retrieved_documents: Annotated[
list[InferenceSection], dedup_inference_sections
] = []
log_messages: list[str] = []
class DocRerankingUpdate(BaseModel):
reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
sub_question_retrieval_stats: RetrievalFitStats | None = None
log_messages: list[str] = []
class ExpandedRetrievalUpdate(BaseModel):
expanded_retrieval_result: ExpandedRetrievalResult
## Graph Output State
class ExpandedRetrievalOutput(BaseModel):
expanded_retrieval_result: ExpandedRetrievalResult = ExpandedRetrievalResult()
base_expanded_retrieval_result: ExpandedRetrievalResult = ExpandedRetrievalResult()
log_messages: list[str] = []
## Graph State
class ExpandedRetrievalState(
# This includes the core state
ExpandedRetrievalInput,
QueryExpansionUpdate,
DocRetrievalUpdate,
DocVerificationUpdate,
DocRerankingUpdate,
ExpandedRetrievalOutput,
):
pass
## Conditional Input States
class DocVerificationInput(ExpandedRetrievalInput):
doc_to_verify: InferenceSection
class RetrievalInput(ExpandedRetrievalInput):
query_to_retrieve: str = ""

View File

@@ -0,0 +1,100 @@
from collections.abc import Hashable
from datetime import datetime
from typing import Literal
from langgraph.types import Send
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
AnswerQuestionInput,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.deep_search_a.main.states import (
RequireRefinedAnswerUpdate,
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.utils.logger import setup_logger
logger = setup_logger()
def parallelize_initial_sub_question_answering(
state: MainState,
) -> list[Send | Hashable]:
now_start = datetime.now()
if len(state.initial_decomp_questions) > 0:
# sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]]
# if len(state["sub_question_records"]) == 0:
# if state["config"].use_persistence:
# raise ValueError("No sub-questions found for initial decompozed questions")
# else:
# # in this case, we are doing retrieval on the original question.
# # to make all the logic consistent, we create a new sub-question
# # with the same content as the original question
# sub_question_record_ids = [1] * len(state["initial_decomp_questions"])
return [
Send(
"answer_query_subgraph",
AnswerQuestionInput(
question=question,
question_id=make_question_id(0, question_nr + 1),
log_messages=[
f"{now_start} -- Main Edge - Parallelize Initial Sub-question Answering"
],
),
)
for question_nr, question in enumerate(state.initial_decomp_questions)
]
else:
return [
Send(
"ingest_answers",
AnswerQuestionOutput(
answer_results=[],
),
)
]
# Define the function that determines whether to continue or not
def continue_to_refined_answer_or_end(
state: RequireRefinedAnswerUpdate,
) -> Literal["refined_sub_question_creation", "logging_node"]:
if state.require_refined_answer:
return "refined_sub_question_creation"
else:
return "logging_node"
def parallelize_refined_sub_question_answering(
state: MainState,
) -> list[Send | Hashable]:
now_start = datetime.now()
if len(state.refined_sub_questions) > 0:
return [
Send(
"answer_refined_question",
AnswerQuestionInput(
question=question_data.sub_question,
question_id=make_question_id(1, question_nr),
log_messages=[
f"{now_start} -- Main Edge - Parallelize Refined Sub-question Answering"
],
),
)
for question_nr, question_data in state.refined_sub_questions.items()
]
else:
return [
Send(
"ingest_refined_sub_answers",
AnswerQuestionOutput(
answer_results=[],
),
)
]

View File

@@ -0,0 +1,375 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.graph_builder import (
answer_query_graph_builder,
)
from onyx.agents.agent_search.deep_search_a.answer_refinement_sub_question.graph_builder import (
answer_refined_query_graph_builder,
)
from onyx.agents.agent_search.deep_search_a.base_raw_search.graph_builder import (
base_raw_search_graph_builder,
)
from onyx.agents.agent_search.deep_search_a.main.edges import (
continue_to_refined_answer_or_end,
)
from onyx.agents.agent_search.deep_search_a.main.edges import (
parallelize_initial_sub_question_answering,
)
from onyx.agents.agent_search.deep_search_a.main.edges import (
parallelize_refined_sub_question_answering,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.agent_logging import (
agent_logging,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.agent_path_decision import (
agent_path_decision,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.agent_path_routing import (
agent_path_routing,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.agent_search_start import (
agent_search_start,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.answer_comparison import (
answer_comparison,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.direct_llm_handling import (
direct_llm_handling,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.entity_term_extraction_llm import (
entity_term_extraction_llm,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.generate_initial_answer import (
generate_initial_answer,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.generate_refined_answer import (
generate_refined_answer,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.ingest_initial_base_retrieval import (
ingest_initial_base_retrieval,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.ingest_initial_sub_question_answers import (
ingest_initial_sub_question_answers,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.ingest_refined_answers import (
ingest_refined_answers,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.initial_answer_quality_check import (
initial_answer_quality_check,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.initial_sub_question_creation import (
initial_sub_question_creation,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.refined_answer_decision import (
refined_answer_decision,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.refined_sub_question_creation import (
refined_sub_question_creation,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.retrieval_consolidation import (
retrieval_consolidation,
)
from onyx.agents.agent_search.deep_search_a.main.states import MainInput
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger
logger = setup_logger()
test_mode = False
def main_graph_builder(test_mode: bool = False) -> StateGraph:
graph = StateGraph(
state_schema=MainState,
input=MainInput,
)
graph.add_node(
node="agent_path_decision",
action=agent_path_decision,
)
graph.add_node(
node="agent_path_routing",
action=agent_path_routing,
)
graph.add_node(
node="LLM",
action=direct_llm_handling,
)
graph.add_node(
node="agent_search_start",
action=agent_search_start,
)
graph.add_node(
node="initial_sub_question_creation",
action=initial_sub_question_creation,
)
answer_query_subgraph = answer_query_graph_builder().compile()
graph.add_node(
node="answer_query_subgraph",
action=answer_query_subgraph,
)
base_raw_search_subgraph = base_raw_search_graph_builder().compile()
graph.add_node(
node="base_raw_search_subgraph",
action=base_raw_search_subgraph,
)
# refined_answer_subgraph = refined_answers_graph_builder().compile()
# graph.add_node(
# node="refined_answer_subgraph",
# action=refined_answer_subgraph,
# )
graph.add_node(
node="refined_sub_question_creation",
action=refined_sub_question_creation,
)
answer_refined_question = answer_refined_query_graph_builder().compile()
graph.add_node(
node="answer_refined_question",
action=answer_refined_question,
)
graph.add_node(
node="ingest_refined_answers",
action=ingest_refined_answers,
)
graph.add_node(
node="generate_refined_answer",
action=generate_refined_answer,
)
# graph.add_node(
# node="check_refined_answer",
# action=check_refined_answer,
# )
graph.add_node(
node="ingest_initial_retrieval",
action=ingest_initial_base_retrieval,
)
graph.add_node(
node="retrieval_consolidation",
action=retrieval_consolidation,
)
graph.add_node(
node="ingest_initial_sub_question_answers",
action=ingest_initial_sub_question_answers,
)
graph.add_node(
node="generate_initial_answer",
action=generate_initial_answer,
)
graph.add_node(
node="initial_answer_quality_check",
action=initial_answer_quality_check,
)
graph.add_node(
node="entity_term_extraction_llm",
action=entity_term_extraction_llm,
)
graph.add_node(
node="refined_answer_decision",
action=refined_answer_decision,
)
graph.add_node(
node="answer_comparison",
action=answer_comparison,
)
graph.add_node(
node="logging_node",
action=agent_logging,
)
# if test_mode:
# graph.add_node(
# node="generate_initial_base_answer",
# action=generate_initial_base_answer,
# )
### Add edges ###
# raph.add_edge(start_key=START, end_key="base_raw_search_subgraph")
graph.add_edge(
start_key=START,
end_key="agent_path_decision",
)
graph.add_edge(
start_key="agent_path_decision",
end_key="agent_path_routing",
)
graph.add_edge(
start_key="agent_search_start",
end_key="base_raw_search_subgraph",
)
graph.add_edge(
start_key="agent_search_start",
end_key="initial_sub_question_creation",
)
graph.add_edge(
start_key="base_raw_search_subgraph",
end_key="ingest_initial_retrieval",
)
graph.add_edge(
start_key=["ingest_initial_retrieval", "ingest_initial_sub_question_answers"],
end_key="retrieval_consolidation",
)
graph.add_edge(
start_key="retrieval_consolidation",
end_key="entity_term_extraction_llm",
)
graph.add_edge(
start_key="retrieval_consolidation",
end_key="generate_initial_answer",
)
graph.add_edge(
start_key="LLM",
end_key=END,
)
# graph.add_edge(
# start_key=START,
# end_key="initial_sub_question_creation",
# )
graph.add_conditional_edges(
source="initial_sub_question_creation",
path=parallelize_initial_sub_question_answering,
path_map=["answer_query_subgraph"],
)
graph.add_edge(
start_key="answer_query_subgraph",
end_key="ingest_initial_sub_question_answers",
)
graph.add_edge(
start_key="retrieval_consolidation",
end_key="generate_initial_answer",
)
# graph.add_edge(
# start_key="generate_initial_answer",
# end_key="entity_term_extraction_llm",
# )
graph.add_edge(
start_key="generate_initial_answer",
end_key="initial_answer_quality_check",
)
graph.add_edge(
start_key=["initial_answer_quality_check", "entity_term_extraction_llm"],
end_key="refined_answer_decision",
)
graph.add_conditional_edges(
source="refined_answer_decision",
path=continue_to_refined_answer_or_end,
path_map=["refined_sub_question_creation", "logging_node"],
)
graph.add_conditional_edges(
source="refined_sub_question_creation", # DONE
path=parallelize_refined_sub_question_answering,
path_map=["answer_refined_question"],
)
graph.add_edge(
start_key="answer_refined_question", # HERE
end_key="ingest_refined_answers",
)
graph.add_edge(
start_key="ingest_refined_answers",
end_key="generate_refined_answer",
)
# graph.add_conditional_edges(
# source="refined_answer_decision",
# path=continue_to_refined_answer_or_end,
# path_map=["refined_answer_subgraph", END],
# )
# graph.add_edge(
# start_key="refined_answer_subgraph",
# end_key="generate_refined_answer",
# )
graph.add_edge(
start_key="generate_refined_answer",
end_key="answer_comparison",
)
graph.add_edge(
start_key="answer_comparison",
end_key="logging_node",
)
graph.add_edge(
start_key="logging_node",
end_key=END,
)
# graph.add_edge(
# start_key="generate_refined_answer",
# end_key="check_refined_answer",
# )
# graph.add_edge(
# start_key="check_refined_answer",
# end_key=END,
# )
return graph
if __name__ == "__main__":
pass
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
graph = main_graph_builder()
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
with get_session_context_manager() as db_session:
search_request = SearchRequest(query="Who created Excel?")
agent_a_config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)
inputs = MainInput(
base_question=agent_a_config.search_request.query, log_messages=[]
)
for thing in compiled_graph.stream(
input=inputs,
config={"configurable": {"config": agent_a_config}},
# stream_mode="debug",
# debug=True,
subgraphs=True,
):
logger.debug(thing)

View File

@@ -0,0 +1,36 @@
from pydantic import BaseModel
class FollowUpSubQuestion(BaseModel):
sub_question: str
sub_question_id: str
verified: bool
answered: bool
answer: str
class AgentTimings(BaseModel):
base_duration__s: float | None
refined_duration__s: float | None
full_duration__s: float | None
class AgentBaseMetrics(BaseModel):
num_verified_documents_total: int | None
num_verified_documents_core: int | None
verified_avg_score_core: float | None
num_verified_documents_base: int | float | None
verified_avg_score_base: float | None = None
base_doc_boost_factor: float | None = None
support_boost_factor: float | None = None
duration__s: float | None = None
class AgentRefinedMetrics(BaseModel):
refined_doc_boost_factor: float | None = None
refined_question_boost_factor: float | None = None
duration__s: float | None = None
class AgentAdditionalMetrics(BaseModel):
pass

View File

@@ -0,0 +1,113 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search_a.main.models import AgentAdditionalMetrics
from onyx.agents.agent_search.deep_search_a.main.models import AgentTimings
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import MainOutput
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.models import CombinedAgentMetrics
from onyx.db.chat import log_agent_metrics
from onyx.db.chat import log_agent_sub_question_results
def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput:
now_start = datetime.now()
logger.debug(f"--------{now_start}--------LOGGING NODE---")
agent_start_time = state.agent_start_time
agent_base_end_time = state.agent_base_end_time
agent_refined_start_time = state.agent_refined_start_time or None
agent_refined_end_time = state.agent_refined_end_time or None
agent_end_time = agent_refined_end_time or agent_base_end_time
agent_base_duration = None
if agent_base_end_time:
agent_base_duration = (agent_base_end_time - agent_start_time).total_seconds()
agent_refined_duration = None
if agent_refined_start_time and agent_refined_end_time:
agent_refined_duration = (
agent_refined_end_time - agent_refined_start_time
).total_seconds()
agent_full_duration = None
if agent_end_time:
agent_full_duration = (agent_end_time - agent_start_time).total_seconds()
agent_type = "refined" if agent_refined_duration else "base"
agent_base_metrics = state.agent_base_metrics
agent_refined_metrics = state.agent_refined_metrics
combined_agent_metrics = CombinedAgentMetrics(
timings=AgentTimings(
base_duration__s=agent_base_duration,
refined_duration__s=agent_refined_duration,
full_duration__s=agent_full_duration,
),
base_metrics=agent_base_metrics,
refined_metrics=agent_refined_metrics,
additional_metrics=AgentAdditionalMetrics(),
)
persona_id = None
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
if agent_a_config.search_request.persona:
persona_id = agent_a_config.search_request.persona.id
user_id = None
user = agent_a_config.search_tool.user
if user:
user_id = user.id
# log the agent metrics
if agent_a_config.db_session is not None:
log_agent_metrics(
db_session=agent_a_config.db_session,
user_id=user_id,
persona_id=persona_id,
agent_type=agent_type,
start_time=agent_start_time,
agent_metrics=combined_agent_metrics,
)
if agent_a_config.use_persistence:
# Persist the sub-answer in the database
db_session = agent_a_config.db_session
chat_session_id = agent_a_config.chat_session_id
primary_message_id = agent_a_config.message_id
sub_question_answer_results = state.decomp_answer_results
log_agent_sub_question_results(
db_session=db_session,
chat_session_id=chat_session_id,
primary_message_id=primary_message_id,
sub_question_answer_results=sub_question_answer_results,
)
# if chat_session_id is not None and primary_message_id is not None and sub_question_id is not None:
# create_sub_answer(
# db_session=db_session,
# chat_session_id=chat_session_id,
# primary_message_id=primary_message_id,
# sub_question_id=sub_question_id,
# answer=answer_str,
# # )
# pass
now_end = datetime.now()
main_output = MainOutput(
log_messages=[
f"{now_end} -- Main - Logging, Time taken: {now_end - now_start}"
],
)
logger.debug(f"--------{now_end}--{now_end - now_start}--------LOGGING NODE END---")
logger.debug(f"--------{now_end}--{now_end - now_start}--------LOGGING NODE END---")
return main_output

View File

@@ -0,0 +1,99 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.deep_search_a.main.states import RoutingDecision
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import AGENT_DECISION_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import (
AGENT_DECISION_PROMPT_AFTER_SEARCH,
)
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_context_manager
from onyx.llm.utils import check_number_of_tokens
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
def agent_path_decision(state: MainState, config: RunnableConfig) -> RoutingDecision:
now_start = datetime.now()
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = agent_a_config.search_request.query
perform_initial_search_path_decision = (
agent_a_config.perform_initial_search_path_decision
)
history = build_history_prompt(agent_a_config.prompt_builder)
logger.debug(f"--------{now_start}--------DECIDING TO SEARCH OR GO TO LLM---")
if perform_initial_search_path_decision:
search_tool = agent_a_config.search_tool
retrieved_docs: list[InferenceSection] = []
# new db session to avoid concurrency issues
with get_session_context_manager() as db_session:
for tool_response in search_tool.run(
query=question,
force_no_rerank=True,
alternate_db_session=db_session,
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
response = cast(SearchResponseSummary, tool_response.response)
retrieved_docs = response.top_sections
break
sample_doc_str = "\n\n".join(
[doc.combined_content for _, doc in enumerate(retrieved_docs[:3])]
)
agent_decision_prompt = AGENT_DECISION_PROMPT_AFTER_SEARCH.format(
question=question, sample_doc_str=sample_doc_str, history=history
)
else:
sample_doc_str = ""
agent_decision_prompt = AGENT_DECISION_PROMPT.format(
question=question, history=history
)
msg = [HumanMessage(content=agent_decision_prompt)]
# Get the rewritten queries in a defined format
model = agent_a_config.fast_llm
# no need to stream this
resp = model.invoke(msg)
if isinstance(resp.content, str) and "research" in resp.content.lower():
routing = "agent_search"
else:
routing = "LLM"
now_end = datetime.now()
logger.debug(
f"--------{now_end}--{now_end - now_start}--------DECIDING TO SEARCH OR GO TO LLM END---"
)
check_number_of_tokens(agent_decision_prompt)
return RoutingDecision(
# Decide which route to take
routing=routing,
sample_doc_str=sample_doc_str,
log_messages=[
f"{now_end} -- Path decision: {routing}, Time taken: {now_end - now_start}"
],
)

View File

@@ -0,0 +1,31 @@
from datetime import datetime
from typing import Literal
from langgraph.types import Command
from onyx.agents.agent_search.deep_search_a.main.states import MainState
def agent_path_routing(
state: MainState,
) -> Command[Literal["agent_search_start", "LLM"]]:
now_start = datetime.now()
routing = state.routing if hasattr(state, "routing") else "agent_search"
if routing == "agent_search":
agent_path = "agent_search_start"
else:
agent_path = "LLM"
now_end = datetime.now()
return Command(
# state update
update={
"log_messages": [
f"{now_end} -- Main - Path routing: {agent_path}, Time taken: {now_end - now_start}"
]
},
# control flow
goto=agent_path,
)

View File

@@ -0,0 +1,8 @@
from datetime import datetime
from onyx.agents.agent_search.core_state import CoreState
def agent_search_start(state: CoreState) -> CoreState:
now_end = datetime.now()
return CoreState(log_messages=[f"{now_end} -- Main - Agent search start"])

View File

@@ -0,0 +1,60 @@
from datetime import datetime
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import AnswerComparison
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.prompts import ANSWER_COMPARISON_PROMPT
from onyx.chat.models import RefinedAnswerImprovement
def answer_comparison(state: MainState, config: RunnableConfig) -> AnswerComparison:
now_start = datetime.now()
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = agent_a_config.search_request.query
initial_answer = state.initial_answer
refined_answer = state.refined_answer
logger.debug(f"--------{now_start}--------ANSWER COMPARISON STARTED--")
answer_comparison_prompt = ANSWER_COMPARISON_PROMPT.format(
question=question, initial_answer=initial_answer, refined_answer=refined_answer
)
msg = [HumanMessage(content=answer_comparison_prompt)]
# Get the rewritten queries in a defined format
model = agent_a_config.fast_llm
# no need to stream this
resp = model.invoke(msg)
refined_answer_improvement = (
isinstance(resp.content, str) and "yes" in resp.content.lower()
)
dispatch_custom_event(
"refined_answer_improvement",
RefinedAnswerImprovement(
refined_answer_improvement=refined_answer_improvement,
),
)
now_end = datetime.now()
logger.debug(
f"--------{now_end}--{now_end - now_start}--------ANSWER COMPARISON COMPLETED---"
)
return AnswerComparison(
refined_answer_improvement=refined_answer_improvement,
log_messages=[
f"{now_start} -- Answer comparison: {refined_answer_improvement}, Time taken: {now_end - now_start}"
],
)

View File

@@ -0,0 +1,89 @@
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import InitialAnswerUpdate
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.prompts import (
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
ASSISTANT_SYSTEM_PROMPT_PERSONA,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import DIRECT_LLM_PROMPT
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_prompt
from onyx.chat.models import AgentAnswerPiece
def direct_llm_handling(
state: MainState, config: RunnableConfig
) -> InitialAnswerUpdate:
now_start = datetime.now()
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = agent_a_config.search_request.query
persona_prompt = get_persona_prompt(agent_a_config.search_request.persona)
if len(persona_prompt) == 0:
persona_specification = ASSISTANT_SYSTEM_PROMPT_DEFAULT
else:
persona_specification = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
persona_prompt=persona_prompt
)
logger.debug(f"--------{now_start}--------LLM HANDLING START---")
model = agent_a_config.fast_llm
msg = [
HumanMessage(
content=DIRECT_LLM_PROMPT.format(
persona_specification=persona_specification, question=question
)
)
]
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
for message in model.stream(msg):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
dispatch_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=0,
level_question_nr=0,
answer_type="agent_level_answer",
),
)
streamed_tokens.append(content)
response = merge_content(*streamed_tokens)
answer = cast(str, response)
now_end = datetime.now()
logger.debug(f"--------{now_end}--{now_end - now_start}--------LLM HANDLING END---")
return InitialAnswerUpdate(
initial_answer=answer,
initial_agent_stats=None,
generated_sub_questions=[],
agent_base_end_time=now_end,
agent_base_metrics=None,
log_messages=[
f"{now_end} -- Main - LLM handling: {answer}, Time taken: {now_end - now_start}"
],
)

View File

@@ -0,0 +1,137 @@
import json
import re
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import (
EntityTermExtractionUpdate,
)
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.models import Entity
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
)
from onyx.agents.agent_search.shared_graph_utils.models import Relationship
from onyx.agents.agent_search.shared_graph_utils.models import Term
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROMPT
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
def entity_term_extraction_llm(
state: MainState, config: RunnableConfig
) -> EntityTermExtractionUpdate:
now_start = datetime.now()
logger.debug(f"--------{now_start}--------GENERATE ENTITIES & TERMS---")
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
if not agent_a_config.allow_refinement:
now_end = datetime.now()
return EntityTermExtractionUpdate(
entity_retlation_term_extractions=EntityRelationshipTermExtraction(
entities=[],
relationships=[],
terms=[],
),
log_messages=[
f"{now_end} -- Main - ETR Extraction, Time taken: {now_end - now_start}"
],
)
# first four lines duplicates from generate_initial_answer
question = agent_a_config.search_request.query
sub_question_docs = state.documents
all_original_question_documents = state.all_original_question_documents
relevant_docs = dedup_inference_sections(
sub_question_docs, all_original_question_documents
)
# start with the entity/term/extraction
doc_context = format_docs(relevant_docs)
doc_context = trim_prompt_piece(
agent_a_config.fast_llm.config, doc_context, ENTITY_TERM_PROMPT + question
)
msg = [
HumanMessage(
content=ENTITY_TERM_PROMPT.format(question=question, context=doc_context),
)
]
fast_llm = agent_a_config.fast_llm
# Grader
llm_response_list = list(
fast_llm.stream(
prompt=msg,
)
)
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
cleaned_response = re.sub(r"```json\n|\n```", "", llm_response)
parsed_response = json.loads(cleaned_response)
entities = []
relationships = []
terms = []
for entity in parsed_response.get("retrieved_entities_relationships", {}).get(
"entities", {}
):
entity_name = entity.get("entity_name", "")
entity_type = entity.get("entity_type", "")
entities.append(Entity(entity_name=entity_name, entity_type=entity_type))
for relationship in parsed_response.get("retrieved_entities_relationships", {}).get(
"relationships", {}
):
relationship_name = relationship.get("relationship_name", "")
relationship_type = relationship.get("relationship_type", "")
relationship_entities = relationship.get("relationship_entities", [])
relationships.append(
Relationship(
relationship_name=relationship_name,
relationship_type=relationship_type,
relationship_entities=relationship_entities,
)
)
for term in parsed_response.get("retrieved_entities_relationships", {}).get(
"terms", {}
):
term_name = term.get("term_name", "")
term_type = term.get("term_type", "")
term_similar_to = term.get("term_similar_to", [])
terms.append(
Term(
term_name=term_name,
term_type=term_type,
term_similar_to=term_similar_to,
)
)
now_end = datetime.now()
logger.debug(
f"--------{now_end}--{now_end - now_start}--------ENTITY TERM EXTRACTION END---"
)
return EntityTermExtractionUpdate(
entity_retlation_term_extractions=EntityRelationshipTermExtraction(
entities=entities,
relationships=relationships,
terms=terms,
),
log_messages=[
f"{now_end} -- Main - ETR Extraction, Time taken: {now_end - now_start}"
],
)

View File

@@ -0,0 +1,264 @@
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search_a.main.models import AgentBaseMetrics
from onyx.agents.agent_search.deep_search_a.main.operations import (
calculate_initial_agent_stats,
)
from onyx.agents.agent_search.deep_search_a.main.operations import get_query_info
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.operations import (
remove_document_citations,
)
from onyx.agents.agent_search.deep_search_a.main.states import InitialAnswerUpdate
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
ASSISTANT_SYSTEM_PROMPT_PERSONA,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import INITIAL_RAG_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import (
INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
SUB_QUESTION_ANSWER_TEMPLATE,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_prompt
from onyx.agents.agent_search.shared_graph_utils.utils import get_today_prompt
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
def generate_initial_answer(
state: MainState, config: RunnableConfig
) -> InitialAnswerUpdate:
now_start = datetime.now()
logger.debug(f"--------{now_start}--------GENERATE INITIAL---")
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = agent_a_config.search_request.query
persona_prompt = get_persona_prompt(agent_a_config.search_request.persona)
history = build_history_prompt(agent_a_config.prompt_builder)
date_str = get_today_prompt()
sub_question_docs = state.context_documents
all_original_question_documents = state.all_original_question_documents
relevant_docs = dedup_inference_sections(
sub_question_docs, all_original_question_documents
)
decomp_questions = []
if len(relevant_docs) == 0:
dispatch_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=UNKNOWN_ANSWER,
level=0,
level_question_nr=0,
answer_type="agent_level_answer",
),
)
answer = UNKNOWN_ANSWER
initial_agent_stats = InitialAgentResultStats(
sub_questions={},
original_question={},
agent_effectiveness={},
)
else:
# Use the query info from the base document retrieval
query_info = get_query_info(state.original_question_retrieval_results)
for tool_response in yield_search_responses(
query=question,
reranked_sections=relevant_docs,
final_context_sections=relevant_docs,
search_query_info=query_info,
get_section_relevance=lambda: None, # TODO: add relevance
search_tool=agent_a_config.search_tool,
):
dispatch_custom_event(
"tool_response",
ExtendedToolResponse(
id=tool_response.id,
response=tool_response.response,
level=0,
level_question_nr=0, # 0, 0 is the base question
),
)
net_new_original_question_docs = []
for all_original_question_doc in all_original_question_documents:
if all_original_question_doc not in sub_question_docs:
net_new_original_question_docs.append(all_original_question_doc)
decomp_answer_results = state.decomp_answer_results
good_qa_list: list[str] = []
sub_question_nr = 1
for decomp_answer_result in decomp_answer_results:
decomp_questions.append(decomp_answer_result.question)
_, question_nr = parse_question_id(decomp_answer_result.question_id)
if (
decomp_answer_result.quality.lower().startswith("yes")
and len(decomp_answer_result.answer) > 0
and decomp_answer_result.answer != UNKNOWN_ANSWER
):
good_qa_list.append(
SUB_QUESTION_ANSWER_TEMPLATE.format(
sub_question=decomp_answer_result.question,
sub_answer=decomp_answer_result.answer,
sub_question_nr=sub_question_nr,
)
)
sub_question_nr += 1
if len(good_qa_list) > 0:
sub_question_answer_str = "\n\n------\n\n".join(good_qa_list)
else:
sub_question_answer_str = ""
# Determine which persona-specification prompt to use
if len(persona_prompt) == 0:
persona_specification = ASSISTANT_SYSTEM_PROMPT_DEFAULT
else:
persona_specification = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
persona_prompt=persona_prompt
)
# Determine which base prompt to use given the sub-question information
if len(good_qa_list) > 0:
base_prompt = INITIAL_RAG_PROMPT
else:
base_prompt = INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS
model = agent_a_config.fast_llm
doc_context = format_docs(relevant_docs)
doc_context = trim_prompt_piece(
model.config,
doc_context,
base_prompt
+ sub_question_answer_str
+ persona_specification
+ history
+ date_str,
)
msg = [
HumanMessage(
content=base_prompt.format(
question=question,
answered_sub_questions=remove_document_citations(
sub_question_answer_str
),
relevant_docs=format_docs(relevant_docs),
persona_specification=persona_specification,
history=history,
date_prompt=date_str,
)
)
]
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
for message in model.stream(msg):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
dispatch_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=0,
level_question_nr=0,
answer_type="agent_level_answer",
),
)
streamed_tokens.append(content)
response = merge_content(*streamed_tokens)
answer = cast(str, response)
initial_agent_stats = calculate_initial_agent_stats(
state.decomp_answer_results, state.original_question_retrieval_stats
)
logger.debug(
f"\n\nYYYYY--Sub-Questions:\n\n{sub_question_answer_str}\n\nStats:\n\n"
)
if initial_agent_stats:
logger.debug(initial_agent_stats.original_question)
logger.debug(initial_agent_stats.sub_questions)
logger.debug(initial_agent_stats.agent_effectiveness)
now_end = datetime.now()
logger.debug(
f"--------{now_end}--{now_end - now_start}--------INITIAL AGENT ANSWER END---\n\n"
)
agent_base_end_time = datetime.now()
agent_base_metrics = AgentBaseMetrics(
num_verified_documents_total=len(relevant_docs),
num_verified_documents_core=state.original_question_retrieval_stats.verified_count,
verified_avg_score_core=state.original_question_retrieval_stats.verified_avg_scores,
num_verified_documents_base=initial_agent_stats.sub_questions.get(
"num_verified_documents", None
),
verified_avg_score_base=initial_agent_stats.sub_questions.get(
"verified_avg_score", None
),
base_doc_boost_factor=initial_agent_stats.agent_effectiveness.get(
"utilized_chunk_ratio", None
),
support_boost_factor=initial_agent_stats.agent_effectiveness.get(
"support_ratio", None
),
duration__s=(agent_base_end_time - state.agent_start_time).total_seconds(),
)
return InitialAnswerUpdate(
initial_answer=answer,
initial_agent_stats=initial_agent_stats,
generated_sub_questions=decomp_questions,
agent_base_end_time=agent_base_end_time,
agent_base_metrics=agent_base_metrics,
log_messages=[
f"{now_end} -- Main - Initial Answer generation, Time taken: {now_end - now_start}"
],
)

View File

@@ -0,0 +1,56 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import InitialAnswerBASEUpdate
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import INITIAL_RAG_BASE_PROMPT
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
def generate_initial_base_search_only_answer(
state: MainState,
config: RunnableConfig,
) -> InitialAnswerBASEUpdate:
now_start = datetime.now()
logger.debug(f"--------{now_start}--------GENERATE INITIAL BASE ANSWER---")
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = agent_a_config.search_request.query
original_question_docs = state.all_original_question_documents
model = agent_a_config.fast_llm
doc_context = format_docs(original_question_docs)
doc_context = trim_prompt_piece(
model.config, doc_context, INITIAL_RAG_BASE_PROMPT + question
)
msg = [
HumanMessage(
content=INITIAL_RAG_BASE_PROMPT.format(
question=question,
context=doc_context,
)
)
]
# Grader
response = model.invoke(msg)
answer = response.pretty_repr()
now_end = datetime.now()
logger.debug(
f"--------{now_end}--{now_end - now_start}--------INITIAL BASE ANSWER END---\n\n"
)
return InitialAnswerBASEUpdate(initial_base_answer=answer)

View File

@@ -0,0 +1,326 @@
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search_a.main.models import AgentRefinedMetrics
from onyx.agents.agent_search.deep_search_a.main.operations import get_query_info
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.operations import (
remove_document_citations,
)
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.deep_search_a.main.states import RefinedAnswerUpdate
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
ASSISTANT_SYSTEM_PROMPT_PERSONA,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import REVISED_RAG_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import (
REVISED_RAG_PROMPT_NO_SUB_QUESTIONS,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
SUB_QUESTION_ANSWER_TEMPLATE,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_prompt
from onyx.agents.agent_search.shared_graph_utils.utils import get_today_prompt
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
def generate_refined_answer(
state: MainState, config: RunnableConfig
) -> RefinedAnswerUpdate:
now_start = datetime.now()
logger.debug(f"--------{now_start}--------GENERATE REFINED ANSWER---")
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = agent_a_config.search_request.query
persona_prompt = get_persona_prompt(agent_a_config.search_request.persona)
history = build_history_prompt(agent_a_config.prompt_builder)
date_str = get_today_prompt()
initial_documents = state.documents
revised_documents = state.refined_documents
combined_documents = dedup_inference_sections(initial_documents, revised_documents)
query_info = get_query_info(state.original_question_retrieval_results)
# stream refined answer docs
for tool_response in yield_search_responses(
query=question,
reranked_sections=combined_documents,
final_context_sections=combined_documents,
search_query_info=query_info,
get_section_relevance=lambda: None, # TODO: add relevance
search_tool=agent_a_config.search_tool,
):
dispatch_custom_event(
"tool_response",
ExtendedToolResponse(
id=tool_response.id,
response=tool_response.response,
level=1,
level_question_nr=0, # 0, 0 is the base question
),
)
if len(initial_documents) > 0:
revision_doc_effectiveness = len(combined_documents) / len(initial_documents)
elif len(revised_documents) == 0:
revision_doc_effectiveness = 0.0
else:
revision_doc_effectiveness = 10.0
decomp_answer_results = state.decomp_answer_results
# revised_answer_results = state.refined_decomp_answer_results
good_qa_list: list[str] = []
decomp_questions = []
initial_good_sub_questions: list[str] = []
new_revised_good_sub_questions: list[str] = []
sub_question_nr = 1
for decomp_answer_result in decomp_answer_results:
question_level, question_nr = parse_question_id(
decomp_answer_result.question_id
)
decomp_questions.append(decomp_answer_result.question)
if (
decomp_answer_result.quality.lower().startswith("yes")
and len(decomp_answer_result.answer) > 0
and decomp_answer_result.answer != UNKNOWN_ANSWER
):
good_qa_list.append(
SUB_QUESTION_ANSWER_TEMPLATE.format(
sub_question=decomp_answer_result.question,
sub_answer=decomp_answer_result.answer,
sub_question_nr=sub_question_nr,
)
)
if question_level == 0:
initial_good_sub_questions.append(decomp_answer_result.question)
else:
new_revised_good_sub_questions.append(decomp_answer_result.question)
sub_question_nr += 1
initial_good_sub_questions = list(set(initial_good_sub_questions))
new_revised_good_sub_questions = list(set(new_revised_good_sub_questions))
total_good_sub_questions = list(
set(initial_good_sub_questions + new_revised_good_sub_questions)
)
if len(initial_good_sub_questions) > 0:
revision_question_efficiency: float = len(total_good_sub_questions) / len(
initial_good_sub_questions
)
elif len(new_revised_good_sub_questions) > 0:
revision_question_efficiency = 10.0
else:
revision_question_efficiency = 1.0
sub_question_answer_str = "\n\n------\n\n".join(list(set(good_qa_list)))
# original answer
initial_answer = state.initial_answer
# Determine which persona-specification prompt to use
if len(persona_prompt) == 0:
persona_specification = ASSISTANT_SYSTEM_PROMPT_DEFAULT
else:
persona_specification = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
persona_prompt=persona_prompt
)
# Determine which base prompt to use given the sub-question information
if len(good_qa_list) > 0:
base_prompt = REVISED_RAG_PROMPT
else:
base_prompt = REVISED_RAG_PROMPT_NO_SUB_QUESTIONS
model = agent_a_config.fast_llm
relevant_docs = format_docs(combined_documents)
relevant_docs = trim_prompt_piece(
model.config,
relevant_docs,
base_prompt
+ question
+ sub_question_answer_str
+ relevant_docs
+ initial_answer
+ persona_specification
+ history,
)
msg = [
HumanMessage(
content=base_prompt.format(
question=question,
history=history,
answered_sub_questions=remove_document_citations(
sub_question_answer_str
),
relevant_docs=relevant_docs,
initial_answer=remove_document_citations(initial_answer),
persona_specification=persona_specification,
date_prompt=date_str,
)
)
]
# Grader
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
for message in model.stream(msg):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
dispatch_custom_event(
"refined_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=1,
level_question_nr=0,
answer_type="agent_level_answer",
),
)
streamed_tokens.append(content)
response = merge_content(*streamed_tokens)
answer = cast(str, response)
# refined_agent_stats = _calculate_refined_agent_stats(
# state.decomp_answer_results, state.original_question_retrieval_stats
# )
initial_good_sub_questions_str = "\n".join(list(set(initial_good_sub_questions)))
new_revised_good_sub_questions_str = "\n".join(
list(set(new_revised_good_sub_questions))
)
refined_agent_stats = RefinedAgentStats(
revision_doc_efficiency=revision_doc_effectiveness,
revision_question_efficiency=revision_question_efficiency,
)
logger.debug(
f"\n\n---INITIAL ANSWER START---\n\n Answer:\n Agent: {initial_answer}"
)
logger.debug("-" * 10)
logger.debug(f"\n\n---REVISED AGENT ANSWER START---\n\n Answer:\n Agent: {answer}")
logger.debug("-" * 100)
logger.debug(f"\n\nINITAL Sub-Questions\n\n{initial_good_sub_questions_str}\n\n")
logger.debug("-" * 10)
logger.debug(
f"\n\nNEW REVISED Sub-Questions\n\n{new_revised_good_sub_questions_str}\n\n"
)
logger.debug("-" * 100)
logger.debug(
f"\n\nINITAL & REVISED Sub-Questions & Answers:\n\n{sub_question_answer_str}\n\nStas:\n\n"
)
logger.debug("-" * 100)
if state.initial_agent_stats:
initial_doc_boost_factor = state.initial_agent_stats.agent_effectiveness.get(
"utilized_chunk_ratio", "--"
)
initial_support_boost_factor = (
state.initial_agent_stats.agent_effectiveness.get("support_ratio", "--")
)
num_initial_verified_docs = state.initial_agent_stats.original_question.get(
"num_verified_documents", "--"
)
initial_verified_docs_avg_score = (
state.initial_agent_stats.original_question.get("verified_avg_score", "--")
)
initial_sub_questions_verified_docs = (
state.initial_agent_stats.sub_questions.get("num_verified_documents", "--")
)
logger.debug("INITIAL AGENT STATS")
logger.debug(f"Document Boost Factor: {initial_doc_boost_factor}")
logger.debug(f"Support Boost Factor: {initial_support_boost_factor}")
logger.debug(f"Originally Verified Docs: {num_initial_verified_docs}")
logger.debug(
f"Originally Verified Docs Avg Score: {initial_verified_docs_avg_score}"
)
logger.debug(
f"Sub-Questions Verified Docs: {initial_sub_questions_verified_docs}"
)
if refined_agent_stats:
logger.debug("-" * 10)
logger.debug("REFINED AGENT STATS")
logger.debug(
f"Revision Doc Factor: {refined_agent_stats.revision_doc_efficiency}"
)
logger.debug(
f"Revision Question Factor: {refined_agent_stats.revision_question_efficiency}"
)
now_end = datetime.now()
logger.debug(
f"--------{now_end}--{now_end - now_start}--------INITIAL AGENT ANSWER END---\n\n"
)
agent_refined_end_time = datetime.now()
if state.agent_refined_start_time:
agent_refined_duration = (
agent_refined_end_time - state.agent_refined_start_time
).total_seconds()
else:
agent_refined_duration = None
agent_refined_metrics = AgentRefinedMetrics(
refined_doc_boost_factor=refined_agent_stats.revision_doc_efficiency,
refined_question_boost_factor=refined_agent_stats.revision_question_efficiency,
duration__s=agent_refined_duration,
)
now_end = datetime.now()
logger.debug(
f"--------{now_end}--{now_end - now_start}--------REFINED ANSWER UPDATE END---"
)
return RefinedAnswerUpdate(
refined_answer=answer,
refined_answer_quality=True, # TODO: replace this with the actual check value
refined_agent_stats=refined_agent_stats,
agent_refined_end_time=agent_refined_end_time,
agent_refined_metrics=agent_refined_metrics,
)

View File

@@ -0,0 +1,39 @@
from datetime import datetime
from onyx.agents.agent_search.deep_search_a.base_raw_search.states import (
BaseRawSearchOutput,
)
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import ExpandedRetrievalUpdate
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
def ingest_initial_base_retrieval(
state: BaseRawSearchOutput,
) -> ExpandedRetrievalUpdate:
now_start = datetime.now()
logger.debug(f"--------{now_start}--------INGEST INITIAL RETRIEVAL---")
sub_question_retrieval_stats = (
state.base_expanded_retrieval_result.sub_question_retrieval_stats
)
if sub_question_retrieval_stats is None:
sub_question_retrieval_stats = AgentChunkStats()
else:
sub_question_retrieval_stats = sub_question_retrieval_stats
now_end = datetime.now()
logger.debug(
f"--------{now_end}--{now_end - now_start}--------INGEST INITIAL RETRIEVAL END---"
)
return ExpandedRetrievalUpdate(
original_question_retrieval_results=state.base_expanded_retrieval_result.expanded_queries_results,
all_original_question_documents=state.base_expanded_retrieval_result.context_documents,
original_question_retrieval_stats=sub_question_retrieval_stats,
log_messages=[
f"{now_end} -- Main - Ingestion base retrieval, Time taken: {now_end - now_start}"
],
)

View File

@@ -0,0 +1,41 @@
from datetime import datetime
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import DecompAnswersUpdate
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
def ingest_initial_sub_question_answers(
state: AnswerQuestionOutput,
) -> DecompAnswersUpdate:
now_start = datetime.now()
logger.debug(f"--------{now_start}--------INGEST ANSWERS---")
documents = []
context_documents = []
answer_results = state.answer_results if hasattr(state, "answer_results") else []
for answer_result in answer_results:
documents.extend(answer_result.documents)
context_documents.extend(answer_result.context_documents)
now_end = datetime.now()
logger.debug(
f"--------{now_end}--{now_end - now_start}--------INGEST ANSWERS END---"
)
return DecompAnswersUpdate(
# Deduping is done by the documents operator for the main graph
# so we might not need to dedup here
documents=dedup_inference_sections(documents, []),
context_documents=dedup_inference_sections(context_documents, []),
decomp_answer_results=answer_results,
log_messages=[
f"{now_end} -- Main - Ingest initial processed sub questions, Time taken: {now_end - now_start}"
],
)

View File

@@ -0,0 +1,39 @@
from datetime import datetime
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import DecompAnswersUpdate
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
def ingest_refined_answers(
state: AnswerQuestionOutput,
) -> DecompAnswersUpdate:
now_start = datetime.now()
logger.debug(f"--------{now_start}--------INGEST FOLLOW UP ANSWERS---")
documents = []
answer_results = state.answer_results if hasattr(state, "answer_results") else []
for answer_result in answer_results:
documents.extend(answer_result.documents)
now_end = datetime.now()
logger.debug(
f"--------{now_end}--{now_end - now_start}--------INGEST FOLLOW UP ANSWERS END---"
)
return DecompAnswersUpdate(
# Deduping is done by the documents operator for the main graph
# so we might not need to dedup here
documents=dedup_inference_sections(documents, []),
decomp_answer_results=answer_results,
log_messages=[
f"{now_end} -- Main - Ingest refined answers, Time taken: {now_end - now_start}"
],
)

View File

@@ -0,0 +1,40 @@
from datetime import datetime
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import (
InitialAnswerQualityUpdate,
)
from onyx.agents.agent_search.deep_search_a.main.states import MainState
def initial_answer_quality_check(state: MainState) -> InitialAnswerQualityUpdate:
"""
Check whether the final output satisfies the original user question
Args:
state (messages): The current state
Returns:
InitialAnswerQualityUpdate
"""
now_start = datetime.now()
logger.debug(
f"--------{now_start}--------Checking for base answer validity - for not set True/False manually"
)
verdict = True
now_end = datetime.now()
logger.debug(
f"--------{now_end}--{now_end - now_start}--------INITIAL ANSWER QUALITY CHECK END---"
)
return InitialAnswerQualityUpdate(
initial_answer_quality=verdict,
log_messages=[
f"{now_end} -- Main - Initial answer quality check, Time taken: {now_end - now_start}"
],
)

View File

@@ -0,0 +1,150 @@
from datetime import datetime
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search_a.main.models import AgentRefinedMetrics
from onyx.agents.agent_search.deep_search_a.main.operations import dispatch_subquestion
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import BaseDecompUpdate
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH,
)
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import SubQuestionPiece
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_context_manager
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
def initial_sub_question_creation(
state: MainState, config: RunnableConfig
) -> BaseDecompUpdate:
now_start = datetime.now()
logger.debug(f"--------{now_start}--------BASE DECOMP START---")
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = agent_a_config.search_request.query
chat_session_id = agent_a_config.chat_session_id
primary_message_id = agent_a_config.message_id
perform_initial_search_decomposition = (
agent_a_config.perform_initial_search_decomposition
)
perform_initial_search_path_decision = (
agent_a_config.perform_initial_search_path_decision
)
history = build_history_prompt(agent_a_config.prompt_builder)
# Use the initial search results to inform the decomposition
sample_doc_str = state.sample_doc_str if hasattr(state, "sample_doc_str") else ""
if not chat_session_id or not primary_message_id:
raise ValueError(
"chat_session_id and message_id must be provided for agent search"
)
agent_start_time = datetime.now()
# Initial search to inform decomposition. Just get top 3 fits
if perform_initial_search_decomposition:
if not perform_initial_search_path_decision:
search_tool = agent_a_config.search_tool
retrieved_docs: list[InferenceSection] = []
# new db session to avoid concurrency issues
with get_session_context_manager() as db_session:
for tool_response in search_tool.run(
query=question,
force_no_rerank=True,
alternate_db_session=db_session,
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
response = cast(SearchResponseSummary, tool_response.response)
retrieved_docs = response.top_sections
break
sample_doc_str = "\n\n".join(
[doc.combined_content for _, doc in enumerate(retrieved_docs[:3])]
)
decomposition_prompt = (
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH.format(
question=question, sample_doc_str=sample_doc_str, history=history
)
)
else:
decomposition_prompt = INITIAL_DECOMPOSITION_PROMPT_QUESTIONS.format(
question=question, history=history
)
# Start decomposition
msg = [HumanMessage(content=decomposition_prompt)]
# Get the rewritten queries in a defined format
model = agent_a_config.fast_llm
# Send the initial question as a subquestion with number 0
dispatch_custom_event(
"decomp_qs",
SubQuestionPiece(
sub_question=question,
level=0,
level_question_nr=0,
),
)
# dispatches custom events for subquestion tokens, adding in subquestion ids.
streamed_tokens = dispatch_separated(model.stream(msg), dispatch_subquestion(0))
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type="sub_questions",
level=0,
)
dispatch_custom_event("stream_finished", stop_event)
deomposition_response = merge_content(*streamed_tokens)
# this call should only return strings. Commenting out for efficiency
# assert [type(tok) == str for tok in streamed_tokens]
# use no-op cast() instead of str() which runs code
# list_of_subquestions = clean_and_parse_list_string(cast(str, response))
list_of_subqs = cast(str, deomposition_response).split("\n")
decomp_list: list[str] = [sq.strip() for sq in list_of_subqs if sq.strip() != ""]
now_end = datetime.now()
logger.debug(f"--------{now_end}--{now_end - now_start}--------BASE DECOMP END---")
return BaseDecompUpdate(
initial_decomp_questions=decomp_list,
agent_start_time=agent_start_time,
agent_refined_start_time=None,
agent_refined_end_time=None,
agent_refined_metrics=AgentRefinedMetrics(
refined_doc_boost_factor=None,
refined_question_boost_factor=None,
duration__s=None,
),
)

View File

@@ -0,0 +1,47 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.deep_search_a.main.states import (
RequireRefinedAnswerUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
def refined_answer_decision(
state: MainState, config: RunnableConfig
) -> RequireRefinedAnswerUpdate:
now_start = datetime.now()
logger.debug(f"--------{now_start}--------REFINED ANSWER DECISION---")
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
if "?" in agent_a_config.search_request.query:
decision = False
else:
decision = True
decision = True
now_end = datetime.now()
logger.debug(
f"--------{now_end}--{now_end - now_start}--------REFINED ANSWER DECISION END---"
)
log_messages = [
f"{now_end} -- Main - Refined answer decision: {decision}, Time taken: {now_end - now_start}"
]
if agent_a_config.allow_refinement:
return RequireRefinedAnswerUpdate(
require_refined_answer=decision,
log_messages=log_messages,
)
else:
return RequireRefinedAnswerUpdate(
require_refined_answer=False,
log_messages=log_messages,
)

View File

@@ -0,0 +1,116 @@
from datetime import datetime
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search_a.main.models import FollowUpSubQuestion
from onyx.agents.agent_search.deep_search_a.main.operations import dispatch_subquestion
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import (
FollowUpSubQuestionsUpdate,
)
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import DEEP_DECOMPOSE_PROMPT
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
format_entity_term_extraction,
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.tools.models import ToolCallKickoff
def refined_sub_question_creation(
state: MainState, config: RunnableConfig
) -> FollowUpSubQuestionsUpdate:
""" """
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
dispatch_custom_event(
"start_refined_answer_creation",
ToolCallKickoff(
tool_name="agent_search_1",
tool_args={
"query": agent_a_config.search_request.query,
"answer": state.initial_answer,
},
),
)
now_start = datetime.now()
logger.debug(f"--------{now_start}--------FOLLOW UP DECOMPOSE---")
agent_refined_start_time = datetime.now()
question = agent_a_config.search_request.query
base_answer = state.initial_answer
history = build_history_prompt(agent_a_config.prompt_builder)
# get the entity term extraction dict and properly format it
entity_retlation_term_extractions = state.entity_retlation_term_extractions
entity_term_extraction_str = format_entity_term_extraction(
entity_retlation_term_extractions
)
initial_question_answers = state.decomp_answer_results
addressed_question_list = [
x.question for x in initial_question_answers if "yes" in x.quality.lower()
]
failed_question_list = [
x.question for x in initial_question_answers if "no" in x.quality.lower()
]
msg = [
HumanMessage(
content=DEEP_DECOMPOSE_PROMPT.format(
question=question,
history=history,
entity_term_extraction_str=entity_term_extraction_str,
base_answer=base_answer,
answered_sub_questions="\n - ".join(addressed_question_list),
failed_sub_questions="\n - ".join(failed_question_list),
),
)
]
# Grader
model = agent_a_config.fast_llm
streamed_tokens = dispatch_separated(model.stream(msg), dispatch_subquestion(1))
response = merge_content(*streamed_tokens)
if isinstance(response, str):
parsed_response = [q for q in response.split("\n") if q.strip() != ""]
else:
raise ValueError("LLM response is not a string")
refined_sub_question_dict = {}
for sub_question_nr, sub_question in enumerate(parsed_response):
refined_sub_question = FollowUpSubQuestion(
sub_question=sub_question,
sub_question_id=make_question_id(1, sub_question_nr + 1),
verified=False,
answered=False,
answer="",
)
refined_sub_question_dict[sub_question_nr + 1] = refined_sub_question
now_end = datetime.now()
logger.debug(
f"--------{now_end}--{now_end - now_start}--------FOLLOW UP DECOMPOSE END---"
)
return FollowUpSubQuestionsUpdate(
refined_sub_questions=refined_sub_question_dict,
agent_refined_start_time=agent_refined_start_time,
)

View File

@@ -0,0 +1,12 @@
from datetime import datetime
from onyx.agents.agent_search.deep_search_a.main.states import LoggerUpdate
from onyx.agents.agent_search.deep_search_a.main.states import MainState
def retrieval_consolidation(
state: MainState,
) -> LoggerUpdate:
now_start = datetime.now()
return LoggerUpdate(log_messages=[f"{now_start} -- Retrieval consolidation"])

View File

@@ -0,0 +1,145 @@
import re
from collections.abc import Callable
from langchain_core.callbacks.manager import dispatch_custom_event
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.agents.agent_search.shared_graph_utils.models import (
QuestionAnswerResults,
)
from onyx.chat.models import SubQuestionPiece
from onyx.tools.models import SearchQueryInfo
from onyx.utils.logger import setup_logger
logger = setup_logger()
def remove_document_citations(text: str) -> str:
"""
Removes citation expressions of format '[[D1]]()' from text.
The number after D can vary.
Args:
text: Input text containing citations
Returns:
Text with citations removed
"""
# Pattern explanation:
# \[\[D\d+\]\]\(\) matches:
# \[\[ - literal [[ characters
# D - literal D character
# \d+ - one or more digits
# \]\] - literal ]] characters
# \(\) - literal () characters
return re.sub(r"\[\[(?:D|Q)\d+\]\]\(\)", "", text)
def dispatch_subquestion(level: int) -> Callable[[str, int], None]:
def _helper(sub_question_part: str, num: int) -> None:
dispatch_custom_event(
"decomp_qs",
SubQuestionPiece(
sub_question=sub_question_part,
level=level,
level_question_nr=num,
),
)
return _helper
def calculate_initial_agent_stats(
decomp_answer_results: list[QuestionAnswerResults],
original_question_stats: AgentChunkStats,
) -> InitialAgentResultStats:
initial_agent_result_stats: InitialAgentResultStats = InitialAgentResultStats(
sub_questions={},
original_question={},
agent_effectiveness={},
)
orig_verified = original_question_stats.verified_count
orig_support_score = original_question_stats.verified_avg_scores
verified_document_chunk_ids = []
support_scores = 0.0
for decomp_answer_result in decomp_answer_results:
verified_document_chunk_ids += (
decomp_answer_result.sub_question_retrieval_stats.verified_doc_chunk_ids
)
if (
decomp_answer_result.sub_question_retrieval_stats.verified_avg_scores
is not None
):
support_scores += (
decomp_answer_result.sub_question_retrieval_stats.verified_avg_scores
)
verified_document_chunk_ids = list(set(verified_document_chunk_ids))
# Calculate sub-question stats
if (
verified_document_chunk_ids
and len(verified_document_chunk_ids) > 0
and support_scores is not None
):
sub_question_stats: dict[str, float | int | None] = {
"num_verified_documents": len(verified_document_chunk_ids),
"verified_avg_score": float(support_scores / len(decomp_answer_results)),
}
else:
sub_question_stats = {"num_verified_documents": 0, "verified_avg_score": None}
initial_agent_result_stats.sub_questions.update(sub_question_stats)
# Get original question stats
initial_agent_result_stats.original_question.update(
{
"num_verified_documents": original_question_stats.verified_count,
"verified_avg_score": original_question_stats.verified_avg_scores,
}
)
# Calculate chunk utilization ratio
sub_verified = initial_agent_result_stats.sub_questions["num_verified_documents"]
chunk_ratio: float | None = None
if sub_verified is not None and orig_verified is not None and orig_verified > 0:
chunk_ratio = (float(sub_verified) / orig_verified) if sub_verified > 0 else 0.0
elif sub_verified is not None and sub_verified > 0:
chunk_ratio = 10.0
initial_agent_result_stats.agent_effectiveness["utilized_chunk_ratio"] = chunk_ratio
if (
orig_support_score is None
or orig_support_score == 0.0
and initial_agent_result_stats.sub_questions["verified_avg_score"] is None
):
initial_agent_result_stats.agent_effectiveness["support_ratio"] = None
elif orig_support_score is None or orig_support_score == 0.0:
initial_agent_result_stats.agent_effectiveness["support_ratio"] = 10
elif initial_agent_result_stats.sub_questions["verified_avg_score"] is None:
initial_agent_result_stats.agent_effectiveness["support_ratio"] = 0
else:
initial_agent_result_stats.agent_effectiveness["support_ratio"] = (
initial_agent_result_stats.sub_questions["verified_avg_score"]
/ orig_support_score
)
return initial_agent_result_stats
def get_query_info(results: list[QueryResult]) -> SearchQueryInfo:
# Use the query info from the base document retrieval
# TODO: see if this is the right way to do this
query_infos = [
result.query_info for result in results if result.query_info is not None
]
if len(query_infos) == 0:
raise ValueError("No query info found")
return query_infos[0]

View File

@@ -0,0 +1,160 @@
from datetime import datetime
from operator import add
from typing import Annotated
from typing import TypedDict
from pydantic import BaseModel
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import (
ExpandedRetrievalResult,
)
from onyx.agents.agent_search.deep_search_a.main.models import AgentBaseMetrics
from onyx.agents.agent_search.deep_search_a.main.models import AgentRefinedMetrics
from onyx.agents.agent_search.deep_search_a.main.models import FollowUpSubQuestion
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
)
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.agents.agent_search.shared_graph_utils.models import (
QuestionAnswerResults,
)
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_question_answer_results,
)
from onyx.context.search.models import InferenceSection
### States ###
## Update States
class LoggerUpdate(BaseModel):
log_messages: Annotated[list[str], add] = []
class RefinedAgentStartStats(BaseModel):
agent_refined_start_time: datetime | None = None
class RefinedAgentEndStats(BaseModel):
agent_refined_end_time: datetime | None = None
agent_refined_metrics: AgentRefinedMetrics = AgentRefinedMetrics()
class BaseDecompUpdate(RefinedAgentStartStats, RefinedAgentEndStats):
agent_start_time: datetime = datetime.now()
initial_decomp_questions: list[str] = []
class AnswerComparison(LoggerUpdate):
refined_answer_improvement: bool = False
class RoutingDecision(LoggerUpdate):
routing: str = ""
sample_doc_str: str = ""
class InitialAnswerBASEUpdate(BaseModel):
initial_base_answer: str = ""
class InitialAnswerUpdate(LoggerUpdate):
initial_answer: str = ""
initial_agent_stats: InitialAgentResultStats | None = None
generated_sub_questions: list[str] = []
agent_base_end_time: datetime | None = None
agent_base_metrics: AgentBaseMetrics | None = None
class RefinedAnswerUpdate(RefinedAgentEndStats):
refined_answer: str = ""
refined_agent_stats: RefinedAgentStats | None = None
refined_answer_quality: bool = False
class InitialAnswerQualityUpdate(LoggerUpdate):
initial_answer_quality: bool = False
class RequireRefinedAnswerUpdate(LoggerUpdate):
require_refined_answer: bool = True
class DecompAnswersUpdate(LoggerUpdate):
documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
context_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
decomp_answer_results: Annotated[
list[QuestionAnswerResults], dedup_question_answer_results
] = []
class FollowUpDecompAnswersUpdate(LoggerUpdate):
refined_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
refined_decomp_answer_results: Annotated[list[QuestionAnswerResults], add] = []
class ExpandedRetrievalUpdate(LoggerUpdate):
all_original_question_documents: Annotated[
list[InferenceSection], dedup_inference_sections
]
original_question_retrieval_results: list[QueryResult] = []
original_question_retrieval_stats: AgentChunkStats = AgentChunkStats()
class EntityTermExtractionUpdate(LoggerUpdate):
entity_retlation_term_extractions: EntityRelationshipTermExtraction = (
EntityRelationshipTermExtraction()
)
class FollowUpSubQuestionsUpdate(RefinedAgentStartStats):
refined_sub_questions: dict[int, FollowUpSubQuestion] = {}
## Graph Input State
## Graph Input State
class MainInput(CoreState):
pass
## Graph State
class MainState(
# This includes the core state
MainInput,
BaseDecompUpdate,
InitialAnswerUpdate,
InitialAnswerBASEUpdate,
DecompAnswersUpdate,
ExpandedRetrievalUpdate,
EntityTermExtractionUpdate,
InitialAnswerQualityUpdate,
RequireRefinedAnswerUpdate,
FollowUpSubQuestionsUpdate,
FollowUpDecompAnswersUpdate,
RefinedAnswerUpdate,
RefinedAgentStartStats,
RefinedAgentEndStats,
RoutingDecision,
AnswerComparison,
):
# expanded_retrieval_result: Annotated[list[ExpandedRetrievalResult], add]
base_raw_search_result: Annotated[list[ExpandedRetrievalResult], add]
## Graph Output State - presently not used
class MainOutput(TypedDict):
log_messages: list[str]

View File

@@ -0,0 +1,86 @@
from dataclasses import dataclass
from uuid import UUID
from pydantic import BaseModel
from pydantic import model_validator
from sqlalchemy.orm import Session
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.context.search.models import SearchRequest
from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
@dataclass
class AgentSearchConfig:
"""
Configuration for the Agent Search feature.
"""
# The search request that was used to generate the Pro Search
search_request: SearchRequest
primary_llm: LLM
fast_llm: LLM
search_tool: SearchTool
# Whether to force use of a tool, or to
# force tool args IF the tool is used
force_use_tool: ForceUseTool
# contains message history for the current chat session
# has the following (at most one is non-None)
# message_history: list[PreviousMessage] | None = None
# single_message_history: str | None = None
prompt_builder: AnswerPromptBuilder
use_agentic_search: bool = False
# For persisting agent search data
chat_session_id: UUID | None = None
# The message ID of the user message that triggered the Pro Search
message_id: int | None = None
# Whether to persistence data for the Pro Search (turned off for testing)
use_persistence: bool = True
# The database session for the Pro Search
db_session: Session | None = None
# Whether to perform initial search to inform decomposition
perform_initial_search_path_decision: bool = True
# Whether to perform initial search to inform decomposition
perform_initial_search_decomposition: bool = True
# Whether to allow creation of refinement questions (and entity extraction, etc.)
allow_refinement: bool = True
# Tools available for use
tools: list[Tool] | None = None
using_tool_calling_llm: bool = False
files: list[InMemoryChatFile] | None = None
structured_response_format: dict | None = None
skip_gen_ai_answer_generation: bool = False
@model_validator(mode="after")
def validate_db_session(self) -> "AgentSearchConfig":
if self.use_persistence and self.db_session is None:
raise ValueError(
"db_session must be provided for pro search when using persistence"
)
return self
class AgentDocumentCitations(BaseModel):
document_id: str
document_title: str
link: str

View File

@@ -0,0 +1,272 @@
import asyncio
from collections.abc import AsyncIterable
from collections.abc import Iterable
from datetime import datetime
from typing import cast
from langchain_core.runnables.schema import StreamEvent
from langgraph.graph.state import CompiledStateGraph
from onyx.agents.agent_search.basic.graph_builder import basic_graph_builder
from onyx.agents.agent_search.basic.states import BasicInput
from onyx.agents.agent_search.deep_search_a.main.graph_builder import (
main_graph_builder as main_graph_builder_a,
)
from onyx.agents.agent_search.deep_search_a.main.states import MainInput as MainInput_a
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStream
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionPiece
from onyx.chat.models import ToolResponse
from onyx.configs.dev_configs import GRAPH_NAME
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager
from onyx.tools.tool_runner import ToolCallKickoff
from onyx.utils.logger import setup_logger
logger = setup_logger()
_COMPILED_GRAPH: CompiledStateGraph | None = None
def _set_combined_token_value(
combined_token: str, parsed_object: AgentAnswerPiece
) -> AgentAnswerPiece:
parsed_object.answer_piece = combined_token
return parsed_object
def _parse_agent_event(
event: StreamEvent,
) -> AnswerPacket | None:
"""
Parse the event into a typed object.
Return None if we are not interested in the event.
"""
event_type = event["event"]
# We always just yield the event data, but this piece is useful for two development reasons:
# 1. It's a list of the names of every place we dispatch a custom event
# 2. We maintain the intended types yielded by each event
if event_type == "on_custom_event":
# TODO: different AnswerStream types for different events
if event["name"] == "decomp_qs":
return cast(SubQuestionPiece, event["data"])
elif event["name"] == "subqueries":
return cast(SubQueryPiece, event["data"])
elif event["name"] == "sub_answers":
return cast(AgentAnswerPiece, event["data"])
elif event["name"] == "stream_finished":
return cast(StreamStopInfo, event["data"])
elif event["name"] == "initial_agent_answer":
return cast(AgentAnswerPiece, event["data"])
elif event["name"] == "refined_agent_answer":
return cast(AgentAnswerPiece, event["data"])
elif event["name"] == "start_refined_answer_creation":
return cast(ToolCallKickoff, event["data"])
elif event["name"] == "tool_response":
return cast(ToolResponse, event["data"])
elif event["name"] == "basic_response":
return cast(AnswerPacket, event["data"])
elif event["name"] == "refined_answer_improvement":
return cast(RefinedAnswerImprovement, event["data"])
return None
# https://stackoverflow.com/questions/60226557/how-to-forcefully-close-an-async-generator
# https://stackoverflow.com/questions/40897428/please-explain-task-was-destroyed-but-it-is-pending-after-cancelling-tasks
task_references: set[asyncio.Task[StreamEvent]] = set()
def _manage_async_event_streaming(
compiled_graph: CompiledStateGraph,
config: AgentSearchConfig | None,
graph_input: MainInput_a | BasicInput,
) -> Iterable[StreamEvent]:
async def _run_async_event_stream() -> AsyncIterable[StreamEvent]:
message_id = config.message_id if config else None
async for event in compiled_graph.astream_events(
input=graph_input,
config={"metadata": {"config": config, "thread_id": str(message_id)}},
# debug=True,
# indicating v2 here deserves further scrutiny
version="v2",
):
yield event
# This might be able to be simplified
def _yield_async_to_sync() -> Iterable[StreamEvent]:
loop = asyncio.new_event_loop()
try:
# Get the async generator
async_gen = _run_async_event_stream()
# Convert to AsyncIterator
async_iter = async_gen.__aiter__()
while True:
try:
# Create a coroutine by calling anext with the async iterator
next_coro = anext(async_iter)
task = asyncio.ensure_future(next_coro, loop=loop)
task_references.add(task)
# Run the coroutine to get the next event
event = loop.run_until_complete(task)
yield event
except (StopAsyncIteration, GeneratorExit):
break
finally:
for task in task_references.pop():
task.cancel()
loop.close()
return _yield_async_to_sync()
def run_graph(
compiled_graph: CompiledStateGraph,
config: AgentSearchConfig,
input: BasicInput | MainInput_a,
) -> AnswerStream:
# TODO: add these to the environment
config.perform_initial_search_path_decision = False
config.perform_initial_search_decomposition = True
config.allow_refinement = True
for event in _manage_async_event_streaming(
compiled_graph=compiled_graph, config=config, graph_input=input
):
if not (parsed_object := _parse_agent_event(event)):
continue
yield parsed_object
# TODO: call this once on startup, TBD where and if it should be gated based
# on dev mode or not
def load_compiled_graph(graph_name: str) -> CompiledStateGraph:
main_graph_builder = (
main_graph_builder_a if graph_name == "a" else main_graph_builder_a
)
global _COMPILED_GRAPH
if _COMPILED_GRAPH is None:
graph = main_graph_builder()
_COMPILED_GRAPH = graph.compile()
return _COMPILED_GRAPH
def run_main_graph(
config: AgentSearchConfig,
graph_name: str = "a",
) -> AnswerStream:
compiled_graph = load_compiled_graph(graph_name)
if graph_name == "a":
input = MainInput_a(base_question=config.search_request.query, log_messages=[])
else:
input = MainInput_a(base_question=config.search_request.query, log_messages=[])
# Agent search is not a Tool per se, but this is helpful for the frontend
yield ToolCallKickoff(
tool_name="agent_search_0",
tool_args={"query": config.search_request.query},
)
yield from run_graph(compiled_graph, config, input)
# TODO: unify input types, especially prosearchconfig
def run_basic_graph(
config: AgentSearchConfig,
) -> AnswerStream:
graph = basic_graph_builder()
compiled_graph = graph.compile()
# TODO: unify basic input
input = BasicInput(
should_stream_answer=True,
)
return run_graph(compiled_graph, config, input)
if __name__ == "__main__":
from onyx.llm.factory import get_default_llms
now_start = datetime.now()
logger.debug(f"Start at {now_start}")
if GRAPH_NAME == "a":
graph = main_graph_builder_a()
else:
graph = main_graph_builder_a()
compiled_graph = graph.compile()
now_end = datetime.now()
logger.debug(f"Graph compiled in {now_end - now_start} seconds")
primary_llm, fast_llm = get_default_llms()
search_request = SearchRequest(
# query="what can you do with gitlab?",
# query="What are the guiding principles behind the development of cockroachDB",
# query="What are the temperatures in Munich, Hawaii, and New York?",
# query="When was Washington born?",
query="What is Onyx?",
)
# Joachim custom persona
with get_session_context_manager() as db_session:
config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)
# search_request.persona = get_persona_by_id(1, None, db_session)
config.use_persistence = True
config.perform_initial_search_path_decision = False
config.perform_initial_search_decomposition = True
if GRAPH_NAME == "a":
input = MainInput_a(
base_question=config.search_request.query, log_messages=[]
)
else:
input = MainInput_a(
base_question=config.search_request.query, log_messages=[]
)
# with open("output.txt", "w") as f:
tool_responses: list = []
for output in run_graph(compiled_graph, config, input):
# pass
if isinstance(output, ToolCallKickoff):
pass
elif isinstance(output, ExtendedToolResponse):
tool_responses.append(output.response)
logger.info(
f" ---- ET {output.level} - {output.level_question_nr} | "
)
elif isinstance(output, SubQueryPiece):
logger.info(
f"Sq {output.level} - {output.level_question_nr} - {output.sub_query} | "
)
elif isinstance(output, SubQuestionPiece):
logger.info(
f"SQ {output.level} - {output.level_question_nr} - {output.sub_question} | "
)
elif (
isinstance(output, AgentAnswerPiece)
and output.answer_type == "agent_sub_answer"
):
logger.info(
f" ---- SA {output.level} - {output.level_question_nr} {output.answer_piece} | "
)
elif (
isinstance(output, AgentAnswerPiece)
and output.answer_type == "agent_level_answer"
):
logger.info(
f" ---------- FA {output.level} - {output.level_question_nr} {output.answer_piece} | "
)
elif isinstance(output, RefinedAnswerImprovement):
logger.info(f" ---------- RE {output.refined_answer_improvement} | ")
# for tool_response in tool_responses:
# logger.debug(tool_response)

View File

@@ -0,0 +1,100 @@
from langchain.schema import AIMessage
from langchain.schema import HumanMessage
from langchain.schema import SystemMessage
from langchain_core.messages.tool import ToolMessage
from onyx.agents.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT_v2
from onyx.agents.agent_search.shared_graph_utils.prompts import HISTORY_PROMPT
from onyx.agents.agent_search.shared_graph_utils.utils import get_today_prompt
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.context.search.models import InferenceSection
from onyx.llm.interfaces import LLMConfig
from onyx.llm.utils import get_max_input_tokens
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.natural_language_processing.utils import tokenizer_trim_content
def build_sub_question_answer_prompt(
question: str,
original_question: str,
docs: list[InferenceSection],
persona_specification: str,
config: LLMConfig,
) -> list[SystemMessage | HumanMessage | AIMessage | ToolMessage]:
system_message = SystemMessage(
content=persona_specification,
)
date_str = get_today_prompt()
docs_format_list = [
f"""Document Number: [D{doc_nr + 1}]\n
Content: {doc.combined_content}\n\n"""
for doc_nr, doc in enumerate(docs)
]
docs_str = "\n\n".join(docs_format_list)
docs_str = trim_prompt_piece(
config, docs_str, BASE_RAG_PROMPT_v2 + question + original_question + date_str
)
human_message = HumanMessage(
content=BASE_RAG_PROMPT_v2.format(
question=question,
original_question=original_question,
context=docs_str,
date_prompt=date_str,
)
)
return [system_message, human_message]
def trim_prompt_piece(config: LLMConfig, prompt_piece: str, reserved_str: str) -> str:
# TODO: this truncating might add latency. We could do a rougher + faster check
# first to determine whether truncation is needed
# TODO: maybe save the tokenizer and max input tokens if this is getting called multiple times?
llm_tokenizer = get_tokenizer(
provider_type=config.model_provider,
model_name=config.model_name,
)
max_tokens = get_max_input_tokens(
model_provider=config.model_provider,
model_name=config.model_name,
)
# slightly conservative trimming
return tokenizer_trim_content(
content=prompt_piece,
desired_length=max_tokens - len(llm_tokenizer.encode(reserved_str)),
tokenizer=llm_tokenizer,
)
def build_history_prompt(prompt_builder: AnswerPromptBuilder | None) -> str:
if prompt_builder is None:
return ""
if prompt_builder.single_message_history is not None:
history = prompt_builder.single_message_history
else:
history_components = []
previous_message_type = None
for message in prompt_builder.raw_message_history:
if "user" in message.message_type:
history_components.append(f"User: {message.message}\n")
previous_message_type = "user"
elif "assistant" in message.message_type:
# only use the last agent answer for the history
if previous_message_type != "assistant":
history_components.append(f"You/Agent: {message.message}\n")
else:
history_components = history_components[:-1]
history_components.append(f"You/Agent: {message.message}\n")
previous_message_type = "assistant"
else:
continue
history = "\n".join(history_components)
return HISTORY_PROMPT.format(history=history) if history else ""

View File

@@ -0,0 +1,98 @@
import numpy as np
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitScoreMetrics
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.chat.models import SectionRelevancePiece
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger
logger = setup_logger()
def unique_chunk_id(doc: InferenceSection) -> str:
return f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}"
def calculate_rank_shift(list1: list, list2: list, top_n: int = 20) -> float:
shift = 0
for rank_first, doc_id in enumerate(list1[:top_n], 1):
try:
rank_second = list2.index(doc_id) + 1
except ValueError:
rank_second = len(list2) # Document not found in second list
shift += np.abs(rank_first - rank_second) / np.log(1 + rank_first * rank_second)
return shift / top_n
def get_fit_scores(
pre_reranked_results: list[InferenceSection],
post_reranked_results: list[InferenceSection] | list[SectionRelevancePiece],
) -> RetrievalFitStats | None:
"""
Calculate retrieval metrics for search purposes
"""
if len(pre_reranked_results) == 0 or len(post_reranked_results) == 0:
return None
ranked_sections = {
"initial": pre_reranked_results,
"reranked": post_reranked_results,
}
fit_eval: RetrievalFitStats = RetrievalFitStats(
fit_score_lift=0,
rerank_effect=0,
fit_scores={
"initial": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
"reranked": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
},
)
for rank_type, docs in ranked_sections.items():
logger.debug(f"rank_type: {rank_type}")
for i in [1, 5, 10]:
fit_eval.fit_scores[rank_type].scores[str(i)] = (
sum(
[
float(doc.center_chunk.score)
for doc in docs[:i]
if type(doc) == InferenceSection
and doc.center_chunk.score is not None
]
)
/ i
)
fit_eval.fit_scores[rank_type].scores["fit_score"] = (
1
/ 3
* (
fit_eval.fit_scores[rank_type].scores["1"]
+ fit_eval.fit_scores[rank_type].scores["5"]
+ fit_eval.fit_scores[rank_type].scores["10"]
)
)
fit_eval.fit_scores[rank_type].scores["fit_score"] = fit_eval.fit_scores[
rank_type
].scores["1"]
fit_eval.fit_scores[rank_type].chunk_ids = [
unique_chunk_id(doc) for doc in docs if type(doc) == InferenceSection
]
fit_eval.fit_score_lift = (
fit_eval.fit_scores["reranked"].scores["fit_score"]
/ fit_eval.fit_scores["initial"].scores["fit_score"]
)
fit_eval.rerank_effect = calculate_rank_shift(
fit_eval.fit_scores["initial"].chunk_ids,
fit_eval.fit_scores["reranked"].chunk_ids,
)
return fit_eval

View File

@@ -0,0 +1,113 @@
from typing import Literal
from pydantic import BaseModel
from onyx.agents.agent_search.deep_search_a.main.models import AgentAdditionalMetrics
from onyx.agents.agent_search.deep_search_a.main.models import AgentBaseMetrics
from onyx.agents.agent_search.deep_search_a.main.models import AgentRefinedMetrics
from onyx.agents.agent_search.deep_search_a.main.models import AgentTimings
from onyx.context.search.models import InferenceSection
from onyx.tools.models import SearchQueryInfo
# Pydantic models for structured outputs
class RewrittenQueries(BaseModel):
rewritten_queries: list[str]
class BinaryDecision(BaseModel):
decision: Literal["yes", "no"]
class BinaryDecisionWithReasoning(BaseModel):
reasoning: str
decision: Literal["yes", "no"]
class RetrievalFitScoreMetrics(BaseModel):
scores: dict[str, float]
chunk_ids: list[str]
class RetrievalFitStats(BaseModel):
fit_score_lift: float
rerank_effect: float
fit_scores: dict[str, RetrievalFitScoreMetrics]
class AgentChunkScores(BaseModel):
scores: dict[str, dict[str, list[int | float]]]
class AgentChunkStats(BaseModel):
verified_count: int | None = None
verified_avg_scores: float | None = None
rejected_count: int | None = None
rejected_avg_scores: float | None = None
verified_doc_chunk_ids: list[str] = []
dismissed_doc_chunk_ids: list[str] = []
class InitialAgentResultStats(BaseModel):
sub_questions: dict[str, float | int | None]
original_question: dict[str, float | int | None]
agent_effectiveness: dict[str, float | int | None]
class RefinedAgentStats(BaseModel):
revision_doc_efficiency: float | None
revision_question_efficiency: float | None
class Term(BaseModel):
term_name: str = ""
term_type: str = ""
term_similar_to: list[str] = []
### Models ###
class Entity(BaseModel):
entity_name: str = ""
entity_type: str = ""
class Relationship(BaseModel):
relationship_name: str = ""
relationship_type: str = ""
relationship_entities: list[str] = []
class EntityRelationshipTermExtraction(BaseModel):
entities: list[Entity] = []
relationships: list[Relationship] = []
terms: list[Term] = []
### Models ###
class QueryResult(BaseModel):
query: str
search_results: list[InferenceSection]
stats: RetrievalFitStats | None
query_info: SearchQueryInfo | None
class QuestionAnswerResults(BaseModel):
question: str
question_id: str
answer: str
quality: str
expanded_retrieval_results: list[QueryResult]
documents: list[InferenceSection]
context_documents: list[InferenceSection]
sub_question_retrieval_stats: AgentChunkStats
class CombinedAgentMetrics(BaseModel):
timings: AgentTimings
base_metrics: AgentBaseMetrics | None
refined_metrics: AgentRefinedMetrics
additional_metrics: AgentAdditionalMetrics

View File

@@ -0,0 +1,31 @@
from onyx.agents.agent_search.shared_graph_utils.models import (
QuestionAnswerResults,
)
from onyx.chat.prune_and_merge import _merge_sections
from onyx.context.search.models import InferenceSection
def dedup_inference_sections(
list1: list[InferenceSection], list2: list[InferenceSection]
) -> list[InferenceSection]:
deduped = _merge_sections(list1 + list2)
return deduped
def dedup_question_answer_results(
question_answer_results_1: list[QuestionAnswerResults],
question_answer_results_2: list[QuestionAnswerResults],
) -> list[QuestionAnswerResults]:
deduped_question_answer_results: list[
QuestionAnswerResults
] = question_answer_results_1
utilized_question_ids: set[str] = set(
[x.question_id for x in question_answer_results_1]
)
for question_answer_result in question_answer_results_2:
if question_answer_result.question_id not in utilized_question_ids:
deduped_question_answer_results.append(question_answer_result)
utilized_question_ids.add(question_answer_result.question_id)
return deduped_question_answer_results

View File

@@ -0,0 +1,975 @@
UNKNOWN_ANSWER = "I do not have enough information to answer this question."
NO_RECOVERED_DOCS = "No relevant information recovered"
DATE_PROMPT = """Today is {date}.\n\n"""
HISTORY_PROMPT = """\n
For more context, here is the history of the conversation so far that preceeded this question:
\n ------- \n
{history}
\n ------- \n\n
"""
REWRITE_PROMPT_MULTI_ORIGINAL = """ \n
Please convert an initial user question into a 2-3 more appropriate short and pointed search queries for retrievel from a
document store. Particularly, try to think about resolving ambiguities and make the search queries more specific,
enabling the system to search more broadly.
Also, try to make the search queries not redundant, i.e. not too similar! \n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Formulate the queries separated by newlines (Do not say 'Query 1: ...', just write the querytext) as follows:
<query 1>
<query 2>
...
queries: """
REWRITE_PROMPT_MULTI = """ \n
Please create a list of 2-3 sample documents that could answer an original question. Each document
should be about as long as the original question. \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Formulate the sample documents separated by '--' (Do not say 'Document 1: ...', just write the text): """
# The prompt is only used if there is no persona prompt, so the placeholder is ''
BASE_RAG_PROMPT = (
""" \n
{persona_specification}
{date_prompt}
Use the context provided below - and only the
provided context - to answer the given question. (Note that the answer is in service of anserwing a broader
question, given below as 'motivation'.)
Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
question based on the context, say """
+ f'"{UNKNOWN_ANSWER}"'
+ """.
It is a matter of life and death that you do NOT use your internal knowledge, just the provided`
information!
Make sure that you keep all relevant information, specifically as it concerns to the ultimate goal.
(But keep other details as well.)
\nContext:\n {context} \n
Motivation:\n {original_question} \n\n
\n\n
And here is the question I want you to answer based on the context above (with the motivation in mind):
\n--\n {question} \n--\n
"""
)
BASE_RAG_PROMPT_v2 = (
""" \n
{date_prompt}
Use the context provided below - and only the
provided context - to answer the given question. (Note that the answer is in service of answering a broader
question, given below as 'motivation'.)
Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
question based on the context, say """
+ f'"{UNKNOWN_ANSWER}"'
+ """. It is a matter of life and death that you do NOT
use your internal knowledge, just the provided information!
Make sure that you keep all relevant information, specifically as it concerns to the ultimate goal.
(But keep other details as well.)
Please remember to provide inline citations in the format [[D1]](), [[D2]](), [[D3]](), etc!
It is important that the citation is close to the information it supports.
Proper citations are very important to the user!\n\n\n
For your general information, here is the ultimate motivation:
\n--\n {original_question} \n--\n
\n\n
And here is the actual question I want you to answer based on the context above (with the motivation in mind):
\n--\n {question} \n--\n
Here is the context:
\n\n\n--\n {context} \n--\n
"""
)
SUB_CHECK_YES = "yes"
SUB_CHECK_NO = "no"
SUB_CHECK_PROMPT = (
"""
Your task is to see whether a given answer addresses a given question.
Please do not use any internal knowledge you may have - just focus on whether the answer
as given seems to largely address the question as given, or at least addresses part of the question.
Here is the question:
\n ------- \n
{question}
\n ------- \n
Here is the suggested answer:
\n ------- \n
{base_answer}
\n ------- \n
Does the suggested answer address the question? Please answer with """
+ f'"{SUB_CHECK_YES}" or "{SUB_CHECK_NO}".'
)
BASE_CHECK_PROMPT = """ \n
Please check whether 1) the suggested answer seems to fully address the original question AND 2)the
original question requests a simple, factual answer, and there are no ambiguities, judgements,
aggregations, or any other complications that may require extra context. (I.e., if the question is
somewhat addressed, but the answer would benefit from more context, then answer with 'no'.)
Please only answer with 'yes' or 'no' \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Here is the proposed answer:
\n ------- \n
{initial_answer}
\n ------- \n
Please answer with yes or no:"""
VERIFIER_PROMPT = """
You are supposed to judge whether a document text contains data or information that is potentially relevant
for a question. It does not have to be fully relevant, but check whether it has some information that
could help to address the question.
Here is a document text that you can take as a fact:
--
DOCUMENT INFORMATION:
{document_content}
--
Do you think that this document text is useful and relevant to answer the following question?
(Other documents may supply additional information, so do not worry if the provided information
is not enough to answer the question, but it needs to be relevant to the question.)
--
QUESTION:
{question}
--
Please answer with 'yes' or 'no':
Answer:
"""
INITIAL_DECOMPOSITION_PROMPT_BASIC = """ \n
If you think it is helpful, please decompose an initial user question into not more
than 4 appropriate sub-questions that help to answer the original question.
The purpose for this decomposition is to isolate individulal entities
(i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
for us'), etc. Each sub-question should be realistically be answerable by a good RAG system.
Importantly, if you think it is not needed or helpful, please just return an empty list. That is ok too.
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Please formulate your answer as a list of subquestions:
Answer:
"""
REWRITE_PROMPT_SINGLE = """ \n
Please convert an initial user question into a more appropriate search query for retrievel from a
document store. \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Formulate the query: """
MODIFIED_RAG_PROMPT = (
"""You are an assistant for question-answering tasks. Use the context provided below
- and only this context - to answer the question. It is a matter of life and death that you do NOT
use your internal knowledge, just the provided information!
If you don't have enough infortmation to generate an answer, just say """
+ f'"{UNKNOWN_ANSWER}"'
+ """.
Use three sentences maximum and keep the answer concise.
Pay also particular attention to the sub-questions and their answers, at least it may enrich the answer.
Again, only use the provided context and do not use your internal knowledge!
\nQuestion: {question}
\nContext: {combined_context} \n
Answer:"""
)
ORIG_DEEP_DECOMPOSE_PROMPT = """ \n
An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
good enough. Also, some sub-questions had been answered and this information has been used to provide
the initial answer. Some other subquestions may have been suggested based on little knowledge, but they
were not directly answerable. Also, some entities, relationships and terms are givenm to you so that
you have an idea of how the avaiolable data looks like.
Your role is to generate 3-5 new sub-questions that would help to answer the initial question,
considering:
1) The initial question
2) The initial answer that was found to be unsatisfactory
3) The sub-questions that were answered
4) The sub-questions that were suggested but not answered
5) The entities, relationships and terms that were extracted from the context
The individual questions should be answerable by a good RAG system.
So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the
question for different entities that may be involved in the original question, but in a way that does
not duplicate questions that were already tried.
Additional Guidelines:
- The sub-questions should be specific to the question and provide richer context for the question,
resolve ambiguities, or address shortcoming of the initial answer
- Each sub-question - when answered - should be relevant for the answer to the original question
- The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any
other complications that may require extra context.
- The sub-questions MUST have the full context of the original question so that it can be executed by
a RAG system independently without the original question available
(Example:
- initial question: "What is the capital of France?"
- bad sub-question: "What is the name of the river there?"
- good sub-question: "What is the name of the river that flows through Paris?"
- For each sub-question, please provide a short explanation for why it is a good sub-question. So
generate a list of dictionaries with the following format:
[{{"sub_question": <sub-question>, "explanation": <explanation>, "search_term": <rewrite the
sub-question using as a search phrase for the document store>}}, ...]
\n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Here is the initial sub-optimal answer:
\n ------- \n
{base_answer}
\n ------- \n
Here are the sub-questions that were answered:
\n ------- \n
{answered_sub_questions}
\n ------- \n
Here are the sub-questions that were suggested but not answered:
\n ------- \n
{failed_sub_questions}
\n ------- \n
And here are the entities, relationships and terms extracted from the context:
\n ------- \n
{entity_term_extraction_str}
\n ------- \n
Please generate the list of good, fully contextualized sub-questions that would help to address the
main question. Again, please find questions that are NOT overlapping too much with the already answered
sub-questions or those that already were suggested and failed.
In other words - what can we try in addition to what has been tried so far?
Please think through it step by step and then generate the list of json dictionaries with the following
format:
{{"sub_questions": [{{"sub_question": <sub-question>,
"explanation": <explanation>,
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
...]}} """
DEEP_DECOMPOSE_PROMPT = """ \n
An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
good enough. Also, some sub-questions had been answered and this information has been used to provide
the initial answer. Some other subquestions may have been suggested based on little knowledge, but they
were not directly answerable. Also, some entities, relationships and terms are givenm to you so that
you have an idea of how the avaiolable data looks like.
Your role is to generate 2-4 new sub-questions that would help to answer the initial question,
considering:
1) The initial question
2) The initial answer that was found to be unsatisfactory
3) The sub-questions that were answered
4) The sub-questions that were suggested but not answered
5) The entities, relationships and terms that were extracted from the context
The individual questions should be answerable by a good RAG system.
So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the
question for different entities that may be involved in the original question, but in a way that does
not duplicate questions that were already tried.
Additional Guidelines:
- The sub-questions should be specific to the question and provide richer context for the question,
resolve ambiguities, or address shortcoming of the initial answer
- Each sub-question - when answered - should be relevant for the answer to the original question
- The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any
other complications that may require extra context.
- The sub-questions MUST have the full context of the original question so that it can be executed by
a RAG system independently without the original question available
(Example:
- initial question: "What is the capital of France?"
- bad sub-question: "What is the name of the river there?"
- good sub-question: "What is the name of the river that flows through Paris?"
- For each sub-question, please also provide a search term that can be used to retrieve relevant
documents from a document store.
- Consider specifically the sub-questions that were suggested but not answered. This is a sign that they are not
answerable with the available context, and you should not ask similar questions.
\n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
{history}
Here is the initial sub-optimal answer:
\n ------- \n
{base_answer}
\n ------- \n
Here are the sub-questions that were answered:
\n ------- \n
{answered_sub_questions}
\n ------- \n
Here are the sub-questions that were suggested but not answered:
\n ------- \n
{failed_sub_questions}
\n ------- \n
And here are the entities, relationships and terms extracted from the context:
\n ------- \n
{entity_term_extraction_str}
\n ------- \n
Please generate the list of good, fully contextualized sub-questions that would help to address the
main question.
Specifically pay attention also to the entities, relationships and terms extracted, as these indicate what type of
objects/relationships/terms you can ask about! Do not ask about entities, terms or relationships that are not
mentioned in the 'entities, relationships and terms' section.
Again, please find questions that are NOT overlapping too much with the already answered
sub-questions or those that already were suggested and failed.
In other words - what can we try in addition to what has been tried so far?
Generate the list of questions separated by one new line like this:
<sub-question 1>
<sub-question 2>
<sub-question 3>
...
"""
DECOMPOSE_PROMPT = """ \n
For an initial user question, please generate at 5-10 individual sub-questions whose answers would help
\n to answer the initial question. The individual questions should be answerable by a good RAG system.
So a good idea would be to \n use the sub-questions to resolve ambiguities and/or to separate the
question for different entities that may be involved in the original question.
In order to arrive at meaningful sub-questions, please also consider the context retrieved from the
document store, expressed as entities, relationships and terms. You can also think about the types
mentioned in brackets
Guidelines:
- The sub-questions should be specific to the question and provide richer context for the question,
and or resolve ambiguities
- Each sub-question - when answered - should be relevant for the answer to the original question
- The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any
other complications that may require extra context.
- The sub-questions MUST have the full context of the original question so that it can be executed by
a RAG system independently without the original question available
(Example:
- initial question: "What is the capital of France?"
- bad sub-question: "What is the name of the river there?"
- good sub-question: "What is the name of the river that flows through Paris?"
- For each sub-question, please provide a short explanation for why it is a good sub-question. So
generate a list of dictionaries with the following format:
[{{"sub_question": <sub-question>, "explanation": <explanation>}}, ...]
\n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
And here are the entities, relationships and terms extracted from the context:
\n ------- \n
{entity_term_extraction_str}
\n ------- \n
Please generate the list of good, fully contextualized sub-questions that would help to address the
main question. Don't be too specific unless the original question is specific.
Please think through it step by step and then generate the list of json dictionaries with the following
format:
{{"sub_questions": [{{"sub_question": <sub-question>,
"explanation": <explanation>,
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
...]}} """
#### Consolidations
COMBINED_CONTEXT = """-------
Below you will find useful information to answer the original question. First, you see a number of
sub-questions with their answers. This information should be considered to be more focussed and
somewhat more specific to the original question as it tries to contextualized facts.
After that will see the documents that were considered to be relevant to answer the original question.
Here are the sub-questions and their answers:
\n\n {deep_answer_context} \n\n
\n\n Here are the documents that were considered to be relevant to answer the original question:
\n\n {formated_docs} \n\n
----------------
"""
SUB_QUESTION_EXPLANATION_RANKER_PROMPT = """-------
Below you will find a question that we ultimately want to answer (the original question) and a list of
motivations in arbitrary order for generated sub-questions that are supposed to help us answering the
original question. The motivations are formatted as <motivation number>: <motivation explanation>.
(Again, the numbering is arbitrary and does not necessarily mean that 1 is the most relevant
motivation and 2 is less relevant.)
Please rank the motivations in order of relevance for answering the original question. Also, try to
ensure that the top questions do not duplicate too much, i.e. that they are not too similar.
Ultimately, create a list with the motivation numbers where the number of the most relevant
motivations comes first.
Here is the original question:
\n\n {original_question} \n\n
\n\n Here is the list of sub-question motivations:
\n\n {sub_question_explanations} \n\n
----------------
Please think step by step and then generate the ranked list of motivations.
Please format your answer as a json object in the following format:
{{"reasonning": <explain your reasoning for the ranking>,
"ranked_motivations": <ranked list of motivation numbers>}}
"""
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS = """
If you think it is helpful, please decompose an initial user question into no more than 3 appropriate sub-questions that help to
answer the original question. The purpose for this decomposition may be to
1) isolate individual entities (i.e., 'compare sales of company A and company B' -> ['what are sales for company A',
'what are sales for company B')]
2) clarify or disambiguate ambiguous terms (i.e., 'what is our success with company A' -> ['what are our sales with company A',
'what is our market share with company A', 'is company A a reference customer for us', etc.])
3) if a term or a metric is essentially clear, but it could relate to various components of an entity and you are generally
familiar with the entity, then you can decompose the question into sub-questions that are more specific to components
(i.e., 'what do we do to improve scalability of product X', 'what do we to to improve scalability of product X',
'what do we do to improve stability of product X', ...])
4) research an area that could really help to answer the question. (But clarifications or disambiguations are more important.)
If you think that a decomposition is not needed or helpful, please just return an empty string. That is ok too.
Here is the initial question:
-------
{question}
-------
{history}
Please formulate your answer as a newline-separated list of questions like so:
<sub-question>
<sub-question>
<sub-question>
Answer:"""
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH = """
If you think it is helpful, please decompose an initial user question into no more than 3 appropriate sub-questions that help to
answer the original question. The purpose for this decomposition may be to
1) isolate individual entities (i.e., 'compare sales of company A and company B' -> ['what are sales for company A',
'what are sales for company B')]
2) clarify or disambiguate ambiguous terms (i.e., 'what is our success with company A' -> ['what are our sales with company A',
'what is our market share with company A', 'is company A a reference customer for us', etc.])
3) if a term or a metric is essentially clear, but it could relate to various components of an entity and you are generally
familiar with the entity, then you can decompose the question into sub-questions that are more specific to components
(i.e., 'what do we do to improve scalability of product X', 'what do we to to improve scalability of product X',
'what do we do to improve stability of product X', ...])
4) research an area that could really help to answer the question. (But clarifications or disambiguations are more important.)
Here are some other ruleds:
1) To give you some context, you will see below also some documents that relate to the question. Please only
use this information to learn what the question is approximately asking about, but do not focus on the details
to construct the sub-questions.
2) If you think that a decomposition is not needed or helpful, please just return an empty string. That is very muchok too.
Here are the sampple docs to give you some context:
-------
{sample_doc_str}
-------
And here is the initial question that you should think about decomposing:
-------
{question}
-------
{history}
Please formulate your answer as a newline-separated list of questions like so:
<sub-question>
<sub-question>
<sub-question>
Answer:"""
INITIAL_DECOMPOSITION_PROMPT = """ \n
Please decompose an initial user question into 2 or 3 appropriate sub-questions that help to
answer the original question. The purpose for this decomposition is to isolate individulal entities
(i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
for us'), etc. Each sub-question should be realistically be answerable by a good RAG system. \n
For each sub-question, please also create one search term that can be used to retrieve relevant
documents from a document store.
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Please formulate your answer as a list of json objects with the following format:
[{{"sub_question": <sub-question>, "search_term": <search term>}}, ...]
Answer:
"""
INITIAL_RAG_BASE_PROMPT = (
""" \n
You are an assistant for question-answering tasks. Use the information provided below - and only the
provided information - to answer the provided question.
The information provided below consists ofa number of documents that were deemed relevant for the question.
IMPORTANT RULES:
- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer.
You may give some additional facts you learned, but do not try to invent an answer.
- If the information is irrelevant, just say """
+ f'"{UNKNOWN_ANSWER}"'
+ """.
- If the information is relevant but not fully conclusive, specify that the information is not conclusive and say why.
Try to keep your answer concise.
Here is the contextual information from the document store:
\n ------- \n
{context} \n\n\n
\n ------- \n
And here is the question I want you to answer based on the context above (with the motivation in mind):
\n--\n {question} \n--\n
Answer:"""
)
AGENT_DECISION_PROMPT = """
You are an large language model assistant helping users address their information needs. You are tasked with deciding
whether to use a thorough agent search ('research') of a document store to answer a question or request, or whether you want to
address the question or request yourself as an LLM.
Here are some rules:
- If you think that a thorough search through a document store will help answer the question
or address the request, you should choose the 'research' option.
- If the question asks you do do somethng ('please create...', 'write for me...', etc.), you should choose the 'LLM' option.
- If you think the question is very general and does not refer to a contents of a document store, you should choose
the 'LLM' option.
- Otherwise, you should choose the 'research' option.
{history}
Here is the initial question:
-------
{question}
-------
Please decide whether to use the agent search or the LLM to answer the question. Choose from two choices,
'research' or 'LLM'.
Answer:"""
AGENT_DECISION_PROMPT_AFTER_SEARCH = """
You are an large language model assistant helping users address their information needs. You are given an initial question
or request and very few sample of documents that a preliminary and fast search from a document store returned.
You are tasked with deciding whether to use a thorough agent search ('research') of the document store to answer a question
or request, or whether you want to address the question or request yourself as an LLM.
Here are some rules:
- If based on the retrieved documents you think there may be useful information in the document
store to answer or materially help with the request, you should choose the 'research' option.
- If you think that the retrieved document do not help to answer the question or do not help with the request, AND
you know the answer/can handle the request, you should choose the 'LLM' option.
- If the question asks you do do somethng ('please create...', 'write for me...', etc.), you should choose the 'LLM' option.
- If in doubt, choose the 'research' option.
{history}
Here is the initial question:
-------
{question}
-------
Here is the sample of documents that were retrieved from a document store:
-------
{sample_doc_str}
-------
Please decide whether to use the agent search ('research') or the LLM to answer the question. Choose from two choices,
'research' or 'LLM'.
Answer:"""
### ANSWER GENERATION PROMPTS
# Persona specification
ASSISTANT_SYSTEM_PROMPT_DEFAULT = """
You are an assistant for question-answering tasks."""
ASSISTANT_SYSTEM_PROMPT_PERSONA = """
You are an assistant for question-answering tasks. Here is more information about you:
\n ------- \n
{persona_prompt}
\n ------- \n
"""
SUB_QUESTION_ANSWER_TEMPLATE = """
Sub-Question: Q{sub_question_nr}\n Sub-Question:\n - \n{sub_question}\n --\nAnswer:\n -\n {sub_answer}\n\n
"""
SUB_QUESTION_ANSWER_TEMPLATE_REVISED = """
Sub-Question: Q{sub_question_nr}\n Type: {level_type}\n Sub-Question:\n
- \n{sub_question}\n --\nAnswer:\n -\n {sub_answer}\n\n
"""
SUB_QUESTION_SEARCH_RESULTS_TEMPLATE = """
Sub-Question: Q{sub_question_nr}\n Sub-Question:\n - \n{sub_question}\n --\nRelevant Documents:\n
-\n {formatted_sub_question_docs}\n\n
"""
INITIAL_RAG_PROMPT_SUB_QUESTION_SEARCH = (
""" \n
{persona_specification}
{date_prompt}
Use the information provided below - and only the provided information - to answer the main question that will be provided.
The information provided below consists of:
1) a number of sub-questions and supporting document information that would help answer them.
2) a broader collection of documents that were deemed relevant for the question. These documents contain informattion
that was also provided in the sub-questions and often more.
IMPORTANT RULES:
- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer.
You may give some additional facts you learned, but do not try to invent an answer.
- If the information is irrelevant, just say """
+ f'"{UNKNOWN_ANSWER}"'
+ """.
- If the information is relevant but not fully conclusive, specify that the information is not conclusive and say why.
- The answers to the subquestions should help you to structure your thoughts in order to answer the question.
Please provide inline citations of documentsin the format [[D1]](), [[D2]](), [[D3]](), etc.!
It is important that the citation is close to the information it supports. If you have multiple citations,
please cite for example as [[D1]]()[[D3]](), or [[D2]]()[[D4]](), etc. Feel free to cite documents in addition
to the sub-questions! Proper citations are important for the final answer to be verifiable! \n\n\n
Again, you should be sure that the answer is supported by the information provided!
Try to keep your answer concise. But also highlight uncertainties you may have should there be substantial ones,
or assumptions you made.
Here is the contextual information:
\n-------\n
*Answered Sub-questions (these should really help to organize your thoughts):
{answered_sub_questions}
And here are relevant document information that supports the sub-question answers, or that are relevant for the actual question:\n
{relevant_docs}
\n-------\n
\n
And here is the main question I want you to answer based on the information above:
\n--\n
{question}
\n--\n\n
Answer:"""
)
DIRECT_LLM_PROMPT = """ \n
{persona_specification}
Please answer the following question/address the request:
\n--\n
{question}
\n--\n\n
Answer:"""
INITIAL_RAG_PROMPT = (
""" \n
{persona_specification}
{date_prompt}
Use the information provided below - and only the provided information - to answer the provided main question.
The information provided below consists of:
1) a number of answered sub-questions - these are very important to help you organize your thoughts and your
answer
2) a number of documents that deemed relevant for the question.
{history}
Please provide inline citations to documents in the format [[D1]](), [[D2]](), [[D3]](), etc.! It is important that the citation
is close to the information it supports. If you have multiple citations that support a fact, please cite for example
as [[D1]]()[[D3]](), or [[D2]]()[[D4]](), etc.
Feel free to also cite sub-questions in addition to documents, but make sure that you have documents cited with the sub-question
citation. If you want to cite both a document and a sub-question, please use [[D1]]()[[Q3]](), or [[D2]]()[[D7]]()[[Q4]](), etc.
Again, please NEVER cite sub-questions without a document citation!
Proper citations are very important for the user!
IMPORTANT RULES:
- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer.
You may give some additional facts you learned, but do not try to invent an answer.
- If the information is empty or irrelevant, just say """
+ f'"{UNKNOWN_ANSWER}"'
+ """.
- If the information is relevant but not fully conclusive, specify that the information is not conclusive and say why.
Again, you should be sure that the answer is supported by the information provided!
Try to keep your answer concise. But also highlight uncertainties you may have should there be substantial ones,
or assumptions you made.
Here is the contextual information:
\n-------\n
*Answered Sub-questions (these should really matter!):
{answered_sub_questions}
And here are relevant document information that support the sub-question answers, or that are relevant for the actual question:\n
{relevant_docs}
\n-------\n
\n
And here is the question I want you to answer based on the information above:
\n--\n
{question}
\n--\n\n
Answer:"""
)
# sub_question_answer_str is empty
INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS = (
"""{answered_sub_questions}
{persona_specification}
{date_prompt}
Use the information provided below
- and only the provided information - to answer the provided question.
The information provided below consists of a number of documents that were deemed relevant for the question.
{history}
IMPORTANT RULES:
- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer.
You may give some additional facts you learned, but do not try to invent an answer.
- If the information is irrelevant, just say """
+ f'"{UNKNOWN_ANSWER}"'
+ """.
- If the information is relevant but not fully conclusive, specify that the information is not conclusive and say why.
Again, you should be sure that the answer is supported by the information provided!
Please provide inline citations to documents in the format [[D1]](), [[D2]](), [[D3]](), etc! It is important that the citation
is close to the information it supports. If you have multiple
citations, please cite for example as [[D1]]()[[D3]](), or [[D2]]()[[D4]](), etc. Citations are very important for the
user!
Try to keep your answer concise.
Here are is the relevant context information:
\n-------\n
{relevant_docs}
\n-------\n
And here is the question I want you to answer based on the context above
\n--\n
{question}
\n--\n
Answer:"""
)
REVISED_RAG_PROMPT = (
"""\n
{persona_specification}
{date_prompt}
Use the information provided below - and only the provided information - to answer the provided main question.
The information provided below consists of:
1) an initial answer that was given but found to be lacking in some way.
2) a number of answered sub-questions - these are very important(!) and definitely should help yoiu to answer
the main question. Note that the sub-questions have a type, 'initial' and 'revised'. The 'initial'
ones were available for the initial answer, and the 'revised' were not. So please use the 'revised' sub-questions in
particular to update/extend/correct the initial answer!
3) a number of documents that were deemed relevant for the question. This the is the context that you use largey for
citations (see below).
Please provide inline citations to documents in the format [[D1]](), [[D2]](), [[D3]](), etc!
It is important that the citation is close to the information it supports. If you have multiple
citations, please cite for example as [[D1]]()[[D3]](), or [[D2]]()[[D4]](), etc.
Feel free to also cite sub-questions in addition to documents, but make sure that you have documents cited with the sub-question
citation. If you want to cite both a document and a sub-question, please use [[D1]]()[[Q3]](), or [[D2]]()[[D7]]()[[Q4]](), etc.
Again, please NEVER cite sub-questions without a document citation!
Proper citations are very important for the user!\n\n
{history}
IMPORTANT RULES:
- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer.
You may give some additional facts you learned, but do not try to invent an answer.
- If the information is empty or irrelevant, just say """
+ f'"{UNKNOWN_ANSWER}"'
+ """.
- If the information is relevant but not fully conclusive, provide and answer to the extent you can but also
specify that the information is not conclusive and why.
- Ignore any exisiting citations within the answered sub-questions, like [[D1]]()... and [[Q2]]()!
The citations you will need to use will need to refer to the documents (and sub-questions) that you are explicitly
presented with below!
Again, you should be sure that the answer is supported by the information provided!
Try to keep your answer concise. But also highlight uncertainties you may have should there be substantial ones,
or assumptions you made.
Here is the contextual information:
\n-------\n
*Initial Answer that was found to be lacking:
{initial_answer}
*Answered Sub-questions (these should really help ypu to research your answer! They also contain questions/answers
that were not available when the original answer was constructed):
{answered_sub_questions}
And here are the relevant documents that support the sub-question answers, and that are relevant for the actual question:\n
{relevant_docs}
\n-------\n
\n
Lastly, here is the main question I want you to answer based on the information above:
\n--\n
{question}
\n--\n\n
Answer:"""
)
# sub_question_answer_str is empty
REVISED_RAG_PROMPT_NO_SUB_QUESTIONS = (
"""{answered_sub_questions}\n
{persona_specification}
{date_prompt}
Use the information provided below - and only the
provided information - to answer the provided question.
The information provided below consists of:
1) an initial answer that was given but found to be lacking in some way.
2) a number of documents that were also deemed relevant for the question.
Please provide inline citations to documents in the format [[D1]](), [[D2]](), [[D3]](), etc!
It is important that the citation is close to the information it supports. If you have multiple
citations, please cite for example as [[D1]]()[[D3]](), or [[D2]]()[[D4]](), etc. Citations are very important for the user!\n\n
{history}
IMPORTANT RULES:
- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer.
You may give some additional facts you learned, but do not try to invent an answer.
- If the information is empty or irrelevant, just say """
+ f'"{UNKNOWN_ANSWER}"'
+ """.
- If the information is relevant but not fully conclusive, provide and answer to the extent you can but also
specify that the information is not conclusive and why.
Again, you should be sure that the answer is supported by the information provided!
Try to keep your answer concise. But also highlight uncertainties you may have should there be substantial ones,
or assumptions you made.
Here is the contextual information:
\n-------\n
*Initial Answer that was found to be lacking:
{initial_answer}
And here are relevant document information that support the sub-question answers, or that are relevant for the actual question:\n
{relevant_docs}
\n-------\n
\n
Lastly, here is the question I want you to answer based on the information above:
\n--\n
{question}
\n--\n\n
Answer:"""
)
ENTITY_TERM_PROMPT = """ \n
Based on the original question and the context retieved from a dataset, please generate a list of
entities (e.g. companies, organizations, industries, products, locations, etc.), terms and concepts
(e.g. sales, revenue, etc.) that are relevant for the question, plus their relations to each other.
\n\n
Here is the original question:
\n ------- \n
{question}
\n ------- \n
And here is the context retrieved:
\n ------- \n
{context}
\n ------- \n
Please format your answer as a json object in the following format:
{{"retrieved_entities_relationships": {{
"entities": [{{
"entity_name": <assign a name for the entity>,
"entity_type": <specify a short type name for the entity, such as 'company', 'location',...>
}}],
"relationships": [{{
"relationship_name": <assign a name for the relationship>,
"relationship_type": <specify a short type name for the relationship, such as 'sales_to', 'is_location_of',...>,
"relationship_entities": [<related entity name 1>, <related entity name 2>, ...]
}}],
"terms": [{{
"term_name": <assign a name for the term>,
"term_type": <specify a short type name for the term, such as 'revenue', 'market_share',...>,
"term_similar_to": <list terms that are similar to this term>
}}]
}}
}}
"""
ANSWER_COMPARISON_PROMPT = """
For the given question, please compare the initial answer and the refined answer and determine if
the refined answer is substantially better than the initial answer. Better could mean:
- additional information
- more comprehensive information
- more concise information
- more structured information
- substantially more document citations ([[D1]](), [[D2]](), [[D3]](), etc.)
Here is the question:
{question}
Here is the initial answer:
{initial_answer}
Here is the refined answer:
{refined_answer}
With these criteria in mind, is the refined answer substantially better than the initial answer?
Please answer with a simple 'yes' or 'no'.
"""

View File

@@ -0,0 +1,283 @@
import ast
import json
import re
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from typing import Any
from typing import cast
from uuid import UUID
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from sqlalchemy.orm import Session
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import DATE_PROMPT
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationConfig
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SearchRequest
from onyx.db.persona import get_persona_by_id
from onyx.db.persona import Persona
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.tool_constructor import SearchToolConfig
from onyx.tools.tool_implementations.search.search_tool import SearchTool
def normalize_whitespace(text: str) -> str:
"""Normalize whitespace in text to single spaces and strip leading/trailing whitespace."""
import re
return re.sub(r"\s+", " ", text.strip())
# Post-processing
def format_docs(docs: Sequence[InferenceSection]) -> str:
formatted_doc_list = []
for doc_nr, doc in enumerate(docs):
formatted_doc_list.append(f"Document D{doc_nr + 1}:\n{doc.combined_content}")
return "\n\n".join(formatted_doc_list)
def format_docs_content_flat(docs: Sequence[InferenceSection]) -> str:
formatted_doc_list = []
for _, doc in enumerate(docs):
formatted_doc_list.append(f"\n...{doc.combined_content}\n")
return "\n\n".join(formatted_doc_list)
def clean_and_parse_list_string(json_string: str) -> list[dict]:
# Remove any prefixes/labels before the actual JSON content
json_string = re.sub(r"^.*?(?=\[)", "", json_string, flags=re.DOTALL)
# Remove markdown code block markers and any newline prefixes
cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
cleaned_string = " ".join(cleaned_string.split())
# Try parsing with json.loads first, fall back to ast.literal_eval
try:
return json.loads(cleaned_string)
except json.JSONDecodeError:
try:
return ast.literal_eval(cleaned_string)
except (ValueError, SyntaxError) as e:
raise ValueError(f"Failed to parse JSON string: {cleaned_string}") from e
def clean_and_parse_json_string(json_string: str) -> dict[str, Any]:
# Remove markdown code block markers and any newline prefixes
cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
cleaned_string = " ".join(cleaned_string.split())
# Parse the cleaned string into a Python dictionary
return json.loads(cleaned_string)
def format_entity_term_extraction(
entity_term_extraction_dict: EntityRelationshipTermExtraction,
) -> str:
entities = entity_term_extraction_dict.entities
terms = entity_term_extraction_dict.terms
relationships = entity_term_extraction_dict.relationships
entity_strs = ["\nEntities:\n"]
for entity in entities:
entity_str = f"{entity.entity_name} ({entity.entity_type})"
entity_strs.append(entity_str)
entity_str = "\n - ".join(entity_strs)
relationship_strs = ["\n\nRelationships:\n"]
for relationship in relationships:
relationship_name = relationship.relationship_name
relationship_type = relationship.relationship_type
relationship_entities = relationship.relationship_entities
relationship_str = (
f"""{relationship_name} ({relationship_type}): {relationship_entities}"""
)
relationship_strs.append(relationship_str)
relationship_str = "\n - ".join(relationship_strs)
term_strs = ["\n\nTerms:\n"]
for term in terms:
term_str = f"{term.term_name} ({term.term_type}): similar to {', '.join(term.term_similar_to)}"
term_strs.append(term_str)
term_str = "\n - ".join(term_strs)
return "\n".join(entity_strs + relationship_strs + term_strs)
def _format_time_delta(time: timedelta) -> str:
seconds_from_start = f"{((time).seconds):03d}"
microseconds_from_start = f"{((time).microseconds):06d}"
return f"{seconds_from_start}.{microseconds_from_start}"
def generate_log_message(
message: str,
node_start_time: datetime,
graph_start_time: datetime | None = None,
) -> str:
current_time = datetime.now()
if graph_start_time is not None:
graph_time_str = _format_time_delta(current_time - graph_start_time)
else:
graph_time_str = "N/A"
node_time_str = _format_time_delta(current_time - node_start_time)
return f"{graph_time_str} ({node_time_str} s): {message}"
def get_test_config(
db_session: Session, primary_llm: LLM, fast_llm: LLM, search_request: SearchRequest
) -> tuple[AgentSearchConfig, SearchTool]:
persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session)
document_pruning_config = DocumentPruningConfig(
max_chunks=int(
persona.num_chunks
if persona.num_chunks is not None
else MAX_CHUNKS_FED_TO_CHAT
),
max_window_percentage=CHAT_TARGET_CHUNK_PERCENTAGE,
)
answer_style_config = AnswerStyleConfig(
citation_config=CitationConfig(
# The docs retrieved by this flow are already relevance-filtered
all_docs_useful=True
),
document_pruning_config=document_pruning_config,
structured_response_format=None,
)
search_tool_config = SearchToolConfig(
answer_style_config=answer_style_config,
document_pruning_config=document_pruning_config,
retrieval_options=RetrievalDetails(), # may want to set dedupe_docs=True
rerank_settings=None, # Can use this to change reranking model
selected_sections=None,
latest_query_files=None,
bypass_acl=False,
)
prompt_config = PromptConfig.from_model(persona.prompts[0])
search_tool = SearchTool(
db_session=db_session,
user=None,
persona=persona,
retrieval_options=search_tool_config.retrieval_options,
prompt_config=prompt_config,
llm=primary_llm,
fast_llm=fast_llm,
pruning_config=search_tool_config.document_pruning_config,
answer_style_config=search_tool_config.answer_style_config,
selected_sections=search_tool_config.selected_sections,
chunks_above=search_tool_config.chunks_above,
chunks_below=search_tool_config.chunks_below,
full_doc=search_tool_config.full_doc,
evaluation_type=(
LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP
),
rerank_settings=search_tool_config.rerank_settings,
bypass_acl=search_tool_config.bypass_acl,
)
config = AgentSearchConfig(
search_request=search_request,
primary_llm=primary_llm,
fast_llm=fast_llm,
search_tool=search_tool,
force_use_tool=ForceUseTool(force_use=False, tool_name=""),
prompt_builder=AnswerPromptBuilder(
user_message=HumanMessage(content=search_request.query),
message_history=[],
llm_config=primary_llm.config,
raw_user_query=search_request.query,
raw_user_uploaded_files=[],
),
# chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
chat_session_id=UUID("edda10d5-6cef-45d8-acfb-39317552a1f4"), # Joachim
# chat_session_id=UUID("d1acd613-2692-4bc3-9d65-c6d3da62e58e"), # Evan
message_id=1,
use_persistence=True,
db_session=db_session,
)
return config, search_tool
def get_persona_prompt(persona: Persona | None) -> str:
if persona is None:
return ""
else:
return "\n".join([x.system_prompt for x in persona.prompts])
def make_question_id(level: int, question_nr: int) -> str:
return f"{level}_{question_nr}"
def parse_question_id(question_id: str) -> tuple[int, int]:
level, question_nr = question_id.split("_")
return int(level), int(question_nr)
def _dispatch_nonempty(
content: str, dispatch_event: Callable[[str, int], None], num: int
) -> None:
if content != "":
dispatch_event(content, num)
def dispatch_separated(
token_itr: Iterator[BaseMessage],
dispatch_event: Callable[[str, int], None],
sep: str = "\n",
) -> list[str | list[str | dict[str, Any]]]:
num = 1
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
for message in token_itr:
content = cast(str, message.content)
if sep in content:
sub_question_parts = content.split(sep)
_dispatch_nonempty(sub_question_parts[0], dispatch_event, num)
num += 1
_dispatch_nonempty(
"".join(sub_question_parts[1:]).strip(), dispatch_event, num
)
else:
_dispatch_nonempty(content, dispatch_event, num)
streamed_tokens.append(content)
return streamed_tokens
def get_today_prompt() -> str:
return DATE_PROMPT.format(date=datetime.now().strftime("%A, %B %d, %Y"))

View File

@@ -23,7 +23,6 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
preferences_data = cast(
Mapping[str, Any], store.load(KV_NO_AUTH_USER_PREFERENCES_KEY)
)
print("preferences_data", preferences_data)
return UserPreferences(**preferences_data)
except KvKeyNotFoundError:
return UserPreferences(

View File

@@ -55,6 +55,7 @@ from onyx.auth.invited_users import get_invited_users
from onyx.auth.schemas import UserCreate
from onyx.auth.schemas import UserRole
from onyx.auth.schemas import UserUpdate
from onyx.configs.app_configs import AUTH_COOKIE_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.configs.app_configs import EMAIL_CONFIGURED
@@ -209,6 +210,7 @@ def verify_email_domain(email: str) -> None:
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = USER_AUTH_SECRET
verification_token_secret = USER_AUTH_SECRET
verification_token_lifetime_seconds = AUTH_COOKIE_EXPIRE_TIME_SECONDS
user_db: SQLAlchemyUserDatabase[User, uuid.UUID]

View File

@@ -23,8 +23,7 @@ from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_sqlalchemy_engine
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa_constants import VESPA_CONFIG_SERVER_URL
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from onyx.redis.redis_connector_delete import RedisConnectorDelete
@@ -280,51 +279,6 @@ def wait_for_db(sender: Any, **kwargs: Any) -> None:
return
def wait_for_vespa(sender: Any, **kwargs: Any) -> None:
"""Waits for Vespa to become ready subject to a hardcoded timeout.
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
ready = False
time_start = time.monotonic()
logger.info("Vespa: Readiness probe starting.")
while True:
try:
client = get_vespa_http_client()
response = client.get(f"{VESPA_CONFIG_SERVER_URL}/state/v1/health")
response.raise_for_status()
response_dict = response.json()
if response_dict["status"]["code"] == "up":
ready = True
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
if time_elapsed > WAIT_LIMIT:
break
logger.info(
f"Vespa: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
time.sleep(WAIT_INTERVAL)
if not ready:
msg = (
f"Vespa: Readiness probe did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
logger.info("Vespa: Readiness probe succeeded. Continuing...")
return
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("Running as a secondary celery worker.")
@@ -510,3 +464,13 @@ def reset_tenant_id(
) -> None:
"""Signal handler to reset tenant ID in context var after task ends."""
CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA)
def wait_for_vespa_or_shutdown(sender: Any, **kwargs: Any) -> None:
"""Waits for Vespa to become ready subject to a timeout.
Raises WorkerShutdown if the timeout is reached."""
if not wait_for_vespa_with_timeout():
msg = "Vespa: Readiness probe did not succeed within the timeout. Exiting..."
logger.error(msg)
raise WorkerShutdown(msg)

View File

@@ -62,7 +62,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:

View File

@@ -68,7 +68,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:

View File

@@ -63,7 +63,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:

View File

@@ -86,7 +86,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
logger.info("Running as the primary celery worker.")

View File

@@ -29,6 +29,16 @@ cloud_tasks_to_schedule = [
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-alembic",
"task": OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.MONITORING,
},
},
]
# tasks that run in either self-hosted on cloud

View File

@@ -674,6 +674,9 @@ def connector_indexing_proxy_task(
while True:
sleep(5)
# renew watchdog signal (this has a shorter timeout than set_active)
redis_connector_index.set_watchdog(True)
# renew active signal
redis_connector_index.set_active()
@@ -780,6 +783,7 @@ def connector_indexing_proxy_task(
)
continue
redis_connector_index.set_watchdog(False)
task_logger.info(
f"Indexing watchdog - finished: attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "

View File

@@ -1,6 +1,8 @@
import json
import time
from collections.abc import Callable
from datetime import timedelta
from itertools import islice
from typing import Any
from celery import shared_task
@@ -10,13 +12,17 @@ from pydantic import BaseModel
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.vespa.tasks import celery_get_queue_length
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_db_current_time
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import IndexingStatus
@@ -27,6 +33,7 @@ from onyx.db.models import IndexAttempt
from onyx.db.models import SyncRecord
from onyx.db.models import UserGroup
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
@@ -456,3 +463,116 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
lock_monitoring.release()
task_logger.info("Background monitoring task finished")
@shared_task(
name=OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
)
def cloud_check_alembic() -> bool | None:
"""A task to verify that all tenants are on the same alembic revision.
This check is expected to fail if a cloud alembic migration is currently running
across all tenants.
TODO: have the cloud migration script set an activity signal that this check
uses to know it doesn't make sense to run a check at the present time.
"""
time_start = time.monotonic()
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
lock_beat: RedisLock = redis_client.lock(
OnyxRedisLocks.CLOUD_CHECK_ALEMBIC_BEAT_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
last_lock_time = time.monotonic()
tenant_to_revision: dict[str, str | None] = {}
revision_counts: dict[str, int] = {}
out_of_date_tenants: dict[str, str | None] = {}
top_revision: str = ""
try:
# map each tenant_id to its revision
tenant_ids = get_all_tenant_ids()
for tenant_id in tenant_ids:
current_time = time.monotonic()
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
lock_beat.reacquire()
last_lock_time = current_time
if tenant_id is None:
continue
with get_session_with_tenant(tenant_id=None) as session:
result = session.execute(
text(f'SELECT * FROM "{tenant_id}".alembic_version LIMIT 1')
)
result_scalar: str | None = result.scalar_one_or_none()
tenant_to_revision[tenant_id] = result_scalar
# get the total count of each revision
for k, v in tenant_to_revision.items():
if v is None:
continue
revision_counts[v] = revision_counts.get(v, 0) + 1
# get the revision with the most counts
sorted_revision_counts = sorted(
revision_counts.items(), key=lambda item: item[1], reverse=True
)
if len(sorted_revision_counts) == 0:
task_logger.error(
f"cloud_check_alembic - No revisions found for {len(tenant_ids)} tenant ids!"
)
else:
top_revision, _ = sorted_revision_counts[0]
# build a list of out of date tenants
for k, v in tenant_to_revision.items():
if v == top_revision:
continue
out_of_date_tenants[k] = v
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception during cloud alembic check")
raise
finally:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error("cloud_check_alembic - Lock not owned on completion")
redis_lock_dump(lock_beat, redis_client)
if len(out_of_date_tenants) > 0:
task_logger.error(
f"Found out of date tenants: "
f"num_out_of_date_tenants={len(out_of_date_tenants)} "
f"num_tenants={len(tenant_ids)} "
f"revision={top_revision}"
)
for k, v in islice(out_of_date_tenants.items(), 5):
task_logger.info(f"Out of date tenant: tenant={k} revision={v}")
else:
task_logger.info(
f"All tenants are up to date: num_tenants={len(tenant_ids)} revision={top_revision}"
)
time_elapsed = time.monotonic() - time_start
task_logger.info(
f"cloud_check_alembic finished: num_tenants={len(tenant_ids)} elapsed={time_elapsed:.2f}"
)
return True

View File

@@ -735,7 +735,7 @@ def monitor_ccpair_indexing_taskset(
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
if composite_id is None:
task_logger.warning(
f"monitor_ccpair_indexing_taskset: could not parse composite_id from {fence_key}"
f"Connector indexing: could not parse composite_id from {fence_key}"
)
return
@@ -785,6 +785,7 @@ def monitor_ccpair_indexing_taskset(
# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state
# Verify: if the generator isn't complete, the task must not be in READY state
# inner = get_completion / generator_complete not signaled
# outer = result.state in READY state
status_int = redis_connector_index.get_completion()
@@ -830,7 +831,7 @@ def monitor_ccpair_indexing_taskset(
)
except Exception:
task_logger.exception(
"monitor_ccpair_indexing_taskset - transient exception marking index attempt as failed: "
"Connector indexing - Transient exception marking index attempt as failed: "
f"attempt={payload.index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
@@ -840,6 +841,20 @@ def monitor_ccpair_indexing_taskset(
redis_connector_index.reset()
return
if redis_connector_index.watchdog_signaled():
# if the generator is complete, don't clean up until the watchdog has exited
task_logger.info(
f"Connector indexing - Delaying finalization until watchdog has exited: "
f"attempt={payload.index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
f"elapsed_started={elapsed_started_str}"
)
return
status_enum = HTTPStatus(status_int)
task_logger.info(
@@ -858,9 +873,13 @@ def monitor_ccpair_indexing_taskset(
@shared_task(name=OnyxCeleryTask.MONITOR_VESPA_SYNC, soft_time_limit=300, bind=True)
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
"""This is a celery beat task that monitors and finalizes various long running tasks.
The name monitor_vespa_sync is a bit of a misnomer since it checks many different tasks
now. Should change that at some point.
It scans for fence values and then gets the counts of any associated tasksets.
If the count is 0, that means all tasks finished and we should clean up.
For many tasks, the count is 0, that means all tasks finished and we should clean up.
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
do anything too expensive in this function!
@@ -1045,6 +1064,8 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
def vespa_metadata_sync_task(
self: Task, document_id: str, tenant_id: str | None
) -> bool:
start = time.monotonic()
try:
with get_session_with_tenant(tenant_id) as db_session:
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
@@ -1095,7 +1116,13 @@ def vespa_metadata_sync_task(
# r = get_redis_client(tenant_id=tenant_id)
# r.delete(redis_syncing_key)
task_logger.info(f"doc={document_id} action=sync chunks={chunks_affected}")
elapsed = time.monotonic() - start
task_logger.info(
f"doc={document_id} "
f"action=sync "
f"chunks={chunks_affected} "
f"elapsed={elapsed:.2f}"
)
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
except Exception as ex:

View File

@@ -1,50 +1,35 @@
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Iterator
from uuid import uuid4
from langchain.schema.messages import BaseMessage
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import ToolCall
from sqlalchemy.orm import Session
from onyx.chat.llm_response_handler import LLMResponseHandlerManager
from onyx.chat.models import AnswerQuestionPossibleReturn
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.run_graph import run_basic_graph
from onyx.agents.agent_search.run_graph import run_main_graph
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStream
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
from onyx.chat.stream_processing.answer_response_handler import (
CitationResponseHandler,
)
from onyx.chat.stream_processing.answer_response_handler import (
DummyAnswerResponseHandler,
)
from onyx.chat.stream_processing.utils import (
map_document_id_order,
)
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.configs.constants import BASIC_KEY
from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.tools.force import ForceUseTool
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_runner import ToolCallKickoff
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.logger import setup_logger
logger = setup_logger()
AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse]
class Answer:
def __init__(
self,
@@ -53,13 +38,13 @@ class Answer:
llm: LLM,
prompt_config: PromptConfig,
force_use_tool: ForceUseTool,
agent_search_config: AgentSearchConfig,
# must be the same length as `docs`. If None, all docs are considered "relevant"
message_history: list[PreviousMessage] | None = None,
single_message_history: str | None = None,
# newly passed in files to include as part of this question
# TODO THIS NEEDS TO BE HANDLED
latest_query_files: list[InMemoryChatFile] | None = None,
files: list[InMemoryChatFile] | None = None,
tools: list[Tool] | None = None,
# NOTE: for native tool-calling, this is only supported by OpenAI atm,
# but we only support them anyways
@@ -69,6 +54,8 @@ class Answer:
return_contexts: bool = False,
skip_gen_ai_answer_generation: bool = False,
is_connected: Callable[[], bool] | None = None,
fast_llm: LLM | None = None,
db_session: Session | None = None,
) -> None:
if single_message_history and message_history:
raise ValueError(
@@ -79,7 +66,6 @@ class Answer:
self.is_connected: Callable[[], bool] | None = is_connected
self.latest_query_files = latest_query_files or []
self.file_id_to_file = {file.file_id: file for file in (files or [])}
self.tools = tools or []
self.force_use_tool = force_use_tool
@@ -92,6 +78,7 @@ class Answer:
self.prompt_config = prompt_config
self.llm = llm
self.fast_llm = fast_llm
self.llm_tokenizer = get_tokenizer(
provider_type=llm.config.model_provider,
model_name=llm.config.model_name,
@@ -100,9 +87,7 @@ class Answer:
self._final_prompt: list[BaseMessage] | None = None
self._streamed_output: list[str] | None = None
self._processed_stream: (
list[AnswerQuestionPossibleReturn | ToolResponse | ToolCallKickoff] | None
) = None
self._processed_stream: (list[AnswerPacket] | None) = None
self._return_contexts = return_contexts
self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
@@ -115,131 +100,150 @@ class Answer:
and not skip_explicit_tool_calling
)
self.agent_search_config = agent_search_config
self.db_session = db_session
def _get_tools_list(self) -> list[Tool]:
if not self.force_use_tool.force_use:
return self.tools
tool = next(
(t for t in self.tools if t.name == self.force_use_tool.tool_name), None
)
if tool is None:
raise RuntimeError(f"Tool '{self.force_use_tool.tool_name}' not found")
tool = get_tool_by_name(self.tools, self.force_use_tool.tool_name)
logger.info(
f"Forcefully using tool='{tool.name}'"
+ (
f" with args='{self.force_use_tool.args}'"
if self.force_use_tool.args is not None
else ""
)
args_str = (
f" with args='{self.force_use_tool.args}'"
if self.force_use_tool.args
else ""
)
logger.info(f"Forcefully using tool='{tool.name}'{args_str}")
return [tool]
def _handle_specified_tool_call(
self, llm_calls: list[LLMCall], tool: Tool, tool_args: dict
) -> AnswerStream:
current_llm_call = llm_calls[-1]
# TODO: delete the function and move the full body to processed_streamed_output
def _get_response(self) -> AnswerStream:
# current_llm_call = llm_calls[-1]
# make a dummy tool handler
tool_handler = ToolResponseHandler([tool])
# tool, tool_args = None, None
# # handle the case where no decision has to be made; we simply run the tool
# if (
# current_llm_call.force_use_tool.force_use
# and current_llm_call.force_use_tool.args is not None
# ):
# tool_name, tool_args = (
# current_llm_call.force_use_tool.tool_name,
# current_llm_call.force_use_tool.args,
# )
# tool = get_tool_by_name(current_llm_call.tools, tool_name)
dummy_tool_call_chunk = AIMessageChunk(content="")
dummy_tool_call_chunk.tool_calls = [
ToolCall(name=tool.name, args=tool_args, id=str(uuid4()))
]
# # special pre-logic for non-tool calling LLM case
# elif not self.using_tool_calling_llm and current_llm_call.tools:
# chosen_tool_and_args = (
# ToolResponseHandler.get_tool_call_for_non_tool_calling_llm(
# current_llm_call, self.llm
# )
# )
# if chosen_tool_and_args:
# tool, tool_args = chosen_tool_and_args
response_handler_manager = LLMResponseHandlerManager(
tool_handler, DummyAnswerResponseHandler(), self.is_cancelled
)
yield from response_handler_manager.handle_llm_response(
iter([dummy_tool_call_chunk])
)
# if tool and tool_args:
# dummy_tool_call_chunk = AIMessageChunk(content="")
# dummy_tool_call_chunk.tool_calls = [
# ToolCall(name=tool.name, args=tool_args, id=str(uuid4()))
# ]
new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
if new_llm_call:
yield from self._get_response(llm_calls + [new_llm_call])
else:
raise RuntimeError("Tool call handler did not return a new LLM call")
# response_handler_manager = LLMResponseHandlerManager(
# ToolResponseHandler([tool]), None, self.is_cancelled
# )
# yield from response_handler_manager.handle_llm_response(
# iter([dummy_tool_call_chunk])
# )
def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream:
current_llm_call = llm_calls[-1]
# tmp_call = response_handler_manager.next_llm_call(current_llm_call)
# if tmp_call is None:
# return # no more LLM calls to process
# current_llm_call = tmp_call
# handle the case where no decision has to be made; we simply run the tool
if (
current_llm_call.force_use_tool.force_use
and current_llm_call.force_use_tool.args is not None
):
tool_name, tool_args = (
current_llm_call.force_use_tool.tool_name,
current_llm_call.force_use_tool.args,
)
tool = next(
(t for t in current_llm_call.tools if t.name == tool_name), None
)
if not tool:
raise RuntimeError(f"Tool '{tool_name}' not found")
# # if we're skipping gen ai answer generation, we should break
# # out unless we're forcing a tool call. If we don't, we might generate an
# # answer, which is a no-no!
# if (
# self.skip_gen_ai_answer_generation
# and not current_llm_call.force_use_tool.force_use
# ):
# return
yield from self._handle_specified_tool_call(llm_calls, tool, tool_args)
return
# # set up "handlers" to listen to the LLM response stream and
# # feed back the processed results + handle tool call requests
# # + figure out what the next LLM call should be
# tool_call_handler = ToolResponseHandler(current_llm_call.tools)
# special pre-logic for non-tool calling LLM case
if not self.using_tool_calling_llm and current_llm_call.tools:
chosen_tool_and_args = (
ToolResponseHandler.get_tool_call_for_non_tool_calling_llm(
current_llm_call, self.llm
# final_search_results, displayed_search_results = SearchTool.get_search_result(
# current_llm_call
# ) or ([], [])
# # NEXT: we still want to handle the LLM response stream, but it is now:
# # 1. handle the tool call requests
# # 2. feed back the processed results
# # 3. handle the citations
# answer_handler = CitationResponseHandler(
# context_docs=final_search_results,
# final_doc_id_to_rank_map=map_document_id_order(final_search_results),
# display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
# )
# # At the moment, this wrapper class passes streamed stuff through citation and tool handlers.
# # In the future, we'll want to handle citations and tool calls in the langgraph graph.
# response_handler_manager = LLMResponseHandlerManager(
# tool_call_handler, answer_handler, self.is_cancelled
# )
# In langgraph, whether we do the basic thing (call llm stream) or pro search
# is based on a flag in the pro search config
if self.agent_search_config.use_agentic_search:
if (
self.agent_search_config.db_session is None
and self.agent_search_config.use_persistence
):
raise ValueError(
"db_session must be provided for pro search when using persistence"
)
stream = run_main_graph(
config=self.agent_search_config,
)
else:
stream = run_basic_graph(
config=self.agent_search_config,
)
if chosen_tool_and_args:
tool, tool_args = chosen_tool_and_args
yield from self._handle_specified_tool_call(llm_calls, tool, tool_args)
return
# if we're skipping gen ai answer generation, we should break
# out unless we're forcing a tool call. If we don't, we might generate an
# answer, which is a no-no!
if (
self.skip_gen_ai_answer_generation
and not current_llm_call.force_use_tool.force_use
):
return
# set up "handlers" to listen to the LLM response stream and
# feed back the processed results + handle tool call requests
# + figure out what the next LLM call should be
tool_call_handler = ToolResponseHandler(current_llm_call.tools)
final_search_results, displayed_search_results = SearchTool.get_search_result(
current_llm_call
) or ([], [])
answer_handler = CitationResponseHandler(
context_docs=final_search_results,
final_doc_id_to_rank_map=map_document_id_order(final_search_results),
display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
)
response_handler_manager = LLMResponseHandlerManager(
tool_call_handler, answer_handler, self.is_cancelled
)
processed_stream = []
for packet in stream:
if self.is_cancelled():
packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
yield packet
break
processed_stream.append(packet)
yield packet
self._processed_stream = processed_stream
return
# DEBUG: good breakpoint
stream = self.llm.stream(
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
prompt=current_llm_call.prompt_builder.build(),
tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
tool_choice=(
"required"
if current_llm_call.tools and current_llm_call.force_use_tool.force_use
else None
),
structured_response_format=self.answer_style_config.structured_response_format,
)
yield from response_handler_manager.handle_llm_response(stream)
# stream = self.llm.stream(
# # For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
# # may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
# prompt=current_llm_call.prompt_builder.build(),
# tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
# tool_choice=(
# "required"
# if current_llm_call.tools and current_llm_call.force_use_tool.force_use
# else None
# ),
# structured_response_format=self.answer_style_config.structured_response_format,
# )
# yield from response_handler_manager.handle_llm_response(stream)
new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
if new_llm_call:
yield from self._get_response(llm_calls + [new_llm_call])
# new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
# if new_llm_call:
# yield from self._get_response(llm_calls + [new_llm_call])
@property
def processed_streamed_output(self) -> AnswerStream:
@@ -247,33 +251,33 @@ class Answer:
yield from self._processed_stream
return
prompt_builder = AnswerPromptBuilder(
user_message=default_build_user_message(
user_query=self.question,
prompt_config=self.prompt_config,
files=self.latest_query_files,
single_message_history=self.single_message_history,
),
message_history=self.message_history,
llm_config=self.llm.config,
raw_user_query=self.question,
raw_user_uploaded_files=self.latest_query_files or [],
single_message_history=self.single_message_history,
)
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)
)
llm_call = LLMCall(
prompt_builder=prompt_builder,
tools=self._get_tools_list(),
force_use_tool=self.force_use_tool,
files=self.latest_query_files,
tool_call_info=[],
using_tool_calling_llm=self.using_tool_calling_llm,
)
# prompt_builder = AnswerPromptBuilder(
# user_message=default_build_user_message(
# user_query=self.question,
# prompt_config=self.prompt_config,
# files=self.latest_query_files,
# single_message_history=self.single_message_history,
# ),
# message_history=self.message_history,
# llm_config=self.llm.config,
# raw_user_query=self.question,
# raw_user_uploaded_files=self.latest_query_files or [],
# single_message_history=self.single_message_history,
# )
# prompt_builder.update_system_prompt(
# default_build_system_message(self.prompt_config)
# )
# llm_call = LLMCall(
# prompt_builder=prompt_builder,
# tools=self._get_tools_list(),
# force_use_tool=self.force_use_tool,
# files=self.latest_query_files,
# tool_call_info=[],
# using_tool_calling_llm=self.using_tool_calling_llm,
# )
processed_stream = []
for processed_packet in self._get_response([llm_call]):
for processed_packet in self._get_response():
processed_stream.append(processed_packet)
yield processed_packet
@@ -283,20 +287,56 @@ class Answer:
def llm_answer(self) -> str:
answer = ""
for packet in self.processed_streamed_output:
if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
# handle basic answer flow, plus level 0 agent answer flow
# since level 0 is the first answer the user sees and therefore the
# child message of the user message in the db (so it is handled
# like a basic flow answer)
if (isinstance(packet, OnyxAnswerPiece) and packet.answer_piece) or (
isinstance(packet, AgentAnswerPiece)
and packet.answer_piece
and packet.answer_type == "agent_level_answer"
and packet.level == 0
):
answer += packet.answer_piece
return answer
def llm_answer_by_level(self) -> dict[int, str]:
answer_by_level: dict[int, str] = defaultdict(str)
for packet in self.processed_streamed_output:
if (
isinstance(packet, AgentAnswerPiece)
and packet.answer_piece
and packet.answer_type == "agent_level_answer"
):
answer_by_level[packet.level] += packet.answer_piece
elif isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
answer_by_level[BASIC_KEY[0]] += packet.answer_piece
return answer_by_level
@property
def citations(self) -> list[CitationInfo]:
citations: list[CitationInfo] = []
for packet in self.processed_streamed_output:
if isinstance(packet, CitationInfo):
if isinstance(packet, CitationInfo) and packet.level is None:
citations.append(packet)
return citations
def citations_by_subquestion(self) -> dict[tuple[int, int], list[CitationInfo]]:
citations_by_subquestion: dict[
tuple[int, int], list[CitationInfo]
] = defaultdict(list)
for packet in self.processed_streamed_output:
if isinstance(packet, CitationInfo):
if packet.level_question_nr is not None and packet.level is not None:
citations_by_subquestion[
(packet.level, packet.level_question_nr)
].append(packet)
elif packet.level is None:
citations_by_subquestion[BASIC_KEY].append(packet)
return citations_by_subquestion
def is_cancelled(self) -> bool:
if self._is_cancelled:
return True

View File

@@ -48,6 +48,8 @@ def prepare_chat_message_request(
retrieval_details: RetrievalDetails | None,
rerank_settings: RerankingDetails | None,
db_session: Session,
use_agentic_search: bool = False,
skip_gen_ai_answer_generation: bool = False,
) -> CreateChatMessageRequest:
# Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
new_chat_session = create_chat_session(
@@ -72,6 +74,8 @@ def prepare_chat_message_request(
search_doc_ids=None,
retrieval_options=retrieval_details,
rerank_settings=rerank_settings,
use_agentic_search=use_agentic_search,
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
)

Some files were not shown because too many files have changed in this diff Show More