Compare commits

...

159 Commits

Author SHA1 Message Date
joachim-danswer
16a4149a05 fixed Clarification treatment 2025-08-16 17:33:31 -07:00
joachim-danswer
fe4e1b75fb fixed tool identification (should be further improved) 2025-08-16 14:24:49 -07:00
Weves
06f0224622 Fix build 2025-08-16 11:00:55 -07:00
Weves
d512912fcd Adjust toggle behavior 2025-08-16 10:51:55 -07:00
Weves
30991db439 Refactor + fix citations 2025-08-15 20:50:03 -07:00
Weves
5f07f189bf Merge branch 'dr-merge-chris' into dr-merge 2025-08-15 19:50:12 -07:00
Weves
c3cb351231 Fix citations + assistant selection 2025-08-15 19:49:58 -07:00
joachim-danswer
088156f0bf initial image generation 2025-08-15 19:13:13 -07:00
joachim-danswer
86c0bfec1c Prompts, Assistants, and Answer style 2025-08-15 18:19:58 -07:00
joachim-danswer
f5d1319ffa fixes
- general conversation
- some Deep Research aspects
2025-08-15 13:20:48 -07:00
Weves
8e55e1c2ec Fix build 2025-08-15 10:47:31 -07:00
Weves
5678ceeae9 small fix 2025-08-15 10:40:03 -07:00
Weves
82f18bec66 Fix 2025-08-15 10:36:47 -07:00
Weves
55b314fba6 Small FE improvements 2025-08-15 10:35:58 -07:00
joachim-danswer
ac062f43a5 streaming 2025-08-14 19:19:49 -07:00
joachim-danswer
9f783e2762 streaming back - base 2025-08-14 14:56:11 -07:00
joachim-danswer
55c26f5f34 wip - citation rework start 2025-08-14 14:56:11 -07:00
Rei Meguro
050524d8b0 custom tool + orchestrator fix for non-document returning tools 2025-08-14 21:10:26 +09:00
joachim-danswer
0930712fdf reload progress 2025-08-13 16:24:13 -07:00
joachim-danswer
486c820c46 nit 2025-08-13 14:59:37 -07:00
joachim-danswer
9a568a10e1 start reload 2025-08-13 14:57:19 -07:00
Weves
91b91311d5 Merge branch 'dr-merge-chris' into dr-merge 2025-08-13 13:54:26 -07:00
Weves
5dc5bf4ff5 Error handling 2025-08-13 13:54:18 -07:00
joachim-danswer
f7fc1c827d kg streaming 2025-08-12 19:55:31 -07:00
Weves
e721618e95 Fix build 2025-08-12 19:40:01 -07:00
Weves
04897e222e Merge branch 'dr-merge-chris' into dr-merge 2025-08-12 19:32:35 -07:00
Weves
09ce320f7f More 2025-08-12 19:32:13 -07:00
joachim-danswer
99f79804bb internet search streaming 2025-08-12 18:31:09 -07:00
joachim-danswer
062c9a4b73 internet search streaming 2025-08-12 18:15:02 -07:00
Weves
fdd83eeaa9 more improvements 2025-08-12 18:02:16 -07:00
Weves
f4828cfc18 More improvements 2025-08-12 17:53:28 -07:00
joachim-danswer
f526171ca4 fix for general assistants 2025-08-12 17:37:55 -07:00
Weves
8b8802f9be More improvements 2025-08-12 16:51:05 -07:00
Weves
dc3ca66f6d Merge branch 'dr-merge-chris' into dr-merge 2025-08-12 16:16:34 -07:00
Weves
3a9f1accaf Input bar stuff draft 2025-08-12 16:16:19 -07:00
joachim-danswer
d264e880f2 commit fix 2025-08-12 16:06:27 -07:00
joachim-danswer
9e0a22d866 fix 2025-08-12 15:42:41 -07:00
joachim-danswer
bcbb075b96 basic search with error at end 2025-08-12 15:37:03 -07:00
Weves
5526e2b34b Simplify backend 2025-08-12 13:32:40 -07:00
Weves
ee1541061c More fixes 2025-08-12 11:54:53 -07:00
Weves
c1446d1508 Cleanup 2025-08-12 11:52:58 -07:00
Weves
acf9f615b1 Merge branch 'refactor-message-protocol' into dr-merge 2025-08-12 11:51:47 -07:00
Weves
ccde845e47 Improve citation text 2025-08-12 10:22:40 -07:00
Weves
cad3517f85 more 2025-08-12 09:48:41 -07:00
joachim-danswer
e71489b2ff initial front-end/backend integration 2025-08-11 21:52:58 -07:00
Weves
bfb6d632d2 Merge branch 'refactor-message-protocol' into dr-merge 2025-08-10 15:54:15 -07:00
Weves
191577fa19 Fix build 2025-08-10 15:53:54 -07:00
Weves
a190934193 Merge branch 'refactor-message-protocol' into dr-merge 2025-08-10 14:51:50 -07:00
Weves
a7d140cb5d Add zustand dependency 2025-08-10 14:51:38 -07:00
Weves
5e743515e9 Merge branch 'refactor-message-protocol' into dr-merge 2025-08-10 14:31:45 -07:00
Weves
4ef7e44c95 try something 2025-08-10 14:30:36 -07:00
Weves
91bc1e93ba Merge branch 'dr_v0' into dr-merge 2025-08-10 13:08:01 -07:00
Weves
e7bd58cc85 Improvements 2025-08-10 13:03:47 -07:00
Weves
dd18291d51 Custom tool support 2025-08-10 12:31:18 -07:00
Weves
9a5ea03cd1 more 2025-08-10 12:06:45 -07:00
Rei Meguro
7b37e72b9d almost working custom tools 2025-08-11 02:10:48 +09:00
Rei Meguro
09d672ff22 more cleanup for tools 2025-08-10 23:36:24 +09:00
Rei Meguro
b028b25737 rename folder 2025-08-10 21:59:27 +09:00
Rei Meguro
07768d5484 feat: initial custom tool support prep 2025-08-10 18:59:46 +09:00
Rei Meguro
5ca8ca2b1e mypy and proper id implementation 2025-08-10 17:07:41 +09:00
Rei Meguro
62872e58ae cleanup 2025-08-10 16:20:59 +09:00
Weves
eee3054b45 More stuff 2025-08-08 19:29:51 -07:00
joachim-danswer
5f66a27c67 ResearchType vs DRTimeBudget 2025-08-08 16:54:34 -07:00
joachim-danswer
c21fa21958 initial decision using tool-calling if tool-calling LLM 2025-08-08 16:37:32 -07:00
joachim-danswer
cd6577c3ca tool_id for custom tools 2025-08-08 12:57:58 -07:00
Rei Meguro
16406f0ebd kg citations 2025-08-08 08:44:25 -07:00
Rei Meguro
4ae5bb1e6b fix: iteration citation replacement 2025-08-08 08:44:25 -07:00
Rei Meguro
b0c95ec876 fix: constants 2025-08-08 08:44:25 -07:00
Rei Meguro
397d30c802 better prompt templating 2025-08-08 08:44:25 -07:00
joachim-danswer
f13b08b461 persistence 2025-08-08 08:44:25 -07:00
Rei Meguro
e66245ec13 properly merge inference section contents 2025-08-08 08:44:25 -07:00
Rei Meguro
c64c6368c1 feat: kg tool proper implementation 2025-08-08 08:44:25 -07:00
joachim-danswer
b2fe55c8f8 more DR updates 2025-08-08 08:44:25 -07:00
joachim-danswer
2b661441d7 nits 2025-08-08 08:44:25 -07:00
joachim-danswer
f83f06228b reworked 'fast' search 2025-08-08 08:44:25 -07:00
joachim-danswer
fabfa8d166 query rejection step 2025-08-08 08:44:25 -07:00
joachim-danswer
994e7f7666 active_source_description 2025-08-08 08:44:25 -07:00
Rei Meguro
c81a7e1ef2 mypy fix 2025-08-08 08:44:25 -07:00
joachim-danswer
1d7d2f06d8 time filter and source prediction 2025-08-08 08:44:25 -07:00
joachim-danswer
916d6cb119 base search in DR refactoring 2025-08-08 08:44:25 -07:00
Rei Meguro
6d3542ded1 fix error overwrite 2025-08-08 08:44:25 -07:00
Rei Meguro
e5dbfc34c3 faster relationship sql generation 2025-08-08 08:44:25 -07:00
Rei Meguro
1aad7f44d2 add back kg 2025-08-08 08:44:25 -07:00
Rei Meguro
a0d6d0b922 cleanup 2025-08-08 08:44:25 -07:00
joachim-danswer
588023a1f6 state updates for internal search 2025-08-08 08:44:25 -07:00
joachim-danswer
e4c2427728 merging of new citation handling and sending back by Closer 2025-08-08 08:44:25 -07:00
joachim-danswer
bf77da26fc closer can suggest more research 2025-08-08 08:44:25 -07:00
Rei Meguro
abfecde097 citation improvements with answer claim structure 2025-08-08 08:44:25 -07:00
Rei Meguro
3f4936ad0a cleanup 2025-08-08 08:44:25 -07:00
joachim-danswer
3b8d16a136 claim improvements 2025-08-08 08:44:24 -07:00
joachim-danswer
322e8668da claim start 2025-08-08 08:44:24 -07:00
Rei Meguro
d1dcad60d6 prompt improvements 2025-08-08 08:44:24 -07:00
Rei Meguro
7b3bdbdf83 fix: mypy 2025-08-08 08:44:24 -07:00
Rei Meguro
8b09fb0cef better clarification (still need prompt work) + prompt template fix 2025-08-08 08:44:24 -07:00
Rei Meguro
a2dd1bbf4f cleanup 2025-08-08 08:44:24 -07:00
Rei Meguro
828231815a fix: mypy 2025-08-08 08:44:24 -07:00
joachim-danswer
d48cbc2b79 custom tools 2025-08-08 08:44:24 -07:00
joachim-danswer
991bd4f8bf separation of tools 2025-08-08 08:44:24 -07:00
Rei Meguro
74418b84a2 consolidate user feedback 2025-08-08 08:44:24 -07:00
joachim-danswer
df1c40c791 prompt spellings 2025-08-08 08:44:24 -07:00
Rei Meguro
c253844500 minor cleanups + mypy fix 2025-08-08 08:44:24 -07:00
joachim-danswer
e972fb3e07 adding current time to prompts 2025-08-08 08:44:24 -07:00
joachim-danswer
726211c27d internet search improvements 2025-08-08 08:44:24 -07:00
joachim-danswer
c0435ddfd6 parallelized internet search 2025-08-08 08:44:24 -07:00
joachim-danswer
48dc934c35 internet search 2025-08-08 08:44:24 -07:00
Rei Meguro
3a575a92d5 fix: incorrect citations 2025-08-08 08:44:24 -07:00
Rei Meguro
de4a9e4687 mypy + rename vars for clarity 2025-08-08 08:44:24 -07:00
joachim-danswer
c330152417 nit 2025-08-08 08:44:24 -07:00
joachim-danswer
dca39f27a6 nit 2025-08-08 08:44:24 -07:00
joachim-danswer
d3cc27846a multi-search for Thoughtful 2025-08-08 08:44:24 -07:00
Rei Meguro
fedc665b88 kg bugfix + fix sql on error + slightly improved dr user feedback prompt 2025-08-08 08:44:24 -07:00
Rei Meguro
614672f357 fix: chat history + question passed to closer 2025-08-08 08:44:24 -07:00
Rei Meguro
6aca9ee005 fix docstring + move shared vars to constants.py 2025-08-08 08:44:24 -07:00
joachim-danswer
f9f64fb1a5 cleaning up of isolating feedback generation 2025-08-08 08:44:24 -07:00
joachim-danswer
4a63e631cd rough - included clarification
TODO: clean up!
2025-08-08 08:44:24 -07:00
Rei Meguro
3d5586d623 feat: make kg query part of state, rather than config 2025-08-08 08:44:24 -07:00
joachim-danswer
6c4eb17b5d prompt improvements 2025-08-08 08:44:24 -07:00
joachim-danswer
0917d9acd3 nits 2025-08-08 08:44:24 -07:00
Rei Meguro
89ea0f8d48 fix: wrong indentation 2025-08-08 08:44:24 -07:00
Rei Meguro
31ae6f1eb1 aggregate context improvements (no duplicates) 2025-08-08 08:44:24 -07:00
Rei Meguro
1b8d246afb feat: preparation for parallel search 2025-08-08 08:44:24 -07:00
Rei Meguro
05e55559d8 feat: citation improvements 2025-08-08 08:44:24 -07:00
joachim-danswer
241b8d062c adding final references 2025-08-08 08:44:24 -07:00
Rei Meguro
6359d2f2d6 formatting 2025-08-08 08:44:24 -07:00
Rei Meguro
83325f9012 feat: previous chat context 2025-08-08 08:44:24 -07:00
joachim-danswer
0b26ed602d updates - KG search w/ citations 2025-08-08 08:44:24 -07:00
Rei Meguro
2b69d1ba52 more minor prompt improvements 2025-08-08 08:44:24 -07:00
Rei Meguro
27cd1d44dc feat: small prompt improvements 2025-08-08 08:44:24 -07:00
Rei Meguro
b5ddf31742 slightly better planner prompt 2025-08-08 08:44:24 -07:00
Rei Meguro
ce1c80148b final answer streaming 2025-08-08 08:44:24 -07:00
Rei Meguro
bb95c46015 feat: structured response 2025-08-08 08:44:24 -07:00
Rei Meguro
e8a593c315 mypy + better typing 2025-08-08 08:44:24 -07:00
joachim-danswer
bb1b12988c improvements 2025-08-08 08:44:24 -07:00
joachim-danswer
72bbcabedf improved DR 2025-08-08 08:44:24 -07:00
Rei Meguro
2ee98ba795 plan of record fix 2025-08-08 08:44:24 -07:00
Rei Meguro
594bbdb167 greptile + evan comments 2025-08-08 08:44:24 -07:00
Rei Meguro
d5c67b6f50 mypy + typing next_step and plan_of_records 2025-08-08 08:44:24 -07:00
joachim-danswer
9c7638ceba iteration prep 2025-08-08 08:44:24 -07:00
joachim-danswer
b1488ddccc update to KG Beta 2025-08-08 08:44:24 -07:00
joachim-danswer
d9a9818b9a nit 2025-08-08 08:44:24 -07:00
joachim-danswer
4bd3b8b0bb nit 2025-08-08 08:44:24 -07:00
joachim-danswer
da3979fc41 is_agentic_overwrite 2025-08-08 08:44:24 -07:00
joachim-danswer
ffed8b4300 orchestration base 2025-08-08 08:44:24 -07:00
Weves
5eea47cb1c more 2025-08-07 18:14:29 -07:00
Weves
c830364c15 Major cleanup 2025-08-07 15:35:05 -07:00
Weves
04f3ba1f3d MORE 2025-08-07 14:45:37 -07:00
Weves
84f76fbee7 remove unused imports 2025-08-03 13:46:13 -07:00
Weves
00aeb3b280 More stuff 2025-08-03 13:30:09 -07:00
Weves
8c30085a9e more 2025-08-01 16:24:40 -07:00
Weves
419e82f9f4 more stuff 2025-08-01 16:24:40 -07:00
Weves
8330e5d8f4 Add missing files 2025-08-01 16:24:40 -07:00
Weves
e06c60a1a3 Many small fixes 2025-08-01 16:24:40 -07:00
Weves
e7eef67893 Fixes 2025-08-01 16:24:40 -07:00
Weves
b5209edffa rebase 2025-08-01 16:24:40 -07:00
Weves
07ad4dc022 More stuff 2025-08-01 16:24:40 -07:00
Weves
06e1a2c1a5 Basic, jank image gen support 2025-08-01 16:24:40 -07:00
Weves
083c152878 Initial new message protocol 2025-08-01 16:24:39 -07:00
Weves
06f11a0a06 Remove more * alternativeAssistant logic 2025-08-01 16:24:06 -07:00
Weves
fabfcddadb initial refactor
more

rebase

remove console.log

Use zustand

more refactor
2025-08-01 16:24:05 -07:00
224 changed files with 16296 additions and 8553 deletions


@@ -0,0 +1,91 @@
"""add research agent database tables and chat message research fields
Revision ID: 5ae8240accb3
Revises: 62c3a055a141
Create Date: 2025-08-06 14:29:24.691388
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "5ae8240accb3"
down_revision = "62c3a055a141"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add research_type and research_plan columns to chat_message table
op.add_column(
"chat_message",
sa.Column("research_type", sa.String(), nullable=True),
)
op.add_column(
"chat_message",
sa.Column("research_plan", postgresql.JSONB(), nullable=True),
)
# Create research_agent_iteration table
op.create_table(
"research_agent_iteration",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column(
"primary_question_id",
sa.Integer(),
sa.ForeignKey("chat_message.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("iteration_nr", sa.Integer(), nullable=False),
sa.Column("created_at", sa.DateTime(), nullable=False),
sa.Column("purpose", sa.String(), nullable=True),
sa.Column("reasoning", sa.String(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
# Create research_agent_iteration_sub_step table
op.create_table(
"research_agent_iteration_sub_step",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column(
"primary_question_id",
sa.Integer(),
sa.ForeignKey("chat_message.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"parent_question_id",
sa.Integer(),
sa.ForeignKey("research_agent_iteration_sub_step.id", ondelete="CASCADE"),
nullable=True,
),
sa.Column("iteration_nr", sa.Integer(), nullable=False),
sa.Column("iteration_sub_step_nr", sa.Integer(), nullable=False),
sa.Column("created_at", sa.DateTime(), nullable=False),
sa.Column("sub_step_instructions", sa.String(), nullable=True),
sa.Column(
"sub_step_tool_id",
sa.Integer(),
sa.ForeignKey("tool.id"),
nullable=True,
),
sa.Column("reasoning", sa.String(), nullable=True),
sa.Column("sub_answer", sa.String(), nullable=True),
sa.Column("cited_doc_results", postgresql.JSONB(), nullable=True),
sa.Column("claims", postgresql.JSONB(), nullable=True),
sa.Column("additional_data", postgresql.JSONB(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
def downgrade() -> None:
# Drop tables in reverse order
op.drop_table("research_agent_iteration_sub_step")
op.drop_table("research_agent_iteration")
# Remove columns from chat_message table
op.drop_column("chat_message", "research_plan")
op.drop_column("chat_message", "research_type")


@@ -0,0 +1,30 @@
"""add research_answer_purpose to chat_message
Revision ID: f8a9b2c3d4e5
Revises: 5ae8240accb3
Create Date: 2025-01-27 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "f8a9b2c3d4e5"
down_revision = "5ae8240accb3"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add research_answer_purpose column to chat_message table
op.add_column(
"chat_message",
sa.Column("research_answer_purpose", sa.String(), nullable=True),
)
def downgrade() -> None:
# Remove research_answer_purpose column from chat_message table
op.drop_column("chat_message", "research_answer_purpose")


@@ -29,7 +29,6 @@ from onyx.chat.models import QADocsResponse
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamingError
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionIdentifier
from onyx.chat.models import SubQuestionPiece
from onyx.chat.process_message import ChatPacketStream
from onyx.chat.process_message import stream_chat_message_objects
@@ -48,6 +47,7 @@ from onyx.natural_language_processing.utils import get_tokenizer
from onyx.secondary_llm_flows.query_expansion import thread_based_query_rephrase
from onyx.server.query_and_chat.models import ChatMessageDetail
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.streaming_models import SubQuestionIdentifier
from onyx.utils.logger import setup_logger
logger = setup_logger()


@@ -6,10 +6,8 @@ from pydantic import BaseModel
from pydantic import Field
from pydantic import model_validator
from onyx.chat.models import CitationInfo
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import SubQuestionIdentifier
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DocumentSource
from onyx.context.search.enums import LLMEvaluationType
@@ -19,6 +17,8 @@ from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SavedSearchDoc
from onyx.server.manage.models import StandardAnswer
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import SubQuestionIdentifier
class StandardAnswerRequest(BaseModel):


@@ -0,0 +1,12 @@
from langchain_core.messages import AIMessageChunk
from pydantic import BaseModel
from onyx.chat.models import LlmDoc
from onyx.context.search.models import InferenceSection
class BasicSearchProcessedStreamResults(BaseModel):
ai_message_chunk: AIMessageChunk = AIMessageChunk(content="")
full_answer: str | None = None
cited_references: list[InferenceSection] = []
retrieved_documents: list[LlmDoc] = []


@@ -6,6 +6,9 @@ from pydantic import BaseModel
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.chat.models import LlmDoc
from onyx.context.search.models import InferenceSection
# States contain values that change over the course of graph execution,
# Config is for values that are set at the start and never change.
@@ -18,11 +21,15 @@ class BasicInput(BaseModel):
# Langgraph needs a nonempty input, but we pass in all static
# data through a RunnableConfig.
unused: bool = True
query_override: str | None = None
## Graph Output State
class BasicOutput(TypedDict):
tool_call_chunk: AIMessageChunk
full_answer: str | None
cited_references: list[InferenceSection] | None
retrieved_documents: list[LlmDoc] | None
## Graph State


@@ -5,7 +5,9 @@ from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langgraph.types import StreamWriter
from onyx.agents.agent_search.basic.models import BasicSearchProcessedStreamResults
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
from onyx.chat.models import LlmDoc
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
@@ -13,6 +15,9 @@ from onyx.chat.stream_processing.answer_response_handler import (
PassThroughAnswerResponseHandler,
)
from onyx.chat.stream_processing.utils import map_document_id_order
from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -22,9 +27,12 @@ def process_llm_stream(
messages: Iterator[BaseMessage],
should_stream_answer: bool,
writer: StreamWriter,
ind: int,
final_search_results: list[LlmDoc] | None = None,
displayed_search_results: list[LlmDoc] | None = None,
) -> AIMessageChunk:
generate_final_answer: bool = False,
chat_message_id: str | None = None,
) -> BasicSearchProcessedStreamResults:
tool_call_chunk = AIMessageChunk(content="")
if final_search_results and displayed_search_results:
@@ -37,6 +45,7 @@ def process_llm_stream(
answer_handler = PassThroughAnswerResponseHandler()
full_answer = ""
start_final_answer_streaming_set = False
# This stream will be the llm answer if no tool is chosen. When a tool is chosen,
# the stream will contain AIMessageChunks with tool call information.
for message in messages:
@@ -54,11 +63,53 @@ def process_llm_stream(
tool_call_chunk += message # type: ignore
elif should_stream_answer:
for response_part in answer_handler.handle_response_part(message, []):
write_custom_event(
"basic_response",
response_part,
writer,
)
if (
hasattr(response_part, "answer_piece")
and generate_final_answer
and response_part.answer_piece
):
if chat_message_id is None:
raise ValueError(
"chat_message_id is required when generating final answer"
)
if not start_final_answer_streaming_set:
# Convert LlmDocs to SavedSearchDocs
saved_search_docs = saved_search_docs_from_llm_docs(
final_search_results
)
write_custom_event(
ind,
MessageStart(content="", final_documents=saved_search_docs),
writer,
)
start_final_answer_streaming_set = True
write_custom_event(
ind,
MessageDelta(
content=response_part.answer_piece, type="message_delta"
),
writer,
)
else:
write_custom_event(
ind,
response_part,
writer,
)
if generate_final_answer and start_final_answer_streaming_set:
# start_final_answer_streaming_set is only set if the answer is verbal and not a tool call
write_custom_event(
ind,
SectionEnd(),
writer,
)
logger.debug(f"Full answer: {full_answer}")
return cast(AIMessageChunk, tool_call_chunk)
return BasicSearchProcessedStreamResults(
ai_message_chunk=cast(AIMessageChunk, tool_call_chunk), full_answer=full_answer
)
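
The final-answer branch above emits a fixed packet sequence per section index. A sketch of the expected order for one section (payloads abbreviated, not an exhaustive protocol description):

# For a given section index `ind`:
#   MessageStart(content="", final_documents=saved_search_docs)  # opens the section
#   MessageDelta(content="...")                                  # repeated, one per answer piece
#   SectionEnd()                                                 # closes the section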


@@ -10,6 +10,7 @@ class CoreState(BaseModel):
"""
log_messages: Annotated[list[str], add] = []
current_step_nr: int = 1
class SubgraphCoreState(BaseModel):


@@ -0,0 +1,54 @@
from collections.abc import Hashable
from langgraph.graph import END
from langgraph.types import Send
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.states import MainState
def decision_router(state: MainState) -> list[Send | Hashable] | DRPath | str:
if not state.tools_used:
raise IndexError("state.tools_used cannot be empty")
# next_tool is either a generic tool name or a DRPath string
next_tool = state.tools_used[-1]
try:
next_path = DRPath(next_tool)
except ValueError:
next_path = DRPath.GENERIC_TOOL
# handle END
if next_path == DRPath.END:
return END
# handle invalid paths
if next_path == DRPath.CLARIFIER:
raise ValueError("CLARIFIER is not a valid path during iteration")
# handle tool calls without a query
if (
next_path
in (
DRPath.INTERNAL_SEARCH,
DRPath.INTERNET_SEARCH,
DRPath.KNOWLEDGE_GRAPH,
DRPath.IMAGE_GENERATION,
)
and len(state.query_list) == 0
):
return DRPath.CLOSER
return next_path
def completeness_router(state: MainState) -> DRPath | str:
if not state.tools_used:
raise IndexError("tools_used cannot be empty")
# return to the orchestrator if the closer requested more research, otherwise end
next_path = state.tools_used[-1]
if next_path == DRPath.ORCHESTRATOR.value:
return DRPath.ORCHESTRATOR
return END
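
The routing hinges on DRPath being a str-valued enum: a tool name that matches a known path coerces directly, and anything else (e.g. a custom tool) falls back to the generic-tool subgraph. A small standalone illustration of that fallback, using only the enum defined in this change:

from onyx.agents.agent_search.dr.enums import DRPath

def resolve_path(tool_name: str) -> DRPath:
    # mirrors the try/except at the top of decision_router
    try:
        return DRPath(tool_name)
    except ValueError:
        return DRPath.GENERIC_TOOL

assert resolve_path("Internet Search") is DRPath.INTERNET_SEARCH
assert resolve_path("MY_CUSTOM_TOOL") is DRPath.GENERIC_TOOL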


@@ -0,0 +1,30 @@
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchType
MAX_CHAT_HISTORY_MESSAGES = (
3 # note: actual count is x2 to account for user and assistant messages
)
MAX_DR_PARALLEL_SEARCH = 4
# TODO: test more, generally not needed/adds unnecessary iterations
MAX_NUM_CLOSER_SUGGESTIONS = (
0 # how many times the closer can send back to the orchestrator
)
CLARIFICATION_REQUEST_PREFIX = "PLEASE CLARIFY:"
HIGH_LEVEL_PLAN_PREFIX = "HIGH_LEVEL PLAN:"
AVERAGE_TOOL_COSTS: dict[DRPath, float] = {
DRPath.INTERNAL_SEARCH: 1.0,
DRPath.KNOWLEDGE_GRAPH: 2.0,
DRPath.INTERNET_SEARCH: 1.5,
DRPath.IMAGE_GENERATION: 3.0,
DRPath.GENERIC_TOOL: 1.5, # TODO: see todo in OrchestratorTool
DRPath.CLOSER: 0.0,
}
DR_TIME_BUDGET_BY_TYPE = {
ResearchType.THOUGHTFUL: 3.0,
ResearchType.DEEP: 6.0,
}
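
Together these constants define a simple cost model: each research type starts from a fixed budget and every tool call burns its average cost. An illustrative walk-through (the orchestrator node below performs the same subtraction, and routes to the CLOSER once the budget is used up):

from onyx.agents.agent_search.dr.constants import AVERAGE_TOOL_COSTS
from onyx.agents.agent_search.dr.constants import DR_TIME_BUDGET_BY_TYPE
from onyx.agents.agent_search.dr.enums import DRPath, ResearchType

budget = DR_TIME_BUDGET_BY_TYPE[ResearchType.THOUGHTFUL]  # 3.0 to start
budget -= AVERAGE_TOOL_COSTS[DRPath.INTERNAL_SEARCH]      # 2.0 remaining
budget -= AVERAGE_TOOL_COSTS[DRPath.KNOWLEDGE_GRAPH]      # 0.0 remaining
# budget <= 0, so the next orchestrator pass defaults to DRPath.CLOSER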


@@ -0,0 +1,114 @@
from datetime import datetime
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import DRPromptPurpose
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.prompts.dr_prompts import GET_CLARIFICATION_PROMPT
from onyx.prompts.dr_prompts import KG_TYPES_DESCRIPTIONS
from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
from onyx.prompts.dr_prompts import TOOL_DIFFERENTIATION_HINTS
from onyx.prompts.dr_prompts import TOOL_QUESTION_HINTS
from onyx.prompts.prompt_template import PromptTemplate
def get_dr_prompt_orchestration_templates(
purpose: DRPromptPurpose,
research_type: ResearchType,
available_tools: dict[str, OrchestratorTool],
entity_types_string: str | None = None,
relationship_types_string: str | None = None,
reasoning_result: str | None = None,
tool_calls_string: str | None = None,
) -> PromptTemplate:
available_tools = available_tools or {}
tool_names = list(available_tools.keys())
tool_description_str = "\n\n".join(
f"- {tool_name}: {tool.description}"
for tool_name, tool in available_tools.items()
)
tool_cost_str = "\n".join(
f"{tool_name}: {tool.cost}" for tool_name, tool in available_tools.items()
)
tool_differentiations: list[str] = []
for tool_1 in available_tools:
for tool_2 in available_tools:
if (tool_1, tool_2) in TOOL_DIFFERENTIATION_HINTS:
tool_differentiations.append(
TOOL_DIFFERENTIATION_HINTS[(tool_1, tool_2)]
)
tool_differentiation_hint_string = (
"\n".join(tool_differentiations) or "(No differentiating hints available)"
)
# TODO: add tool differentiation pairs for custom tools as well
tool_question_hint_string = (
"\n".join(
"- " + TOOL_QUESTION_HINTS[tool]
for tool in available_tools
if tool in TOOL_QUESTION_HINTS
)
or "(No examples available)"
)
if DRPath.KNOWLEDGE_GRAPH.value in available_tools:
if not entity_types_string or not relationship_types_string:
raise ValueError(
"Entity types and relationship types must be provided if the Knowledge Graph is used."
)
kg_types_descriptions = KG_TYPES_DESCRIPTIONS.build(
possible_entities=entity_types_string,
possible_relationships=relationship_types_string,
)
else:
kg_types_descriptions = "(The Knowledge Graph is not used.)"
if purpose == DRPromptPurpose.PLAN:
if research_type == ResearchType.THOUGHTFUL:
raise ValueError("plan generation is not supported for FAST time budget")
base_template = ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
elif purpose == DRPromptPurpose.NEXT_STEP_REASONING:
if research_type == ResearchType.THOUGHTFUL:
base_template = ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
else:
raise ValueError(
"a separate reasoning step is not used for the DEEP research type"
)
elif purpose == DRPromptPurpose.NEXT_STEP_PURPOSE:
base_template = ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
elif purpose == DRPromptPurpose.NEXT_STEP:
if research_type == ResearchType.THOUGHTFUL:
base_template = ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
else:
base_template = ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
elif purpose == DRPromptPurpose.CLARIFICATION:
if research_type == ResearchType.THOUGHTFUL:
raise ValueError("clarification is not supported for FAST time budget")
base_template = GET_CLARIFICATION_PROMPT
else:
# unreachable, but kept so mypy's exhaustiveness check passes
raise ValueError(f"Invalid purpose: {purpose}")
return base_template.partial_build(
num_available_tools=str(len(tool_names)),
available_tools=", ".join(tool_names),
tool_choice_options=" or ".join(tool_names),
current_time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
kg_types_descriptions=kg_types_descriptions,
tool_descriptions=tool_description_str,
tool_differentiation_hints=tool_differentiation_hint_string,
tool_question_hints=tool_question_hint_string,
average_tool_costs=tool_cost_str,
reasoning_result=reasoning_result or "(No reasoning result provided.)",
tool_calls_string=tool_calls_string or "(No tool calls provided.)",
)
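
For orientation, this is roughly how the orchestrator node further down exercises the builder for a THOUGHTFUL next-step decision; the free names (question, available_tools, and so on) are as defined in that node:

base_decision_prompt = get_dr_prompt_orchestration_templates(
    DRPromptPurpose.NEXT_STEP,
    ResearchType.THOUGHTFUL,
    entity_types_string=all_entity_types,
    relationship_types_string=all_relationship_types,
    available_tools=available_tools,
)
decision_prompt = base_decision_prompt.build(
    question=question,
    chat_history_string=chat_history_string,
    answer_history_string=answer_history_string,
    iteration_nr=str(iteration_nr),
    remaining_time_budget=str(remaining_time_budget),
    reasoning_result=reasoning_result,
)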


@@ -0,0 +1,28 @@
from enum import Enum
class ResearchType(str, Enum):
"""Research type options for agent search operations"""
# BASIC = "BASIC"
THOUGHTFUL = "THOUGHTFUL"
DEEP = "DEEP"
class ResearchAnswerPurpose(str, Enum):
"""Research answer purpose options for agent search operations"""
ANSWER = "ANSWER"
CLARIFICATION_REQUEST = "CLARIFICATION_REQUEST"
class DRPath(str, Enum):
CLARIFIER = "Clarifier"
ORCHESTRATOR = "Orchestrator"
INTERNAL_SEARCH = "Internal Search"
GENERIC_TOOL = "Generic Tool"
KNOWLEDGE_GRAPH = "Knowledge Graph"
INTERNET_SEARCH = "Internet Search"
IMAGE_GENERATION = "Image Generation"
CLOSER = "Closer"
END = "End"


@@ -0,0 +1,80 @@
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.conditional_edges import completeness_router
from onyx.agents.agent_search.dr.conditional_edges import decision_router
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.nodes.dr_a0_clarification import clarifier
from onyx.agents.agent_search.dr.nodes.dr_a1_orchestrator import orchestrator
from onyx.agents.agent_search.dr.nodes.dr_a2_closer import closer
from onyx.agents.agent_search.dr.states import MainInput
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_graph_builder import (
dr_basic_search_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_graph_builder import (
dr_custom_tool_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_graph_builder import (
dr_image_generation_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.internet_search.dr_is_graph_builder import (
dr_is_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_graph_builder import (
dr_kg_search_graph_builder,
)
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import search
logger = setup_logger()
def dr_graph_builder() -> StateGraph:
"""
LangGraph graph builder for the deep research agent.
"""
graph = StateGraph(state_schema=MainState, input=MainInput)
### Add nodes ###
graph.add_node(DRPath.CLARIFIER, clarifier)
graph.add_node(DRPath.ORCHESTRATOR, orchestrator)
basic_search_graph = dr_basic_search_graph_builder().compile()
graph.add_node(DRPath.INTERNAL_SEARCH, basic_search_graph)
kg_search_graph = dr_kg_search_graph_builder().compile()
graph.add_node(DRPath.KNOWLEDGE_GRAPH, kg_search_graph)
internet_search_graph = dr_is_graph_builder().compile()
graph.add_node(DRPath.INTERNET_SEARCH, internet_search_graph)
image_generation_graph = dr_image_generation_graph_builder().compile()
graph.add_node(DRPath.IMAGE_GENERATION, image_generation_graph)
custom_tool_graph = dr_custom_tool_graph_builder().compile()
graph.add_node(DRPath.GENERIC_TOOL, custom_tool_graph)
graph.add_node(DRPath.CLOSER, closer)
### Add edges ###
graph.add_edge(start_key=START, end_key=DRPath.CLARIFIER)
graph.add_conditional_edges(DRPath.CLARIFIER, decision_router)
graph.add_conditional_edges(DRPath.ORCHESTRATOR, decision_router)
graph.add_edge(start_key=DRPath.INTERNAL_SEARCH, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.KNOWLEDGE_GRAPH, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.INTERNET_SEARCH, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.IMAGE_GENERATION, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.GENERIC_TOOL, end_key=DRPath.ORCHESTRATOR)
graph.add_conditional_edges(DRPath.CLOSER, completeness_router)
return graph
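
Callers run the compiled graph; per-request data travels through the RunnableConfig metadata rather than the graph state (the nodes read it via config["metadata"]["config"]). A minimal sketch, with the import path hypothetical since file names are not shown in this diff:

from onyx.agents.agent_search.dr.graph_builder import dr_graph_builder  # hypothetical path

compiled_graph = dr_graph_builder().compile()
# per request: compiled_graph.invoke(input=main_input, config={"metadata": {"config": graph_config}})
# where main_input is a MainInput and graph_config is the GraphConfig read by the nodes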


@@ -0,0 +1,108 @@
from enum import Enum
from pydantic import BaseModel
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.context.search.models import InferenceSection
from onyx.tools.tool import Tool
class OrchestratorStep(BaseModel):
tool: str
questions: list[str]
class OrchestratorDecisonsNoPlan(BaseModel):
reasoning: str
next_step: OrchestratorStep
class OrchestrationPlan(BaseModel):
reasoning: str
plan: str
class ClarificationGenerationResponse(BaseModel):
clarification_needed: bool
clarification_question: str
class QueryEvaluationResponse(BaseModel):
reasoning: str
query_permitted: bool
class OrchestrationClarificationInfo(BaseModel):
clarification_question: str
clarification_response: str | None = None
class SearchAnswer(BaseModel):
reasoning: str
answer: str
claims: list[str] | None = None
class TestInfoCompleteResponse(BaseModel):
reasoning: str
complete: bool
gaps: list[str]
# TODO: revisit with custom tools implementation in v2
# each tool should be a class with the attributes below, plus the actual tool implementation
# this will also allow custom tools to have their own cost
class OrchestratorTool(BaseModel):
tool_id: int
name: str
llm_path: str # the path for the LLM to refer by
path: DRPath # the actual path in the graph
description: str
metadata: dict[str, str]
cost: float
tool_object: Tool | None = None # None for CLOSER
class Config:
arbitrary_types_allowed = True
class IterationInstructions(BaseModel):
iteration_nr: int
plan: str | None
reasoning: str
purpose: str
class IterationAnswer(BaseModel):
tool: str
tool_id: int
iteration_nr: int
parallelization_nr: int
question: str
reasoning: str | None
answer: str
cited_documents: dict[int, InferenceSection]
background_info: str | None = None
claims: list[str] | None = None
additional_data: dict[str, str] | None = None
class AggregatedDRContext(BaseModel):
context: str
cited_documents: list[InferenceSection]
is_internet_marker_dict: dict[str, bool]
global_iteration_responses: list[IterationAnswer]
class DRPromptPurpose(str, Enum):
PLAN = "PLAN"
NEXT_STEP = "NEXT_STEP"
NEXT_STEP_REASONING = "NEXT_STEP_REASONING"
NEXT_STEP_PURPOSE = "NEXT_STEP_PURPOSE"
CLARIFICATION = "CLARIFICATION"
class BaseSearchProcessingResponse(BaseModel):
specified_source_types: list[str]
rewritten_query: str
time_filter: str


@@ -0,0 +1,572 @@
import re
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.dr.constants import AVERAGE_TOOL_COSTS
from onyx.agents.agent_search.dr.constants import MAX_CHAT_HISTORY_MESSAGES
from onyx.agents.agent_search.dr.dr_prompt_builder import (
get_dr_prompt_orchestration_templates,
)
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import ClarificationGenerationResponse
from onyx.agents.agent_search.dr.models import DRPromptPurpose
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.states import OrchestrationSetup
from onyx.agents.agent_search.dr.utils import get_chat_history_string
from onyx.agents.agent_search.dr.utils import update_db_session_with_messages
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import run_with_timeout
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.configs.constants import DocumentSourceDescription
from onyx.db.connector import fetch_unique_document_sources
from onyx.kg.utils.extraction_utils import get_entity_types_str
from onyx.kg.utils.extraction_utils import get_relationship_types_str
from onyx.prompts.dr_prompts import DECISION_PROMPT_W_TOOL_CALLING
from onyx.prompts.dr_prompts import DECISION_PROMPT_WO_TOOL_CALLING
from onyx.prompts.dr_prompts import DEFAULT_DR_SYSTEM_PROMPT
from onyx.prompts.dr_prompts import EVAL_SYSTEM_PROMPT_W_TOOL_CALLING
from onyx.prompts.dr_prompts import EVAL_SYSTEM_PROMPT_WO_TOOL_CALLING
from onyx.prompts.dr_prompts import GENERAL_DR_ANSWER_PROMPT
from onyx.prompts.dr_prompts import TOOL_DESCRIPTION
from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
InternetSearchTool,
)
from onyx.tools.tool_implementations.knowledge_graph.knowledge_graph_tool import (
KnowledgeGraphTool,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _format_tool_name(tool_name: str) -> str:
"""Convert tool name to LLM-friendly format."""
name = tool_name.replace(" ", "_")
# take care of camel case like GetAPIKey -> GET_API_KEY for LLM readability
name = re.sub(r"(?<=[a-z0-9])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])", "_", name)
return name.upper()
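# Examples of the conversion:
#   _format_tool_name("GetAPIKey")       -> "GET_API_KEY"
#   _format_tool_name("Internal Search") -> "INTERNAL_SEARCH"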
def _get_available_tools(
graph_config: GraphConfig, kg_enabled: bool
) -> dict[str, OrchestratorTool]:
available_tools: dict[str, OrchestratorTool] = {}
for tool in graph_config.tooling.tools:
tool_info = OrchestratorTool(
tool_id=tool.id,
name=tool.name,
llm_path=_format_tool_name(tool.name),
path=DRPath.GENERIC_TOOL,
description=tool.description,
metadata={},
cost=1.0,
tool_object=tool,
)
if isinstance(tool, CustomTool):
# tool_info.metadata["summary_signature"] = CUSTOM_TOOL_RESPONSE_ID
pass
elif isinstance(tool, InternetSearchTool):
# tool_info.metadata["summary_signature"] = (
# INTERNET_SEARCH_RESPONSE_SUMMARY_ID
# )
tool_info.llm_path = DRPath.INTERNET_SEARCH.value
tool_info.path = DRPath.INTERNET_SEARCH
elif isinstance(tool, SearchTool):
# tool_info.metadata["summary_signature"] = SEARCH_RESPONSE_SUMMARY_ID
tool_info.llm_path = DRPath.INTERNAL_SEARCH.value
tool_info.path = DRPath.INTERNAL_SEARCH
elif isinstance(tool, KnowledgeGraphTool):
if not kg_enabled:
logger.warning("KG must be enabled to use KG search tool, skipping")
continue
tool_info.llm_path = DRPath.KNOWLEDGE_GRAPH.value
tool_info.path = DRPath.KNOWLEDGE_GRAPH
elif isinstance(tool, ImageGenerationTool):
tool_info.llm_path = DRPath.IMAGE_GENERATION.value
tool_info.path = DRPath.IMAGE_GENERATION
else:
logger.warning(f"Tool {tool.name} ({type(tool)}) is not supported")
continue
tool_info.description = TOOL_DESCRIPTION.get(tool_info.path, tool.description)
tool_info.cost = AVERAGE_TOOL_COSTS[tool_info.path]
# TODO: handle custom tools with same name as other tools (e.g., CLOSER)
available_tools[tool_info.llm_path] = tool_info
# make sure KG isn't enabled without internal search
if (
DRPath.KNOWLEDGE_GRAPH.value in available_tools
and DRPath.INTERNAL_SEARCH.value not in available_tools
):
raise ValueError(
"The Knowledge Graph is not supported without internal search tool"
)
# add CLOSER tool, which is always available
available_tools[DRPath.CLOSER.value] = OrchestratorTool(
tool_id=-1,
name="closer",
llm_path=DRPath.CLOSER.value,
path=DRPath.CLOSER,
description=TOOL_DESCRIPTION[DRPath.CLOSER],
metadata={},
cost=0.0,
tool_object=None,
)
return available_tools
def _get_existing_clarification_request(
graph_config: GraphConfig,
) -> tuple[OrchestrationClarificationInfo, str, str] | None:
"""
Returns the clarification info, original question, and updated chat history if
a clarification request and response exists, otherwise returns None.
"""
# check for clarification request and response in message history
previous_raw_messages = graph_config.inputs.prompt_builder.raw_message_history
if len(previous_raw_messages) == 0 or (
previous_raw_messages[-1].research_answer_purpose
!= ResearchAnswerPurpose.CLARIFICATION_REQUEST
):
return None
# get the clarification request and response
previous_messages = graph_config.inputs.prompt_builder.message_history
last_message = previous_raw_messages[-1].message
clarification = OrchestrationClarificationInfo(
clarification_question=last_message.strip(),
clarification_response=graph_config.inputs.prompt_builder.raw_user_query,
)
original_question = graph_config.inputs.prompt_builder.raw_user_query
chat_history_string = "(No chat history yet available)"
# get the original user query and chat history string before the original query
# e.g., if history = [user query, assistant clarification request, user clarification response],
# previous_messages = [user query, assistant clarification request], we want the user query
for i, message in enumerate(reversed(previous_messages), 1):
if (
isinstance(message, HumanMessage)
and message.content
and isinstance(message.content, str)
):
original_question = message.content
chat_history_string = (
get_chat_history_string(
graph_config.inputs.prompt_builder.message_history[:-i],
MAX_CHAT_HISTORY_MESSAGES,
)
or "(No chat history yet available)"
)
break
return clarification, original_question, chat_history_string
_ARTIFICIAL_ALL_ENCOMPASSING_TOOL = {
"type": "function",
"function": {
"name": "run_any_knowledge_retrieval_and_any_action_tool",
"description": "Use this tool to get any external information \
that is relevant to the question, or for any action to be taken.",
"parameters": {
"type": "object",
"properties": {
"request": {
"type": "string",
"description": "The request to be made to the tool",
},
},
"required": ["request"],
},
},
}
def clarifier(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> OrchestrationSetup:
"""
Perform a quick pass on the question as-is and decide whether a clarification
question is needed. For now this decision is based on the model's judgment.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
use_tool_calling_llm = graph_config.tooling.using_tool_calling_llm
db_session = graph_config.persistence.db_session
original_question = graph_config.inputs.prompt_builder.raw_user_query
research_type = graph_config.behavior.research_type
message_id = graph_config.persistence.message_id
# get the connected tools and format for the Deep Research flow
kg_enabled = graph_config.behavior.kg_config_settings.KG_ENABLED
available_tools = _get_available_tools(graph_config, kg_enabled)
non_internal_search_tools = [
tool
for tool in available_tools.values()
if tool.path != DRPath.INTERNAL_SEARCH and tool.path != DRPath.KNOWLEDGE_GRAPH
]
all_entity_types = get_entity_types_str(active=True)
all_relationship_types = get_relationship_types_str(active=True)
db_session = graph_config.persistence.db_session
active_source_types = fetch_unique_document_sources(db_session)
# if not active_source_types:
# raise ValueError("No active source types found")
active_source_types_descriptions = [
DocumentSourceDescription[source_type] for source_type in active_source_types
]
if graph_config.inputs.persona and len(graph_config.inputs.persona.prompts) > 0:
assistant_system_prompt = (
graph_config.inputs.persona.prompts[0].system_prompt
or DEFAULT_DR_SYSTEM_PROMPT
) + "\n\n"
if graph_config.inputs.persona.prompts[0].task_prompt:
assistant_task_prompt = (
"\n\nHere are more specifications from the user:\n\n"
+ graph_config.inputs.persona.prompts[0].task_prompt
)
else:
assistant_task_prompt = ""
else:
assistant_system_prompt = DEFAULT_DR_SYSTEM_PROMPT + "\n\n"
assistant_task_prompt = ""
chat_history_string = (
get_chat_history_string(
graph_config.inputs.prompt_builder.message_history,
MAX_CHAT_HISTORY_MESSAGES,
)
or "(No chat history yet available)"
)
if len(available_tools) == 0 or (
len(non_internal_search_tools) == 0 and len(active_source_types) == 0
):
answer_prompt = GENERAL_DR_ANSWER_PROMPT.build(
question=original_question, chat_history_string=chat_history_string
)
stream = graph_config.tooling.primary_llm.stream(
prompt=create_question_prompt(
assistant_system_prompt, answer_prompt + assistant_task_prompt
),
tools=None,
tool_choice=None,
structured_response_format=None,
)
full_response = process_llm_stream(
messages=stream,
should_stream_answer=True,
writer=writer,
ind=0,
generate_final_answer=True,
chat_message_id=str(graph_config.persistence.chat_session_id),
)
if isinstance(full_response.full_answer, str):
full_answer = full_response.full_answer
else:
full_answer = None
update_db_session_with_messages(
db_session=db_session,
chat_message_id=message_id,
chat_session_id=str(graph_config.persistence.chat_session_id),
is_agentic=graph_config.behavior.use_agentic_search,
message=full_answer,
update_parent_message=True,
research_answer_purpose=ResearchAnswerPurpose.ANSWER,
)
db_session.commit()
return OrchestrationSetup(
original_question=original_question,
chat_history_string="",
tools_used=[DRPath.END.value],
query_list=[],
assistant_system_prompt=assistant_system_prompt,
assistant_task_prompt=assistant_task_prompt,
)
elif not use_tool_calling_llm:
decision_prompt = DECISION_PROMPT_WO_TOOL_CALLING.build(
question=original_question, chat_history_string=chat_history_string
)
initial_decision_tokens, _, _ = run_with_timeout(
80,
lambda: stream_llm_answer(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt + EVAL_SYSTEM_PROMPT_WO_TOOL_CALLING,
decision_prompt + assistant_task_prompt,
),
event_name="basic_response",
writer=writer,
agent_answer_level=0,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=60,
max_tokens=None,
),
)
initial_decision_str = cast(str, merge_content(*initial_decision_tokens))
if len(initial_decision_str.replace(" ", "")) > 0:
return OrchestrationSetup(
original_question=original_question,
chat_history_string="",
tools_used=[DRPath.END.value],
query_list=[],
assistant_system_prompt=assistant_system_prompt,
assistant_task_prompt=assistant_task_prompt,
)
else:
decision_prompt = DECISION_PROMPT_W_TOOL_CALLING.build(
question=original_question, chat_history_string=chat_history_string
)
stream = graph_config.tooling.primary_llm.stream(
prompt=create_question_prompt(
assistant_system_prompt + EVAL_SYSTEM_PROMPT_W_TOOL_CALLING,
decision_prompt + assistant_task_prompt,
),
tools=[_ARTIFICIAL_ALL_ENCOMPASSING_TOOL],
tool_choice=None,
structured_response_format=graph_config.inputs.structured_response_format,
)
full_response = process_llm_stream(
messages=stream,
should_stream_answer=True,
writer=writer,
ind=0,
generate_final_answer=True,
chat_message_id=str(graph_config.persistence.chat_session_id),
)
if len(full_response.ai_message_chunk.tool_calls) == 0:
if isinstance(full_response.full_answer, str):
full_answer = full_response.full_answer
else:
full_answer = None
update_db_session_with_messages(
db_session=db_session,
chat_message_id=message_id,
chat_session_id=str(graph_config.persistence.chat_session_id),
is_agentic=graph_config.behavior.use_agentic_search,
message=full_answer,
update_parent_message=True,
research_answer_purpose=ResearchAnswerPurpose.ANSWER,
)
db_session.commit()
return OrchestrationSetup(
original_question=original_question,
chat_history_string="",
tools_used=[DRPath.END.value],
query_list=[],
assistant_system_prompt=assistant_system_prompt,
assistant_task_prompt=assistant_task_prompt,
)
# Continue, as external knowledge is required.
clarification = None
if research_type != ResearchType.THOUGHTFUL:
result = _get_existing_clarification_request(graph_config)
if result is not None:
clarification, original_question, chat_history_string = result
else:
# generate clarification questions if needed
chat_history_string = (
get_chat_history_string(
graph_config.inputs.prompt_builder.message_history,
MAX_CHAT_HISTORY_MESSAGES,
)
or "(No chat history yet available)"
)
base_clarification_prompt = get_dr_prompt_orchestration_templates(
DRPromptPurpose.CLARIFICATION,
research_type,
entity_types_string=all_entity_types,
relationship_types_string=all_relationship_types,
available_tools=available_tools,
)
clarification_prompt = base_clarification_prompt.build(
question=original_question,
chat_history_string=chat_history_string,
)
try:
clarification_response = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt, clarification_prompt
),
schema=ClarificationGenerationResponse,
timeout_override=25,
# max_tokens=1500,
)
except Exception as e:
logger.error(f"Error in clarification generation: {e}")
raise e
if (
clarification_response.clarification_needed
and clarification_response.clarification_question
):
clarification = OrchestrationClarificationInfo(
clarification_question=clarification_response.clarification_question,
clarification_response=None,
)
write_custom_event(
0,
MessageStart(
content="",
final_documents=None,
),
writer,
)
write_custom_event(
0,
MessageDelta(
content=clarification_response.clarification_question,
type="message_delta",
),
writer,
)
write_custom_event(
0,
SectionEnd(
type="section_end",
),
writer,
)
write_custom_event(
1,
OverallStop(),
writer,
)
update_db_session_with_messages(
db_session=db_session,
chat_message_id=message_id,
chat_session_id=str(graph_config.persistence.chat_session_id),
is_agentic=graph_config.behavior.use_agentic_search,
message=clarification_response.clarification_question,
update_parent_message=True,
research_type=research_type,
research_answer_purpose=ResearchAnswerPurpose.CLARIFICATION_REQUEST,
)
db_session.commit()
else:
chat_history_string = (
get_chat_history_string(
graph_config.inputs.prompt_builder.message_history,
MAX_CHAT_HISTORY_MESSAGES,
)
or "(No chat history yet available)"
)
if (
clarification
and clarification.clarification_question
and clarification.clarification_response is None
):
update_db_session_with_messages(
db_session=db_session,
chat_message_id=message_id,
chat_session_id=str(graph_config.persistence.chat_session_id),
is_agentic=graph_config.behavior.use_agentic_search,
message=clarification.clarification_question,
update_parent_message=True,
research_type=research_type,
research_answer_purpose=ResearchAnswerPurpose.CLARIFICATION_REQUEST,
)
db_session.commit()
next_tool = DRPath.END.value
else:
next_tool = DRPath.ORCHESTRATOR.value
return OrchestrationSetup(
original_question=original_question,
chat_history_string=chat_history_string,
tools_used=[next_tool],
query_list=[],
iteration_nr=0,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="clarifier",
node_start_time=node_start_time,
)
],
clarification=clarification,
available_tools=available_tools,
active_source_types=active_source_types,
active_source_types_descriptions="\n".join(active_source_types_descriptions),
assistant_system_prompt=assistant_system_prompt,
assistant_task_prompt=assistant_task_prompt,
)


@@ -0,0 +1,445 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.constants import DR_TIME_BUDGET_BY_TYPE
from onyx.agents.agent_search.dr.constants import HIGH_LEVEL_PLAN_PREFIX
from onyx.agents.agent_search.dr.dr_prompt_builder import (
get_dr_prompt_orchestration_templates,
)
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import DRPromptPurpose
from onyx.agents.agent_search.dr.models import OrchestrationPlan
from onyx.agents.agent_search.dr.models import OrchestratorDecisonsNoPlan
from onyx.agents.agent_search.dr.states import IterationInstructions
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.states import OrchestrationUpdate
from onyx.agents.agent_search.dr.utils import aggregate_context
from onyx.agents.agent_search.dr.utils import create_tool_call_string
from onyx.agents.agent_search.dr.utils import get_prompt_question
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import run_with_timeout
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.kg.utils.extraction_utils import get_entity_types_str
from onyx.kg.utils.extraction_utils import get_relationship_types_str
from onyx.prompts.dr_prompts import SUFFICIENT_INFORMATION_STRING
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
logger = setup_logger()
def orchestrator(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> OrchestrationUpdate:
"""
LangGraph node to decide the next step in the DR process.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = state.original_question
if not question:
raise ValueError("Question is required for orchestrator")
plan_of_record = state.plan_of_record
clarification = state.clarification
assistant_system_prompt = state.assistant_system_prompt
assistant_task_prompt = state.assistant_task_prompt
iteration_nr = state.iteration_nr + 1
current_step_nr = state.current_step_nr
research_type = graph_config.behavior.research_type
remaining_time_budget = state.remaining_time_budget
chat_history_string = state.chat_history_string or "(No chat history yet available)"
answer_history_string = (
aggregate_context(state.iteration_responses, include_documents=True).context
or "(No answer history yet available)"
)
available_tools = state.available_tools or {}
questions = [
f"{iteration_response.tool}: {iteration_response.question}"
for iteration_response in state.iteration_responses
if len(iteration_response.question) > 0
]
question_history_string = (
"\n".join(f" - {question}" for question in questions)
if questions
else "(No question history yet available)"
)
prompt_question = get_prompt_question(question, clarification)
gaps_str = (
("\n - " + "\n - ".join(state.gaps))
if state.gaps
else "(No explicit gaps were pointed out so far)"
)
all_entity_types = get_entity_types_str(active=True)
all_relationship_types = get_relationship_types_str(active=True)
# default to closer
next_tool = DRPath.CLOSER.value
query_list = ["Answer the question with the information you have."]
decision_prompt = None
reasoning_result = "(No reasoning result provided yet.)"
tool_calls_string = "(No tool calls provided yet.)"
if research_type == ResearchType.THOUGHTFUL:
if iteration_nr == 1:
remaining_time_budget = DR_TIME_BUDGET_BY_TYPE[ResearchType.THOUGHTFUL]
elif iteration_nr > 1:
# for each iteration past the first one, we need to see whether we
# have enough information to answer the question.
# if we do, we can stop the iteration and return the answer.
# if we do not, we need to continue the iteration.
base_reasoning_prompt = get_dr_prompt_orchestration_templates(
DRPromptPurpose.NEXT_STEP_REASONING,
ResearchType.THOUGHTFUL,
entity_types_string=all_entity_types,
relationship_types_string=all_relationship_types,
available_tools=available_tools,
)
reasoning_prompt = base_reasoning_prompt.build(
question=question,
chat_history_string=chat_history_string,
answer_history_string=answer_history_string,
iteration_nr=str(iteration_nr),
remaining_time_budget=str(remaining_time_budget),
)
reasoning_tokens: list[str] = [""]
reasoning_tokens, _, _ = run_with_timeout(
80,
lambda: stream_llm_answer(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt,
reasoning_prompt + (assistant_task_prompt or ""),
),
event_name="basic_response",
writer=writer,
agent_answer_level=0,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=60,
answer_piece="reasoning_delta",
ind=current_step_nr,
# max_tokens=None,
),
)
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
current_step_nr += 1
reasoning_result = cast(str, merge_content(*reasoning_tokens))
if SUFFICIENT_INFORMATION_STRING in reasoning_result:
return OrchestrationUpdate(
tools_used=[DRPath.CLOSER.value],
current_step_nr=current_step_nr,
query_list=[],
iteration_nr=iteration_nr,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="orchestrator",
node_start_time=node_start_time,
)
],
plan_of_record=plan_of_record,
remaining_time_budget=remaining_time_budget,
iteration_instructions=[
IterationInstructions(
iteration_nr=iteration_nr,
plan=None,
reasoning=reasoning_result,
purpose="",
)
],
)
base_decision_prompt = get_dr_prompt_orchestration_templates(
DRPromptPurpose.NEXT_STEP,
ResearchType.THOUGHTFUL,
entity_types_string=all_entity_types,
relationship_types_string=all_relationship_types,
available_tools=available_tools,
)
decision_prompt = base_decision_prompt.build(
question=question,
chat_history_string=chat_history_string,
answer_history_string=answer_history_string,
iteration_nr=str(iteration_nr),
remaining_time_budget=str(remaining_time_budget),
reasoning_result=reasoning_result,
)
if remaining_time_budget > 0:
if decision_prompt is None:
raise ValueError("Decision prompt is required")
try:
orchestrator_action = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt,
decision_prompt + (assistant_task_prompt or ""),
),
schema=OrchestratorDecisonsNoPlan,
timeout_override=35,
# max_tokens=2500,
)
next_step = orchestrator_action.next_step
next_tool = next_step.tool
query_list = [q for q in (next_step.questions or [])]
tool_calls_string = create_tool_call_string(next_tool, query_list)
except Exception as e:
logger.error(f"Error in approach extraction: {e}")
raise e
remaining_time_budget -= available_tools[next_tool].cost
else:
if iteration_nr == 1 and not plan_of_record:
            # by default we start a new iteration; if there is a feedback
            # request, we restart at iteration 0 instead (set a bit later)
remaining_time_budget = DR_TIME_BUDGET_BY_TYPE[ResearchType.DEEP]
base_plan_prompt = get_dr_prompt_orchestration_templates(
DRPromptPurpose.PLAN,
ResearchType.DEEP,
entity_types_string=all_entity_types,
relationship_types_string=all_relationship_types,
available_tools=available_tools,
)
plan_generation_prompt = base_plan_prompt.build(
question=prompt_question,
chat_history_string=chat_history_string,
)
try:
plan_of_record = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt,
plan_generation_prompt + (assistant_task_prompt or ""),
),
schema=OrchestrationPlan,
timeout_override=25,
# max_tokens=3000,
)
except Exception as e:
logger.error(f"Error in plan generation: {e}")
raise
write_custom_event(
current_step_nr,
ReasoningStart(
type="reasoning_start",
),
writer,
)
write_custom_event(
current_step_nr,
ReasoningDelta(
reasoning=f"{HIGH_LEVEL_PLAN_PREFIX} {plan_of_record.plan}\n\n",
type="reasoning_delta",
),
writer,
)
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
current_step_nr += 1
if not plan_of_record:
raise ValueError(
"Plan information is required for iterative decision making"
)
base_decision_prompt = get_dr_prompt_orchestration_templates(
DRPromptPurpose.NEXT_STEP,
ResearchType.DEEP,
entity_types_string=all_entity_types,
relationship_types_string=all_relationship_types,
available_tools=available_tools,
)
decision_prompt = base_decision_prompt.build(
answer_history_string=answer_history_string,
question_history_string=question_history_string,
question=prompt_question,
iteration_nr=str(iteration_nr),
current_plan_of_record_string=plan_of_record.plan,
chat_history_string=chat_history_string,
remaining_time_budget=str(remaining_time_budget),
gaps=gaps_str,
)
if remaining_time_budget > 0:
if decision_prompt is None:
raise ValueError("Decision prompt is required")
try:
orchestrator_action = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt,
decision_prompt + (assistant_task_prompt or ""),
),
schema=OrchestratorDecisonsNoPlan,
timeout_override=15,
# max_tokens=1500,
)
next_step = orchestrator_action.next_step
next_tool = next_step.tool
query_list = [q for q in (next_step.questions or [])]
reasoning_result = orchestrator_action.reasoning
tool_calls_string = create_tool_call_string(next_tool, query_list)
except Exception as e:
logger.error(f"Error in approach extraction: {e}")
raise e
remaining_time_budget -= available_tools[next_tool].cost
else:
reasoning_result = "Time to wrap up."
write_custom_event(
current_step_nr,
ReasoningStart(
type="reasoning_start",
),
writer,
)
write_custom_event(
current_step_nr,
ReasoningDelta(
reasoning=reasoning_result,
type="reasoning_delta",
),
writer,
)
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
current_step_nr += 1
base_next_step_purpose_prompt = get_dr_prompt_orchestration_templates(
DRPromptPurpose.NEXT_STEP_PURPOSE,
ResearchType.DEEP,
entity_types_string=all_entity_types,
relationship_types_string=all_relationship_types,
available_tools=available_tools,
)
orchestration_next_step_purpose_prompt = base_next_step_purpose_prompt.build(
question=prompt_question,
reasoning_result=reasoning_result,
tool_calls=tool_calls_string,
)
purpose_tokens: list[str] = [""]
try:
write_custom_event(
current_step_nr,
ReasoningStart(
type="reasoning_start",
),
writer,
)
purpose_tokens, _, _ = run_with_timeout(
80,
lambda: stream_llm_answer(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt,
orchestration_next_step_purpose_prompt
+ (assistant_task_prompt or ""),
),
event_name="basic_response",
writer=writer,
agent_answer_level=0,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=60,
answer_piece="reasoning_delta",
ind=current_step_nr,
# max_tokens=None,
),
)
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
current_step_nr += 1
except Exception as e:
logger.error(f"Error in orchestration next step purpose: {e}")
raise e
purpose = cast(str, merge_content(*purpose_tokens))
return OrchestrationUpdate(
tools_used=[next_tool],
query_list=query_list or [],
iteration_nr=iteration_nr,
current_step_nr=current_step_nr,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="orchestrator",
node_start_time=node_start_time,
)
],
plan_of_record=plan_of_record,
remaining_time_budget=remaining_time_budget,
iteration_instructions=[
IterationInstructions(
iteration_nr=iteration_nr,
plan=plan_of_record.plan if plan_of_record else None,
reasoning=reasoning_result,
purpose=purpose,
)
],
)
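# Illustrative control-flow summary (derived from the function above): in
# THOUGHTFUL mode the orchestrator reasons about sufficiency each iteration and
# routes to the CLOSER as soon as SUFFICIENT_INFORMATION_STRING appears in the
# reasoning; in DEEP mode it first generates a plan_of_record, then picks the
# next tool per iteration, subtracting each tool's cost from
# remaining_time_budget until the budget runs out and the CLOSER default applies.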


@@ -0,0 +1,381 @@
import re
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from sqlalchemy.orm import Session
from onyx.agents.agent_search.dr.constants import MAX_CHAT_HISTORY_MESSAGES
from onyx.agents.agent_search.dr.constants import MAX_NUM_CLOSER_SUGGESTIONS
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import AggregatedDRContext
from onyx.agents.agent_search.dr.models import TestInfoCompleteResponse
from onyx.agents.agent_search.dr.states import FinalUpdate
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.states import OrchestrationUpdate
from onyx.agents.agent_search.dr.utils import aggregate_context
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.dr.utils import get_chat_history_string
from onyx.agents.agent_search.dr.utils import get_prompt_question
from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
from onyx.agents.agent_search.dr.utils import update_db_session_with_messages
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.chat.chat_utils import llm_doc_from_inference_section
from onyx.context.search.models import InferenceSection
from onyx.db.chat import create_search_doc_from_inference_section
from onyx.db.models import ChatMessage__SearchDoc
from onyx.db.models import ResearchAgentIteration
from onyx.db.models import ResearchAgentIterationSubStep
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
from onyx.prompts.dr_prompts import TEST_INFO_COMPLETE_PROMPT
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationStart
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def extract_citation_numbers(text: str) -> list[int]:
"""
Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
Returns a list of all unique citation numbers found.
"""
# Pattern to match [[number]] or [[number1, number2, ...]]
pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
matches = re.findall(pattern, text)
cited_numbers = []
for match in matches:
# Split by comma and extract all numbers
numbers = [int(num.strip()) for num in match.split(",")]
cited_numbers.extend(numbers)
return list(set(cited_numbers)) # Return unique numbers
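# Illustrative example (not in the original code): for the text
# "See [[3]] and [[5, 7]].", the pattern matches "3" and "5, 7", so
# extract_citation_numbers returns [3, 5, 7] in arbitrary order (set-deduplicated).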
def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
citation_content = match.group(1) # e.g., "3" or "3, 5, 7"
numbers = [int(num.strip()) for num in citation_content.split(",")]
# For multiple citations like [[3, 5, 7]], create separate linked citations
linked_citations = []
for num in numbers:
        if 1 <= num <= len(docs):  # Check bounds (citations are 1-indexed)
link = docs[num - 1].link or ""
linked_citations.append(f"[[{num}]]({link})")
else:
linked_citations.append(f"[[{num}]]") # No link if out of bounds
return "".join(linked_citations)
def insert_chat_message_search_doc_pair(
message_id: int, search_doc_ids: list[int], db_session: Session
) -> None:
"""
    Insert (message_id, search_doc_id) pairs into the chat_message__search_doc table.
    Args:
        message_id: The ID of the chat message
        search_doc_ids: The IDs of the search documents to link
db_session: The database session
"""
for search_doc_id in search_doc_ids:
chat_message_search_doc = ChatMessage__SearchDoc(
chat_message_id=message_id, search_doc_id=search_doc_id
)
db_session.add(chat_message_search_doc)
def save_iteration(
state: MainState,
graph_config: GraphConfig,
aggregated_context: AggregatedDRContext,
final_answer: str,
all_cited_documents: list[InferenceSection],
is_internet_marker_dict: dict[str, bool],
) -> None:
db_session = graph_config.persistence.db_session
message_id = graph_config.persistence.message_id
research_type = graph_config.behavior.research_type
# first, insert the search_docs
search_docs = [
create_search_doc_from_inference_section(
inference_section=inference_section,
is_internet=is_internet_marker_dict.get(
inference_section.center_chunk.document_id, False
), # TODO: revisit
db_session=db_session,
commit=False,
)
for inference_section in all_cited_documents
]
# then, map_search_docs to message
insert_chat_message_search_doc_pair(
message_id, [search_doc.id for search_doc in search_docs], db_session
)
# lastly, insert the citations
cited_doc_nrs = extract_citation_numbers(final_answer)
citation_dict: dict[str | int, int] = {}
    for cited_doc_nr in cited_doc_nrs:
        # guard against hallucinated citation numbers that are out of range
        if 1 <= cited_doc_nr <= len(search_docs):
            citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
# TODO: generate plan as dict in the first place
plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
plan_of_record_dict = parse_plan_to_dict(plan_of_record)
# Update the chat message and its parent message in database
update_db_session_with_messages(
db_session=db_session,
chat_message_id=message_id,
chat_session_id=str(graph_config.persistence.chat_session_id),
is_agentic=graph_config.behavior.use_agentic_search,
message=final_answer,
citations=citation_dict,
research_type=research_type,
research_plan=plan_of_record_dict,
final_documents=search_docs,
update_parent_message=True,
research_answer_purpose=ResearchAnswerPurpose.ANSWER,
)
for iteration_preparation in state.iteration_instructions:
research_agent_iteration_step = ResearchAgentIteration(
primary_question_id=message_id,
reasoning=iteration_preparation.reasoning,
purpose=iteration_preparation.purpose,
iteration_nr=iteration_preparation.iteration_nr,
created_at=datetime.now(),
)
db_session.add(research_agent_iteration_step)
for iteration_answer in aggregated_context.global_iteration_responses:
retrieved_search_docs = convert_inference_sections_to_search_docs(
list(iteration_answer.cited_documents.values())
)
# Convert SavedSearchDoc objects to JSON-serializable format
serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
primary_question_id=message_id,
parent_question_id=None,
iteration_nr=iteration_answer.iteration_nr,
iteration_sub_step_nr=iteration_answer.parallelization_nr,
sub_step_instructions=iteration_answer.question,
sub_step_tool_id=iteration_answer.tool_id,
sub_answer=iteration_answer.answer,
reasoning=iteration_answer.reasoning,
claims=iteration_answer.claims,
cited_doc_results=serialized_search_docs,
additional_data=iteration_answer.additional_data,
created_at=datetime.now(),
)
db_session.add(research_agent_iteration_sub_step)
db_session.commit()
def closer(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> FinalUpdate | OrchestrationUpdate:
"""
LangGraph node to close the DR process and finalize the answer.
"""
node_start_time = datetime.now()
# TODO: generate final answer using all the previous steps
# (right now, answers from each step are concatenated onto each other)
# Also, add missing fields once usage in UI is clear.
current_step_nr = state.current_step_nr
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = state.original_question
if not base_question:
raise ValueError("Question is required for closer")
research_type = graph_config.behavior.research_type
assistant_system_prompt = state.assistant_system_prompt
assistant_task_prompt = state.assistant_task_prompt
clarification = state.clarification
prompt_question = get_prompt_question(base_question, clarification)
chat_history_string = (
get_chat_history_string(
graph_config.inputs.prompt_builder.message_history,
MAX_CHAT_HISTORY_MESSAGES,
)
or "(No chat history yet available)"
)
aggregated_context = aggregate_context(
state.iteration_responses, include_documents=True
)
iteration_responses_string = aggregated_context.context
all_cited_documents = aggregated_context.cited_documents
is_internet_marker_dict = aggregated_context.is_internet_marker_dict
num_closer_suggestions = state.num_closer_suggestions
if (
num_closer_suggestions < MAX_NUM_CLOSER_SUGGESTIONS
and research_type == ResearchType.DEEP
):
test_info_complete_prompt = TEST_INFO_COMPLETE_PROMPT.build(
base_question=prompt_question,
questions_answers_claims=iteration_responses_string,
chat_history_string=chat_history_string,
high_level_plan=(
state.plan_of_record.plan
if state.plan_of_record
else "No plan available"
),
)
test_info_complete_json = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt,
test_info_complete_prompt + (assistant_task_prompt or ""),
),
schema=TestInfoCompleteResponse,
timeout_override=40,
# max_tokens=1000,
)
        if not test_info_complete_json.complete:
            return OrchestrationUpdate(
tools_used=[DRPath.ORCHESTRATOR.value],
query_list=[],
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="closer",
node_start_time=node_start_time,
)
],
gaps=test_info_complete_json.gaps,
num_closer_suggestions=num_closer_suggestions + 1,
)
retrieved_search_docs = convert_inference_sections_to_search_docs(
all_cited_documents
)
write_custom_event(
current_step_nr,
MessageStart(
content="",
final_documents=retrieved_search_docs,
),
writer,
)
if research_type == ResearchType.THOUGHTFUL:
final_answer_base_prompt = FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
else:
final_answer_base_prompt = FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
final_answer_prompt = final_answer_base_prompt.build(
base_question=prompt_question,
iteration_responses_string=iteration_responses_string,
chat_history_string=chat_history_string,
)
all_context_llmdocs = [
llm_doc_from_inference_section(inference_section)
for inference_section in all_cited_documents
]
try:
streamed_output, _, citation_infos = run_with_timeout(
240,
lambda: stream_llm_answer(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt,
final_answer_prompt + (assistant_task_prompt or ""),
),
event_name="basic_response",
writer=writer,
agent_answer_level=0,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=60,
answer_piece="message_delta",
ind=current_step_nr,
context_docs=all_context_llmdocs,
replace_citations=True,
# max_tokens=None,
),
)
final_answer = "".join(streamed_output)
except Exception as e:
raise ValueError(f"Error in consolidate_research: {e}")
write_custom_event(current_step_nr, SectionEnd(), writer)
current_step_nr += 1
write_custom_event(current_step_nr, CitationStart(), writer)
write_custom_event(current_step_nr, CitationDelta(citations=citation_infos), writer)
write_custom_event(current_step_nr, SectionEnd(), writer)
current_step_nr += 1
write_custom_event(current_step_nr, OverallStop(), writer)
# Log the research agent steps
save_iteration(
state,
graph_config,
aggregated_context,
final_answer,
all_cited_documents,
is_internet_marker_dict,
)
return FinalUpdate(
final_answer=final_answer,
all_cited_documents=all_cited_documents,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="closer",
node_start_time=node_start_time,
)
],
)


@@ -0,0 +1,79 @@
from operator import add
from typing import Annotated
from typing import TypedDict
from pydantic import BaseModel
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
from onyx.agents.agent_search.dr.models import OrchestrationPlan
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.context.search.models import InferenceSection
from onyx.db.connector import DocumentSource
### States ###
class LoggerUpdate(BaseModel):
log_messages: Annotated[list[str], add] = []
class OrchestrationUpdate(LoggerUpdate):
tools_used: Annotated[list[str], add] = []
query_list: list[str] = []
iteration_nr: int = 0
current_step_nr: int = 1
plan_of_record: OrchestrationPlan | None = None # None for Thoughtful
remaining_time_budget: float = 2.0 # set by default to about 2 searches
num_closer_suggestions: int = 0 # how many times the closer was suggested
    # gaps that may be identified by the closer before it can answer the question
    gaps: list[str] = []
iteration_instructions: Annotated[list[IterationInstructions], add] = []
class OrchestrationSetup(OrchestrationUpdate):
original_question: str | None = None
chat_history_string: str | None = None
clarification: OrchestrationClarificationInfo | None = None
available_tools: dict[str, OrchestratorTool] | None = None
num_closer_suggestions: int = 0 # how many times the closer was suggested
active_source_types: list[DocumentSource] | None = None
active_source_types_descriptions: str | None = None
assistant_system_prompt: str | None = None
assistant_task_prompt: str | None = None
class AnswerUpdate(LoggerUpdate):
iteration_responses: Annotated[list[IterationAnswer], add] = []
class FinalUpdate(LoggerUpdate):
final_answer: str | None = None
all_cited_documents: list[InferenceSection] = []
## Graph Input State
class MainInput(CoreState):
pass
## Graph State
class MainState(
# This includes the core state
MainInput,
OrchestrationSetup,
AnswerUpdate,
FinalUpdate,
):
pass
## Graph Output State
class MainOutput(TypedDict):
log_messages: list[str]
final_answer: str | None
all_cited_documents: list[InferenceSection]
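# Illustrative note on the Annotated[..., add] fields above (standard LangGraph
# reducer semantics): updates from parallel branches are merged with operator.add,
# e.g. two concurrent AnswerUpdates carrying one IterationAnswer each yield a
# combined two-element iteration_responses list instead of overwriting each other.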


@@ -0,0 +1,36 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def basic_search_branch(
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
    LangGraph node to branch out into parallel standard searches as part of the DR process.
"""
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
logger.debug(f"Search start for Basic Search {iteration_nr} at {datetime.now()}")
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="basic_search",
node_name="branching",
node_start_time=node_start_time,
)
],
)


@@ -0,0 +1,232 @@
import re
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import BaseSearchProcessingResponse
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import SearchAnswer
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.utils import extract_document_citations
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.chat.models import LlmDoc
from onyx.context.search.models import InferenceSection
from onyx.db.connector import DocumentSource
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.prompts.dr_prompts import BASE_SEARCH_PROCESSING_PROMPT
from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger
logger = setup_logger()
def basic_search(
state: BranchInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
"""
LangGraph node to perform a standard search as part of the DR process.
"""
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
parallelization_nr = state.parallelization_nr
assistant_system_prompt = state.assistant_system_prompt
assistant_task_prompt = state.assistant_task_prompt
branch_query = state.branch_question
if not branch_query:
raise ValueError("branch_query is not set")
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = graph_config.inputs.prompt_builder.raw_user_query
research_type = graph_config.behavior.research_type
if not state.available_tools:
raise ValueError("available_tools is not set")
search_tool_info = state.available_tools[state.tools_used[-1]]
search_tool = cast(SearchTool, search_tool_info.tool_object)
# sanity check
if search_tool != graph_config.tooling.search_tool:
raise ValueError("search_tool does not match the configured search tool")
# rewrite query and identify source types
active_source_types_str = ", ".join(
[source.value for source in state.active_source_types or []]
)
base_search_processing_prompt = BASE_SEARCH_PROCESSING_PROMPT.build(
active_source_types_str=active_source_types_str,
branch_query=branch_query,
)
try:
search_processing = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt, base_search_processing_prompt
),
schema=BaseSearchProcessingResponse,
timeout_override=5,
# max_tokens=100,
)
except Exception as e:
logger.error(f"Could not process query: {e}")
raise e
rewritten_query = search_processing.rewritten_query
implied_start_date = search_processing.time_filter
# Validate time_filter format if it exists
implied_time_filter = None
if implied_start_date:
# Check if time_filter is in YYYY-MM-DD format
date_pattern = r"^\d{4}-\d{2}-\d{2}$"
if re.match(date_pattern, implied_start_date):
implied_time_filter = datetime.strptime(implied_start_date, "%Y-%m-%d")
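    # Illustrative example: a time_filter of "2024-01-01" passes the YYYY-MM-DD
    # check and becomes datetime(2024, 1, 1); any other shape is silently dropped.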
specified_source_types: list[DocumentSource] | None = [
DocumentSource(source_type)
for source_type in search_processing.specified_source_types
]
    if not specified_source_types:
specified_source_types = None
logger.debug(
f"Search start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
retrieved_docs: list[InferenceSection] = []
callback_container: list[list[InferenceSection]] = []
# new db session to avoid concurrency issues
with get_session_with_current_tenant() as search_db_session:
for tool_response in search_tool.run(
query=rewritten_query,
document_sources=specified_source_types,
time_filter=implied_time_filter,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=True,
alternate_db_session=search_db_session,
retrieved_sections_callback=callback_container.append,
skip_query_analysis=True,
),
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
response = cast(SearchResponseSummary, tool_response.response)
retrieved_docs = response.top_sections
break
document_texts_list = []
for doc_num, retrieved_doc in enumerate(retrieved_docs[:15]):
if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
chunk_text = build_document_context(retrieved_doc, doc_num + 1)
document_texts_list.append(chunk_text)
document_texts = "\n\n".join(document_texts_list)
logger.debug(
f"Search end/LLM start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
    # Build the prompt
if research_type == ResearchType.DEEP:
search_prompt = INTERNAL_SEARCH_PROMPTS[research_type].build(
search_query=branch_query,
base_question=base_question,
document_text=document_texts,
)
# Run LLM
search_answer_json = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
),
schema=SearchAnswer,
timeout_override=40,
# max_tokens=1500,
)
logger.debug(
f"LLM/all done for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# get cited documents
answer_string = search_answer_json.answer
claims = search_answer_json.claims or []
reasoning = search_answer_json.reasoning
# answer_string = ""
# claims = []
(
citation_numbers,
answer_string,
claims,
) = extract_document_citations(answer_string, claims)
cited_documents = {
citation_number: retrieved_docs[citation_number - 1]
for citation_number in citation_numbers
}
else:
answer_string = ""
claims = []
cited_documents = {
doc_num + 1: retrieved_doc
for doc_num, retrieved_doc in enumerate(retrieved_docs[:15])
}
reasoning = ""
return BranchUpdate(
branch_iteration_responses=[
IterationAnswer(
tool=search_tool_info.llm_path,
tool_id=search_tool_info.tool_id,
iteration_nr=iteration_nr,
parallelization_nr=parallelization_nr,
question=branch_query,
answer=answer_string,
claims=claims,
cited_documents=cited_documents,
reasoning=reasoning,
additional_data=None,
)
],
log_messages=[
get_langgraph_node_log_string(
graph_component="basic_search",
node_name="searching",
node_start_time=node_start_time,
)
],
)


@@ -0,0 +1,99 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.dr.utils import chunks_or_sections_to_search_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.context.search.models import SavedSearchDoc
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
logger = setup_logger()
def is_reducer(
state: SubAgentMainState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
"""
    LangGraph node to consolidate the parallel search results as part of the DR process.
"""
node_start_time = datetime.now()
branch_updates = state.branch_iteration_responses
current_iteration = state.iteration_nr
current_step_nr = state.current_step_nr
new_updates = [
update for update in branch_updates if update.iteration_nr == current_iteration
]
queries = [update.question for update in new_updates]
doc_lists = [list(update.cited_documents.values()) for update in new_updates]
    doc_list = [doc for docs in doc_lists for doc in docs]
# Convert InferenceSections to SavedSearchDocs
search_docs = chunks_or_sections_to_search_docs(doc_list)
retrieved_saved_search_docs = [
SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
for search_doc in search_docs
]
for retrieved_saved_search_doc in retrieved_saved_search_docs:
retrieved_saved_search_doc.is_internet = False
# Write the results to the stream
write_custom_event(
current_step_nr,
SearchToolStart(
type="internal_search_tool_start",
is_internet_search=False,
),
writer,
)
write_custom_event(
current_step_nr,
SearchToolDelta(
queries=queries,
documents=retrieved_saved_search_docs,
type="internal_search_tool_delta",
),
writer,
)
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
current_step_nr += 1
return SubAgentUpdate(
iteration_responses=new_updates,
current_step_nr=current_step_nr,
log_messages=[
get_langgraph_node_log_string(
graph_component="basic_search",
node_name="consolidation",
node_start_time=node_start_time,
)
],
)


@@ -0,0 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_1_branch import (
basic_search_branch,
)
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import (
basic_search,
)
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_3_reduce import (
is_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_image_generation_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
logger = setup_logger()
def dr_basic_search_graph_builder() -> StateGraph:
"""
    LangGraph graph builder for the Basic Search Sub-Agent
"""
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
### Add nodes ###
graph.add_node("branch", basic_search_branch)
graph.add_node("act", basic_search)
graph.add_node("reducer", is_reducer)
### Add edges ###
graph.add_edge(start_key=START, end_key="branch")
graph.add_conditional_edges("branch", branching_router)
graph.add_edge(start_key="act", end_key="reducer")
graph.add_edge(start_key="reducer", end_key=END)
return graph
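# Illustrative usage sketch (assumes a prepared SubAgentInput state; the actual
# invocation happens in the parent DR graph):
#   compiled = dr_basic_search_graph_builder().compile()
#   final_state = compiled.invoke(initial_state)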


@@ -0,0 +1,29 @@
from collections.abc import Hashable
from langgraph.types import Send
from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
return [
Send(
"act",
BranchInput(
iteration_nr=state.iteration_nr,
parallelization_nr=parallelization_nr,
branch_question=query,
context="",
active_source_types=state.active_source_types,
tools_used=state.tools_used,
available_tools=state.available_tools,
assistant_system_prompt=state.assistant_system_prompt,
assistant_task_prompt=state.assistant_task_prompt,
),
)
for parallelization_nr, query in enumerate(
state.query_list[:MAX_DR_PARALLEL_SEARCH]
)
]
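# Illustrative example (hypothetical values): with query_list == ["q1", "q2", "q3"]
# and MAX_DR_PARALLEL_SEARCH == 2, the router returns two Send("act", BranchInput(...))
# packets, so LangGraph runs the "act" node twice in parallel, one branch per query.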


@@ -0,0 +1,36 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def custom_tool_branch(
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
    LangGraph node to branch out into generic tool calls as part of the DR process.
"""
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="custom_tool",
node_name="branching",
node_start_time=node_start_time,
)
],
)


@@ -0,0 +1,152 @@
import json
from datetime import datetime
from typing import cast
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.states import AnswerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
from onyx.tools.tool_implementations.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
from onyx.utils.logger import setup_logger
logger = setup_logger()
def custom_tool_act(
state: BranchInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> AnswerUpdate:
"""
LangGraph node to perform a generic tool call as part of the DR process.
"""
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
parallelization_nr = state.parallelization_nr
if not state.available_tools:
raise ValueError("available_tools is not set")
custom_tool_info = state.available_tools[state.tools_used[-1]]
custom_tool_name = custom_tool_info.llm_path
custom_tool = cast(CustomTool, custom_tool_info.tool_object)
branch_query = state.branch_question
if not branch_query:
raise ValueError("branch_query is not set")
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = graph_config.inputs.prompt_builder.raw_user_query
logger.debug(
f"Tool call start for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# get tool call args
tool_args: dict | None = None
if graph_config.tooling.using_tool_calling_llm:
# get tool call args from tool-calling LLM
tool_use_prompt = CUSTOM_TOOL_USE_PROMPT.build(
query=branch_query,
base_question=base_question,
tool_response="(No tool response yet. You need to call the tool to answer the question.)",
)
tool_calling_msg = graph_config.tooling.primary_llm.invoke(
tool_use_prompt,
tools=[custom_tool.tool_definition()],
tool_choice="required",
timeout_override=40,
)
# make sure we got a tool call
if (
isinstance(tool_calling_msg, AIMessage)
and len(tool_calling_msg.tool_calls) == 1
):
tool_args = tool_calling_msg.tool_calls[0]["args"]
else:
logger.warning("Tool-calling LLM did not emit a tool call")
if tool_args is None:
# get tool call args from non-tool-calling LLM or for failed tool-calling LLM
tool_args = custom_tool.get_args_for_non_tool_calling_llm(
query=branch_query,
history=[],
llm=graph_config.tooling.primary_llm,
force_run=True,
)
if tool_args is None:
raise ValueError("Failed to obtain tool arguments from LLM")
# run the tool
response_summary: CustomToolCallSummary | None = None
for tool_response in custom_tool.run(**tool_args):
if tool_response.id == CUSTOM_TOOL_RESPONSE_ID:
response_summary = cast(CustomToolCallSummary, tool_response.response)
break
if not response_summary:
raise ValueError("Custom tool did not return a valid response summary")
# summarise tool result
if response_summary.response_type == "json":
tool_result_str = json.dumps(response_summary.tool_result, ensure_ascii=False)
elif response_summary.response_type in {"image", "csv"}:
tool_result_str = f"{response_summary.response_type} files: {response_summary.tool_result.file_ids}"
else:
tool_result_str = str(response_summary.tool_result)
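    # Illustrative example: a JSON result like {"status": "ok"} is serialized to
    # '{"status": "ok"}', while image/csv results are represented only by their
    # file ids before the LLM summarization below.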
tool_str = (
f"Tool used: {custom_tool_name}\n"
f"Description: {custom_tool_info.description}\n"
f"Result: {tool_result_str}"
)
tool_summary_prompt = CUSTOM_TOOL_USE_PROMPT.build(
query=branch_query, base_question=base_question, tool_response=tool_str
)
answer_string = str(
graph_config.tooling.primary_llm.invoke(
tool_summary_prompt, timeout_override=40
).content
).strip()
logger.debug(
f"Tool call end for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
return AnswerUpdate(
iteration_responses=[
IterationAnswer(
tool=custom_tool_name,
tool_id=custom_tool_info.tool_id,
iteration_nr=iteration_nr,
parallelization_nr=parallelization_nr,
question=branch_query,
answer=answer_string,
claims=[],
cited_documents={},
reasoning="",
additional_data=None,
)
],
log_messages=[
get_langgraph_node_log_string(
graph_component="custom_tool",
node_name="tool_calling",
node_start_time=node_start_time,
)
],
)


@@ -0,0 +1,44 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def custom_tool_reducer(
state: SubAgentMainState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
"""
    LangGraph node to consolidate the generic tool call results as part of the DR process.
"""
node_start_time = datetime.now()
branch_updates = state.branch_iteration_responses
current_iteration = state.iteration_nr
new_updates = [
update for update in branch_updates if update.iteration_nr == current_iteration
]
return SubAgentUpdate(
iteration_responses=new_updates,
log_messages=[
get_langgraph_node_log_string(
graph_component="custom_tool",
node_name="consolidation",
node_start_time=node_start_time,
)
],
)


@@ -0,0 +1,28 @@
from collections.abc import Hashable
from langgraph.types import Send
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import (
SubAgentInput,
)
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
return [
Send(
"act",
BranchInput(
iteration_nr=state.iteration_nr,
parallelization_nr=parallelization_nr,
branch_question=query,
context="",
active_source_types=state.active_source_types,
tools_used=state.tools_used,
available_tools=state.available_tools,
),
)
for parallelization_nr, query in enumerate(
state.query_list[:1] # no parallel call for now
)
]


@@ -0,0 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_1_branch import (
custom_tool_branch,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_2_act import (
custom_tool_act,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_3_reduce import (
custom_tool_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
logger = setup_logger()
def dr_custom_tool_graph_builder() -> StateGraph:
"""
LangGraph graph builder for Generic Tool Sub-Agent
"""
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
### Add nodes ###
graph.add_node("branch", custom_tool_branch)
graph.add_node("act", custom_tool_act)
graph.add_node("reducer", custom_tool_reducer)
### Add edges ###
graph.add_edge(start_key=START, end_key="branch")
graph.add_conditional_edges("branch", branching_router)
graph.add_edge(start_key="act", end_key="reducer")
graph.add_edge(start_key="reducer", end_key=END)
return graph


@@ -0,0 +1,36 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def image_generation_branch(
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
    LangGraph node to branch out into parallel image generation calls as part of the DR process.
"""
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
logger.debug(f"Search start for Basic Search {iteration_nr} at {datetime.now()}")
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="image_generation",
node_name="branching",
node_start_time=node_start_time,
)
],
)


@@ -0,0 +1,115 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
IMAGE_GENERATION_RESPONSE_ID,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationResponse,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def image_generation(
state: BranchInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
"""
    LangGraph node to generate images as part of the DR process.
"""
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
parallelization_nr = state.parallelization_nr
branch_query = state.branch_question
if not branch_query:
raise ValueError("branch_query is not set")
graph_config = cast(GraphConfig, config["metadata"]["config"])
if not state.available_tools:
raise ValueError("available_tools is not set")
image_tool_info = state.available_tools[state.tools_used[-1]]
image_tool = cast(ImageGenerationTool, image_tool_info.tool_object)
logger.debug(
f"Image generation start for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# Generate images using the image generation tool
generated_images: list[ImageGenerationResponse] = []
for tool_response in image_tool.run(prompt=branch_query):
if tool_response.id == IMAGE_GENERATION_RESPONSE_ID:
response = cast(list[ImageGenerationResponse], tool_response.response)
generated_images = response
break
logger.debug(
f"Image generation complete for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# Create answer string describing the generated images
if generated_images:
image_descriptions = []
for i, img in enumerate(generated_images, 1):
image_descriptions.append(f"Image {i}: {img.revised_prompt}")
answer_string = (
f"Generated {len(generated_images)} image(s) based on the request: {branch_query}\n\n"
+ "\n".join(image_descriptions)
)
reasoning = f"Used image generation tool to create {len(generated_images)} image(s) based on the user's request."
else:
answer_string = f"Failed to generate images for request: {branch_query}"
reasoning = "Image generation tool did not return any results."
return BranchUpdate(
branch_iteration_responses=[
IterationAnswer(
tool=image_tool_info.llm_path,
tool_id=image_tool_info.tool_id,
iteration_nr=iteration_nr,
parallelization_nr=parallelization_nr,
question=branch_query,
answer=answer_string,
claims=[],
cited_documents={},
reasoning=reasoning,
additional_data=(
{"generated_images": str(len(generated_images))}
if generated_images
else None
),
)
],
log_messages=[
get_langgraph_node_log_string(
graph_component="image_generation",
node_name="generating",
node_start_time=node_start_time,
)
],
)


@@ -0,0 +1,76 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolDelta
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
logger = setup_logger()
def is_reducer(
state: SubAgentMainState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
"""
    LangGraph node to consolidate the image generation results as part of the DR process.
"""
node_start_time = datetime.now()
branch_updates = state.branch_iteration_responses
current_iteration = state.iteration_nr
current_step_nr = state.current_step_nr
new_updates = [
update for update in branch_updates if update.iteration_nr == current_iteration
]
# Write the results to the stream
write_custom_event(
current_step_nr,
ImageGenerationToolStart(
type="image_generation_tool_start",
),
writer,
)
write_custom_event(
current_step_nr,
ImageGenerationToolDelta(
images={},
type="image_generation_tool_delta",
),
writer,
)
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
current_step_nr += 1
return SubAgentUpdate(
iteration_responses=new_updates,
current_step_nr=current_step_nr,
log_messages=[
get_langgraph_node_log_string(
graph_component="image_generation",
node_name="consolidation",
node_start_time=node_start_time,
)
],
)


@@ -0,0 +1,29 @@
from collections.abc import Hashable
from langgraph.types import Send
from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
return [
Send(
"act",
BranchInput(
iteration_nr=state.iteration_nr,
parallelization_nr=parallelization_nr,
branch_question=query,
context="",
active_source_types=state.active_source_types,
tools_used=state.tools_used,
available_tools=state.available_tools,
assistant_system_prompt=state.assistant_system_prompt,
assistant_task_prompt=state.assistant_task_prompt,
),
)
for parallelization_nr, query in enumerate(
state.query_list[:MAX_DR_PARALLEL_SEARCH]
)
]


@@ -0,0 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_1_branch import (
image_generation_branch,
)
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_2_act import (
image_generation,
)
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_3_reduce import (
is_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
logger = setup_logger()
def dr_image_generation_graph_builder() -> StateGraph:
"""
    LangGraph graph builder for the Image Generation Sub-Agent
"""
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
### Add nodes ###
graph.add_node("branch", image_generation_branch)
graph.add_node("act", image_generation)
graph.add_node("reducer", is_reducer)
### Add edges ###
graph.add_edge(start_key=START, end_key="branch")
graph.add_conditional_edges("branch", branching_router)
graph.add_edge(start_key="act", end_key="reducer")
graph.add_edge(start_key="reducer", end_key=END)
return graph


@@ -0,0 +1,36 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def is_branch(
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
    LangGraph node to branch out into parallel internet searches as part of the DR process.
"""
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
logger.debug(f"Search start for Internet Search {iteration_nr} at {datetime.now()}")
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="internet_search",
node_name="branching",
node_start_time=node_start_time,
)
],
)


@@ -0,0 +1,175 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import SearchAnswer
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.utils import extract_document_citations
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.chat.models import LlmDoc
from onyx.context.search.models import InferenceSection
from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
INTERNET_SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
InternetSearchTool,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.utils.logger import setup_logger
logger = setup_logger()
def internet_search(
state: BranchInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BranchUpdate:
"""
    LangGraph node to perform an internet search as part of the DR process.
"""
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
parallelization_nr = state.parallelization_nr
assistant_system_prompt = state.assistant_system_prompt
assistant_task_prompt = state.assistant_task_prompt
search_query = state.branch_question
if not search_query:
raise ValueError("search_query is not set")
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = graph_config.inputs.prompt_builder.raw_user_query
research_type = graph_config.behavior.research_type
logger.debug(
f"Search start for Internet Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
if graph_config.inputs.persona is None:
raise ValueError("persona is not set")
if not state.available_tools:
raise ValueError("available_tools is not set")
is_tool_info = state.available_tools[state.tools_used[-1]]
internet_search_tool = cast(InternetSearchTool, is_tool_info.tool_object)
if internet_search_tool.provider is None:
raise ValueError(
"internet_search_tool.provider is not set. This should not happen."
)
# Update search parameters
internet_search_tool.max_chunks = 10
internet_search_tool.provider.num_results = 10
retrieved_docs: list[InferenceSection] = []
for tool_response in internet_search_tool.run(internet_search_query=search_query):
# get retrieved docs to send to the rest of the graph
if tool_response.id == INTERNET_SEARCH_RESPONSE_SUMMARY_ID:
response = cast(SearchResponseSummary, tool_response.response)
retrieved_docs = response.top_sections
break
document_texts_list = []
for doc_num, retrieved_doc in enumerate(retrieved_docs[:15]):
if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
chunk_text = build_document_context(retrieved_doc, doc_num + 1)
document_texts_list.append(chunk_text)
document_texts = "\n\n".join(document_texts_list)
logger.debug(
f"Search end/LLM start for Internet Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
    # Build the prompt
if research_type == ResearchType.DEEP:
search_prompt = INTERNAL_SEARCH_PROMPTS[research_type].build(
search_query=search_query,
base_question=base_question,
document_text=document_texts,
)
# Run LLM
search_answer_json = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
),
schema=SearchAnswer,
timeout_override=40,
# max_tokens=3000,
)
logger.debug(
f"LLM/all done for Internet Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# get cited documents
answer_string = search_answer_json.answer
claims = search_answer_json.claims or []
reasoning = search_answer_json.reasoning or ""
(
citation_numbers,
answer_string,
claims,
) = extract_document_citations(answer_string, claims)
cited_documents = {
citation_number: retrieved_docs[citation_number - 1]
for citation_number in citation_numbers
}
else:
answer_string = ""
claims = []
reasoning = ""
cited_documents = {
doc_num + 1: retrieved_doc
for doc_num, retrieved_doc in enumerate(retrieved_docs[:15])
}
return BranchUpdate(
branch_iteration_responses=[
IterationAnswer(
tool=is_tool_info.llm_path,
tool_id=is_tool_info.tool_id,
iteration_nr=iteration_nr,
parallelization_nr=parallelization_nr,
question=search_query,
answer=answer_string,
claims=claims,
cited_documents=cited_documents,
reasoning=reasoning,
additional_data=None,
)
],
log_messages=[
get_langgraph_node_log_string(
graph_component="internet_search",
node_name="searching",
node_start_time=node_start_time,
)
],
)

View File
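The node above relies on invoke_llm_json to get schema-validated output from the model. That helper is internal and not shown in this diff; a minimal sketch of such a call, assuming a pydantic v2 schema and an LLM client whose invoke returns a message with a .content attribute (all names here are illustrative):

import json
from pydantic import BaseModel

class SearchAnswerExample(BaseModel):  # stand-in for the SearchAnswer schema above
    answer: str
    claims: list[str] | None = None
    reasoning: str | None = None

def invoke_llm_json_sketch(llm, prompt, schema: type[BaseModel]) -> BaseModel:
    # assumption: llm.invoke returns a message object exposing .content
    raw = str(llm.invoke(prompt).content)
    # strip a possible markdown fence before parsing
    cleaned = raw.removeprefix("```json\n").removesuffix("\n```")
    return schema.model_validate(json.loads(cleaned))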

@@ -0,0 +1,92 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
logger = setup_logger()
def is_reducer(
state: SubAgentMainState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
"""
LangGraph node to consolidate the internet search results as part of the DR process.
"""
node_start_time = datetime.now()
branch_updates = state.branch_iteration_responses
current_iteration = state.iteration_nr
current_step_nr = state.current_step_nr
new_updates = [
update for update in branch_updates if update.iteration_nr == current_iteration
]
queries = [update.question for update in new_updates]
doc_lists = [list(update.cited_documents.values()) for update in new_updates]
doc_list = [doc for docs in doc_lists for doc in docs]
retrieved_search_docs = convert_inference_sections_to_search_docs(
doc_list, is_internet=True
)
# Write the results to the stream
write_custom_event(
current_step_nr,
SearchToolStart(
type="internal_search_tool_start",
is_internet_search=True,
),
writer,
)
write_custom_event(
current_step_nr,
SearchToolDelta(
queries=queries,
documents=retrieved_search_docs,
type="internal_search_tool_delta",
),
writer,
)
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
current_step_nr += 1
return SubAgentUpdate(
iteration_responses=new_updates,
current_step_nr=current_step_nr,
log_messages=[
get_langgraph_node_log_string(
graph_component="internet_search",
node_name="consolidation",
node_start_time=node_start_time,
)
],
)

View File
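The reducer emits a fixed per-step event sequence: SearchToolStart, then a SearchToolDelta carrying the queries and documents, then SectionEnd. A hedged sketch of a consumer for that stream, assuming packets arrive as dicts with a "type" discriminator and a step index field "ind" (both field names are assumptions for illustration, not the actual wire format):

def handle_packet(packet: dict, steps: dict[int, dict]) -> None:
    # group packets by step index; a step is complete once its SectionEnd arrives
    step = steps.setdefault(
        packet.get("ind", 0), {"queries": [], "documents": [], "done": False}
    )
    if packet["type"] == "internal_search_tool_delta":
        step["queries"].extend(packet.get("queries", []))
        step["documents"].extend(packet.get("documents", []))
    elif packet["type"] == "section_end":
        step["done"] = True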

@@ -0,0 +1,28 @@
from collections.abc import Hashable
from langgraph.types import Send
from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
return [
Send(
"act",
BranchInput(
iteration_nr=state.iteration_nr,
parallelization_nr=parallelization_nr,
branch_question=query,
context="",
tools_used=state.tools_used,
available_tools=state.available_tools,
assistant_system_prompt=state.assistant_system_prompt,
assistant_task_prompt=state.assistant_task_prompt,
),
)
for parallelization_nr, query in enumerate(
state.query_list[:MAX_DR_PARALLEL_SEARCH]
)
]

View File

@@ -0,0 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.sub_agents.internet_search.dr_is_1_branch import (
is_branch,
)
from onyx.agents.agent_search.dr.sub_agents.internet_search.dr_is_2_act import (
internet_search,
)
from onyx.agents.agent_search.dr.sub_agents.internet_search.dr_is_3_reduce import (
is_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.internet_search.dr_is_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
logger = setup_logger()
def dr_is_graph_builder() -> StateGraph:
"""
LangGraph graph builder for Internet Search Sub-Agent
"""
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
### Add nodes ###
graph.add_node("branch", is_branch)
graph.add_node("act", internet_search)
graph.add_node("reducer", is_reducer)
### Add edges ###
graph.add_edge(start_key=START, end_key="branch")
graph.add_conditional_edges("branch", branching_router)
graph.add_edge(start_key="act", end_key="reducer")
graph.add_edge(start_key="reducer", end_key=END)
return graph

View File
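The branch → act → reducer wiring together with the Send-returning router above is LangGraph's standard map-reduce pattern. A self-contained toy version (illustrative names, not Onyx code) showing how each Send schedules a parallel "act" run whose list output merges back through the add reducer:

from operator import add
from typing import Annotated, TypedDict

from langgraph.graph import END, START, StateGraph
from langgraph.types import Send

class ToyState(TypedDict):
    queries: list[str]
    answers: Annotated[list[str], add]  # merged across parallel branches

def branch(state: ToyState) -> dict:
    return {}  # routing happens in the conditional edge below

def route(state: ToyState) -> list[Send]:
    # one Send per query; each becomes an independent "act" invocation
    return [Send("act", {"queries": [q], "answers": []}) for q in state["queries"]]

def act(state: ToyState) -> dict:
    return {"answers": [f"answered: {state['queries'][0]}"]}

g = StateGraph(ToyState)
g.add_node("branch", branch)
g.add_node("act", act)
g.add_edge(START, "branch")
g.add_conditional_edges("branch", route)
g.add_edge("act", END)
app = g.compile()
# app.invoke({"queries": ["q1", "q2"], "answers": []})["answers"] -> both results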

@@ -0,0 +1,36 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def kg_search_branch(
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
LangGraph node to start the KG search branching as part of the DR process.
"""
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
logger.debug(f"Search start for KG Search {iteration_nr} at {datetime.now()}")
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="kg_search",
node_name="branching",
node_start_time=node_start_time,
)
],
)

View File

@@ -0,0 +1,97 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.utils import extract_document_citations
from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
from onyx.agents.agent_search.kb_search.states import MainInput as KbMainInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger
logger = setup_logger()
def kg_search(
state: BranchInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BranchUpdate:
"""
LangGraph node to perform a KG search as part of the DR process.
"""
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
parallelization_nr = state.parallelization_nr
search_query = state.branch_question
if not search_query:
raise ValueError("search_query is not set")
logger.debug(
f"Search start for KG Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
if not state.available_tools:
raise ValueError("available_tools is not set")
kg_tool_info = state.available_tools[state.tools_used[-1]]
kb_graph = kb_graph_builder().compile()
kb_results = kb_graph.invoke(
input=KbMainInput(question=search_query, individual_flow=False),
config=config,
)
# get cited documents
answer_string = kb_results.get("final_answer") or "No answer provided"
claims: list[str] = []
retrieved_docs: list[InferenceSection] = kb_results.get("retrieved_documents", [])
(
citation_numbers,
answer_string,
claims,
) = extract_document_citations(answer_string, claims)
# if citation is empty, the answer must have come from the KG rather than a doc
# in that case, simply cite the docs returned by the KG
if not citation_numbers:
citation_numbers = [i + 1 for i in range(len(retrieved_docs))]
cited_documents = {
citation_number: retrieved_docs[citation_number - 1]
for citation_number in citation_numbers
if citation_number <= len(retrieved_docs)
}
return BranchUpdate(
branch_iteration_responses=[
IterationAnswer(
tool=kg_tool_info.llm_path,
tool_id=kg_tool_info.tool_id,
iteration_nr=iteration_nr,
parallelization_nr=parallelization_nr,
question=search_query,
answer=answer_string,
claims=claims,
cited_documents=cited_documents,
reasoning=None,
additional_data=None,
)
],
log_messages=[
get_langgraph_node_log_string(
graph_component="kg_search",
node_name="searching",
node_start_time=node_start_time,
)
],
)

View File
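kg_search above invokes a fully compiled sub-agent graph from inside a node, passing the parent RunnableConfig through so metadata and callbacks propagate. A toy sketch of that nested-graph pattern (hypothetical state and names, not the real KB graph):

from typing import TypedDict
from langgraph.graph import END, START, StateGraph

class SubState(TypedDict):
    question: str
    final_answer: str

def sub_node(state: SubState) -> dict:
    return {"final_answer": f"sub-answer to {state['question']}"}

sub = StateGraph(SubState)
sub.add_node("answer", sub_node)
sub.add_edge(START, "answer")
sub.add_edge("answer", END)
sub_app = sub.compile()

def parent_node(state: SubState) -> dict:
    # invoke the compiled sub-agent synchronously, like kg_search does
    result = sub_app.invoke({"question": state["question"], "final_answer": ""})
    return {"final_answer": result["final_answer"]}
# parent_node would then be registered as a node in the parent graph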

@@ -0,0 +1,124 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
logger = setup_logger()
_MAX_KG_STREAMED_ANSWER_LENGTH = 1000  # num characters
def kg_search_reducer(
state: SubAgentMainState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
"""
LangGraph node to consolidate the KG search results as part of the DR process.
"""
node_start_time = datetime.now()
branch_updates = state.branch_iteration_responses
current_iteration = state.iteration_nr
current_step_nr = state.current_step_nr
new_updates = [
update for update in branch_updates if update.iteration_nr == current_iteration
]
queries = [update.question for update in new_updates]
doc_lists = [list(update.cited_documents.values()) for update in new_updates]
doc_list = [doc for docs in doc_lists for doc in docs]
retrieved_search_docs = convert_inference_sections_to_search_docs(doc_list)
if len(queries) == 1:
kg_answer: str | None = (
"The Knowledge Graph Answer:\n\n" + new_updates[0].answer
)
else:
kg_answer = None
if len(retrieved_search_docs) > 0:
write_custom_event(
current_step_nr,
SearchToolStart(
type="internal_search_tool_start",
),
writer,
)
write_custom_event(
current_step_nr,
SearchToolDelta(
queries=queries,
documents=retrieved_search_docs,
type="internal_search_tool_delta",
),
writer,
)
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
current_step_nr += 1
if kg_answer is not None:
kg_display_answer = (
f"{kg_answer[:_MAX_KG_STEAMED_ANSWER_LENGTH]}..."
if len(kg_answer) > _MAX_KG_STEAMED_ANSWER_LENGTH
else kg_answer
)
write_custom_event(
current_step_nr,
ReasoningStart(),
writer,
)
write_custom_event(
current_step_nr,
ReasoningDelta(reasoning=kg_display_answer, type="reasoning_delta"),
writer,
)
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
current_step_nr += 1
return SubAgentUpdate(
iteration_responses=new_updates,
current_step_nr=current_step_nr,
log_messages=[
get_langgraph_node_log_string(
graph_component="kg_search",
node_name="consolidation",
node_start_time=node_start_time,
)
],
)

View File

@@ -0,0 +1,27 @@
from collections.abc import Hashable
from langgraph.types import Send
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
return [
Send(
"act",
BranchInput(
iteration_nr=state.iteration_nr,
parallelization_nr=parallelization_nr,
branch_question=query,
context="",
tools_used=state.tools_used,
available_tools=state.available_tools,
assistant_system_prompt=state.assistant_system_prompt,
assistant_task_prompt=state.assistant_task_prompt,
),
)
for parallelization_nr, query in enumerate(
state.query_list[:1] # no parallel search for now
)
]

View File

@@ -0,0 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_1_branch import (
kg_search_branch,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_2_act import (
kg_search,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_3_reduce import (
kg_search_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
logger = setup_logger()
def dr_kg_search_graph_builder() -> StateGraph:
"""
LangGraph graph builder for KG Search Sub-Agent
"""
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
### Add nodes ###
graph.add_node("branch", kg_search_branch)
graph.add_node("act", kg_search)
graph.add_node("reducer", kg_search_reducer)
### Add edges ###
graph.add_edge(start_key=START, end_key="branch")
graph.add_conditional_edges("branch", branching_router)
graph.add_edge(start_key="act", end_key="reducer")
graph.add_edge(start_key="reducer", end_key=END)
return graph

View File

@@ -0,0 +1,46 @@
from operator import add
from typing import Annotated
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.db.connector import DocumentSource
class SubAgentUpdate(LoggerUpdate):
iteration_responses: Annotated[list[IterationAnswer], add] = []
current_step_nr: int = 1
class BranchUpdate(LoggerUpdate):
branch_iteration_responses: Annotated[list[IterationAnswer], add] = []
class SubAgentInput(LoggerUpdate):
iteration_nr: int = 0
current_step_nr: int = 1
query_list: list[str] = []
context: str | None = None
active_source_types: list[DocumentSource] | None = None
tools_used: Annotated[list[str], add] = []
available_tools: dict[str, OrchestratorTool] | None = None
assistant_system_prompt: str | None = None
assistant_task_prompt: str | None = None
class SubAgentMainState(
# This includes the core state
SubAgentInput,
SubAgentUpdate,
BranchUpdate,
):
pass
class BranchInput(SubAgentInput):
parallelization_nr: int = 0
branch_question: str | None = None
class CustomToolBranchInput(LoggerUpdate):
tool_info: OrchestratorTool

View File
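The Annotated[..., add] fields in these state models are LangGraph reducers: when parallel branches return the same key, the framework merges the values with the annotated operator rather than overwriting. A toy illustration with a hypothetical model (not the real IterationAnswer):

from operator import add
from typing import Annotated
from pydantic import BaseModel

class ToyBranchUpdate(BaseModel):
    branch_iteration_responses: Annotated[list[str], add] = []

# two parallel branches each emit an update; the reducer concatenates:
merged = add(
    ToyBranchUpdate(branch_iteration_responses=["it1.q0 answer"]).branch_iteration_responses,
    ToyBranchUpdate(branch_iteration_responses=["it1.q1 answer"]).branch_iteration_responses,
)
assert merged == ["it1.q0 answer", "it1.q1 answer"]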

@@ -0,0 +1,343 @@
import re
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from sqlalchemy.orm import Session
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import AggregatedDRContext
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_section_list,
)
from onyx.configs.constants import MessageType
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import SavedSearchDoc
from onyx.context.search.utils import chunks_or_sections_to_search_docs
from onyx.db.models import ChatMessage
from onyx.db.models import SearchDoc
CITATION_PREFIX = "CITE:"
def extract_document_citations(
answer: str, claims: list[str]
) -> tuple[list[int], str, list[str]]:
"""
Finds all citations of the form [1], [2, 3], etc. and returns the list of cited indices,
as well as the answer and claims with the citations replaced with [<CITATION_PREFIX>1],
etc., to help with citation deduplication later on.
"""
citations: set[int] = set()
# Pattern to match both single citations [1] and multiple citations [1, 2, 3]
# This regex matches:
# - \[(\d+)\] for single citations like [1]
# - \[(\d+(?:,\s*\d+)*)\] for multiple citations like [1, 2, 3]
pattern = re.compile(r"\[(\d+(?:,\s*\d+)*)\]")
def _extract_and_replace(match: re.Match[str]) -> str:
numbers = [int(num) for num in match.group(1).split(",")]
citations.update(numbers)
return "".join(f"[{CITATION_PREFIX}{num}]" for num in numbers)
new_answer = pattern.sub(_extract_and_replace, answer)
new_claims = [pattern.sub(_extract_and_replace, claim) for claim in claims]
return list(citations), new_answer, new_claims
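# Illustrative example (not part of the module): given
#   answer = "The answer is x [1][2, 3]."
# extract_document_citations(answer, []) returns roughly
#   ([1, 2, 3], "The answer is x [CITE:1][CITE:2][CITE:3].", [])
# i.e. the cited indices plus the text with stable [CITE:n] placeholders,
# which aggregate_context below remaps to global citation numbers.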
def aggregate_context(
iteration_responses: list[IterationAnswer], include_documents: bool = True
) -> AggregatedDRContext:
"""
Converts the iteration responses into a single string with unified citations.
For example,
it 1: the answer is x [3][4]. {3: doc_abc, 4: doc_xyz}
it 2: blah blah [1, 3]. {1: doc_xyz, 3: doc_pqr}
Output:
it 1: the answer is x [1][2].
it 2: blah blah [2][3]
[1]: doc_abc
[2]: doc_xyz
[3]: doc_pqr
"""
# dedupe and merge inference section contents
unrolled_inference_sections: list[InferenceSection] = []
is_internet_marker_dict: dict[str, bool] = {}
for iteration_response in sorted(
iteration_responses,
key=lambda x: (x.iteration_nr, x.parallelization_nr),
):
iteration_tool = iteration_response.tool
if iteration_tool == "InternetSearchTool":
is_internet = True
else:
is_internet = False
for cited_doc in iteration_response.cited_documents.values():
unrolled_inference_sections.append(cited_doc)
if cited_doc.center_chunk.document_id not in is_internet_marker_dict:
is_internet_marker_dict[cited_doc.center_chunk.document_id] = (
is_internet
)
cited_doc.center_chunk.score = None # None means maintain order
global_documents = dedup_inference_section_list(unrolled_inference_sections)
global_citations = {
doc.center_chunk.document_id: i for i, doc in enumerate(global_documents, 1)
}
# build output string
output_strings: list[str] = []
global_iteration_responses: list[IterationAnswer] = []
for iteration_response in sorted(
iteration_responses,
key=lambda x: (x.iteration_nr, x.parallelization_nr),
):
# add basic iteration info
output_strings.append(
f"Iteration: {iteration_response.iteration_nr}, "
f"Question {iteration_response.parallelization_nr}"
)
output_strings.append(f"Tool: {iteration_response.tool}")
output_strings.append(f"Question: {iteration_response.question}")
# get answer and claims with global citations
answer_str = iteration_response.answer
claims = iteration_response.claims or []
iteration_citations: list[int] = []
for local_number, cited_doc in iteration_response.cited_documents.items():
global_number = global_citations[cited_doc.center_chunk.document_id]
# translate local citations to global citations
answer_str = answer_str.replace(
f"[{CITATION_PREFIX}{local_number}]", f"[{global_number}]"
)
claims = [
claim.replace(
f"[{CITATION_PREFIX}{local_number}]", f"[{global_number}]"
)
for claim in claims
]
iteration_citations.append(global_number)
# add answer, claims, and citation info
if answer_str:
output_strings.append(f"Answer: {answer_str}")
if claims:
output_strings.append(
"Claims: " + "".join(f"\n - {claim}" for claim in claims or [])
or "No claims provided"
)
if not answer_str and not claims:
output_strings.append(
"Retrieved documents: "
+ (
"".join(
f"[{global_number}]"
for global_number in sorted(iteration_citations)
)
or "No documents retrieved"
)
)
output_strings.append("\n---\n")
# save global iteration response
global_iteration_responses.append(
IterationAnswer(
tool=iteration_response.tool,
tool_id=iteration_response.tool_id,
iteration_nr=iteration_response.iteration_nr,
parallelization_nr=iteration_response.parallelization_nr,
question=iteration_response.question,
reasoning=iteration_response.reasoning,
answer=answer_str,
cited_documents={
global_citations[doc.center_chunk.document_id]: doc
for doc in iteration_response.cited_documents.values()
},
background_info=iteration_response.background_info,
claims=claims,
additional_data=iteration_response.additional_data,
)
)
# add document contents if requested
if include_documents:
if global_documents:
output_strings.append("Cited document contents:")
for doc in global_documents:
output_strings.append(
build_document_context(
doc, global_citations[doc.center_chunk.document_id]
)
)
output_strings.append("\n---\n")
return AggregatedDRContext(
context="\n".join(output_strings),
cited_documents=global_documents,
is_internet_marker_dict=is_internet_marker_dict,
global_iteration_responses=global_iteration_responses,
)
def get_chat_history_string(chat_history: list[BaseMessage], max_messages: int) -> str:
"""
Get the chat history (up to max_messages) as a string.
"""
# get past max_messages USER, ASSISTANT message pairs
past_messages = chat_history[-max_messages * 2 :]
truncation_prefix = "...\n" if len(chat_history) > len(past_messages) else ""
return truncation_prefix + "\n".join(
("user" if isinstance(msg, HumanMessage) else "you")
+ f": {str(msg.content).strip()}"
for msg in past_messages
)
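# e.g. with max_messages=1 and three prior exchanges, the result looks like
#   "...\nuser: latest user message\nyou: latest assistant reply"
# (the leading "...\n" marks that older history was truncated)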
def get_prompt_question(
question: str, clarification: OrchestrationClarificationInfo | None
) -> str:
if clarification:
clarification_question = clarification.clarification_question
clarification_response = clarification.clarification_response
return (
f"Initial User Question: {question}\n"
f"(Clarification Question: {clarification_question}\n"
f"User Response: {clarification_response})"
)
return question
def create_tool_call_string(tool_name: str, query_list: list[str]) -> str:
"""
Create a string representation of the tool call.
"""
questions_str = "\n - ".join(query_list)
return f"Tool: {tool_name}\n\nQuestions:\n{questions_str}"
def parse_plan_to_dict(plan_text: str) -> dict[str, str]:
# Convert plan string to numbered dict format
if not plan_text:
return {}
# Split by numbered items (1., 2., 3., etc. or 1), 2), 3), etc.)
parts = re.split(r"(\d+[.)])", plan_text)
plan_dict = {}
for i in range(
1, len(parts), 2
): # Skip empty first part, then take number and text pairs
if i + 1 < len(parts):
number = parts[i].rstrip(".)") # Remove the dot or parenthesis
text = parts[i + 1].strip()
if text: # Only add if there's actual content
plan_dict[number] = text
return plan_dict
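# e.g. parse_plan_to_dict("1. Gather sources 2) Compare claims")
# -> {"1": "Gather sources", "2": "Compare claims"}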
def convert_inference_sections_to_search_docs(
inference_sections: list[InferenceSection],
is_internet: bool = False,
) -> list[SavedSearchDoc]:
# Convert InferenceSections to SavedSearchDocs
search_docs = chunks_or_sections_to_search_docs(inference_sections)
for search_doc in search_docs:
search_doc.is_internet = is_internet
retrieved_saved_search_docs = [
SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
for search_doc in search_docs
]
return retrieved_saved_search_docs
def update_db_session_with_messages(
db_session: Session,
chat_message_id: int,
chat_session_id: str,
is_agentic: bool | None,
message: str | None = None,
message_type: str | None = None,
token_count: int | None = None,
rephrased_query: str | None = None,
prompt_id: int | None = None,
citations: dict[str | int, int] | None = None,
error: str | None = None,
alternate_assistant_id: int | None = None,
overridden_model: str | None = None,
research_type: str | None = None,
research_plan: dict[str, str] | None = None,
final_documents: list[SearchDoc] | None = None,
update_parent_message: bool = True,
research_answer_purpose: ResearchAnswerPurpose | None = None,
) -> None:
chat_message = (
db_session.query(ChatMessage)
.filter(
ChatMessage.id == chat_message_id,
ChatMessage.chat_session_id == chat_session_id,
)
.first()
)
if not chat_message:
raise ValueError("Chat message with id not found") # should never happen
if message:
chat_message.message = message
if message_type:
chat_message.message_type = MessageType(message_type)
if token_count:
chat_message.token_count = token_count
if rephrased_query:
chat_message.rephrased_query = rephrased_query
if prompt_id:
chat_message.prompt_id = prompt_id
if citations:
# Convert string keys to integers to match database field type
chat_message.citations = {int(k): v for k, v in citations.items()}
if error:
chat_message.error = error
if alternate_assistant_id:
chat_message.alternate_assistant_id = alternate_assistant_id
if overridden_model:
chat_message.overridden_model = overridden_model
if research_type:
chat_message.research_type = ResearchType(research_type)
if research_plan:
chat_message.research_plan = research_plan
if final_documents:
chat_message.search_docs = final_documents
if is_agentic is not None:
chat_message.is_agentic = is_agentic
if research_answer_purpose:
chat_message.research_answer_purpose = research_answer_purpose
if update_parent_message:
parent_chat_message = (
db_session.query(ChatMessage)
.filter(ChatMessage.id == chat_message.parent_message)
.first()
)
if parent_chat_message:
parent_chat_message.latest_child_message = chat_message.id
return

View File
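A hypothetical call site for update_db_session_with_messages, showing the partial-update style in which only the provided fields are written (argument values are illustrative, and committing is assumed to be the caller's responsibility):

update_db_session_with_messages(
    db_session=db_session,
    chat_message_id=assistant_message_id,
    chat_session_id=str(chat_session_id),
    is_agentic=True,
    message=final_answer,
    research_type=ResearchType.DEEP.value,
    final_documents=saved_search_docs,
)
db_session.commit()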

@@ -6,7 +6,12 @@ from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.models import KGEntityDocInfo
from onyx.agents.agent_search.kb_search.models import KGExpandedGraphObjects
from onyx.agents.agent_search.kb_search.states import SubQuestionAnswerResults
from onyx.agents.agent_search.kb_search.step_definitions import STEP_DESCRIPTIONS
from onyx.agents.agent_search.kb_search.step_definitions import (
BASIC_SEARCH_STEP_DESCRIPTIONS,
)
from onyx.agents.agent_search.kb_search.step_definitions import (
KG_SEARCH_STEP_DESCRIPTIONS,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
@@ -95,14 +100,14 @@ def create_minimal_connected_query_graph(
return KGExpandedGraphObjects(entities=entities, relationships=relationships)
def stream_write_step_description(
def stream_write_kg_search_description(
writer: StreamWriter, step_nr: int, level: int = 0
) -> None:
write_custom_event(
"decomp_qs",
SubQuestionPiece(
sub_question=STEP_DESCRIPTIONS[step_nr].description,
sub_question=KG_SEARCH_STEP_DESCRIPTIONS[step_nr].description,
level=level,
level_question_num=step_nr,
),
@@ -113,10 +118,12 @@ def stream_write_step_description(
sleep(0.2)
def stream_write_step_activities(
def stream_write_kg_search_activities(
writer: StreamWriter, step_nr: int, level: int = 0
) -> None:
for activity_nr, activity in enumerate(STEP_DESCRIPTIONS[step_nr].activities):
for activity_nr, activity in enumerate(
KG_SEARCH_STEP_DESCRIPTIONS[step_nr].activities
):
write_custom_event(
"subqueries",
SubQueryPiece(
@@ -129,23 +136,25 @@ def stream_write_step_activities(
)
def stream_write_step_activity_explicit(
writer: StreamWriter, step_nr: int, query_id: int, activity: str, level: int = 0
def stream_write_basic_search_activities(
writer: StreamWriter, step_nr: int, level: int = 0
) -> None:
for activity in STEP_DESCRIPTIONS[step_nr].activities:
for activity_nr, activity in enumerate(
BASIC_SEARCH_STEP_DESCRIPTIONS[step_nr].activities
):
write_custom_event(
"subqueries",
SubQueryPiece(
sub_query=activity,
level=level,
level_question_num=step_nr,
query_id=query_id,
query_id=activity_nr + 1,
),
writer,
)
def stream_write_step_answer_explicit(
def stream_write_kg_search_answer_explicit(
writer: StreamWriter, step_nr: int, answer: str, level: int = 0
) -> None:
write_custom_event(
@@ -160,8 +169,8 @@ def stream_write_step_answer_explicit(
)
def stream_write_step_structure(writer: StreamWriter, level: int = 0) -> None:
for step_nr, step_detail in STEP_DESCRIPTIONS.items():
def stream_write_kg_search_structure(writer: StreamWriter, level: int = 0) -> None:
for step_nr, step_detail in KG_SEARCH_STEP_DESCRIPTIONS.items():
write_custom_event(
"decomp_qs",
@@ -173,7 +182,7 @@ def stream_write_step_structure(writer: StreamWriter, level: int = 0) -> None:
writer,
)
for step_nr in STEP_DESCRIPTIONS.keys():
for step_nr in KG_SEARCH_STEP_DESCRIPTIONS.keys():
write_custom_event(
"stream_finished",
@@ -195,7 +204,40 @@ def stream_write_step_structure(writer: StreamWriter, level: int = 0) -> None:
write_custom_event("stream_finished", stop_event, writer)
def stream_close_step_answer(
def stream_write_basic_search_structure(writer: StreamWriter, level: int = 0) -> None:
for step_nr, step_detail in BASIC_SEARCH_STEP_DESCRIPTIONS.items():
write_custom_event(
"decomp_qs",
SubQuestionPiece(
sub_question=step_detail.description,
level=level,
level_question_num=step_nr,
),
writer,
)
for step_nr in BASIC_SEARCH_STEP_DESCRIPTIONS:
write_custom_event(
"stream_finished",
StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type=StreamType.SUB_QUESTIONS,
level=level,
level_question_num=step_nr,
),
writer,
)
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type=StreamType.SUB_QUESTIONS,
level=0,
)
write_custom_event("stream_finished", stop_event, writer)
def stream_kg_search_close_step_answer(
writer: StreamWriter, step_nr: int, level: int = 0
) -> None:
stop_event = StreamStopInfo(
@@ -207,7 +249,7 @@ def stream_close_step_answer(
write_custom_event("stream_finished", stop_event, writer)
def stream_write_close_steps(writer: StreamWriter, level: int = 0) -> None:
def stream_write_kg_search_close_steps(writer: StreamWriter, level: int = 0) -> None:
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type=StreamType.SUB_QUESTIONS,
@@ -355,7 +397,7 @@ def get_near_empty_step_results(
Get near-empty step results from a list of step results.
"""
return SubQuestionAnswerResults(
question=STEP_DESCRIPTIONS[step_number].description,
question=KG_SEARCH_STEP_DESCRIPTIONS[step_number].description,
question_id="0_" + str(step_number),
answer=step_answer,
verified_high_quality=True,

View File

@@ -7,17 +7,23 @@ from langgraph.types import StreamWriter
from pydantic import ValidationError
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_activities
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_step_answer_explicit,
stream_kg_search_close_step_answer,
)
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_kg_search_activities,
)
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_kg_search_answer_explicit,
)
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_kg_search_structure,
)
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_structure
from onyx.agents.agent_search.kb_search.models import KGQuestionEntityExtractionResult
from onyx.agents.agent_search.kb_search.models import (
KGQuestionRelationshipExtractionResult,
)
from onyx.agents.agent_search.kb_search.states import ERTExtractionUpdate
from onyx.agents.agent_search.kb_search.states import EntityRelationshipExtractionUpdate
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
@@ -42,7 +48,7 @@ logger = setup_logger()
def extract_ert(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ERTExtractionUpdate:
) -> EntityRelationshipExtractionUpdate:
"""
LangGraph node to start the agentic search process.
"""
@@ -68,17 +74,17 @@ def extract_ert(
user_name = user_email.split("@")[0] or "unknown"
# first four lines duplicates from generate_initial_answer
question = graph_config.inputs.prompt_builder.raw_user_query
question = state.question
today_date = datetime.now().strftime("%A, %Y-%m-%d")
all_entity_types = get_entity_types_str(active=True)
all_relationship_types = get_relationship_types_str(active=True)
# Stream structure of substeps out to the UI
stream_write_step_structure(writer)
if state.individual_flow:
# Stream structure of substeps out to the UI
stream_write_kg_search_structure(writer)
# Now specify core activities in the step (step 1)
stream_write_step_activities(writer, _KG_STEP_NR)
stream_write_kg_search_activities(writer, _KG_STEP_NR)
# Create temporary views. TODO: move into parallel step, if ultimately materialized
tenant_id = get_current_tenant_id()
@@ -240,12 +246,13 @@ def extract_ert(
step_answer = f"""Entities and relationships have been extracted from query - \n \
Entities: {extracted_entity_string} - \n Relationships: {extracted_relationship_string}"""
stream_write_step_answer_explicit(writer, step_nr=1, answer=step_answer)
if state.individual_flow:
stream_write_kg_search_answer_explicit(writer, step_nr=1, answer=step_answer)
# Finish Step 1
stream_close_step_answer(writer, _KG_STEP_NR)
# Finish Step 1
stream_kg_search_close_step_answer(writer, _KG_STEP_NR)
return ERTExtractionUpdate(
return EntityRelationshipExtractionUpdate(
entities_types_str=all_entity_types,
relationship_types_str=all_relationship_types,
extracted_entities_w_attributes=entity_extraction_result.entities,

View File

@@ -9,10 +9,14 @@ from onyx.agents.agent_search.kb_search.graph_utils import (
create_minimal_connected_query_graph,
)
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_activities
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_step_answer_explicit,
stream_kg_search_close_step_answer,
)
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_kg_search_activities,
)
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_kg_search_answer_explicit,
)
from onyx.agents.agent_search.kb_search.models import KGAnswerApproach
from onyx.agents.agent_search.kb_search.states import AnalysisUpdate
@@ -141,7 +145,7 @@ def analyze(
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.prompt_builder.raw_user_query
question = state.question
entities = (
state.extracted_entities_no_attributes
) # attribute knowledge is not required for this step
@@ -150,7 +154,8 @@ def analyze(
## STEP 2 - stream out goals
stream_write_step_activities(writer, _KG_STEP_NR)
if state.individual_flow:
stream_write_kg_search_activities(writer, _KG_STEP_NR)
# Continue with node
@@ -277,9 +282,12 @@ Format: {output_format.value}, Broken down question: {broken_down_question}"
else:
query_type = KGRelationshipDetection.NO_RELATIONSHIPS.value
stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=step_answer)
if state.individual_flow:
stream_write_kg_search_answer_explicit(
writer, step_nr=_KG_STEP_NR, answer=step_answer
)
stream_close_step_answer(writer, _KG_STEP_NR)
stream_kg_search_close_step_answer(writer, _KG_STEP_NR)
# End node

View File

@@ -8,10 +8,14 @@ from langgraph.types import StreamWriter
from sqlalchemy import text
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_activities
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_step_answer_explicit,
stream_kg_search_close_step_answer,
)
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_kg_search_activities,
)
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_kg_search_answer_explicit,
)
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import KGRelationshipDetection
@@ -33,8 +37,10 @@ from onyx.db.engine.sql_engine import get_db_readonly_user_session_with_current_
from onyx.db.kg_temp_view import drop_views
from onyx.llm.interfaces import LLM
from onyx.prompts.kg_prompts import ENTITY_SOURCE_DETECTION_PROMPT
from onyx.prompts.kg_prompts import ENTITY_TABLE_DESCRIPTION
from onyx.prompts.kg_prompts import RELATIONSHIP_TABLE_DESCRIPTION
from onyx.prompts.kg_prompts import SIMPLE_ENTITY_SQL_PROMPT
from onyx.prompts.kg_prompts import SIMPLE_SQL_CORRECTION_PROMPT
from onyx.prompts.kg_prompts import SIMPLE_SQL_ERROR_FIX_PROMPT
from onyx.prompts.kg_prompts import SIMPLE_SQL_PROMPT
from onyx.prompts.kg_prompts import SOURCE_DETECTION_PROMPT
from onyx.utils.logger import setup_logger
@@ -122,6 +128,22 @@ def _sql_is_aggregate_query(sql_statement: str) -> bool:
)
def _run_sql(
sql_statement: str, rel_temp_view: str, ent_temp_view: str
) -> list[dict[str, Any]]:
# check sql, just in case
_raise_error_if_sql_fails_problem_test(sql_statement, rel_temp_view, ent_temp_view)
with get_db_readonly_user_session_with_current_tenant() as db_session:
result = db_session.execute(text(sql_statement))
# Handle scalar results (like COUNT)
if sql_statement.upper().startswith("SELECT COUNT"):
scalar_result = result.scalar()
return [{"count": int(scalar_result)}] if scalar_result is not None else []
# Handle regular row results
rows = result.fetchall()
return [dict(row._mapping) for row in rows]
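# e.g. "SELECT COUNT(*) FROM ..." yields [{"count": 12}], while a row query
# yields one dict per row keyed by column name (via SQLAlchemy's row._mapping)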
def _get_source_documents(
sql_statement: str,
llm: LLM,
@@ -189,7 +211,7 @@ def generate_simple_sql(
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.prompt_builder.raw_user_query
question = state.question
entities_types_str = state.entities_types_str
relationship_types_str = state.relationship_types_str
@@ -199,7 +221,6 @@ def generate_simple_sql(
raise ValueError("kg_doc_temp_view_name is not set")
if state.kg_rel_temp_view_name is None:
raise ValueError("kg_rel_temp_view_name is not set")
if state.kg_entity_temp_view_name is None:
raise ValueError("kg_entity_temp_view_name is not set")
@@ -207,7 +228,8 @@ def generate_simple_sql(
## STEP 3 - articulate goals
stream_write_step_activities(writer, _KG_STEP_NR)
if state.individual_flow:
stream_write_kg_search_activities(writer, _KG_STEP_NR)
if graph_config.tooling.search_tool is None:
raise ValueError("Search tool is not set")
@@ -270,6 +292,12 @@ def generate_simple_sql(
)
.replace("---question---", question)
.replace("---entity_explanation_string---", entity_explanation_str)
.replace(
"---query_entities_with_attributes---",
"\n".join(state.query_graph_entities_w_attributes),
)
.replace("---today_date---", datetime.now().strftime("%Y-%m-%d"))
.replace("---user_name---", f"EMPLOYEE:{user_name}")
)
else:
simple_sql_prompt = (
@@ -289,8 +317,7 @@ def generate_simple_sql(
.replace("---user_name---", f"EMPLOYEE:{user_name}")
)
# prepare SQL query generation
# generate initial sql statement
msg = [
HumanMessage(
content=simple_sql_prompt,
@@ -298,7 +325,6 @@ def generate_simple_sql(
]
primary_llm = graph_config.tooling.primary_llm
# Grader
try:
llm_response = run_with_timeout(
KG_SQL_GENERATION_TIMEOUT,
@@ -336,53 +362,6 @@ def generate_simple_sql(
)
raise e
if state.query_type == KGRelationshipDetection.RELATIONSHIPS.value:
# Correction if needed:
correction_prompt = SIMPLE_SQL_CORRECTION_PROMPT.replace(
"---draft_sql---", sql_statement
)
msg = [
HumanMessage(
content=correction_prompt,
)
]
try:
llm_response = run_with_timeout(
KG_SQL_GENERATION_TIMEOUT,
primary_llm.invoke,
prompt=msg,
timeout_override=25,
max_tokens=1500,
)
cleaned_response = (
str(llm_response.content)
.replace("```json\n", "")
.replace("\n```", "")
)
sql_statement = (
cleaned_response.split("<sql>")[1].split("</sql>")[0].strip()
)
sql_statement = sql_statement.split(";")[0].strip() + ";"
sql_statement = sql_statement.replace("sql", "").strip()
except Exception as e:
logger.error(
f"Error in generating the sql correction: {e}. Original model response: {cleaned_response}"
)
drop_views(
allowed_docs_view_name=doc_temp_view,
kg_relationships_view_name=rel_temp_view,
kg_entity_view_name=ent_temp_view,
)
raise e
# display sql statement with view names replaced by general view names
sql_statement_display = sql_statement.replace(
state.kg_doc_temp_view_name, "<your_allowed_docs_view_name>"
@@ -437,51 +416,93 @@ def generate_simple_sql(
logger.debug(f"A3 source_documents_sql: {source_documents_sql_display}")
scalar_result = None
query_results = None
query_results = [] # if no results, will be empty (not None)
query_generation_error = None
# check sql, just in case
_raise_error_if_sql_fails_problem_test(
sql_statement, rel_temp_view, ent_temp_view
)
# run sql
try:
query_results = _run_sql(sql_statement, rel_temp_view, ent_temp_view)
if not query_results:
query_generation_error = "SQL query returned no results"
logger.warning(f"{query_generation_error}, retrying...")
except Exception as e:
query_generation_error = str(e)
logger.warning(f"Error executing SQL query: {e}, retrying...")
# fix sql and try one more time if sql query didn't work out
# if the result is still empty after this, the kg probably doesn't have the answer,
# so we update the strategy to simple and address this in the answer generation
if query_generation_error is not None:
sql_fix_prompt = (
SIMPLE_SQL_ERROR_FIX_PROMPT.replace(
"---table_description---",
(
ENTITY_TABLE_DESCRIPTION
if state.query_type
== KGRelationshipDetection.NO_RELATIONSHIPS.value
else RELATIONSHIP_TABLE_DESCRIPTION
),
)
.replace("---entity_types---", entities_types_str)
.replace("---relationship_types---", relationship_types_str)
.replace("---question---", question)
.replace("---sql_statement---", sql_statement)
.replace("---error_message---", query_generation_error)
.replace("---today_date---", datetime.now().strftime("%Y-%m-%d"))
.replace("---user_name---", f"EMPLOYEE:{user_name}")
)
msg = [HumanMessage(content=sql_fix_prompt)]
primary_llm = graph_config.tooling.primary_llm
with get_db_readonly_user_session_with_current_tenant() as db_session:
try:
result = db_session.execute(text(sql_statement))
# Handle scalar results (like COUNT)
if sql_statement.upper().startswith("SELECT COUNT"):
scalar_result = result.scalar()
query_results = (
[{"count": int(scalar_result)}]
if scalar_result is not None
else []
)
else:
# Handle regular row results
rows = result.fetchall()
query_results = [dict(row._mapping) for row in rows]
llm_response = run_with_timeout(
KG_SQL_GENERATION_TIMEOUT,
primary_llm.invoke,
prompt=msg,
timeout_override=KG_SQL_GENERATION_TIMEOUT_OVERRIDE,
max_tokens=KG_SQL_GENERATION_MAX_TOKENS,
)
cleaned_response = (
str(llm_response.content)
.replace("```json\n", "")
.replace("\n```", "")
)
sql_statement = (
cleaned_response.split("<sql>")[1].split("</sql>")[0].strip()
)
sql_statement = sql_statement.split(";")[0].strip() + ";"
sql_statement = sql_statement.replace("sql", "").strip()
sql_statement = sql_statement.replace(
"relationship_table", rel_temp_view
)
sql_statement = sql_statement.replace("entity_table", ent_temp_view)
reasoning = (
cleaned_response.split("<reasoning>")[1]
.strip()
.split("</reasoning>")[0]
)
query_results = _run_sql(sql_statement, rel_temp_view, ent_temp_view)
except Exception as e:
logger.error(f"Error executing SQL query even after retry: {e}")
# TODO: raise error on frontend
logger.error(f"Error executing SQL query: {e}")
drop_views(
allowed_docs_view_name=doc_temp_view,
kg_relationships_view_name=rel_temp_view,
kg_entity_view_name=ent_temp_view,
)
raise e
raise
source_document_results = None
if source_documents_sql is not None and source_documents_sql != sql_statement:
# check source document sql, just in case
_raise_error_if_sql_fails_problem_test(
source_documents_sql, rel_temp_view, ent_temp_view
)
with get_db_readonly_user_session_with_current_tenant() as db_session:
try:
result = db_session.execute(text(source_documents_sql))
rows = result.fetchall()
@@ -491,28 +512,16 @@ def generate_simple_sql(
for source_document_result in query_source_document_results
]
except Exception as e:
# TODO: raise error on frontend
drop_views(
allowed_docs_view_name=doc_temp_view,
kg_relationships_view_name=rel_temp_view,
kg_entity_view_name=ent_temp_view,
)
# TODO: raise warning on frontend
logger.error(f"Error executing Individualized SQL query: {e}")
elif state.query_type == KGRelationshipDetection.NO_RELATIONSHIPS.value:
# source documents should be returned for entity queries
source_document_results = [
x["source_document"] for x in query_results if "source_document" in x
]
else:
if state.query_type == KGRelationshipDetection.NO_RELATIONSHIPS.value:
# source documents should be returned for entity queries
source_document_results = [
x["source_document"]
for x in query_results
if "source_document" in x
]
else:
source_document_results = None
source_document_results = None
drop_views(
allowed_docs_view_name=doc_temp_view,
@@ -528,21 +537,25 @@ def generate_simple_sql(
main_sql_statement = sql_statement
if reasoning:
stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=reasoning)
if reasoning and state.individual_flow:
stream_write_kg_search_answer_explicit(
writer, step_nr=_KG_STEP_NR, answer=reasoning
)
if sql_statement_display:
stream_write_step_answer_explicit(
if sql_statement_display and state.individual_flow:
stream_write_kg_search_answer_explicit(
writer,
step_nr=_KG_STEP_NR,
answer=f" \n Generated SQL: {sql_statement_display}",
)
stream_close_step_answer(writer, _KG_STEP_NR)
if state.individual_flow:
stream_kg_search_close_step_answer(writer, _KG_STEP_NR)
# Update path if too many results are retrieved
if query_results and len(query_results) > KG_MAX_DEEP_SEARCH_RESULTS:
# Update path if too many, or no results were retrieved from sql
if main_sql_statement and (
not query_results or len(query_results) > KG_MAX_DEEP_SEARCH_RESULTS
):
updated_strategy = KGAnswerStrategy.SIMPLE
else:
updated_strategy = None

View File

@@ -34,7 +34,7 @@ def construct_deep_search_filters(
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.prompt_builder.raw_user_query
question = state.question
entities_types_str = state.entities_types_str
entities = state.query_graph_entities_no_attributes
@@ -155,7 +155,11 @@ def construct_deep_search_filters(
if div_con_structure:
for entity_type in double_grounded_entity_types:
if entity_type.grounded_source_name.lower() in div_con_structure[0].lower():
# entity_type is guaranteed to have grounded_source_name
if (
cast(str, entity_type.grounded_source_name).lower()
in div_con_structure[0].lower()
):
source_division = True
break

View File

@@ -98,16 +98,17 @@ def process_individual_deep_search(
kg_relationship_filters = None
# Step 4 - stream out the research query
write_custom_event(
"subqueries",
SubQueryPiece(
sub_query=f"{get_doc_information_for_entity(object).semantic_entity_name}",
level=0,
level_question_num=_KG_STEP_NR,
query_id=research_nr + 1,
),
writer,
)
if state.individual_flow:
write_custom_event(
"subqueries",
SubQueryPiece(
sub_query=f"{get_doc_information_for_entity(object).semantic_entity_name}",
level=0,
level_question_num=_KG_STEP_NR,
query_id=research_nr + 1,
),
writer,
)
if source_filters and (len(source_filters) > KG_MAX_SEARCH_DOCUMENTS):
logger.debug(

View File

@@ -7,9 +7,11 @@ from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_step_answer_explicit,
stream_kg_search_close_step_answer,
)
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_kg_search_answer_explicit,
)
from onyx.agents.agent_search.kb_search.graph_utils import write_custom_event
from onyx.agents.agent_search.kb_search.ops import research
@@ -49,7 +51,7 @@ def filtered_search(
graph_config = cast(GraphConfig, config["metadata"]["config"])
search_tool = graph_config.tooling.search_tool
question = graph_config.inputs.prompt_builder.raw_user_query
question = state.question
if not search_tool:
raise ValueError("search_tool is not provided")
@@ -72,17 +74,18 @@ def filtered_search(
logger.debug(f"kg_entity_filters: {kg_entity_filters}")
logger.debug(f"kg_relationship_filters: {kg_relationship_filters}")
# Step 4 - stream out the research query
write_custom_event(
"subqueries",
SubQueryPiece(
sub_query="Conduct a filtered search",
level=0,
level_question_num=_KG_STEP_NR,
query_id=1,
),
writer,
)
if state.individual_flow:
# Step 4 - stream out the research query
write_custom_event(
"subqueries",
SubQueryPiece(
sub_query="Conduct a filtered search",
level=0,
level_question_num=_KG_STEP_NR,
query_id=1,
),
writer,
)
retrieved_docs = cast(
list[InferenceSection],
@@ -165,11 +168,12 @@ def filtered_search(
step_answer = "Filtered search is complete."
stream_write_step_answer_explicit(
writer, answer=step_answer, level=0, step_nr=_KG_STEP_NR
)
if state.individual_flow:
stream_write_kg_search_answer_explicit(
writer, answer=step_answer, level=0, step_nr=_KG_STEP_NR
)
stream_close_step_answer(writer, level=0, step_nr=_KG_STEP_NR)
stream_kg_search_close_step_answer(writer, level=0, step_nr=_KG_STEP_NR)
return ConsolidatedResearchUpdate(
consolidated_research_object_results_str=filtered_search_answer,

View File

@@ -5,9 +5,11 @@ from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import rename_entities_in_answer
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_step_answer_explicit,
stream_kg_search_close_step_answer,
)
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_kg_search_answer_explicit,
)
from onyx.agents.agent_search.kb_search.states import ConsolidatedResearchUpdate
from onyx.agents.agent_search.kb_search.states import MainState
@@ -41,11 +43,12 @@ def consolidate_individual_deep_search(
step_answer = "All research is complete. Consolidating results..."
stream_write_step_answer_explicit(
writer, answer=step_answer, level=0, step_nr=_KG_STEP_NR
)
if state.individual_flow:
stream_write_kg_search_answer_explicit(
writer, answer=step_answer, level=0, step_nr=_KG_STEP_NR
)
stream_close_step_answer(writer, level=0, step_nr=_KG_STEP_NR)
stream_kg_search_close_step_answer(writer, level=0, step_nr=_KG_STEP_NR)
return ConsolidatedResearchUpdate(
consolidated_research_object_results_str=consolidated_research_object_results_str,

View File

@@ -4,9 +4,11 @@ from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_step_answer_explicit,
stream_kg_search_close_step_answer,
)
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_kg_search_answer_explicit,
)
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import ResultsDataUpdate
@@ -66,28 +68,26 @@ def process_kg_only_answers(
# we use this stream write explicitly
write_custom_event(
"subqueries",
SubQueryPiece(
sub_query="Formatted References",
level=0,
level_question_num=_KG_STEP_NR,
query_id=1,
),
writer,
)
query_results_list = []
if state.individual_flow:
write_custom_event(
"subqueries",
SubQueryPiece(
sub_query="Formatted References",
level=0,
level_question_num=_KG_STEP_NR,
query_id=1,
),
writer,
)
if query_results:
for query_result in query_results:
query_results_list.append(
str(query_result).replace("::", ":: ").capitalize()
)
query_results_data_str = "\n".join(
str(query_result).replace("::", ":: ").capitalize()
for query_result in query_results
)
else:
raise ValueError("No query results were found")
query_results_data_str = "\n".join(query_results_list)
logger.warning("No query results were found")
query_results_data_str = "(No query results were found)"
source_reference_result_str = _get_formated_source_reference_results(
source_document_results
@@ -99,9 +99,12 @@ def process_kg_only_answers(
"No further research is needed, the answer is derived from the knowledge graph."
)
stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=step_answer)
if state.individual_flow:
stream_write_kg_search_answer_explicit(
writer, step_nr=_KG_STEP_NR, answer=step_answer
)
stream_close_step_answer(writer, _KG_STEP_NR)
stream_kg_search_close_step_answer(writer, _KG_STEP_NR)
return ResultsDataUpdate(
query_results_data_str=query_results_data_str,

View File

@@ -7,14 +7,17 @@ from langgraph.types import StreamWriter
from onyx.access.access import get_acl_for_user
from onyx.agents.agent_search.kb_search.graph_utils import rename_entities_in_answer
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_close_steps
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_kg_search_close_steps,
)
from onyx.agents.agent_search.kb_search.ops import research
from onyx.agents.agent_search.kb_search.states import MainOutput
from onyx.agents.agent_search.kb_search.states import FinalAnswerUpdate
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.calculations import (
get_answer_generation_documents,
)
from onyx.agents.agent_search.shared_graph_utils.llm import get_answer_from_llm
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
@@ -42,7 +45,7 @@ logger = setup_logger()
def generate_answer(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> MainOutput:
) -> FinalAnswerUpdate:
"""
LangGraph node to start the agentic search process.
"""
@@ -50,7 +53,9 @@ def generate_answer(
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.prompt_builder.raw_user_query
question = state.question
final_answer: str | None = None
user = (
graph_config.tooling.search_tool.user
@@ -69,7 +74,8 @@ def generate_answer(
# DECLARE STEPS DONE
stream_write_close_steps(writer)
if state.individual_flow:
stream_write_kg_search_close_steps(writer)
## MAIN ANSWER
@@ -128,16 +134,17 @@ def generate_answer(
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
):
write_custom_event(
"tool_response",
ExtendedToolResponse(
id=tool_response.id,
response=tool_response.response,
level=0,
level_question_num=0, # 0, 0 is the base question
),
writer,
)
if state.individual_flow:
write_custom_event(
"tool_response",
ExtendedToolResponse(
id=tool_response.id,
response=tool_response.response,
level=0,
level_question_num=0, # 0, 0 is the base question
),
writer,
)
# continue with the answer generation
@@ -206,24 +213,40 @@ def generate_answer(
)
]
try:
run_with_timeout(
KG_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
lambda: stream_llm_answer(
llm=graph_config.tooling.fast_llm,
prompt=msg,
event_name="initial_agent_answer",
writer=writer,
agent_answer_level=0,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
if state.individual_flow:
stream_results, _, _ = run_with_timeout(
KG_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
lambda: stream_llm_answer(
llm=graph_config.tooling.primary_llm,
prompt=msg,
event_name="initial_agent_answer",
writer=writer,
agent_answer_level=0,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=KG_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
max_tokens=KG_MAX_TOKENS_ANSWER_GENERATION,
),
)
final_answer = "".join(stream_results)
else:
final_answer = get_answer_from_llm(
llm=graph_config.tooling.primary_llm,
prompt=output_format_prompt,
stream=False,
json_string_flag=False,
timeout_override=KG_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
max_tokens=KG_MAX_TOKENS_ANSWER_GENERATION,
),
)
)
except Exception as e:
raise ValueError(f"Could not generate the answer. Error {e}")
return MainOutput(
return FinalAnswerUpdate(
final_answer=final_answer,
retrieved_documents=answer_generation_documents.context_documents,
step_results=[],
remarks=[],
log_messages=[
get_langgraph_node_log_string(
graph_component="main",

View File

@@ -48,6 +48,8 @@ def log_data(
)
return MainOutput(
final_answer=state.final_answer,
retrieved_documents=state.retrieved_documents,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",

View File

@@ -120,7 +120,7 @@ class ResearchObjectOutput(LoggerUpdate):
research_object_results: Annotated[list[dict[str, Any]], add] = []
class ERTExtractionUpdate(LoggerUpdate):
class EntityRelationshipExtractionUpdate(LoggerUpdate):
entities_types_str: str = ""
relationship_types_str: str = ""
extracted_entities_w_attributes: list[str] = []
@@ -144,7 +144,13 @@ class ResearchObjectUpdate(LoggerUpdate):
## Graph Input State
class MainInput(CoreState):
pass
question: str
individual_flow: bool = True # used for UI display purposes
class FinalAnswerUpdate(LoggerUpdate):
final_answer: str | None = None
retrieved_documents: list[InferenceSection] | None = None
## Graph State
@@ -154,7 +160,7 @@ class MainState(
ToolChoiceInput,
ToolCallUpdate,
ToolChoiceUpdate,
ERTExtractionUpdate,
EntityRelationshipExtractionUpdate,
AnalysisUpdate,
SQLSimpleGenerationUpdate,
ResultsDataUpdate,
@@ -162,6 +168,7 @@ class MainState(
DeepSearchFilterUpdate,
ResearchObjectUpdate,
ConsolidatedResearchUpdate,
FinalAnswerUpdate,
):
pass
@@ -169,6 +176,8 @@ class MainState(
## Graph Output State - presently not used
class MainOutput(TypedDict):
log_messages: list[str]
final_answer: str | None
retrieved_documents: list[InferenceSection] | None
class ResearchObjectInput(LoggerUpdate):
@@ -179,3 +188,4 @@ class ResearchObjectInput(LoggerUpdate):
source_division: bool | None
source_entity_filters: list[str] | None
segment_type: str
individual_flow: bool = True # used for UI display purposes

View File

@@ -1,6 +1,6 @@
from onyx.agents.agent_search.kb_search.models import KGSteps
STEP_DESCRIPTIONS: dict[int, KGSteps] = {
KG_SEARCH_STEP_DESCRIPTIONS: dict[int, KGSteps] = {
1: KGSteps(
description="Analyzing the question...",
activities=[
@@ -27,3 +27,7 @@ STEP_DESCRIPTIONS: dict[int, KGSteps] = {
description="Conducting further research on source documents...", activities=[]
),
}
BASIC_SEARCH_STEP_DESCRIPTIONS: dict[int, KGSteps] = {
1: KGSteps(description="Conducting a standard search...", activities=[]),
}
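
For orientation, a hedged sketch of how a step-description dict like the ones above might be flattened for display. The KGSteps fields assumed here (description, activities) match the constructor calls in this file, but the helper itself is hypothetical and not part of this diff.

from onyx.agents.agent_search.kb_search.models import KGSteps

def render_step_descriptions(steps: dict[int, KGSteps]) -> list[str]:
    # Flatten a step-description dict into display lines, in step order.
    lines: list[str] = []
    for step_num in sorted(steps):
        step = steps[step_num]
        lines.append(f"Step {step_num}: {step.description}")
        lines.extend(f"  - {activity}" for activity in step.activities)
    return lines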

View File

@@ -4,6 +4,7 @@ from pydantic import BaseModel
from pydantic import model_validator
from sqlalchemy.orm import Session
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.context.search.models import RerankingDetails
from onyx.db.models import Persona
@@ -72,6 +73,7 @@ class GraphSearchConfig(BaseModel):
skip_gen_ai_answer_generation: bool = False
allow_agent_reranking: bool = False
kg_config_settings: KGConfigSettings = KGConfigSettings()
research_type: ResearchType = ResearchType.THOUGHTFUL
class GraphConfig(BaseModel):

View File

@@ -271,7 +271,10 @@ def choose_tool(
should_stream_answer
and not agent_config.behavior.skip_gen_ai_answer_generation,
writer,
)
).ai_message_chunk
if tool_message is None:
raise ValueError("No tool message emitted by LLM")
# If no tool calls are emitted by the LLM, we should not choose a tool
if len(tool_message.tool_calls) == 0:

View File

@@ -4,6 +4,7 @@ from langchain_core.messages import AIMessageChunk
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.basic.models import BasicSearchProcessedStreamResults
from onyx.agents.agent_search.basic.states import BasicOutput
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.utils import process_llm_stream
@@ -21,6 +22,7 @@ from onyx.tools.tool_implementations.search_like_tool_utils import (
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
logger = setup_logger()
@@ -62,7 +64,9 @@ def basic_use_tool_response(
for section in dedupe_documents(search_response_summary.top_sections)[0]
]
new_tool_call_chunk = AIMessageChunk(content="")
new_tool_call_chunk = BasicSearchProcessedStreamResults(
ai_message_chunk=AIMessageChunk(content=""), full_answer=None
)
if not agent_config.behavior.skip_gen_ai_answer_generation:
stream = llm.stream(
prompt=new_prompt_builder.build(),
@@ -80,4 +84,9 @@ def basic_use_tool_response(
displayed_search_results=initial_search_results or final_search_results,
)
return BasicOutput(tool_call_chunk=new_tool_call_chunk)
return BasicOutput(
tool_call_chunk=new_tool_call_chunk.ai_message_chunk,
full_answer=new_tool_call_chunk.full_answer,
cited_references=new_tool_call_chunk.cited_references,
retrieved_documents=new_tool_call_chunk.retrieved_documents,
)
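
BasicSearchProcessedStreamResults is imported above, but its definition is not part of this hunk. Below is a hedged sketch of its plausible shape, inferred only from the constructor call and the attribute accesses in the return statement; the real model in onyx.agents.agent_search.basic.models may differ.

from langchain_core.messages import AIMessageChunk
from pydantic import BaseModel, ConfigDict

class BasicSearchProcessedStreamResults(BaseModel):  # speculative sketch
    model_config = ConfigDict(arbitrary_types_allowed=True)

    ai_message_chunk: AIMessageChunk
    full_answer: str | None = None
    cited_references: list | None = None  # assumed from the return statement above
    retrieved_documents: list | None = None  # assumed from the return statement above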

View File

@@ -18,79 +18,37 @@ from onyx.agents.agent_search.deep_search.main.graph_builder import (
from onyx.agents.agent_search.deep_search.main.states import (
MainInput as MainInput,
)
from onyx.agents.agent_search.dr.graph_builder import dr_graph_builder
from onyx.agents.agent_search.dr.states import MainInput as DRMainInput
from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
from onyx.agents.agent_search.kb_search.states import MainInput as KBMainInput
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStream
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamingError
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionPiece
from onyx.chat.models import ToolResponse
from onyx.context.search.models import SearchRequest
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.llm.factory import get_default_llms
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.tools.tool_runner import ToolCallKickoff
from onyx.utils.logger import setup_logger
logger = setup_logger()
GraphInput = BasicInput | MainInput | DCMainInput | KBMainInput | DRMainInput
_COMPILED_GRAPH: CompiledStateGraph | None = None
def _parse_agent_event(
event: StreamEvent,
) -> AnswerPacket | None:
"""
Parse the event into a typed object.
Return None if we are not interested in the event.
"""
event_type = event["event"]
# We always just yield the event data, but this piece is useful for two development reasons:
# 1. It's a list of the names of every place we dispatch a custom event
# 2. We maintain the intended types yielded by each event
if event_type == "on_custom_event":
if event["name"] == "decomp_qs":
return cast(SubQuestionPiece, event["data"])
elif event["name"] == "subqueries":
return cast(SubQueryPiece, event["data"])
elif event["name"] == "sub_answers":
return cast(AgentAnswerPiece, event["data"])
elif event["name"] == "stream_finished":
return cast(StreamStopInfo, event["data"])
elif event["name"] == "initial_agent_answer":
return cast(AgentAnswerPiece, event["data"])
elif event["name"] == "refined_agent_answer":
return cast(AgentAnswerPiece, event["data"])
elif event["name"] == "start_refined_answer_creation":
return cast(ToolCallKickoff, event["data"])
elif event["name"] == "tool_response":
return cast(ToolResponse, event["data"])
elif event["name"] == "basic_response":
return cast(AnswerPacket, event["data"])
elif event["name"] == "refined_answer_improvement":
return cast(RefinedAnswerImprovement, event["data"])
elif event["name"] == "refined_sub_question_creation_error":
return cast(StreamingError, event["data"])
else:
logger.error(f"Unknown event name: {event['name']}")
return None
logger.error(f"Unknown event type: {event_type}")
return None
def manage_sync_streaming(
compiled_graph: CompiledStateGraph,
config: GraphConfig,
graph_input: BasicInput | MainInput | DCMainInput | KBMainInput,
graph_input: GraphInput,
) -> Iterable[StreamEvent]:
message_id = config.persistence.message_id if config.persistence else None
for event in compiled_graph.stream(
@@ -104,16 +62,14 @@ def manage_sync_streaming(
def run_graph(
compiled_graph: CompiledStateGraph,
config: GraphConfig,
input: BasicInput | MainInput | DCMainInput | KBMainInput,
input: GraphInput,
) -> AnswerStream:
for event in manage_sync_streaming(
compiled_graph=compiled_graph, config=config, graph_input=input
):
if not (parsed_object := _parse_agent_event(event)):
continue
yield parsed_object
yield cast(Packet, event["data"])
# It doesn't actually take very long to load the graph, but we'd rather
@@ -154,16 +110,23 @@ def run_kb_graph(
) -> AnswerStream:
graph = kb_graph_builder()
compiled_graph = graph.compile()
input = KBMainInput(log_messages=[])
yield ToolCallKickoff(
tool_name="agent_search_0",
tool_args={"query": config.inputs.prompt_builder.raw_user_query},
input = KBMainInput(
log_messages=[], question=config.inputs.prompt_builder.raw_user_query
)
yield from run_graph(compiled_graph, config, input)
def run_dr_graph(
config: GraphConfig,
) -> AnswerStream:
graph = dr_graph_builder()
compiled_graph = graph.compile()
input = DRMainInput(log_messages=[])
yield from run_graph(compiled_graph, config, input)
def run_dc_graph(
config: GraphConfig,
) -> AnswerStream:

View File

@@ -1,12 +1,32 @@
import re
from datetime import datetime
from typing import cast
from typing import Literal
from typing import Type
from typing import TypeVar
from langchain.schema.language_model import LanguageModelInput
from langchain_core.messages import HumanMessage
from langgraph.types import StreamWriter
from litellm import get_supported_openai_params
from litellm import supports_response_schema
from pydantic import BaseModel
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.stream_processing.citation_processing import CitationProcessorGraph
from onyx.chat.stream_processing.citation_processing import LlmDoc
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
from onyx.utils.threadpool_concurrency import run_with_timeout
SchemaType = TypeVar("SchemaType", bound=BaseModel)
# match ```json{...}``` or ```{...}```
JSON_PATTERN = re.compile(r"```(?:json)?\s*(\{.*?\})\s*```", re.DOTALL)
def stream_llm_answer(
@@ -19,7 +39,11 @@ def stream_llm_answer(
agent_answer_type: Literal["agent_level_answer", "agent_sub_answer"],
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> tuple[list[str], list[float]]:
answer_piece: str | None = None,
ind: int | None = None,
context_docs: list[LlmDoc] | None = None,
replace_citations: bool = False,
) -> tuple[list[str], list[float], list[CitationInfo]]:
"""Stream the initial answer from the LLM.
Args:
@@ -32,16 +56,32 @@ def stream_llm_answer(
agent_answer_type: The type of answer ("agent_level_answer" or "agent_sub_answer").
timeout_override: The LLM timeout to use.
max_tokens: The LLM max tokens to use.
answer_piece: The type of answer piece to write.
ind: The index of the answer piece.
context_docs: The context documents used for citation processing.
replace_citations: Whether to rewrite citation markers in the streamed tokens.
Returns:
A tuple of the response tokens, the dispatch timings, and the citation infos.
"""
response: list[str] = []
dispatch_timings: list[float] = []
citation_infos: list[CitationInfo] = []
if context_docs:
citation_processor = CitationProcessorGraph(
context_docs=context_docs,
)
else:
replace_citations = False
for message in llm.stream(
prompt, timeout_override=timeout_override, max_tokens=max_tokens
prompt,
timeout_override=timeout_override,
max_tokens=max_tokens,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
@@ -50,19 +90,153 @@ def stream_llm_answer(
)
start_stream_token = datetime.now()
write_custom_event(
event_name,
AgentAnswerPiece(
answer_piece=content,
level=agent_answer_level,
level_question_num=agent_answer_question_num,
answer_type=agent_answer_type,
),
writer,
)
if answer_piece == "message_delta":
if ind is None:
raise ValueError("index is required when answer_piece is message_delta")
if replace_citations:
processed_token = citation_processor.process_token(content)
if isinstance(processed_token, tuple):
content = processed_token[0]
citation_infos.extend(processed_token[1])
elif isinstance(processed_token, str):
content = processed_token
else:
continue
write_custom_event(
ind,
MessageDelta(content=content, type="message_delta"),
writer,
)
elif answer_piece == "reasoning_delta":
if ind is None:
raise ValueError(
"index is required when answer_piece is reasoning_delta"
)
write_custom_event(
ind,
ReasoningDelta(reasoning=content, type="reasoning_delta"),
writer,
)
else:
raise ValueError(f"Invalid answer piece: {answer_piece}")
end_stream_token = datetime.now()
dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
response.append(content)
return response, dispatch_timings
return response, dispatch_timings, citation_infos
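
A hedged caller sketch for the extended signature, mirroring the generate_answer hunk earlier in this diff; llm, msg, writer, and context_docs are assumed to already exist in the caller's scope.

tokens, _dispatch_timings, citation_infos = stream_llm_answer(
    llm=llm,
    prompt=msg,
    event_name="initial_agent_answer",
    writer=writer,
    agent_answer_level=0,
    agent_answer_question_num=0,
    agent_answer_type="agent_level_answer",
    answer_piece="message_delta",  # emit tokens as MessageDelta packets
    ind=0,
    context_docs=context_docs,  # enables CitationProcessorGraph rewriting
    replace_citations=True,
)
final_answer = "".join(tokens)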
def invoke_llm_json(
llm: LLM,
prompt: LanguageModelInput,
schema: Type[SchemaType],
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> SchemaType:
"""
Invoke an LLM, forcing it to respond in a specified JSON format if possible,
and return an object of that schema.
"""
# check if the model supports response_format: json_schema
supports_json = "response_format" in (
get_supported_openai_params(llm.config.model_name, llm.config.model_provider)
or []
) and supports_response_schema(llm.config.model_name, llm.config.model_provider)
response_content = str(
llm.invoke(
prompt,
tools=tools,
tool_choice=tool_choice,
timeout_override=timeout_override,
max_tokens=max_tokens,
**cast(
dict, {"structured_response_format": schema} if supports_json else {}
),
).content
)
if not supports_json:
# remove newlines as they often lead to json decoding errors
response_content = response_content.replace("\n", " ")
# hope the prompt is structured so that a JSON object is output...
json_block_match = JSON_PATTERN.search(response_content)
if json_block_match:
response_content = json_block_match.group(1)
else:
first_bracket = response_content.find("{")
last_bracket = response_content.rfind("}")
response_content = response_content[first_bracket : last_bracket + 1]
return schema.model_validate_json(response_content)
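
A minimal usage sketch, assuming llm is an onyx LLM instance; the schema is illustrative, not a model from this diff. For providers without response_format support, the JSON_PATTERN fallback above strips any ```json fences before validation.

from pydantic import BaseModel

class ClarificationDecision(BaseModel):  # hypothetical schema
    needs_clarification: bool
    clarifying_question: str | None = None

decision = invoke_llm_json(
    llm=llm,
    prompt="Decide whether the user query needs clarification. Answer as JSON.",
    schema=ClarificationDecision,
)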
def get_answer_from_llm(
llm: LLM,
prompt: str,
timeout: int = 25,
timeout_override: int = 5,
max_tokens: int = 500,
stream: bool = False,
writer: StreamWriter = lambda _: None,
agent_answer_level: int = 0,
agent_answer_question_num: int = 0,
agent_answer_type: Literal[
"agent_sub_answer", "agent_level_answer"
] = "agent_level_answer",
json_string_flag: bool = False,
) -> str:
msg = [
HumanMessage(
content=prompt,
)
]
if stream:
# TODO: adjust for the new UI. This currently does not work with the existing UI/Basic Search
stream_response, _, _ = run_with_timeout(
timeout,
lambda: stream_llm_answer(
llm=llm,
prompt=msg,
event_name="sub_answers",
writer=writer,
agent_answer_level=agent_answer_level,
agent_answer_question_num=agent_answer_question_num,
agent_answer_type=agent_answer_type,
timeout_override=timeout_override,
max_tokens=max_tokens,
),
)
content = "".join(stream_response)
else:
llm_response = run_with_timeout(
timeout,
llm.invoke,
prompt=msg,
timeout_override=timeout_override,
max_tokens=max_tokens,
)
content = str(llm_response.content)
cleaned_response = content
if json_string_flag:
cleaned_response = (
str(content).replace("```json\n", "").replace("\n```", "").replace("\n", "")
)
first_bracket = cleaned_response.find("{")
last_bracket = cleaned_response.rfind("}")
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
return cleaned_response
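
A hedged example of the non-streaming path, with fast_llm as a stand-in name for any LLM instance; json_string_flag=True strips ```json fences and trims to the outermost braces, as implemented above.

raw_json = get_answer_from_llm(
    llm=fast_llm,
    prompt='Return {"ok": true} as a JSON object.',
    stream=False,
    json_string_flag=True,
)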

View File

@@ -73,6 +73,7 @@ from onyx.prompts.agent_search import (
HISTORY_CONTEXT_SUMMARY_PROMPT,
)
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.tools.force import ForceUseTool
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_constructor import SearchToolConfig
@@ -353,7 +354,7 @@ def dispatch_main_answer_stop_info(level: int, writer: StreamWriter) -> None:
stream_type=StreamType.MAIN_ANSWER,
level=level,
)
write_custom_event("stream_finished", stop_event, writer)
write_custom_event(0, stop_event, writer)
def retrieve_search_docs(
@@ -438,9 +439,41 @@ class CustomStreamEvent(TypedDict):
def write_custom_event(
name: str, event: AnswerPacket, stream_writer: StreamWriter
ind: int,
event: AnswerPacket,
stream_writer: StreamWriter,
) -> None:
stream_writer(CustomStreamEvent(event="on_custom_event", name=name, data=event))
# For types that are in PacketObj, wrap in Packet
# For types like StreamStopInfo that frontend handles directly, stream directly
if hasattr(event, "stop_reason"): # StreamStopInfo
stream_writer(
CustomStreamEvent(
event="on_custom_event",
data=event,
name="",
)
)
else:
# Try to wrap in Packet for types that are compatible
try:
stream_writer(
CustomStreamEvent(
event="on_custom_event",
data=Packet(ind=ind, obj=event),
name="",
)
)
except Exception:
# Fallback: stream directly if Packet wrapping fails
stream_writer(
CustomStreamEvent(
event="on_custom_event",
data=event,
name="",
)
)
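
A hedged example of the new index-based call style, assuming MessageDelta is imported from onyx.server.query_and_chat.streaming_models and writer is the LangGraph StreamWriter in scope; the event is wrapped as Packet(ind=0, obj=...) before dispatch.

write_custom_event(
    0,  # packet index, replacing the old event-name string
    MessageDelta(content="partial answer text", type="message_delta"),
    writer,
)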
def relevance_from_docs(

View File

@@ -0,0 +1,39 @@
from typing import Any
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from onyx.context.search.models import InferenceSection
def create_citation_format_list(
document_citations: list[InferenceSection],
) -> list[dict[str, Any]]:
citation_list: list[dict[str, Any]] = []
for document_citation in document_citations:
document_citation_dict = {
"link": "",
"blurb": document_citation.center_chunk.blurb,
"content": document_citation.center_chunk.content,
"metadata": document_citation.center_chunk.metadata,
"updated_at": str(document_citation.center_chunk.updated_at),
"document_id": document_citation.center_chunk.document_id,
"source_type": "file",
"source_links": document_citation.center_chunk.source_links,
"match_highlights": document_citation.center_chunk.match_highlights,
"semantic_identifier": document_citation.center_chunk.semantic_identifier,
}
citation_list.append(document_citation_dict)
return citation_list
def create_question_prompt(
system_prompt: str | None, human_prompt: str
) -> list[BaseMessage]:
return [
SystemMessage(content=system_prompt or ""),
HumanMessage(content=human_prompt),
]
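
A minimal usage sketch of the helper above; the prompt strings are illustrative only.

messages = create_question_prompt(
    system_prompt="You are a helpful research assistant.",
    human_prompt="Summarize the retrieved documents in three bullet points.",
)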

View File

@@ -1,9 +1,11 @@
from collections import defaultdict
from collections.abc import Callable
from typing import Any
from uuid import UUID
from sqlalchemy.orm import Session
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.models import GraphInputs
from onyx.agents.agent_search.models import GraphPersistence
@@ -12,12 +14,11 @@ from onyx.agents.agent_search.models import GraphTooling
from onyx.agents.agent_search.run_graph import run_agent_search_graph
from onyx.agents.agent_search.run_graph import run_basic_graph
from onyx.agents.agent_search.run_graph import run_dc_graph
from onyx.agents.agent_search.run_graph import run_kb_graph
from onyx.agents.agent_search.run_graph import run_dr_graph
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStream
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
@@ -32,6 +33,7 @@ from onyx.db.kg_config import get_kg_config_settings
from onyx.db.models import Persona
from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
@@ -68,6 +70,8 @@ class Answer:
skip_gen_ai_answer_generation: bool = False,
is_connected: Callable[[], bool] | None = None,
use_agentic_search: bool = False,
research_type: ResearchType | None = None,
research_plan: dict[str, Any] | None = None,
) -> None:
self.is_connected: Callable[[], bool] | None = is_connected
self._processed_stream: list[AnswerPacket] | None = None
@@ -124,6 +128,9 @@ class Answer:
allow_agent_reranking=allow_agent_reranking,
perform_initial_search_decomposition=INITIAL_SEARCH_DECOMPOSITION_ENABLED,
kg_config_settings=get_kg_config_settings(),
research_type=(
ResearchType.DEEP if use_agentic_search else ResearchType.THOUGHTFUL
),
)
self.graph_config = GraphConfig(
inputs=self.graph_inputs,
@@ -138,12 +145,10 @@ class Answer:
yield from self._processed_stream
return
if self.graph_config.behavior.use_agentic_search and (
self.graph_config.inputs.persona
and self.graph_config.behavior.kg_config_settings.KG_ENABLED
and self.graph_config.inputs.persona.name.startswith("KG Beta")
):
run_langgraph = run_kb_graph
# TODO: add toggle in UI with customizable TimeBudget
if self.graph_config.inputs.persona:
run_langgraph = run_dr_graph
elif self.graph_config.behavior.use_agentic_search:
run_langgraph = run_agent_search_graph
elif (
@@ -210,23 +215,6 @@ class Answer:
return citations
def citations_by_subquestion(self) -> dict[SubQuestionKey, list[CitationInfo]]:
citations_by_subquestion: dict[SubQuestionKey, list[CitationInfo]] = (
defaultdict(list)
)
basic_subq_key = SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])
for packet in self.processed_streamed_output:
if isinstance(packet, CitationInfo):
if packet.level_question_num is not None and packet.level is not None:
citations_by_subquestion[
SubQuestionKey(
level=packet.level, question_num=packet.level_question_num
)
].append(packet)
elif packet.level is None:
citations_by_subquestion[basic_subq_key].append(packet)
return citations_by_subquestion
def is_cancelled(self) -> bool:
if self._is_cancelled:
return True

View File

@@ -13,15 +13,16 @@ from onyx.background.celery.tasks.kg_processing.kg_indexing import (
from onyx.background.celery.tasks.kg_processing.kg_indexing import (
try_creating_kg_source_reset_task,
)
from onyx.chat.models import CitationInfo
from onyx.chat.models import LlmDoc
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.configs.constants import MessageType
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SavedSearchDoc
from onyx.db.chat import create_chat_session
from onyx.db.chat import get_chat_messages_by_session
from onyx.db.kg_config import get_kg_config_settings
@@ -31,6 +32,7 @@ from onyx.db.llm import fetch_existing_tools
from onyx.db.models import ChatMessage
from onyx.db.models import Persona
from onyx.db.models import Prompt
from onyx.db.models import SearchDoc
from onyx.db.models import Tool
from onyx.db.models import User
from onyx.db.prompts import get_prompts_by_ids
@@ -42,6 +44,7 @@ from onyx.kg.setup.kg_default_entity_definitions import (
from onyx.llm.models import PreviousMessage
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.tools.tool_implementations.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
@@ -113,6 +116,42 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
)
def saved_search_docs_from_llm_docs(
llm_docs: list[LlmDoc] | None,
) -> list[SavedSearchDoc]:
"""Convert LlmDoc objects to SavedSearchDoc format."""
if not llm_docs:
return []
search_docs = []
for i, llm_doc in enumerate(llm_docs):
# Convert LlmDoc to SearchDoc format
# Note: Some fields need default values as they're not in LlmDoc
search_doc = SearchDoc(
document_id=llm_doc.document_id,
chunk_ind=0, # Default value as LlmDoc doesn't have chunk index
semantic_identifier=llm_doc.semantic_identifier,
link=llm_doc.link,
blurb=llm_doc.blurb,
source_type=llm_doc.source_type,
boost=0, # Default value
hidden=False, # Default value
metadata=llm_doc.metadata,
score=None, # Will be set by SavedSearchDoc
match_highlights=llm_doc.match_highlights or [],
updated_at=llm_doc.updated_at,
primary_owners=None, # Default value
secondary_owners=None, # Default value
is_internet=False, # Default value
)
# Convert SearchDoc to SavedSearchDoc
saved_search_doc = SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
search_docs.append(saved_search_doc)
return search_docs
def combine_message_thread(
messages: list[ThreadMessage],
max_tokens: int | None,
@@ -401,7 +440,7 @@ def process_kg_commands(
) -> None:
# Temporarily, until we have a draft UI for the KG Operations/Management
# TODO: move to api endpoint once we get frontend
if not persona_name.startswith("KG Beta"):
if not persona_name.startswith(TMP_DRALPHA_PERSONA_NAME):
return
kg_config_settings = get_kg_config_settings()

View File

@@ -1,7 +1,5 @@
from collections import OrderedDict
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Mapping
from datetime import datetime
from enum import Enum
from typing import Any
@@ -22,6 +20,19 @@ from onyx.context.search.models import RetrievalDocs
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.file_store.models import FileDescriptor
from onyx.llm.override_models import PromptOverride
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import CitationStart
from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.server.query_and_chat.streaming_models import SubQuestionIdentifier
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
@@ -46,46 +57,6 @@ class LlmDoc(BaseModel):
match_highlights: list[str] | None
class SubQuestionIdentifier(BaseModel):
"""None represents references to objects in the original flow. To our understanding,
these will not be None in the packets returned from agent search.
"""
level: int | None = None
level_question_num: int | None = None
@staticmethod
def make_dict_by_level(
original_dict: Mapping[tuple[int, int], "SubQuestionIdentifier"],
) -> dict[int, list["SubQuestionIdentifier"]]:
"""returns a dict of level to object list (sorted by level_question_num)
Ordering is asc for readability.
"""
# organize by level, then sort ascending by question_index
level_dict: dict[int, list[SubQuestionIdentifier]] = {}
# group by level
for k, obj in original_dict.items():
level = k[0]
if level not in level_dict:
level_dict[level] = []
level_dict[level].append(obj)
# for each level, sort the group
for k2, value2 in level_dict.items():
# we need to handle the none case due to SubQuestionIdentifier typing
# level_question_num as int | None, even though it should never be None here.
level_dict[k2] = sorted(
value2,
key=lambda x: (x.level_question_num is None, x.level_question_num),
)
# sort by level
sorted_dict = OrderedDict(sorted(level_dict.items()))
return sorted_dict
# First chunk of info for streaming QA
class QADocsResponse(RetrievalDocs, SubQuestionIdentifier):
rephrased_query: str | None = None
@@ -164,11 +135,6 @@ class OnyxAnswerPiece(BaseModel):
# An intermediate representation of citations, later translated into
# a mapping of the citation [n] number to SearchDoc
class CitationInfo(SubQuestionIdentifier):
citation_num: int
document_id: str
class AllCitations(BaseModel):
citations: list[CitationInfo]
@@ -388,7 +354,21 @@ AgentSearchPacket = Union[
]
AnswerPacket = (
AnswerQuestionPossibleReturn | AgentSearchPacket | ToolCallKickoff | ToolResponse
AnswerQuestionPossibleReturn
| AgentSearchPacket
| ToolCallKickoff
| ToolResponse
| MessageStart
| MessageDelta
| SectionEnd
| ReasoningStart
| ReasoningDelta
| SearchToolStart
| SearchToolDelta
| OnyxAnswerPiece
| CitationStart
| CitationDelta
| OverallStop
)
@@ -402,12 +382,12 @@ ResponsePart = (
| AgentSearchPacket
)
AnswerStream = Iterator[AnswerPacket]
AnswerStream = Iterator[Packet]
class AnswerPostInfo(BaseModel):
ai_message_files: list[FileDescriptor]
qa_docs_response: QADocsResponse | None = None
rephrased_query: str | None = None
reference_db_search_docs: list[DbSearchDoc] | None = None
dropped_indices: list[int] | None = None
tool_result: ToolCallFinalResult | None = None

View File

@@ -0,0 +1,68 @@
from collections.abc import Generator
from typing import cast
from typing import Union
from onyx.chat.models import AgenticMessageResponseIDInfo
from onyx.chat.models import AgentSearchPacket
from onyx.chat.models import AllCitations
from onyx.chat.models import AnswerStream
from onyx.chat.models import CustomToolResponse
from onyx.chat.models import FileChatDisplay
from onyx.chat.models import FinalUsedContextDocsResponse
from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import MessageResponseIDInfo
from onyx.chat.models import MessageSpecificCitations
from onyx.chat.models import QADocsResponse
from onyx.chat.models import StreamingError
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import UserKnowledgeFilePacket
from onyx.file_store.models import ChatFileType
from onyx.server.query_and_chat.models import ChatMessageDetail
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.utils.logger import setup_logger
logger = setup_logger()
COMMON_TOOL_RESPONSE_TYPES = {
"image": ChatFileType.IMAGE,
"csv": ChatFileType.CSV,
}
# Type definitions for packet processing
ChatPacket = Union[
StreamingError,
QADocsResponse,
LLMRelevanceFilterResponse,
FinalUsedContextDocsResponse,
ChatMessageDetail,
AllCitations,
CitationInfo,
FileChatDisplay,
CustomToolResponse,
MessageResponseIDInfo,
MessageSpecificCitations,
AgenticMessageResponseIDInfo,
StreamStopInfo,
AgentSearchPacket,
UserKnowledgeFilePacket,
Packet,
]
def process_streamed_packets(
answer_processed_output: AnswerStream,
) -> Generator[ChatPacket, None, None]:
"""Process the streamed output from the answer and yield chat packets."""
last_index = 0
for packet in answer_processed_output:
if isinstance(packet, Packet):
if packet.ind > last_index:
last_index = packet.ind
yield cast(ChatPacket, packet)
# Yield STOP packet to indicate streaming is complete
yield Packet(ind=last_index, obj=OverallStop())
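
This mirrors the call site in stream_chat_message_objects later in this diff: inside a generator, the packets are simply forwarded, and the trailing OverallStop packet is appended automatically. The wrapper below is hypothetical; Generator and ChatPacket come from this module's imports.

def forward_chat_packets(answer) -> Generator[ChatPacket, None, None]:
    # `answer` is assumed to be an Answer instance whose
    # processed_streamed_output yields Packet objects.
    yield from process_streamed_packets(
        answer_processed_output=answer.processed_streamed_output,
    )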

View File

@@ -0,0 +1,164 @@
from collections.abc import Generator
from onyx.context.search.utils import chunks_or_sections_to_search_docs
from onyx.context.search.utils import dedupe_documents
from onyx.db.chat import create_db_search_doc
from onyx.db.chat import create_search_doc_from_user_file
from onyx.db.chat import translate_db_search_doc_to_server_search_doc
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.db.models import UserFile
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import save_files
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolDelta
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationResponse,
)
from onyx.tools.tool_implementations.internet_search.models import (
InternetSearchResponseSummary,
)
from onyx.tools.tool_implementations.internet_search.utils import (
internet_search_response_to_search_docs,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
def handle_search_tool_response_summary(
current_ind: int,
search_response: SearchResponseSummary,
selected_search_docs: list[DbSearchDoc] | None,
is_extended: bool,
dedupe_docs: bool = False,
user_files: list[UserFile] | None = None,
loaded_user_files: list[InMemoryChatFile] | None = None,
) -> Generator[Packet, None, tuple[list[DbSearchDoc], list[int] | None]]:
dropped_inds = None
if not selected_search_docs:
top_docs = chunks_or_sections_to_search_docs(search_response.top_sections)
deduped_docs = top_docs
if (
dedupe_docs and not is_extended
): # Extended tool responses are already deduped
deduped_docs, dropped_inds = dedupe_documents(top_docs)
with get_session_with_current_tenant() as db_session:
reference_db_search_docs = [
create_db_search_doc(server_search_doc=doc, db_session=db_session)
for doc in deduped_docs
]
else:
reference_db_search_docs = selected_search_docs
doc_ids = {doc.id for doc in reference_db_search_docs}
if user_files is not None and loaded_user_files is not None:
for user_file in user_files:
if user_file.id in doc_ids:
continue
associated_chat_file = next(
(
file
for file in loaded_user_files
if file.file_id == str(user_file.file_id)
),
None,
)
# Use create_search_doc_from_user_file to properly add the document to the database
if associated_chat_file is not None:
with get_session_with_current_tenant() as db_session:
db_doc = create_search_doc_from_user_file(
user_file, associated_chat_file, db_session
)
reference_db_search_docs.append(db_doc)
response_docs = [
translate_db_search_doc_to_server_search_doc(db_search_doc)
for db_search_doc in reference_db_search_docs
]
yield Packet(
ind=current_ind,
obj=SearchToolDelta(
documents=response_docs,
),
)
yield Packet(
ind=current_ind,
obj=SectionEnd(),
)
return reference_db_search_docs, dropped_inds
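
Because the handler is a Generator with a return value, a caller must use yield from to both forward its Packets and capture the returned tuple. A hedged sketch (the wrapper and its argument values are hypothetical; Packet, SearchResponseSummary, and Generator are already imported in this module):

def relay_search_packets(
    current_ind: int,
    search_response: SearchResponseSummary,
) -> Generator[Packet, None, None]:
    reference_docs, dropped_inds = yield from handle_search_tool_response_summary(
        current_ind=current_ind,
        search_response=search_response,
        selected_search_docs=None,
        is_extended=False,
    )
    # reference_docs / dropped_inds are now available for persistence logic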
def handle_internet_search_tool_response(
current_ind: int,
internet_search_response: InternetSearchResponseSummary,
) -> Generator[Packet, None, list[DbSearchDoc]]:
server_search_docs = internet_search_response_to_search_docs(
internet_search_response
)
with get_session_with_current_tenant() as db_session:
reference_db_search_docs = [
create_db_search_doc(server_search_doc=doc, db_session=db_session)
for doc in server_search_docs
]
response_docs = [
translate_db_search_doc_to_server_search_doc(db_search_doc)
for db_search_doc in reference_db_search_docs
]
yield Packet(
ind=current_ind,
obj=SearchToolDelta(
documents=response_docs,
),
)
yield Packet(
ind=current_ind,
obj=SectionEnd(),
)
return reference_db_search_docs
def handle_image_generation_tool_response(
current_ind: int,
img_generation_responses: list[ImageGenerationResponse],
) -> Generator[Packet, None, None]:
# Save files and get file IDs
file_ids = save_files(
urls=[img.url for img in img_generation_responses if img.url],
base64_files=[
img.image_data for img in img_generation_responses if img.image_data
],
)
yield Packet(
ind=current_ind,
obj=ImageGenerationToolDelta(
images=[
{
"id": str(file_id),
"url": "", # URL will be constructed by frontend
"prompt": img.revised_prompt,
}
for file_id, img in zip(file_ids, img_generation_responses)
]
),
)
# Emit ImageToolEnd packet with file information
yield Packet(
ind=current_ind,
obj=SectionEnd(),
)

View File

@@ -1,6 +1,5 @@
import time
import traceback
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
@@ -17,30 +16,25 @@ from onyx.chat.chat_utils import create_temporary_persona
from onyx.chat.chat_utils import process_kg_commands
from onyx.chat.models import AgenticMessageResponseIDInfo
from onyx.chat.models import AgentMessageIDInfo
from onyx.chat.models import AgentSearchPacket
from onyx.chat.models import AllCitations
from onyx.chat.models import AnswerPostInfo
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import ChatOnyxBotResponse
from onyx.chat.models import CitationConfig
from onyx.chat.models import CitationInfo
from onyx.chat.models import CustomToolResponse
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import FileChatDisplay
from onyx.chat.models import FinalUsedContextDocsResponse
from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import MessageResponseIDInfo
from onyx.chat.models import MessageSpecificCitations
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import PromptConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamingError
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import SubQuestionKey
from onyx.chat.models import UserKnowledgeFilePacket
from onyx.chat.packet_proccessing.process_streamed_packets import ChatPacket
from onyx.chat.packet_proccessing.process_streamed_packets import (
process_streamed_packets,
)
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
@@ -54,22 +48,15 @@ from onyx.configs.constants import BASIC_KEY
from onyx.configs.constants import MessageType
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import NO_AUTH_USER_ID
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
from onyx.context.search.enums import OptionalSearchSetting
from onyx.context.search.enums import QueryFlow
from onyx.context.search.enums import SearchType
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.retrieval.search_runner import (
inference_sections_from_ids,
)
from onyx.context.search.utils import chunks_or_sections_to_search_docs
from onyx.context.search.utils import dedupe_documents
from onyx.context.search.utils import drop_llm_indices
from onyx.context.search.utils import relevant_sections_to_indices
from onyx.db.chat import attach_files_to_chat_message
from onyx.db.chat import create_db_search_doc
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import create_search_doc_from_user_file
from onyx.db.chat import get_chat_message
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_db_search_doc_by_id
@@ -77,7 +64,6 @@ from onyx.db.chat import get_doc_query_identifiers_from_model
from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import reserve_message_id
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import translate_db_search_doc_to_server_search_doc
from onyx.db.chat import update_chat_session_updated_at_timestamp
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.milestone import check_multi_assistant_milestone
@@ -88,15 +74,12 @@ from onyx.db.models import Persona
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.db.models import ToolCall
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.persona import get_persona_by_id
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import FileDescriptor
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import load_all_chat_files
from onyx.file_store.utils import save_files
from onyx.kg.models import KGException
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.factory import get_llms_for_persona
@@ -107,50 +90,20 @@ from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.query_and_chat.models import ChatMessageDetail
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.utils import get_json_line
from onyx.tools.force import ForceUseTool
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_constructor import construct_tools
from onyx.tools.tool_constructor import CustomToolConfig
from onyx.tools.tool_constructor import ImageGenerationToolConfig
from onyx.tools.tool_constructor import InternetSearchToolConfig
from onyx.tools.tool_constructor import SearchToolConfig
from onyx.tools.tool_implementations.custom.custom_tool import (
CUSTOM_TOOL_RESPONSE_ID,
)
from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
from onyx.tools.tool_implementations.images.image_generation_tool import (
IMAGE_GENERATION_RESPONSE_ID,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationResponse,
)
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
INTERNET_SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
InternetSearchTool,
)
from onyx.tools.tool_implementations.internet_search.models import (
InternetSearchResponseSummary,
)
from onyx.tools.tool_implementations.internet_search.utils import (
internet_search_response_to_search_docs,
)
from onyx.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.search.search_tool import (
SECTION_RELEVANCE_LIST_ID,
)
from onyx.tools.tool_runner import ToolCallFinalResult
from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger
from onyx.utils.telemetry import mt_cloud_telemetry
@@ -201,113 +154,6 @@ def _translate_citations(
return MessageSpecificCitations(citation_map=citation_to_saved_doc_id_map)
def _handle_search_tool_response_summary(
packet: ToolResponse,
db_session: Session,
selected_search_docs: list[DbSearchDoc] | None,
dedupe_docs: bool = False,
user_files: list[UserFile] | None = None,
loaded_user_files: list[InMemoryChatFile] | None = None,
) -> tuple[QADocsResponse, list[DbSearchDoc], list[int] | None]:
response_summary = cast(SearchResponseSummary, packet.response)
is_extended = isinstance(packet, ExtendedToolResponse)
dropped_inds = None
if not selected_search_docs:
top_docs = chunks_or_sections_to_search_docs(response_summary.top_sections)
deduped_docs = top_docs
if (
dedupe_docs and not is_extended
): # Extended tool responses are already deduped
deduped_docs, dropped_inds = dedupe_documents(top_docs)
reference_db_search_docs = [
create_db_search_doc(server_search_doc=doc, db_session=db_session)
for doc in deduped_docs
]
else:
reference_db_search_docs = selected_search_docs
doc_ids = {doc.id for doc in reference_db_search_docs}
if user_files is not None and loaded_user_files is not None:
for user_file in user_files:
if user_file.id in doc_ids:
continue
associated_chat_file = next(
(
file
for file in loaded_user_files
if file.file_id == str(user_file.file_id)
),
None,
)
# Use create_search_doc_from_user_file to properly add the document to the database
if associated_chat_file is not None:
db_doc = create_search_doc_from_user_file(
user_file, associated_chat_file, db_session
)
reference_db_search_docs.append(db_doc)
response_docs = [
translate_db_search_doc_to_server_search_doc(db_search_doc)
for db_search_doc in reference_db_search_docs
]
level, question_num = None, None
if isinstance(packet, ExtendedToolResponse):
level, question_num = packet.level, packet.level_question_num
return (
QADocsResponse(
rephrased_query=response_summary.rephrased_query,
top_documents=response_docs,
predicted_flow=response_summary.predicted_flow,
predicted_search=response_summary.predicted_search,
applied_source_filters=response_summary.final_filters.source_type,
applied_time_cutoff=response_summary.final_filters.time_cutoff,
recency_bias_multiplier=response_summary.recency_bias_multiplier,
level=level,
level_question_num=question_num,
),
reference_db_search_docs,
dropped_inds,
)
def _handle_internet_search_tool_response_summary(
packet: ToolResponse,
db_session: Session,
) -> tuple[QADocsResponse, list[DbSearchDoc]]:
internet_search_response = cast(InternetSearchResponseSummary, packet.response)
server_search_docs = internet_search_response_to_search_docs(
internet_search_response
)
reference_db_search_docs = [
create_db_search_doc(server_search_doc=doc, db_session=db_session)
for doc in server_search_docs
]
response_docs = [
translate_db_search_doc_to_server_search_doc(db_search_doc)
for db_search_doc in reference_db_search_docs
]
return (
QADocsResponse(
rephrased_query=internet_search_response.query,
top_documents=response_docs,
predicted_flow=QueryFlow.QUESTION_ANSWER,
predicted_search=SearchType.INTERNET,
applied_source_filters=[],
applied_time_cutoff=None,
recency_bias_multiplier=1.0,
),
reference_db_search_docs,
)
def _get_force_search_settings(
new_msg_req: CreateChatMessageRequest,
tools: list[Tool],
@@ -392,136 +238,9 @@ def _get_persona_for_chat_session(
return persona
ChatPacket = (
StreamingError
| QADocsResponse
| LLMRelevanceFilterResponse
| FinalUsedContextDocsResponse
| ChatMessageDetail
| OnyxAnswerPiece
| AllCitations
| CitationInfo
| FileChatDisplay
| CustomToolResponse
| MessageSpecificCitations
| MessageResponseIDInfo
| AgenticMessageResponseIDInfo
| StreamStopInfo
| AgentSearchPacket
| UserKnowledgeFilePacket
)
ChatPacketStream = Iterator[ChatPacket]
def _process_tool_response(
packet: ToolResponse,
db_session: Session,
selected_db_search_docs: list[DbSearchDoc] | None,
info_by_subq: dict[SubQuestionKey, AnswerPostInfo],
retrieval_options: RetrievalDetails | None,
user_file_files: list[UserFile] | None,
user_files: list[InMemoryChatFile] | None,
) -> Generator[ChatPacket, None, dict[SubQuestionKey, AnswerPostInfo]]:
level, level_question_num = (
(packet.level, packet.level_question_num)
if isinstance(packet, ExtendedToolResponse)
else BASIC_KEY
)
assert level is not None
assert level_question_num is not None
info = info_by_subq[SubQuestionKey(level=level, question_num=level_question_num)]
# TODO: don't need to dedupe here when we do it in agent flow
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
(
info.qa_docs_response,
info.reference_db_search_docs,
info.dropped_indices,
) = _handle_search_tool_response_summary(
packet=packet,
db_session=db_session,
selected_search_docs=selected_db_search_docs,
# Deduping happens at the last step to avoid harming quality by dropping content early on
dedupe_docs=bool(retrieval_options and retrieval_options.dedupe_docs),
user_files=[],
loaded_user_files=[],
)
yield info.qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
relevance_sections = packet.response
if info.reference_db_search_docs is None:
logger.warning("No reference docs found for relevance filtering")
return info_by_subq
llm_indices = relevant_sections_to_indices(
relevance_sections=relevance_sections,
items=[
translate_db_search_doc_to_server_search_doc(doc)
for doc in info.reference_db_search_docs
],
)
if info.dropped_indices:
llm_indices = drop_llm_indices(
llm_indices=llm_indices,
search_docs=info.reference_db_search_docs,
dropped_indices=info.dropped_indices,
)
yield LLMRelevanceFilterResponse(llm_selected_doc_indices=llm_indices)
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
yield FinalUsedContextDocsResponse(final_context_docs=packet.response)
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(list[ImageGenerationResponse], packet.response)
file_ids = save_files(
urls=[img.url for img in img_generation_response if img.url],
base64_files=[
img.image_data for img in img_generation_response if img.image_data
],
)
info.ai_message_files.extend(
[
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
for file_id in file_ids
]
)
yield FileChatDisplay(file_ids=[str(file_id) for file_id in file_ids])
elif packet.id == INTERNET_SEARCH_RESPONSE_SUMMARY_ID:
(
info.qa_docs_response,
info.reference_db_search_docs,
) = _handle_internet_search_tool_response_summary(
packet=packet,
db_session=db_session,
)
yield info.qa_docs_response
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
custom_tool_response = cast(CustomToolCallSummary, packet.response)
response_type = custom_tool_response.response_type
if response_type in COMMON_TOOL_RESPONSE_TYPES:
file_ids = custom_tool_response.tool_result.file_ids
file_type = COMMON_TOOL_RESPONSE_TYPES[response_type]
info.ai_message_files.extend(
[
FileDescriptor(id=str(file_id), type=file_type)
for file_id in file_ids
]
)
yield FileChatDisplay(file_ids=[str(file_id) for file_id in file_ids])
else:
yield CustomToolResponse(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
return info_by_subq
def stream_chat_message_objects(
new_msg_req: CreateChatMessageRequest,
user: User | None,
@@ -561,6 +280,7 @@ def stream_chat_message_objects(
new_msg_req.chunks_below = 0
llm: LLM
answer: Answer
try:
# Move these variables inside the try block
@@ -845,6 +565,18 @@ def stream_chat_message_objects(
error: str | None,
tool_call: ToolCall | None,
) -> ChatMessage:
is_kg_beta = parent_message.chat_session.persona.name.startswith(
TMP_DRALPHA_PERSONA_NAME
)
is_basic_search = tool_call and tool_call.tool_name == SearchTool._NAME
is_agentic_overwrite = new_msg_req.use_agentic_search and not (
is_kg_beta and is_basic_search
)
if is_kg_beta:
is_agentic_overwrite = False
return create_new_chat_message(
chat_session_id=chat_session_id,
parent_message=(
@@ -867,11 +599,9 @@ def stream_chat_message_objects(
db_session=db_session,
commit=False,
reserved_message_id=reserved_message_id,
is_agentic=new_msg_req.use_agentic_search,
is_agentic=is_agentic_overwrite,
)
partial_response = create_response
prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
if new_msg_req.persona_override_config:
prompt_config = PromptConfig(
@@ -983,7 +713,6 @@ def stream_chat_message_objects(
)
# LLM prompt building, response capturing, etc.
answer = Answer(
prompt_builder=prompt_builder,
is_connected=is_connected,
@@ -1013,41 +742,10 @@ def stream_chat_message_objects(
skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
)
info_by_subq: dict[SubQuestionKey, AnswerPostInfo] = defaultdict(
lambda: AnswerPostInfo(ai_message_files=[])
# Process streamed packets using the new packet processing module
yield from process_streamed_packets(
answer_processed_output=answer.processed_streamed_output,
)
refined_answer_improvement = True
for packet in answer.processed_streamed_output:
if isinstance(packet, ToolResponse):
info_by_subq = yield from _process_tool_response(
packet=packet,
db_session=db_session,
selected_db_search_docs=selected_db_search_docs,
info_by_subq=info_by_subq,
retrieval_options=retrieval_options,
user_file_files=user_file_models,
user_files=in_memory_user_files,
)
elif isinstance(packet, StreamStopInfo):
if packet.stop_reason == StreamStopReason.FINISHED:
yield packet
elif isinstance(packet, RefinedAnswerImprovement):
refined_answer_improvement = packet.refined_answer_improvement
yield packet
else:
if isinstance(packet, ToolCallFinalResult):
level, level_question_num = (
(packet.level, packet.level_question_num)
if packet.level is not None
and packet.level_question_num is not None
else BASIC_KEY
)
info = info_by_subq[
SubQuestionKey(level=level, question_num=level_question_num)
]
info.tool_result = packet
yield cast(ChatPacket, packet)
except ValueError as e:
logger.exception("Failed to process chat message.")
@@ -1083,17 +781,6 @@ def stream_chat_message_objects(
db_session.rollback()
return
yield from _post_llm_answer_processing(
answer=answer,
info_by_subq=info_by_subq,
tool_dict=tool_dict,
partial_response=partial_response,
llm_tokenizer_encode_func=llm_tokenizer_encode_func,
db_session=db_session,
chat_session_id=chat_session_id,
refined_answer_improvement=refined_answer_improvement,
)
def _post_llm_answer_processing(
answer: Answer,
@@ -1103,7 +790,6 @@ def _post_llm_answer_processing(
llm_tokenizer_encode_func: Callable[[str], list[int]],
db_session: Session,
chat_session_id: UUID,
refined_answer_improvement: bool | None,
) -> Generator[ChatPacket, None, None]:
"""
Stores messages in the db and yields some final packets to the frontend
@@ -1115,20 +801,6 @@ def _post_llm_answer_processing(
for tool in tool_list:
tool_name_to_tool_id[tool.name] = tool_id
subq_citations = answer.citations_by_subquestion()
for subq_key in subq_citations:
info = info_by_subq[subq_key]
logger.debug("Post-LLM answer processing")
if info.reference_db_search_docs:
info.message_specific_citations = _translate_citations(
citations_list=subq_citations[subq_key],
db_docs=info.reference_db_search_docs,
)
# TODO: AllCitations should contain subq info?
if not answer.is_cancelled():
yield AllCitations(citations=subq_citations[subq_key])
# Saving Gen AI answer and responding with message info
basic_key = SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])
@@ -1144,9 +816,7 @@ def _post_llm_answer_processing(
)
gen_ai_response_message = partial_response(
message=answer.llm_answer,
rephrased_query=(
info.qa_docs_response.rephrased_query if info.qa_docs_response else None
),
rephrased_query=info.rephrased_query,
reference_docs=info.reference_db_search_docs,
files=info.ai_message_files,
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
@@ -1205,7 +875,6 @@ def _post_llm_answer_processing(
else None
),
error=ERROR_TYPE_CANCELLED if answer.is_cancelled() else None,
refined_answer_improvement=refined_answer_improvement,
is_agentic=True,
)
agentic_message_ids.append(

View File

@@ -3,12 +3,12 @@ from collections.abc import Generator
from langchain_core.messages import BaseMessage
from onyx.chat.models import CitationInfo
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import ResponsePart
from onyx.chat.stream_processing.citation_processing import CitationProcessor
from onyx.chat.stream_processing.utils import DocumentIdOrderMapping
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.utils.logger import setup_logger
logger = setup_logger()

View File

@@ -1,12 +1,12 @@
import re
from collections.abc import Generator
from onyx.chat.models import CitationInfo
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.stream_processing.utils import DocumentIdOrderMapping
from onyx.configs.chat_configs import STOP_STREAM_PAT
from onyx.prompts.constants import TRIPLE_BACKTICK
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -172,3 +172,151 @@ class CitationProcessor:
)
return final_processed_str, final_citation_info
class CitationProcessorGraph:
def __init__(
self,
context_docs: list[LlmDoc],
stop_stream: str | None = STOP_STREAM_PAT,
):
self.context_docs = context_docs # list of docs in the order the LLM sees
self.max_citation_num = len(context_docs)
self.stop_stream = stop_stream
self.llm_out = "" # entire output so far
self.curr_segment = "" # tokens held for citation processing
self.hold = "" # tokens held for stop token processing
self.recent_cited_documents: set[str] = set() # docs recently cited
self.cited_documents: set[str] = set() # docs cited in the entire stream
self.non_citation_count = 0
# '[', '[[', '[1', '[[1', '[1,', '[1, ', '[1,2', '[1, 2,', etc.
self.possible_citation_pattern = re.compile(r"(\[+(?:\d+,? ?)*$)")
# group 1: '[[1]]', '[[2]]', etc.
# group 2: '[1]', '[1, 2]', '[1,2,16]', etc.
self.citation_pattern = re.compile(r"(\[\[\d+\]\])|(\[\d+(?:, ?\d+)*\])")
def process_token(
self, token: str | None
) -> str | tuple[str, list[CitationInfo]] | None:
# None -> end of stream
if token is None:
return None
if self.stop_stream:
next_hold = self.hold + token
if self.stop_stream in next_hold:
return None
if next_hold == self.stop_stream[: len(next_hold)]:
self.hold = next_hold
return None
token = next_hold
self.hold = ""
self.curr_segment += token
self.llm_out += token
# Handle code blocks without language tags
if "`" in self.curr_segment:
if self.curr_segment.endswith("`"):
pass
elif "```" in self.curr_segment:
piece_that_comes_after = self.curr_segment.split("```")[1][0]
if piece_that_comes_after == "\n" and in_code_block(self.llm_out):
self.curr_segment = self.curr_segment.replace("```", "```plaintext")
citation_matches = list(self.citation_pattern.finditer(self.curr_segment))
possible_citation_found = bool(
re.search(self.possible_citation_pattern, self.curr_segment)
)
result = ""
if citation_matches and not in_code_block(self.llm_out):
match_idx = 0
citation_infos = []
for match in citation_matches:
match_span = match.span()
# add stuff before/between the matches
intermatch_str = self.curr_segment[match_idx : match_span[0]]
self.non_citation_count += len(intermatch_str)
match_idx = match_span[1]
result += intermatch_str
# reset recent citations if no citations found for a while
if self.non_citation_count > 5:
self.recent_cited_documents.clear()
# process the citation string and emit citation info
res, citation_info = self.process_citation(match)
result += res
citation_infos.extend(citation_info)
self.non_citation_count = 0
# leftover could be part of next citation
self.curr_segment = self.curr_segment[match_idx:]
self.non_citation_count = len(self.curr_segment)
return result, citation_infos
# hold onto the current segment if potential citations found, otherwise stream
if not possible_citation_found:
result += self.curr_segment
self.non_citation_count += len(self.curr_segment)
self.curr_segment = ""
if result:
return result
return None
def process_citation(self, match: re.Match) -> tuple[str, list[CitationInfo]]:
"""
Process a single citation match and return the citation string and the
citation info. The match string can look like '[1]', '[1, 13, 6]', '[[4]]', etc.
"""
citation_str: str = match.group() # e.g., '[1]', '[1, 2, 3]', '[[1]]', etc.
formatted = match.lastindex == 1 # True means already in the form '[[1]]'
final_processed_str = ""
final_citation_info: list[CitationInfo] = []
# process the citation_str
citation_content = citation_str[2:-2] if formatted else citation_str[1:-1]
for num in (int(num) for num in citation_content.split(",")):
# keep invalid citations as is
if not (1 <= num <= self.max_citation_num):
final_processed_str += f"[[{num}]]" if formatted else f"[{num}]"
continue
# translate the LLM's citation number to the corresponding context document
context_llm_doc = self.context_docs[num - 1]
llm_docid = context_llm_doc.document_id
# skip citations of the same document if it was cited recently
if llm_docid in self.recent_cited_documents:
continue
self.recent_cited_documents.add(llm_docid)
# format the citation string as a markdown link, always in the [[n]](link) form
link = context_llm_doc.link or ""
final_processed_str += f"[[{num}]]({link})"
# create the citation info
if llm_docid not in self.cited_documents:
self.cited_documents.add(llm_docid)
final_citation_info.append(
CitationInfo(
citation_num=num,
document_id=llm_docid,
)
)
return final_processed_str, final_citation_info
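As a standalone illustration of how the two patterns above split the stream, here is a minimal sketch (the regexes are copied verbatim from the constructor; nothing else from the class is needed):

import re

possible_citation = re.compile(r"(\[+(?:\d+,? ?)*$)")
citation = re.compile(r"(\[\[\d+\]\])|(\[\d+(?:, ?\d+)*\])")

assert citation.fullmatch("[[4]]").lastindex == 1        # already-formatted citation
assert citation.fullmatch("[1, 13, 6]").lastindex == 2   # raw multi-citation
assert possible_citation.search("see [1, ") is not None  # held back: may still grow
assert possible_citation.search("plain text") is None    # safe to stream out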

View File

@@ -3,6 +3,7 @@ import socket
from enum import auto
from enum import Enum
ONYX_DEFAULT_APPLICATION_NAME = "Onyx"
ONYX_SLACK_URL = "https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA"
ONYX_EMAILABLE_LOGO_MAX_DIM = 512
@@ -138,6 +139,8 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
TMP_DRALPHA_PERSONA_NAME = "KG Beta"
class DocumentSource(str, Enum):
# Special case, document passed in via Onyx APIs without specifying a source type
@@ -512,3 +515,57 @@ else:
class OnyxCallTypes(str, Enum):
FIREFLIES = "FIREFLIES"
GONG = "GONG"
# TODO: this should be stored likely in database
DocumentSourceDescription: dict[DocumentSource, str] = {
# Special case, document passed in via Onyx APIs without specifying a source type
DocumentSource.INGESTION_API: "ingestion_api",
DocumentSource.SLACK: "slack channels",
DocumentSource.WEB: "web pages",
DocumentSource.GOOGLE_DRIVE: "google drive documents (docs, sheets, etc.)",
DocumentSource.GMAIL: "email messages",
DocumentSource.REQUESTTRACKER: "requesttracker",
DocumentSource.GITHUB: "github data",
DocumentSource.GITBOOK: "gitbook data",
DocumentSource.GITLAB: "gitlab data",
DocumentSource.GURU: "guru data",
DocumentSource.BOOKSTACK: "bookstack data",
DocumentSource.CONFLUENCE: "confluence data (pages, spaces, etc.)",
DocumentSource.JIRA: "jira data (issues, tickets, projects, etc.)",
DocumentSource.SLAB: "slab data",
DocumentSource.PRODUCTBOARD: "productboard data (boards, etc.)",
DocumentSource.FILE: "files",
DocumentSource.NOTION: "notion data - a workspace that combines note-taking, \
project management, and collaboration tools into a single, customizable platform",
DocumentSource.ZULIP: "zulip data",
DocumentSource.LINEAR: "linear data - project management tool, including tickets etc.",
DocumentSource.HUBSPOT: "hubspot data - CRM and marketing automation data",
DocumentSource.DOCUMENT360: "document360 data",
DocumentSource.GONG: "gong - call transcripts",
DocumentSource.GOOGLE_SITES: "google_sites - websites",
DocumentSource.ZENDESK: "zendesk - customer support data",
DocumentSource.LOOPIO: "loopio - rfp data",
DocumentSource.DROPBOX: "dropbox - files",
DocumentSource.SHAREPOINT: "sharepoint - files",
DocumentSource.TEAMS: "teams - chat and collaboration",
DocumentSource.SALESFORCE: "salesforce - CRM data",
DocumentSource.DISCOURSE: "discourse - discussion forums",
DocumentSource.AXERO: "axero - employee engagement data",
DocumentSource.CLICKUP: "clickup - project management tool",
DocumentSource.MEDIAWIKI: "mediawiki - wiki data",
DocumentSource.WIKIPEDIA: "wikipedia - encyclopedia data",
DocumentSource.ASANA: "asana",
DocumentSource.S3: "s3",
DocumentSource.R2: "r2",
DocumentSource.GOOGLE_CLOUD_STORAGE: "google_cloud_storage - cloud storage",
DocumentSource.OCI_STORAGE: "oci_storage - cloud storage",
DocumentSource.XENFORO: "xenforo - forum data",
DocumentSource.DISCORD: "discord - chat and collaboration",
DocumentSource.FRESHDESK: "freshdesk - customer support data",
DocumentSource.FIREFLIES: "fireflies - call transcripts",
DocumentSource.EGNYTE: "egnyte - files",
DocumentSource.AIRTABLE: "airtable - database",
DocumentSource.HIGHSPOT: "highspot - CRM data",
DocumentSource.IMAP: "imap - email data",
}

View File

@@ -140,3 +140,5 @@ KG_MAX_SEARCH_DOCUMENTS: int = int(os.environ.get("KG_MAX_SEARCH_DOCUMENTS", "15
KG_MAX_DECOMPOSITION_SEGMENTS: int = int(
os.environ.get("KG_MAX_DECOMPOSITION_SEGMENTS", "10")
)
KG_BETA_ASSISTANT_DESCRIPTION = "The KG Beta assistant uses the Onyx Knowledge Graph (beta) structure \
to answer questions"

View File

@@ -378,6 +378,11 @@ class SavedSearchDoc(SearchDoc):
search_doc_data["score"] = search_doc_data.get("score") or 0.0
return cls(**search_doc_data, db_doc_id=db_doc_id)
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "SavedSearchDoc":
"""Create SavedSearchDoc from serialized dictionary data (e.g., from database JSON)"""
return cls(**data)
def __lt__(self, other: Any) -> bool:
if not isinstance(other, SavedSearchDoc):
return NotImplemented

View File

@@ -19,10 +19,12 @@ from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.shared_graph_utils.models import CombinedAgentMetrics
from onyx.agents.agent_search.shared_graph_utils.models import (
SubQuestionAnswerResults,
)
from onyx.agents.agent_search.utils import create_citation_format_list
from onyx.auth.schemas import UserRole
from onyx.chat.models import DocumentRelevance
from onyx.configs.chat_configs import HARD_DELETE_CHATS
@@ -41,12 +43,14 @@ from onyx.db.models import ChatMessage__SearchDoc
from onyx.db.models import ChatSession
from onyx.db.models import ChatSessionSharedStatus
from onyx.db.models import Prompt
from onyx.db.models import ResearchAgentIteration
from onyx.db.models import SearchDoc
from onyx.db.models import SearchDoc as DBSearchDoc
from onyx.db.models import ToolCall
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.persona import get_best_persona_id_for_user
from onyx.db.tools import get_tool_by_id
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import FileDescriptor
from onyx.file_store.models import InMemoryChatFile
@@ -55,12 +59,211 @@ from onyx.llm.override_models import PromptOverride
from onyx.server.query_and_chat.models import ChatMessageDetail
from onyx.server.query_and_chat.models import SubQueryDetail
from onyx.server.query_and_chat.models import SubQuestionDetail
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import CitationStart
from onyx.server.query_and_chat.streaming_models import EndStepPacketList
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolDelta
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.tools.tool_runner import ToolCallFinalResult
from onyx.utils.logger import setup_logger
from onyx.utils.special_types import JSON_ro
logger = setup_logger()
_CANNOT_SHOW_STEP_RESULTS_STR = "[Cannot display step results]"
def create_message_packets(
message_text: str, final_documents: list[SavedSearchDoc] | None, step_nr: int
) -> list[Packet]:
packets: list[Packet] = []
packets.append(
Packet(
ind=step_nr,
obj=MessageStart(
content="",
final_documents=final_documents,
),
)
)
packets.append(
Packet(
ind=step_nr,
obj=MessageDelta(
type="message_delta",
content=message_text,
),
),
)
packets.append(
Packet(
ind=step_nr,
obj=SectionEnd(
type="section_end",
),
)
)
return packets
def create_citation_packets(
citation_info_list: list[CitationInfo], step_nr: int
) -> list[Packet]:
packets: list[Packet] = []
packets.append(
Packet(
ind=step_nr,
obj=CitationStart(
type="citation_start",
),
)
)
packets.append(
Packet(
ind=step_nr,
obj=CitationDelta(
type="citation_delta",
citations=citation_info_list,
),
)
)
packets.append(
Packet(
ind=step_nr,
obj=SectionEnd(
type="section_end",
),
)
)
return packets
def create_reasoning_packets(reasoning_text: str, step_nr: int) -> list[Packet]:
packets: list[Packet] = []
packets.append(
Packet(
ind=step_nr,
obj=ReasoningStart(
type="reasoning_start",
),
)
)
packets.append(
Packet(
ind=step_nr,
obj=ReasoningDelta(
type="reasoning_delta",
reasoning=reasoning_text,
),
),
)
packets.append(
Packet(
ind=step_nr,
obj=SectionEnd(
type="section_end",
),
)
)
return packets
def create_image_generation_packets(
images: list[dict[str, str]] | None, step_nr: int
) -> list[Packet]:
packets: list[Packet] = []
packets.append(
Packet(
ind=step_nr,
obj=ImageGenerationToolStart(type="image_generation_tool_start"),
)
)
packets.append(
Packet(
ind=step_nr,
obj=ImageGenerationToolDelta(
type="image_generation_tool_delta", images=images
),
),
)
packets.append(
Packet(
ind=step_nr,
obj=SectionEnd(
type="section_end",
),
)
)
return packets
def create_search_packets(
search_queries: list[str],
saved_search_docs: list[SavedSearchDoc] | None,
is_internet_search: bool,
step_nr: int,
) -> list[Packet]:
packets: list[Packet] = []
packets.append(
Packet(
ind=step_nr,
obj=SearchToolStart(
type="internal_search_tool_start",
is_internet_search=is_internet_search,
),
)
)
packets.append(
Packet(
ind=step_nr,
obj=SearchToolDelta(
type="internal_search_tool_delta",
queries=search_queries,
documents=saved_search_docs,
),
),
)
packets.append(
Packet(
ind=step_nr,
obj=SectionEnd(
type="section_end",
),
)
)
return packets
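Taken together, these helpers reconstruct the start/delta/section-end triples the live stream would have produced. A minimal sketch of replaying one internal search step followed by the final answer (queries and text are hypothetical):

packets: list[Packet] = []
packets += create_search_packets(
    search_queries=["onboarding checklist"],  # hypothetical query
    saved_search_docs=None,
    is_internet_search=False,
    step_nr=1,
)
packets += create_message_packets(
    message_text="Here is what I found...",  # hypothetical answer text
    final_documents=None,
    step_nr=2,
)
# packets now holds two step_nr groups of (start, delta, section_end)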
def get_chat_session_by_id(
chat_session_id: UUID,
@@ -550,11 +753,23 @@ def get_chat_messages_by_session(
)
if prefetch_tool_calls:
# stmt = stmt.options(
# joinedload(ChatMessage.tool_call),
# joinedload(ChatMessage.sub_questions).joinedload(
# AgentSubQuestion.sub_queries
# ),
# )
# result = db_session.scalars(stmt).unique().all()
stmt = (
select(ChatMessage)
.where(ChatMessage.chat_session_id == chat_session_id)
.order_by(nullsfirst(ChatMessage.parent_message))
)
stmt = stmt.options(
joinedload(ChatMessage.tool_call),
joinedload(ChatMessage.sub_questions).joinedload(
AgentSubQuestion.sub_queries
),
joinedload(ChatMessage.research_iterations).joinedload(
ResearchAgentIteration.sub_steps
)
)
result = db_session.scalars(stmt).unique().all()
else:
@@ -645,8 +860,9 @@ def create_new_chat_message(
commit: bool = True,
reserved_message_id: int | None = None,
overridden_model: str | None = None,
refined_answer_improvement: bool | None = None,
is_agentic: bool = False,
research_type: ResearchType | None = None,
research_plan: dict[str, Any] | None = None,
) -> ChatMessage:
if reserved_message_id is not None:
# Edit existing message
@@ -667,8 +883,9 @@ def create_new_chat_message(
existing_message.error = error
existing_message.alternate_assistant_id = alternate_assistant_id
existing_message.overridden_model = overridden_model
existing_message.refined_answer_improvement = refined_answer_improvement
existing_message.is_agentic = is_agentic
existing_message.research_type = research_type
existing_message.research_plan = research_plan
new_chat_message = existing_message
else:
# Create new message
@@ -687,8 +904,9 @@ def create_new_chat_message(
error=error,
alternate_assistant_id=alternate_assistant_id,
overridden_model=overridden_model,
refined_answer_improvement=refined_answer_improvement,
is_agentic=is_agentic,
research_type=research_type,
research_plan=research_plan,
)
db_session.add(new_chat_message)
@@ -1032,6 +1250,160 @@ def get_retrieval_docs_from_search_docs(
return RetrievalDocs(top_documents=top_documents)
def translate_db_message_to_packets(
chat_message: ChatMessage,
db_session: Session,
remove_doc_content: bool = False,
start_step_nr: int = 1,
) -> EndStepPacketList:
step_nr = start_step_nr
packet_list: list[Packet] = []
# only stream out packets for assistant messages
if chat_message.message_type == MessageType.ASSISTANT:
citations = chat_message.citations
# Get document IDs from SearchDoc table using citation mapping
citation_info_list = []
if citations:
for citation_num, search_doc_id in citations.items():
search_doc = get_db_search_doc_by_id(search_doc_id, db_session)
if search_doc:
citation_info_list.append(
CitationInfo(
citation_num=citation_num,
document_id=search_doc.document_id,
)
)
if chat_message.research_type in [ResearchType.THOUGHTFUL, ResearchType.DEEP]:
research_iterations = sorted(
chat_message.research_iterations, key=lambda x: x.iteration_nr
) # sorted iterations
for research_iteration in research_iterations:
if research_iteration.iteration_nr > 1:
# the first iteration does not need a separate reasoning step
packet_list.extend(
create_reasoning_packets(research_iteration.reasoning, step_nr)
)
step_nr += 1
if research_iteration.purpose:
packet_list.extend(
create_reasoning_packets(research_iteration.purpose, step_nr)
)
step_nr += 1
sub_steps = research_iteration.sub_steps
tasks = []
tool_call_ids = []
cited_docs: list[SavedSearchDoc] = []
for sub_step in sub_steps:
tasks.append(sub_step.sub_step_instructions)
tool_call_ids.append(sub_step.sub_step_tool_id)
sub_step_cited_docs = sub_step.cited_doc_results
if isinstance(sub_step_cited_docs, list):
# Convert serialized dict data back to SavedSearchDoc objects
saved_search_docs = [
(
SavedSearchDoc.from_dict(doc_data)
if isinstance(doc_data, dict)
else doc_data
)
for doc_data in sub_step_cited_docs
]
cited_docs.extend(saved_search_docs)
else:
packet_list.extend(
create_reasoning_packets(
_CANNOT_SHOW_STEP_RESULTS_STR, step_nr
)
)
step_nr += 1
if len(set(tool_call_ids)) > 1:
packet_list.extend(
create_reasoning_packets(_CANNOT_SHOW_STEP_RESULTS_STR, step_nr)
)
step_nr += 1
elif (
len(sub_steps) == 0
):  # no sub-steps or tool calls, but the iteration may still have reasoning or a purpose
continue
else:
# TODO: replace with isinstance, resolving circular imports
tool_id = tool_call_ids[0]
tool = get_tool_by_id(tool_id, db_session)
tool_name = tool.name
if tool_name in ["SearchTool", "KnowledgeGraphTool"]:
cited_docs = cast(list[SavedSearchDoc], cited_docs)
packet_list.extend(
create_search_packets(tasks, cited_docs, False, step_nr)
)
step_nr += 1
elif tool_name == "InternetSearchTool":
cited_docs = cast(list[SavedSearchDoc], cited_docs)
packet_list.extend(
create_search_packets(tasks, cited_docs, True, step_nr)
)
step_nr += 1
elif tool_name == "ImageGenerationTool":
if len(tasks) > 1:
packet_list.extend(
create_reasoning_packets(
_CANNOT_SHOW_STEP_RESULTS_STR, step_nr
)
)
step_nr += 1
else:
images = cited_docs[0]
packet_list.extend(
create_image_generation_packets(images, step_nr)
)
step_nr += 1
else:
raise ValueError(f"Unknown tool name: {tool_name}")
packet_list.extend(
create_message_packets(
message_text=chat_message.message,
final_documents=[
translate_db_search_doc_to_server_search_doc(doc)
for doc in chat_message.search_docs
],
step_nr=step_nr,
)
)
step_nr += 1
packet_list.extend(create_citation_packets(citation_info_list, step_nr))
step_nr += 1
packet_list.append(Packet(ind=step_nr, obj=OverallStop()))
return EndStepPacketList(
end_step_nr=step_nr,
packet_list=packet_list,
)
def translate_db_message_to_chat_message_detail(
chat_message: ChatMessage,
remove_doc_content: bool = False,
@@ -1061,11 +1433,6 @@ def translate_db_message_to_chat_message_detail(
),
alternate_assistant_id=chat_message.alternate_assistant_id,
overridden_model=chat_message.overridden_model,
sub_questions=translate_db_sub_questions_to_server_objects(
chat_message.sub_questions
),
refined_answer_improvement=chat_message.refined_answer_improvement,
is_agentic=chat_message.is_agentic,
error=chat_message.error,
)
@@ -1111,27 +1478,6 @@ def log_agent_sub_question_results(
primary_message_id: int | None,
sub_question_answer_results: list[SubQuestionAnswerResults],
) -> None:
def _create_citation_format_list(
document_citations: list[InferenceSection],
) -> list[dict[str, Any]]:
citation_list: list[dict[str, Any]] = []
for document_citation in document_citations:
document_citation_dict = {
"link": "",
"blurb": document_citation.center_chunk.blurb,
"content": document_citation.center_chunk.content,
"metadata": document_citation.center_chunk.metadata,
"updated_at": str(document_citation.center_chunk.updated_at),
"document_id": document_citation.center_chunk.document_id,
"source_type": "file",
"source_links": document_citation.center_chunk.source_links,
"match_highlights": document_citation.center_chunk.match_highlights,
"semantic_identifier": document_citation.center_chunk.semantic_identifier,
}
citation_list.append(document_citation_dict)
return citation_list
now = datetime.now()
@@ -1141,7 +1487,7 @@ def log_agent_sub_question_results(
]
sub_question = sub_question_answer_result.question
sub_answer = sub_question_answer_result.answer
sub_document_results = _create_citation_format_list(
sub_document_results = create_citation_format_list(
sub_question_answer_result.context_documents
)
@@ -1198,3 +1544,58 @@ def update_chat_session_updated_at_timestamp(
.values(time_updated=func.now())
)
# No commit - the caller is responsible for committing the transaction
def create_search_doc_from_inference_section(
inference_section: InferenceSection,
is_internet: bool,
db_session: Session,
score: float = 0.0,
is_relevant: bool | None = None,
relevance_explanation: str | None = None,
commit: bool = False,
) -> SearchDoc:
"""Create a SearchDoc in the database from an InferenceSection."""
db_search_doc = SearchDoc(
document_id=inference_section.center_chunk.document_id,
chunk_ind=inference_section.center_chunk.chunk_id,
semantic_id=inference_section.center_chunk.semantic_identifier,
link=(
inference_section.center_chunk.source_links.get(0)
if inference_section.center_chunk.source_links
else None
),
blurb=inference_section.center_chunk.blurb,
source_type=inference_section.center_chunk.source_type,
boost=inference_section.center_chunk.boost,
hidden=inference_section.center_chunk.hidden,
doc_metadata=inference_section.center_chunk.metadata,
score=score,
is_relevant=is_relevant,
relevance_explanation=relevance_explanation,
match_highlights=inference_section.center_chunk.match_highlights,
updated_at=inference_section.center_chunk.updated_at,
primary_owners=inference_section.center_chunk.primary_owners or [],
secondary_owners=inference_section.center_chunk.secondary_owners or [],
is_internet=is_internet,
)
db_session.add(db_search_doc)
if commit:
db_session.commit()
else:
db_session.flush()
return db_search_doc
def create_search_doc_from_saved_search_doc(
saved_search_doc: SavedSearchDoc,
) -> SearchDoc:
"""Convert SavedSearchDoc to SearchDoc by excluding the additional fields"""
data = saved_search_doc.model_dump()
# Remove the fields that are specific to SavedSearchDoc
data.pop("db_doc_id", None)
# Keep score since SearchDoc has it as an optional field
return SearchDoc(**data)

View File

@@ -82,6 +82,8 @@ from onyx.utils.encryption import encrypt_string_to_bytes
from onyx.utils.headers import HeaderItemDict
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import RerankerProvider
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
logger = setup_logger()
@@ -677,8 +679,8 @@ class KGEntityType(Base):
DateTime(timezone=True), server_default=func.now()
)
grounded_source_name: Mapped[str] = mapped_column(
NullFilteredString, nullable=False, index=False
grounded_source_name: Mapped[str | None] = mapped_column(
NullFilteredString, nullable=True, index=False
)
entity_values: Mapped[list[str]] = mapped_column(
@@ -2139,12 +2141,26 @@ class ChatMessage(Base):
order_by="(AgentSubQuestion.level, AgentSubQuestion.level_question_num)",
)
research_iterations: Mapped[list["ResearchAgentIteration"]] = relationship(
"ResearchAgentIteration",
foreign_keys="ResearchAgentIteration.primary_question_id",
cascade="all, delete-orphan",
)
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
"StandardAnswer",
secondary=ChatMessage__StandardAnswer.__table__,
back_populates="chat_messages",
)
research_type: Mapped[ResearchType] = mapped_column(
Enum(ResearchType, native_enum=False), nullable=True
)
research_plan: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)
research_answer_purpose: Mapped[ResearchAnswerPurpose] = mapped_column(
Enum(ResearchAnswerPurpose, native_enum=False), nullable=True
)
class ChatFolder(Base):
"""For organizing chat sessions"""
@@ -3343,3 +3359,71 @@ class TenantAnonymousUserPath(Base):
anonymous_user_path: Mapped[str] = mapped_column(
String, nullable=False, unique=True
)
class ResearchAgentIteration(Base):
__tablename__ = "research_agent_iteration"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
primary_question_id: Mapped[int] = mapped_column(
ForeignKey("chat_message.id", ondelete="CASCADE")
)
iteration_nr: Mapped[int] = mapped_column(Integer, nullable=False)
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False)
purpose: Mapped[str] = mapped_column(String, nullable=True)
reasoning: Mapped[str] = mapped_column(String, nullable=True)
# Relationships
primary_message: Mapped["ChatMessage"] = relationship(
"ChatMessage",
foreign_keys=[primary_question_id],
back_populates="research_iterations",
)
sub_steps: Mapped[list["ResearchAgentIterationSubStep"]] = relationship(
"ResearchAgentIterationSubStep",
primaryjoin=(
"and_("
"ResearchAgentIteration.primary_question_id == ResearchAgentIterationSubStep.primary_question_id, "
"ResearchAgentIteration.iteration_nr == ResearchAgentIterationSubStep.iteration_nr"
")"
),
foreign_keys="[ResearchAgentIterationSubStep.primary_question_id, ResearchAgentIterationSubStep.iteration_nr]",
cascade="all, delete-orphan",
)
class ResearchAgentIterationSubStep(Base):
__tablename__ = "research_agent_iteration_sub_step"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
primary_question_id: Mapped[int] = mapped_column(
ForeignKey("chat_message.id", ondelete="CASCADE")
)
parent_question_id: Mapped[int | None] = mapped_column(
ForeignKey("research_agent_iteration_sub_step.id", ondelete="CASCADE"),
nullable=True,
)
iteration_nr: Mapped[int] = mapped_column(Integer, nullable=False)
iteration_sub_step_nr: Mapped[int] = mapped_column(Integer, nullable=False)
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False)
sub_step_instructions: Mapped[str] = mapped_column(String, nullable=True)
sub_step_tool_id: Mapped[int] = mapped_column(ForeignKey("tool.id"), nullable=True)
reasoning: Mapped[str] = mapped_column(String, nullable=True)
sub_answer: Mapped[str] = mapped_column(String, nullable=True)
cited_doc_results: Mapped[JSON_ro] = mapped_column(postgresql.JSONB())
claims: Mapped[list[str]] = mapped_column(postgresql.JSONB(), nullable=True)
additional_data: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)
# Relationships
primary_message: Mapped["ChatMessage"] = relationship(
"ChatMessage",
foreign_keys=[primary_question_id],
)
parent_sub_step: Mapped["ResearchAgentIterationSubStep"] = relationship(
"ResearchAgentIterationSubStep",
foreign_keys=[parent_question_id],
remote_side="ResearchAgentIterationSubStep.id",
)
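Since the sub-step relationship joins on the composite (primary_question_id, iteration_nr) pair rather than a single foreign key, both columns must agree when rows are written together. A minimal sketch, assuming an open SQLAlchemy session db_session and an existing chat message id message_id (both hypothetical):

import datetime

iteration = ResearchAgentIteration(
    primary_question_id=message_id,
    iteration_nr=1,
    created_at=datetime.datetime.utcnow(),
    purpose="Initial search pass",  # hypothetical
)
sub_step = ResearchAgentIterationSubStep(
    primary_question_id=message_id,
    iteration_nr=1,  # must match the iteration's key pair to appear in sub_steps
    iteration_sub_step_nr=1,
    created_at=datetime.datetime.utcnow(),
    sub_step_instructions="Search internal docs for X",  # hypothetical
    cited_doc_results=[],
)
db_session.add_all([iteration, sub_step])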

View File

@@ -16,7 +16,8 @@ from onyx.db.models import User
from onyx.db.persona import mark_persona_as_deleted
from onyx.db.persona import upsert_persona
from onyx.db.prompts import get_default_prompt
from onyx.tools.built_in_tools import get_search_tool
from onyx.tools.built_in_tools import get_builtin_tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.errors import EERequiredError
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
@@ -49,9 +50,7 @@ def create_slack_channel_persona(
) -> Persona:
"""NOTE: does not commit changes"""
search_tool = get_search_tool(db_session)
if search_tool is None:
raise ValueError("Search tool not found")
search_tool = get_builtin_tool(db_session=db_session, tool_type=SearchTool)
# create/update persona associated with the Slack channel
persona_name = _build_persona_name(channel_name)

View File

@@ -47,15 +47,15 @@ logger = setup_logger()
def _get_classification_extraction_instructions() -> (
dict[str, dict[str, KGEntityTypeInstructions]]
dict[str | None, dict[str, KGEntityTypeInstructions]]
):
"""
Prepare the classification instructions for the given source.
"""
classification_instructions_dict: dict[str, dict[str, KGEntityTypeInstructions]] = (
{}
)
classification_instructions_dict: dict[
str | None, dict[str, KGEntityTypeInstructions]
] = {}
with get_session_with_current_tenant() as db_session:
entity_types = get_entity_types(db_session, active=True)

View File

@@ -32,9 +32,7 @@ def format_entity_id_for_models(entity_id_name: str) -> str:
separator = entity_type = ""
formatted_entity_type = entity_type.strip().upper()
formatted_entity_name = (
entity_name.strip().replace('"', "").replace("'", "").title()
)
formatted_entity_name = entity_name.strip().replace('"', "").replace("'", "")
return f"{formatted_entity_type}{separator}{formatted_entity_name}"

View File

@@ -6,6 +6,7 @@ from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from pydantic import BaseModel
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
from onyx.configs.constants import MessageType
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.utils import build_content_with_imgs
@@ -25,6 +26,7 @@ class PreviousMessage(BaseModel):
files: list[InMemoryChatFile]
tool_call: ToolCallFinalResult | None
refined_answer_improvement: bool | None
research_answer_purpose: ResearchAnswerPurpose | None
@classmethod
def from_chat_message(
@@ -52,6 +54,7 @@ class PreviousMessage(BaseModel):
else None
),
refined_answer_improvement=chat_message.refined_answer_improvement,
research_answer_purpose=chat_message.research_answer_purpose,
)
def to_langchain_msg(self) -> BaseMessage:

View File

@@ -9,7 +9,6 @@ from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.webhook import WebhookClient
from onyx.chat.models import ChatOnyxBotResponse
from onyx.chat.models import CitationInfo
from onyx.chat.models import QADocsResponse
from onyx.configs.constants import MessageType
from onyx.configs.constants import SearchFeedbackType
@@ -50,6 +49,7 @@ from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
from onyx.onyxbot.slack.utils import TenantSocketModeClient
from onyx.onyxbot.slack.utils import update_emote_react
from onyx.server.query_and_chat.models import ChatMessageDetail
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.utils.logger import setup_logger

File diff suppressed because it is too large Load Diff

View File

@@ -669,8 +669,8 @@ that should be used to analyze each object/each source (or 'the object' that fit
}}
Do not include any other text or explanations.
"""
SOURCE_DETECTION_PROMPT = f"""
You are an expert in generating, understanding and analyzing SQL statements.
@@ -773,11 +773,29 @@ Please structure your answer using <reasoning>, </reasoning>,<sql>, </sql> start
""".strip()
SIMPLE_SQL_PROMPT = f"""
You are an expert in generating a SQL statement that only uses ONE TABLE that captures RELATIONSHIPS \
between TWO ENTITIES. The table has the following structure:
ENTITY_TABLE_DESCRIPTION = f"""\
- Table name: entity_table
- Columns:
- entity (str): The name of the ENTITY, combining the nature of the entity and the id of the entity. \
It is of the form <entity_type>::<entity_name> [example: ACCOUNT::625482894].
- entity_type (str): the type of the entity [example: ACCOUNT].
- entity_attributes (json): the attributes of the entity [example: {{"priority": "high", "status": "active"}}]
- source_document (str): the id of the document that contains the entity. Note that the combination of \
id_name and source_document IS UNIQUE!
- source_date (timestamp): the 'event' date of the source document [example: 2025-04-25 21:43:31.054741+00]
{SEPARATOR_LINE}
Importantly, here are the entity (node) types that you can use, with a short description of what they mean. You may need to \
identify the proper entity type through its description. Also notice the allowed attributes for each entity type and \
their values, if provided. Of particular importance is the 'subtype' attribute, if provided, as this is how \
the entity type may also often be referred to.
{SEPARATOR_LINE}
---entity_types---
{SEPARATOR_LINE}
"""
RELATIONSHIP_TABLE_DESCRIPTION = f"""\
- Table name: relationship_table
- Columns:
- relationship (str): The name of the RELATIONSHIP, combining the nature of the relationship and the names of the entities. \
@@ -803,17 +821,27 @@ id_name and source_document IS UNIQUE!
Importantly, here are the entity (node) types that you can use, with a short description of what they mean. You may need to \
identify the proper entity type through its description. Also notice the allowed attributes for each entity type and \
their values, if provided.
their values, if provided. Of particular importance is the 'subtype' attribute, if provided, as this is how \
the entity type may also often be referred to.
{SEPARATOR_LINE}
---entity_types---
{SEPARATOR_LINE}
Here are the relationship types that are in the table, denoted as <source_entity_type>__<relationship_type>__<target_entity_type>:
Here are the relationship types that are in the table, denoted as <source_entity_type>__<relationship_type>__<target_entity_type>.
In the table, the actual relationships are not quite of this form, but each <entity_type> is followed by '::<entity_name>' \
in the relationship id as shown above.
{SEPARATOR_LINE}
---relationship_types---
{SEPARATOR_LINE}
In the table, the actual relationships are not quite of this form, but each <entity_type> is followed by ':<entity_name>' in the \
relationship id as shown above.
"""
SIMPLE_SQL_PROMPT = f"""
You are an expert in generating a SQL statement that only uses ONE TABLE that captures RELATIONSHIPS \
between TWO ENTITIES. The table has the following structure:
{SEPARATOR_LINE}
{RELATIONSHIP_TABLE_DESCRIPTION}
Here is the question you are supposed to translate into a SQL statement:
{SEPARATOR_LINE}
@@ -936,7 +964,7 @@ Please structure your answer using <reasoning>, </reasoning>, <sql>, </sql> star
<sql>[the SQL statement that you generate to satisfy the task]</sql>
""".strip()
# TODO: remove following before merging after enough testing
SIMPLE_SQL_CORRECTION_PROMPT = f"""
You are an expert in reviewing and fixing SQL statements.
@@ -949,7 +977,7 @@ Guidance:
SELECT statement as well! And it needs to be in the EXACT FORM! So if a \
conversion took place, make sure to include the conversion in the SELECT and the ORDER BY clause!
- never should 'source_document' be in the SELECT clause! Remove if present!
- if there are joins, they must be on entities, never sour ce documents
- if there are joins, they must be on entities, never source documents
- if there are joins, consider the possibility that the second entity does not exist for all examples.\
Therefore consider using LEFT joins (or RIGHT joins) as appropriate.
@@ -969,26 +997,7 @@ You are an expert in generating a SQL statement that only uses ONE TABLE that ca
and their attributes and other data. The table has the following structure:
{SEPARATOR_LINE}
- Table name: entity_table
- Columns:
- entity (str): The name of the ENTITY, combining the nature of the entity and the id of the entity. \
It is of the form <entity_type>::<entity_name> [example: ACCOUNT::625482894].
- entity_type (str): the type of the entity [example: ACCOUNT].
- entity_attributes (json): the attributes of the entity [example: {{"priority": "high", "status": "active"}}]
- source_document (str): the id of the document that contains the entity. Note that the combination of \
id_name and source_document IS UNIQUE!
- source_date (timestamp): the 'event' date of the source document [example: 2025-04-25 21:43:31.054741+00]
{SEPARATOR_LINE}
Importantly, here are the entity (node) types that you can use, with a short description of what they mean. You may need to \
identify the proper entity type through its description. Also notice the allowed attributes for each entity type and \
their values, if provided. Of particular importance is the 'subtype' attribute, if provided, as this is how \
the entity type may also often be referred to.
{SEPARATOR_LINE}
---entity_types---
{SEPARATOR_LINE}
{ENTITY_TABLE_DESCRIPTION}
Here is the question you are supposed to translate into a SQL statement:
{SEPARATOR_LINE}
@@ -1077,33 +1086,55 @@ Please structure your answer using <reasoning>, </reasoning>, <sql>, </sql> star
<sql>[the SQL statement that you generate to satisfy the task]</sql>
""".strip()
SIMPLE_SQL_ERROR_FIX_PROMPT = f"""
You are an expert at fixing SQL statements. You will be provided with a SQL statement that aims to address \
a question, but it contains an error. Your task is to fix the SQL statement, based on the error message.
SQL_AGGREGATION_REMOVAL_PROMPT = f"""
You are a SQL expert. You were provided with a SQL statement that returns an aggregation, and you are \
tasked to show the underlying objects that were aggregated. For this you need to remove the aggregate functions \
from the SQL statement in the correct way.
Here is the description of the table that the SQL statement is supposed to use:
---table_description---
Additional rules:
- if you see a 'select count(*)', you should NOT convert \
that to 'select *...', but rather return the corresponding id_name, entity_type_id_name, name, and document_id. \
As in: 'select <table, if necessary>.id_name, <table, if necessary>.entity_type_id_name, \
<table, if necessary>.name, <table, if necessary>.document_id ...'. \
The id_name is always the primary index, and those should be returned, along with the type (entity_type_id_name), \
the name (name) of the objects, and the document_id (document_id) of the object.
- Add a limit of 30 to the select statement.
- Don't change anything else.
- The final select statement needs obviously to be a valid SQL statement.
Here is the question you are supposed to translate into a SQL statement:
{SEPARATOR_LINE}
---question---
{SEPARATOR_LINE}
Here is the SQL statement you are supposed to remove the aggregate functions from:
Here is the SQL statement that you should fix:
{SEPARATOR_LINE}
---sql_statement---
{SEPARATOR_LINE}
Here is the error message that was returned:
{SEPARATOR_LINE}
---error_message---
{SEPARATOR_LINE}
Note that in the case the error states the sql statement did not return any results, it is possible that the \
sql statement is correct, but the question is not addressable with the information in the knowledge graph. \
If you are absolutely certain that is the case, you may return the original sql statement.
Here are a couple common errors that you may encounter:
- source_document is in the SELECT clause -> remove it
- columns used in ORDER BY must also appear in the SELECT DISTINCT clause
- consider carefully the type of the columns you are using, especially for attributes. You may need to cast them
- dates are ALWAYS in string format of the form YYYY-MM-DD, for source date as well as for date-like the attributes! \
So please use that format, particularly if you use data comparisons (>, <, ...)
- attributes are stored in the attributes json field. As this is postgres, querying for those must be done as \
"attributes ->> '<attribute>' = '<attribute value>'" (or "attributes ? '<attribute>'" to check for existence).
- if you are using joins and the sql returned no joins, make sure you are using the appropriate join type (LEFT, RIGHT, etc.) \
it is possible that the second entity does not exist for all examples.
- (ignore if using entity_table) if using the relationship_table and the sql returned no results, make sure you are \
selecting the correct column! Use the available relationship types to determine whether to use the source or target entity.
APPROACH:
Please think through this step by step. Please also bear in mind that the sql statement is written in postgres syntax.
Also, in case it is important, today is ---today_date--- and the user/employee asking is ---user_name---.
Please structure your answer using <reasoning>, </reasoning>, <sql>, </sql> start and end tags as in:
<reasoning>[your short step-by step thinking]</reasoning>
<sql>[the SQL statement without the aggregate functions]</sql>
""".strip()
<reasoning>[think through the logic but do so extremely briefly! Not more than 3-4 sentences.]</reasoning>
<sql>[the SQL statement that you generate to satisfy the task]</sql>
"""
SEARCH_FILTER_CONSTRUCTION_PROMPT = f"""

View File

@@ -0,0 +1,43 @@
import re
class PromptTemplate:
"""
A class for building prompt templates with placeholders.
Useful for templates containing JSON schemas, where literal braces would clash with f-string/str.format substitution.
Unlike str.replace, this class raises an error if required fields are missing.
"""
DEFAULT_PATTERN = r"---([a-zA-Z0-9_]+)---"
def __init__(self, template: str, pattern: str = DEFAULT_PATTERN):
self._pattern_str = pattern
self._pattern = re.compile(pattern)
self._template = template
self._fields: set[str] = set(self._pattern.findall(template))
def build(self, **kwargs: str) -> str:
"""
Build the prompt template with the given fields.
Will raise an error if the fields are missing.
Will ignore fields that are not in the template.
"""
missing = self._fields - set(kwargs.keys())
if missing:
raise ValueError(f"Missing required fields: {missing}.")
return self._replace_fields(kwargs)
def partial_build(self, **kwargs: str) -> "PromptTemplate":
"""
Returns another PromptTemplate with the given fields replaced.
Will ignore fields that are not in the template.
"""
new_template = self._replace_fields(kwargs)
return PromptTemplate(new_template, self._pattern_str)
def _replace_fields(self, field_vals: dict[str, str]) -> str:
def repl(match: re.Match) -> str:
key = match.group(1)
return field_vals.get(key, match.group(0))
return self._pattern.sub(repl, self._template)
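A short usage sketch (the template text is made up for illustration):

template = PromptTemplate("Answer as of ---today_date--- for user ---user_name---.")

# partial_build fills a subset of fields and returns a new template
dated = template.partial_build(today_date="2025-08-16")

# build fills the remaining fields; a missing one would raise ValueError
print(dated.build(user_name="Alex"))
# -> Answer as of 2025-08-16 for user Alex.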

View File

@@ -3,6 +3,8 @@ from fastapi import Depends
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
from onyx.configs.kg_configs import KG_BETA_ASSISTANT_DESCRIPTION
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.engine.sql_engine import get_session
from onyx.db.entities import get_entity_stats_by_grounded_source_name
@@ -31,12 +33,13 @@ from onyx.server.kg.models import KGConfig
from onyx.server.kg.models import KGConfig as KGConfigAPIModel
from onyx.server.kg.models import SourceAndEntityTypeView
from onyx.server.kg.models import SourceStatistics
from onyx.tools.built_in_tools import get_search_tool
from onyx.tools.built_in_tools import get_builtin_tool
from onyx.tools.tool_implementations.knowledge_graph.knowledge_graph_tool import (
KnowledgeGraphTool,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
_KG_BETA_ASSISTANT_DESCRIPTION = "The KG Beta assistant uses the Onyx Knowledge Graph (beta) structure \
to answer questions"
admin_router = APIRouter(prefix="/admin/kg")
@@ -95,12 +98,9 @@ def enable_or_disable_kg(
enable_kg(enable_req=req)
populate_missing_default_entity_types__commit(db_session=db_session)
# Create or restore KG Beta persona
# Get the search tool
search_tool = get_search_tool(db_session=db_session)
if not search_tool:
raise RuntimeError("SearchTool not found in the database.")
# Get the search and knowledge graph tools
search_tool = get_builtin_tool(db_session=db_session, tool_type=SearchTool)
kg_tool = get_builtin_tool(db_session=db_session, tool_type=KnowledgeGraphTool)
# Check if we have a previously created persona
kg_config_settings = get_kg_config_settings()
@@ -132,8 +132,8 @@ def enable_or_disable_kg(
is_public = len(user_ids) == 0
persona_request = PersonaUpsertRequest(
name="KG Beta",
description=_KG_BETA_ASSISTANT_DESCRIPTION,
name=TMP_DRALPHA_PERSONA_NAME,
description=KG_BETA_ASSISTANT_DESCRIPTION,
system_prompt=KG_BETA_ASSISTANT_SYSTEM_PROMPT,
task_prompt=KG_BETA_ASSISTANT_TASK_PROMPT,
datetime_aware=False,
@@ -145,7 +145,7 @@ def enable_or_disable_kg(
recency_bias=RecencyBiasSetting.NO_DECAY,
prompt_ids=[0],
document_set_ids=[],
tool_ids=[search_tool.id],
tool_ids=[search_tool.id, kg_tool.id],
llm_model_provider_override=None,
llm_model_version_override=None,
starter_messages=None,

View File

@@ -47,6 +47,7 @@ from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import set_as_latest_chat_message
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import translate_db_message_to_packets
from onyx.db.chat import update_chat_session
from onyx.db.chat_search import search_chat_sessions
from onyx.db.connector import create_connector
@@ -92,6 +93,8 @@ from onyx.server.query_and_chat.models import RenameChatSessionResponse
from onyx.server.query_and_chat.models import SearchFeedbackRequest
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.token_limit import check_token_rate_limits
from onyx.utils.file_types import UploadMimeTypes
from onyx.utils.headers import get_custom_tool_additional_request_headers
@@ -233,6 +236,24 @@ def get_chat_session(
prefetch_tool_calls=True,
)
# Convert messages to ChatMessageDetail format
chat_message_details = [
translate_db_message_to_chat_message_detail(msg) for msg in session_messages
]
simplified_packet_lists: list[list[Packet]] = []
end_step_nr = 1
for msg in session_messages:
if msg.message_type == MessageType.ASSISTANT:
msg_packet_object = translate_db_message_to_packets(
msg, db_session=db_session, start_step_nr=end_step_nr
)
end_step_nr = msg_packet_object.end_step_nr
msg_packet_list = msg_packet_object.packet_list
msg_packet_list.append(Packet(ind=end_step_nr, obj=OverallStop()))
simplified_packet_lists.append(msg_packet_list)
return ChatSessionDetailResponse(
chat_session_id=session_id,
description=chat_session.description,
@@ -245,13 +266,13 @@ def get_chat_session(
chat_session.persona.icon_shape if chat_session.persona else None
),
current_alternate_model=chat_session.current_alternate_model,
messages=[
translate_db_message_to_chat_message_detail(msg) for msg in session_messages
],
messages=chat_message_details,
time_created=chat_session.time_created,
shared_status=chat_session.shared_status,
current_temperature_override=chat_session.temperature_override,
deleted=chat_session.deleted,
# specifically for the Onyx Chat UI
packets=simplified_packet_lists,
)
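For reference, a minimal sketch of what the new packets field on the response might contain for a single assistant message (contents hypothetical; the classes come from streaming_models):

packets = [
    [
        Packet(ind=1, obj=MessageStart(content="", final_documents=None)),
        Packet(ind=1, obj=MessageDelta(content="Hello!")),
        Packet(ind=1, obj=SectionEnd()),
        Packet(ind=2, obj=OverallStop()),
    ]
]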

View File

@@ -22,6 +22,7 @@ from onyx.db.enums import ChatSessionSharedStatus
from onyx.file_store.models import FileDescriptor
from onyx.llm.override_models import LLMOverride
from onyx.llm.override_models import PromptOverride
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.tools.models import ToolCallFinalResult
@@ -240,11 +241,8 @@ class ChatMessageDetail(BaseModel):
chat_session_id: UUID | None = None
# Dict mapping citation number to db_doc_id
citations: dict[int, int] | None = None
sub_questions: list[SubQuestionDetail] | None = None
files: list[FileDescriptor]
tool_call: ToolCallFinalResult | None
refined_answer_improvement: bool | None = None
is_agentic: bool | None = None
error: str | None = None
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
@@ -274,6 +272,8 @@ class ChatSessionDetailResponse(BaseModel):
current_temperature_override: float | None
deleted: bool = False
packets: list[list[Packet]]
# This one is not used anymore
class QueryValidationResponse(BaseModel):

View File

@@ -0,0 +1,190 @@
from collections import OrderedDict
from collections.abc import Mapping
from typing import Annotated
from typing import Literal
from typing import Union
from pydantic import BaseModel
from pydantic import Field
from onyx.context.search.models import SavedSearchDoc
class BaseObj(BaseModel):
type: str = ""
"""Basic Message Packets"""
class MessageStart(BaseObj):
type: Literal["message_start"] = "message_start"
# Merged set of all documents considered
final_documents: list[SavedSearchDoc] | None
content: str
class MessageDelta(BaseObj):
content: str
type: Literal["message_delta"] = "message_delta"
"""Control Packets"""
class OverallStop(BaseObj):
type: Literal["stop"] = "stop"
class SectionEnd(BaseObj):
type: Literal["section_end"] = "section_end"
"""Tool Packets"""
class SearchToolStart(BaseObj):
type: Literal["internal_search_tool_start"] = "internal_search_tool_start"
is_internet_search: bool = False
class SearchToolDelta(BaseObj):
type: Literal["internal_search_tool_delta"] = "internal_search_tool_delta"
queries: list[str] | None = None
documents: list[SavedSearchDoc] | None = None
class ImageGenerationToolStart(BaseObj):
type: Literal["image_generation_tool_start"] = "image_generation_tool_start"
class ImageGenerationToolDelta(BaseObj):
type: Literal["image_generation_tool_delta"] = "image_generation_tool_delta"
images: list[dict[str, str]] | None = None
class CustomToolStart(BaseObj):
type: Literal["custom_tool_start"] = "custom_tool_start"
tool_name: str
class CustomToolDelta(BaseObj):
type: Literal["custom_tool_delta"] = "custom_tool_delta"
tool_name: str
response_type: str
# For non-file responses
data: dict | list | str | int | float | bool | None = None
# For file-based responses like image/csv
file_ids: list[str] | None = None
"""Reasoning Packets"""
class ReasoningStart(BaseObj):
type: Literal["reasoning_start"] = "reasoning_start"
class ReasoningDelta(BaseObj):
type: Literal["reasoning_delta"] = "reasoning_delta"
reasoning: str
"""Citation Packets"""
class CitationStart(BaseObj):
type: Literal["citation_start"] = "citation_start"
class SubQuestionIdentifier(BaseModel):
"""None represents references to objects in the original flow. To our understanding,
these will not be None in the packets returned from agent search.
"""
level: int | None = None
level_question_num: int | None = None
@staticmethod
def make_dict_by_level(
original_dict: Mapping[tuple[int, int], "SubQuestionIdentifier"],
) -> dict[int, list["SubQuestionIdentifier"]]:
"""returns a dict of level to object list (sorted by level_question_num)
Ordering is asc for readability.
"""
# organize by level, then sort ascending by question_index
level_dict: dict[int, list[SubQuestionIdentifier]] = {}
# group by level
for k, obj in original_dict.items():
level = k[0]
if level not in level_dict:
level_dict[level] = []
level_dict[level].append(obj)
# for each level, sort the group
for k2, value2 in level_dict.items():
# we need to handle the none case due to SubQuestionIdentifier typing
# level_question_num as int | None, even though it should never be None here.
level_dict[k2] = sorted(
value2,
key=lambda x: (x.level_question_num is None, x.level_question_num),
)
# sort by level
sorted_dict = OrderedDict(sorted(level_dict.items()))
return sorted_dict
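A quick sketch of make_dict_by_level in action:

grouped = SubQuestionIdentifier.make_dict_by_level(
    {
        (0, 2): SubQuestionIdentifier(level=0, level_question_num=2),
        (0, 1): SubQuestionIdentifier(level=0, level_question_num=1),
        (1, 1): SubQuestionIdentifier(level=1, level_question_num=1),
    }
)
# levels ascend, and within each level questions sort by level_question_num:
# {0: [num=1, num=2], 1: [num=1]}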
class CitationInfo(SubQuestionIdentifier):
citation_num: int
document_id: str
class CitationDelta(BaseObj):
type: Literal["citation_delta"] = "citation_delta"
citations: list[CitationInfo] | None = None
"""Packet"""
# Discriminated union of all possible packet object types
PacketObj = Annotated[
Union[
MessageStart,
MessageDelta,
OverallStop,
SectionEnd,
SearchToolStart,
SearchToolDelta,
ImageGenerationToolStart,
ImageGenerationToolDelta,
CustomToolStart,
CustomToolDelta,
ReasoningStart,
ReasoningDelta,
CitationStart,
CitationDelta,
],
Field(discriminator="type"),
]
class Packet(BaseModel):
ind: int
obj: PacketObj
class EndStepPacketList(BaseModel):
end_step_nr: int
packet_list: list[Packet]
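Because PacketObj discriminates on the type field, a serialized packet round-trips into the right concrete class from JSON alone. A small sketch (pydantic v2 assumed, matching the model_dump usage elsewhere in the codebase):

raw = {"ind": 3, "obj": {"type": "citation_delta", "citations": None}}
packet = Packet.model_validate(raw)
assert isinstance(packet.obj, CitationDelta)
assert packet.model_dump()["obj"]["type"] == "citation_delta"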

View File

@@ -0,0 +1,318 @@
from onyx.configs.constants import MessageType
from onyx.file_store.models import ChatFileType
from onyx.server.query_and_chat.models import ChatMessageDetail
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import CitationStart
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolDelta
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
def create_simplified_packets_for_message(
message: ChatMessageDetail, packet_index_start: int = 0
) -> list[Packet]:
"""
Convert a ChatMessageDetail into simplified streaming packets that represent
what would have been sent during the original streaming response.
Args:
message: The chat message to convert to packets
packet_index_start: Starting index for packet numbering
Returns:
List of simplified packets representing the message
"""
packets: list[Packet] = []
current_index = packet_index_start
# Only create packets for assistant messages
if message.message_type != MessageType.ASSISTANT:
return packets
# Handle all tool-related packets in one unified block
# Check for tool calls first, then fall back to inferred tools from context/files
if message.tool_call:
tool_call = message.tool_call
# Handle different tool types based on tool name
if tool_call.tool_name == "run_search":
# Handle search tools - create search tool packets
# Use context docs if available, otherwise use tool result
if message.context_docs and message.context_docs.top_documents:
search_docs = message.context_docs.top_documents
# Start search tool
packets.append(
Packet(
ind=current_index,
obj=SearchToolStart(),
)
)
# Include queries and documents in the delta
if message.rephrased_query and message.rephrased_query.strip():
queries = [str(message.rephrased_query)]
else:
queries = [message.message]
packets.append(
Packet(
ind=current_index,
obj=SearchToolDelta(
queries=queries,
documents=search_docs,
),
)
)
# End search tool
packets.append(
Packet(
ind=current_index,
obj=SectionEnd(),
)
)
current_index += 1
elif tool_call.tool_name == "run_image_generation":
# Handle image generation tools - create image generation packets
# Use files if available, otherwise create from tool result
if message.files:
image_files = [
f for f in message.files if f["type"] == ChatFileType.IMAGE
]
if image_files:
# Start image tool
image_tool_start = ImageGenerationToolStart()
packets.append(Packet(ind=current_index, obj=image_tool_start))
# Send images via tool delta
images = []
for file in image_files:
images.append(
{
"id": file["id"],
"url": "", # URL will be constructed by frontend
"prompt": file.get("name") or "Generated image",
}
)
image_tool_delta = ImageGenerationToolDelta(images=images)
packets.append(Packet(ind=current_index, obj=image_tool_delta))
# End image tool
image_tool_end = SectionEnd()
packets.append(Packet(ind=current_index, obj=image_tool_end))
current_index += 1
elif tool_call.tool_name == "run_internet_search":
# Internet search tools return document data, but should be treated as custom tools
# for packet purposes since they have a different data structure
# Start custom tool
custom_tool_start = CustomToolStart(tool_name=tool_call.tool_name)
packets.append(Packet(ind=current_index, obj=custom_tool_start))
# Send internet search results as custom tool data
custom_tool_delta = CustomToolDelta(
tool_name=tool_call.tool_name,
response_type="json",
data=tool_call.tool_result,
file_ids=None,
)
packets.append(Packet(ind=current_index, obj=custom_tool_delta))
# End custom tool
custom_tool_end = SectionEnd()
packets.append(Packet(ind=current_index, obj=custom_tool_end))
current_index += 1
else:
# Handle custom tools and any other tool types
# Start custom tool
custom_tool_start = CustomToolStart(tool_name=tool_call.tool_name)
packets.append(Packet(ind=current_index, obj=custom_tool_start))
# Determine response type and data from tool result
response_type = "json" # default
data = None
file_ids = None
if tool_call.tool_result:
# Check if it's a custom tool call summary (most common case)
if isinstance(tool_call.tool_result, dict):
# Try to extract response_type if it's structured like CustomToolCallSummary
if "response_type" in tool_call.tool_result:
response_type = tool_call.tool_result["response_type"]
tool_result = tool_call.tool_result.get("tool_result")
# Handle file-based responses
if isinstance(tool_result, dict) and "file_ids" in tool_result:
file_ids = tool_result["file_ids"]
else:
data = tool_result
else:
# Plain dict response
data = tool_call.tool_result
else:
# Non-dict response (string, number, etc.)
data = tool_call.tool_result
# Send tool response via tool delta
custom_tool_delta = CustomToolDelta(
tool_name=tool_call.tool_name,
response_type=response_type,
data=data,
file_ids=file_ids,
)
packets.append(Packet(ind=current_index, obj=custom_tool_delta))
# End custom tool
custom_tool_end = SectionEnd()
packets.append(Packet(ind=current_index, obj=custom_tool_end))
current_index += 1
# Fallback handling for when there's no explicit tool_call but we have tool-related data
elif message.context_docs and message.context_docs.top_documents:
# Handle search results without explicit tool call (legacy support)
search_docs = message.context_docs.top_documents
# Start search tool
packets.append(
Packet(
ind=current_index,
obj=SearchToolStart(),
)
)
# Include queries and documents in the delta
if message.rephrased_query and message.rephrased_query.strip():
queries = [str(message.rephrased_query)]
else:
queries = [message.message]
packets.append(
Packet(
ind=current_index,
obj=SearchToolDelta(
queries=queries,
documents=search_docs,
),
)
)
# End search tool
packets.append(
Packet(
ind=current_index,
obj=SectionEnd(),
)
)
current_index += 1
# Handle image files without explicit tool call (legacy support)
if message.files:
image_files = [f for f in message.files if f["type"] == ChatFileType.IMAGE]
if image_files and not message.tool_call:
# Only create image packets if there's no tool call that might have handled them
# Start image tool
image_tool_start = ImageGenerationToolStart()
packets.append(Packet(ind=current_index, obj=image_tool_start))
# Send images via tool delta
images = []
for file in image_files:
images.append(
{
"id": file["id"],
"url": "", # URL will be constructed by frontend
"prompt": file.get("name") or "Generated image",
}
)
image_tool_delta = ImageGenerationToolDelta(images=images)
packets.append(Packet(ind=current_index, obj=image_tool_delta))
# End image tool
image_tool_end = SectionEnd()
packets.append(Packet(ind=current_index, obj=image_tool_end))
current_index += 1
# Create Citation packets if there are citations
if message.citations:
# Start citation flow
citation_start = CitationStart()
packets.append(Packet(ind=current_index, obj=citation_start))
# Create citation data
# Convert dict[int, int] to list[CitationInfo] format
citations_list: list[CitationInfo] = []
for citation_num, doc_id in message.citations.items():
citation = CitationInfo(citation_num=citation_num, document_id=str(doc_id))
citations_list.append(citation)
# Send citations via citation delta
citation_delta = CitationDelta(citations=citations_list)
packets.append(Packet(ind=current_index, obj=citation_delta))
# End citation flow
citation_end = SectionEnd()
packets.append(Packet(ind=current_index, obj=citation_end))
current_index += 1
# Create MESSAGE_START packet
message_start = MessageStart(
content="",
final_documents=(
message.context_docs.top_documents if message.context_docs else None
),
)
packets.append(Packet(ind=current_index, obj=message_start))
# Create MESSAGE_DELTA packet with the full message content
# In a real streaming scenario, this would be broken into multiple deltas
if message.message:
message_delta = MessageDelta(content=message.message)
packets.append(Packet(ind=current_index, obj=message_delta))
# Create MESSAGE_END packet
message_end = SectionEnd()
packets.append(Packet(ind=current_index, obj=message_end))
current_index += 1
# Create STOP packet
stop = OverallStop()
packets.append(Packet(ind=current_index, obj=stop))
return packets
def create_simplified_packets_for_session(
messages: list[ChatMessageDetail],
) -> list[list[Packet]]:
"""
Convert a list of chat messages into simplified streaming packets organized by message.
Each inner list contains packets for a single assistant message.
Args:
messages: List of chat messages from the session
Returns:
List of lists of simplified packets, where each inner list represents one assistant message
"""
packets_by_message: list[list[Packet]] = []
for message in messages:
if message.message_type == MessageType.ASSISTANT:
message_packets = create_simplified_packets_for_message(message, 0)
if message_packets: # Only add if there are actual packets
packets_by_message.append(message_packets)
return packets_by_message
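For orientation, a minimal sketch of how these simplified packets might be replayed to a client, assuming the Packet and payload classes above are Pydantic v2 models (so model_dump_json is an assumption) and that the ChatMessageDetail list is already loaded; replay_session and the JSON-lines framing are illustrative, not part of this diff:

def replay_session(messages: list[ChatMessageDetail]) -> list[str]:
    # one inner packet list per assistant message, in session order
    lines: list[str] = []
    for message_packets in create_simplified_packets_for_session(messages):
        for packet in message_packets:
            # each packet pairs a section index (ind) with one typed payload
            lines.append(packet.model_dump_json())
    return lines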

View File

@@ -17,6 +17,9 @@ from onyx.tools.tool_implementations.internet_search.internet_search_tool import
from onyx.tools.tool_implementations.internet_search.providers import (
get_available_providers,
)
from onyx.tools.tool_implementations.knowledge_graph.knowledge_graph_tool import (
KnowledgeGraphTool,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool import Tool
from onyx.utils.logger import setup_logger
@@ -63,6 +66,15 @@ BUILT_IN_TOOLS: list[InCodeToolInfo] = [
if (bool(get_available_providers()))
else []
),
InCodeToolInfo(
cls=KnowledgeGraphTool,
description=(
"The Knowledge Graph Search Action allows the assistant to search the knowledge graph for information."
"This tool should only be used by the Deep Research Agent, not via tool calling."
),
in_code_tool_id=KnowledgeGraphTool.__name__,
display_name=KnowledgeGraphTool._DISPLAY_NAME,
),
]
@@ -106,27 +118,37 @@ def load_builtin_tools(db_session: Session) -> None:
logger.notice("All built-in tools are loaded/verified.")
-def get_search_tool(db_session: Session) -> ToolDBModel | None:
+def get_builtin_tool(
+    db_session: Session,
+    tool_type: Type[
+        SearchTool | ImageGenerationTool | InternetSearchTool | KnowledgeGraphTool
+    ],
+) -> ToolDBModel:
    """
-    Retrieves for the SearchTool from the BUILT_IN_TOOLS list.
+    Retrieves a built-in tool from the database based on the tool type.
    """
-    search_tool_id = next(
+    tool_id = next(
        (
            tool["in_code_tool_id"]
            for tool in BUILT_IN_TOOLS
-            if tool["cls"].__name__ == SearchTool.__name__
+            if tool["cls"].__name__ == tool_type.__name__
        ),
        None,
    )
-    if not search_tool_id:
-        raise RuntimeError("SearchTool not found in the BUILT_IN_TOOLS list.")
+    if not tool_id:
+        raise RuntimeError(
+            f"Tool type {tool_type.__name__} not found in the BUILT_IN_TOOLS list."
+        )
-    search_tool = db_session.execute(
-        select(ToolDBModel).where(ToolDBModel.in_code_tool_id == search_tool_id)
+    db_tool = db_session.execute(
+        select(ToolDBModel).where(ToolDBModel.in_code_tool_id == tool_id)
    ).scalar_one_or_none()
-    return search_tool
+    if not db_tool:
+        raise RuntimeError(f"Tool type {tool_type.__name__} not found in the database.")
+    return db_tool
def auto_add_search_tool_to_personas(db_session: Session) -> None:
@@ -136,10 +158,7 @@ def auto_add_search_tool_to_personas(db_session: Session) -> None:
Persona objects that were created before the concept of Tools were added.
"""
# Fetch the SearchTool from the database based on in_code_tool_id from BUILT_IN_TOOLS
-    search_tool = get_search_tool(db_session)
-    if not search_tool:
-        raise RuntimeError("SearchTool not found in the database.")
+    search_tool = get_builtin_tool(db_session=db_session, tool_type=SearchTool)
# Fetch all Personas that need the SearchTool added
personas_to_update = (
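A short usage sketch for the generalized helper, assuming the imports shown in the hunks above and an open Session; illustrative only:

search_tool_row = get_builtin_tool(db_session=db_session, tool_type=SearchTool)
kg_tool_row = get_builtin_tool(db_session=db_session, tool_type=KnowledgeGraphTool)
# unlike the old get_search_tool, a missing row raises RuntimeError rather than
# returning None, so callers can drop their own not-found checks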

View File

@@ -20,6 +20,11 @@ OVERRIDE_T = TypeVar("OVERRIDE_T")
class Tool(abc.ABC, Generic[OVERRIDE_T]):
@property
@abc.abstractmethod
def id(self) -> int:
raise NotImplementedError
@property
@abc.abstractmethod
def name(self) -> str:
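The new abstract id property means every concrete Tool must now expose its database row id. The built-in tools changed below all resolve it the same way at construction time; here is that pattern factored into a hypothetical helper for illustration (resolve_builtin_tool_id does not exist in this diff):

def resolve_builtin_tool_id(in_code_tool_id: str) -> int:
    # mirrors the lookup each built-in tool performs in __init__
    with get_session_with_current_tenant() as db_session:
        tool_id: int | None = (
            db_session.query(ToolDBModel.id)
            .filter(ToolDBModel.in_code_tool_id == in_code_tool_id)
            .scalar()
        )
    if tool_id is None:
        raise ValueError(f"{in_code_tool_id} not found in the tools table")
    return tool_id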

View File

@@ -16,6 +16,7 @@ from onyx.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
from onyx.configs.app_configs import IMAGE_MODEL_NAME
from onyx.configs.chat_configs import NUM_INTERNET_SEARCH_CHUNKS
from onyx.configs.chat_configs import NUM_INTERNET_SEARCH_RESULTS
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.enums import OptionalSearchSetting
@@ -41,6 +42,9 @@ from onyx.tools.tool_implementations.images.image_generation_tool import (
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
InternetSearchTool,
)
from onyx.tools.tool_implementations.knowledge_graph.knowledge_graph_tool import (
KnowledgeGraphTool,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import compute_all_tool_tokens
from onyx.tools.utils import explicit_tool_calling_supported
@@ -265,6 +269,14 @@ def construct_tools(
"Internet search tool requires a Bing or Exa API key, please contact your Onyx admin to get it added!"
)
# Handle KG Tool
elif tool_cls.__name__ == KnowledgeGraphTool.__name__:
if persona.name != TMP_DRALPHA_PERSONA_NAME:
raise ValueError(
f"Knowledge Graph Tool should only be used by the '{TMP_DRALPHA_PERSONA_NAME}' Agent."
)
tool_dict[db_tool_model.id] = [KnowledgeGraphTool()]
# Handle custom tools
elif db_tool_model.openapi_schema:
if not custom_tool_config:

View File

@@ -17,6 +17,8 @@ from requests import JSONDecodeError
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.constants import FileOrigin
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.tools import get_tools
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
@@ -77,6 +79,7 @@ class CustomToolCallSummary(BaseModel):
class CustomTool(BaseTool):
def __init__(
self,
id: int,
method_spec: MethodSpec,
base_url: str,
custom_headers: list[HeaderItemDict] | None = None,
@@ -86,6 +89,7 @@ class CustomTool(BaseTool):
self._method_spec = method_spec
self._tool_definition = self._method_spec.to_tool_definition()
self._user_oauth_token = user_oauth_token
self._id = id
self._name = self._method_spec.name
self._description = self._method_spec.summary
@@ -107,6 +111,10 @@ class CustomTool(BaseTool):
if self._user_oauth_token:
self.headers["Authorization"] = f"Bearer {self._user_oauth_token}"
@property
def id(self) -> int:
return self._id
@property
def name(self) -> str:
return self._name
@@ -382,11 +390,27 @@ def build_custom_tools_from_openapi_schema_and_headers(
url = openapi_to_url(openapi_schema)
method_specs = openapi_to_method_specs(openapi_schema)
openapi_schema_str = json.dumps(openapi_schema)
with get_session_with_current_tenant() as temp_db_session:
tools = get_tools(temp_db_session)
tool_id: int | None = None
for tool in tools:
if tool.openapi_schema and (
json.dumps(tool.openapi_schema) == openapi_schema_str
):
tool_id = tool.id
break
if not tool_id:
raise ValueError(f"Tool with openapi_schema {openapi_schema_str} not found")
return [
CustomTool(
-            method_spec,
-            url,
-            custom_headers,
+            id=tool_id,
+            method_spec=method_spec,
+            base_url=url,
+            custom_headers=custom_headers,
user_oauth_token=user_oauth_token,
)
for method_spec in method_specs
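Note that the lookup above matches a CustomTool to its database row by comparing json.dumps output directly, which is sensitive to key order. A hedged sketch of an order-insensitive comparison (a suggestion, not what this diff does):

import json

def schemas_match(a: dict, b: dict) -> bool:
    # sort_keys normalizes key order before comparing the serialized schemas
    return json.dumps(a, sort_keys=True) == json.dumps(b, sort_keys=True)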

View File

@@ -13,6 +13,8 @@ from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.app_configs import IMAGE_MODEL_NAME
from onyx.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from onyx.configs.tool_configs import IMAGE_GENERATION_OUTPUT_FORMAT
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import Tool as ToolDBModel
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import build_content_with_imgs
@@ -112,6 +114,22 @@ class ImageGenerationTool(Tool[None]):
self.additional_headers = additional_headers
self.output_format = output_format
with get_session_with_current_tenant() as db_session:
tool_id: int | None = (
db_session.query(ToolDBModel.id)
.filter(ToolDBModel.in_code_tool_id == ImageGenerationTool.__name__)
.scalar()
)
if not tool_id:
raise ValueError(
"Image Generation tool not found. This should never happen."
)
self._id = tool_id
@property
def id(self) -> int:
return self._id
@property
def name(self) -> str:
return self._NAME

View File

@@ -29,6 +29,7 @@ from onyx.context.search.enums import SearchType
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
from onyx.db.models import Persona
from onyx.db.models import Tool as ToolDBModel
from onyx.db.search_settings import get_current_search_settings
from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import DefaultIndexingEmbedder
@@ -143,8 +144,23 @@ class InternetSearchTool(Tool[None]):
)
)
tool_id: int | None = (
db_session.query(ToolDBModel.id)
.filter(ToolDBModel.in_code_tool_id == InternetSearchTool.__name__)
.scalar()
)
if not tool_id:
raise ValueError(
"Internet Search tool not found. This should never happen."
)
self._id = tool_id
"""For explicit tool calling"""
@property
def id(self) -> int:
return self._id
@property
def name(self) -> str:
return self._NAME

View File

@@ -0,0 +1,118 @@
from collections.abc import Generator
from typing import Any
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import Tool as ToolDBModel
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.utils.logger import setup_logger
from onyx.utils.special_types import JSON_ro
logger = setup_logger()
QUERY_FIELD = "query"
class KnowledgeGraphTool(Tool[None]):
_NAME = "run_kg_search"
_DESCRIPTION = "Search the knowledge graph for information. Never call this tool."
_DISPLAY_NAME = "Knowledge Graph Search"
def __init__(self) -> None:
with get_session_with_current_tenant() as db_session:
tool_id: int | None = (
db_session.query(ToolDBModel.id)
.filter(ToolDBModel.in_code_tool_id == KnowledgeGraphTool.__name__)
.scalar()
)
if not tool_id:
raise ValueError(
"Knowledge Graph tool not found. This should never happen."
)
self._id = tool_id
@property
def id(self) -> int:
return self._id
@property
def name(self) -> str:
return self._NAME
@property
def description(self) -> str:
return self._DESCRIPTION
@property
def display_name(self) -> str:
return self._DISPLAY_NAME
def tool_definition(self) -> dict:
return {
"type": "function",
"function": {
"name": self.name,
"description": self.description,
"parameters": {
"type": "object",
"properties": {
QUERY_FIELD: {
"type": "string",
"description": "What to search for",
},
},
"required": [QUERY_FIELD],
},
},
}
def get_args_for_non_tool_calling_llm(
self,
query: str,
history: list[PreviousMessage],
llm: LLM,
force_run: bool = False,
) -> dict[str, Any] | None:
raise ValueError(
"KnowledgeGraphTool should only be used by the Deep Research Agent, "
"not via tool calling."
)
def build_tool_message_content(
self, *args: ToolResponse
) -> str | list[str | dict[str, Any]]:
raise ValueError(
"KnowledgeGraphTool should only be used by the Deep Research Agent, "
"not via tool calling."
)
def run(
self, override_kwargs: None = None, **kwargs: str
) -> Generator[ToolResponse, None, None]:
raise ValueError(
"KnowledgeGraphTool should only be used by the Deep Research Agent, "
"not via tool calling."
)
def final_result(self, *args: ToolResponse) -> JSON_ro:
raise ValueError(
"KnowledgeGraphTool should only be used by the Deep Research Agent, "
"not via tool calling."
)
def build_next_prompt(
self,
prompt_builder: AnswerPromptBuilder,
tool_call_summary: ToolCallSummary,
tool_responses: list[ToolResponse],
using_tool_calling_llm: bool,
) -> AnswerPromptBuilder:
raise ValueError(
"KnowledgeGraphTool should only be used by the Deep Research Agent, "
"not via tool calling."
)
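The tool is intentionally inert outside the Deep Research flow: every tool-calling entry point raises immediately. A small illustrative sketch of the expected behavior (the query string is made up):

kg_tool = KnowledgeGraphTool()  # resolves its row id from the tools table, or raises ValueError
try:
    kg_tool.run(query="deployments of service X")
except ValueError:
    pass  # expected: direct tool calling is rejected by design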

View File

@@ -34,6 +34,7 @@ from onyx.context.search.models import UserFileFilters
from onyx.context.search.pipeline import SearchPipeline
from onyx.context.search.pipeline import section_relevance_list_impl
from onyx.db.models import Persona
from onyx.db.models import Tool as ToolDBModel
from onyx.db.models import User
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
@@ -162,6 +163,19 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
)
)
tool_id: int | None = (
db_session.query(ToolDBModel.id)
.filter(ToolDBModel.in_code_tool_id == SearchTool.__name__)
.scalar()
)
if not tool_id:
raise ValueError("Search tool not found. This should never happen.")
self._id = tool_id
@property
def id(self) -> int:
return self._id
@property
def name(self) -> str:
return self._NAME

Some files were not shown because too many files have changed in this diff.