Compare commits

..

27 Commits

Author SHA1 Message Date
pablodanswer
e55dd89444 k 2025-01-23 18:51:14 -08:00
pablodanswer
811d564a0f updated + functional 2025-01-23 18:51:14 -08:00
pablodanswer
206b247ca5 update- reorg 2025-01-23 18:51:14 -08:00
pablodanswer
cd445f0e3f k 2025-01-23 18:51:14 -08:00
pablodanswer
967b0e5b0f build fix 2025-01-23 18:05:34 -08:00
joachim-danswer
fac9525833 Merge pull request #3757 from onyx-dot-app/agent-search-feature-jr-1
Changes addressing YS questions from 01/22/25
2025-01-23 15:20:16 -08:00
joachim-danswer
79fc4ae47d EL comments addressed 2025-01-23 15:18:43 -08:00
joachim-danswer
1c09c75e5f loser verification prompt 2025-01-23 14:11:34 -08:00
joachim-danswer
23ec33e411 turning off initial search pre route decision 2025-01-23 13:29:13 -08:00
joachim-danswer
fea429e11b change of sub-question answer if no docs recovered 2025-01-23 13:23:10 -08:00
joachim-danswer
f9d7d21d8e various fixes from Yuhong's list 2025-01-23 13:06:26 -08:00
Yuhong Sun
d2a8938545 Copy changes 2025-01-23 11:25:06 -08:00
evan-danswer
ac909f8437 Merge pull request #3752 from onyx-dot-app/asf-evan-async-task-cleanup
async task cleanup + basic citations
2025-01-23 10:47:29 -08:00
Evan Lohn
2f32111169 removed print statements, fixed pass through handling 2025-01-23 10:38:09 -08:00
Evan Lohn
23acb163f5 fixed basic flow citations and second test 2025-01-23 10:18:22 -08:00
Evan Lohn
ebe15b42d2 fix for early cancellation test; solves issue with tasks being destroyed while pending 2025-01-22 21:15:56 -08:00
pablodanswer
6cbb237945 add agent search frontend 2025-01-22 18:31:30 -08:00
Evan Lohn
6803548066 fix alembic history 2025-01-22 17:35:26 -08:00
joachim-danswer
1111ce6ce4 streaming + saving of search docs of no verified ones available
- sub-questions only
2025-01-22 17:30:36 -08:00
Evan Lohn
3f68e8ea8e reworked history messages in agent config 2025-01-22 17:30:36 -08:00
Evan Lohn
06a8373ff4 missed files from prev commit 2025-01-22 17:30:36 -08:00
Evan Lohn
86e770d968 basic search restructure: WIP on fixing tests 2025-01-22 17:30:36 -08:00
joachim-danswer
f11216132e prompts that even further motivates to cite docs over sub-q's 2025-01-22 17:30:36 -08:00
joachim-danswer
1f7d05cd75 pydantic for LangGraph + changed ERT extraction flow 2025-01-22 17:30:36 -08:00
joachim-danswer
c8bf051fb6 history added to agent flow 2025-01-22 17:30:36 -08:00
pablodanswer
14b54db033 minor fixes to branch 2025-01-22 17:30:36 -08:00
Evan Lohn
0e9f9301ba second clean commit 2025-01-22 17:30:36 -08:00
465 changed files with 11618 additions and 15824 deletions

View File

@@ -8,8 +8,6 @@ on: push
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
MOCK_LLM_RESPONSE: true
jobs:
playwright-tests:

View File

@@ -21,10 +21,10 @@ jobs:
- name: Set up Helm
uses: azure/setup-helm@v4.2.0
with:
version: v3.17.0
version: v3.14.4
- name: Set up chart-testing
uses: helm/chart-testing-action@v2.7.0
uses: helm/chart-testing-action@v2.6.1
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
- name: Run chart-testing (list-changed)
@@ -37,6 +37,22 @@ jobs:
echo "changed=true" >> "$GITHUB_OUTPUT"
fi
# rkuo: I don't think we need python?
# - name: Set up Python
# uses: actions/setup-python@v5
# with:
# python-version: '3.11'
# cache: 'pip'
# cache-dependency-path: |
# backend/requirements/default.txt
# backend/requirements/dev.txt
# backend/requirements/model_server.txt
# - run: |
# python -m pip install --upgrade pip
# pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
# pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
# pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
# lint all charts if any changes were detected
- name: Run chart-testing (lint)
if: steps.list-changed.outputs.changed == 'true'
@@ -46,7 +62,7 @@ jobs:
- name: Create kind cluster
if: steps.list-changed.outputs.changed == 'true'
uses: helm/kind-action@v1.12.0
uses: helm/kind-action@v1.10.0
- name: Run chart-testing (install)
if: steps.list-changed.outputs.changed == 'true'

View File

@@ -9,9 +9,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Check PR body for Linear link or override
env:
PR_BODY: ${{ github.event.pull_request.body }}
run: |
PR_BODY="${{ github.event.pull_request.body }}"
# Looking for "https://linear.app" in the body
if echo "$PR_BODY" | grep -qE "https://linear\.app"; then
echo "Found a Linear link. Check passed."

View File

@@ -39,12 +39,6 @@ env:
AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
# Sharepoint
SHAREPOINT_CLIENT_ID: ${{ secrets.SHAREPOINT_CLIENT_ID }}
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
jobs:
connectors-check:
# See https://runs-on.com/runners/linux/

View File

@@ -124,7 +124,7 @@ There are two editions of Onyx:
To try the Onyx Enterprise Edition:
1. Checkout our [Cloud product](https://cloud.onyx.app/signup).
2. For self-hosting, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/onyx/founders).
2. For self-hosting, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/danswer/founders).
## 💡 Contributing

1
Untitled-12 Normal file
View File

@@ -0,0 +1 @@

View File

@@ -9,10 +9,8 @@ founders@onyx.app for more information. Please visit https://github.com/onyx-dot
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
ARG ONYX_VERSION=0.8-dev
# DO_NOT_TRACK is used to disable telemetry for Unstructured
ENV ONYX_VERSION=${ONYX_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true" \
DO_NOT_TRACK="true"
DANSWER_RUNNING_IN_DOCKER="true"
RUN echo "ONYX_VERSION: ${ONYX_VERSION}"

View File

@@ -0,0 +1,29 @@
"""agent_doc_result_col
Revision ID: 1adf5ea20d2b
Revises: e9cf2bd7baed
Create Date: 2025-01-05 13:14:58.344316
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "1adf5ea20d2b"
down_revision = "e9cf2bd7baed"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add the new column with JSONB type
op.add_column(
"sub_question",
sa.Column("sub_question_doc_results", postgresql.JSONB(), nullable=True),
)
def downgrade() -> None:
# Drop the column
op.drop_column("sub_question", "sub_question_doc_results")

View File

@@ -1,36 +0,0 @@
"""add chat session specific temperature override
Revision ID: 2f80c6a2550f
Revises: 33ea50e88f24
Create Date: 2025-01-31 10:30:27.289646
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2f80c6a2550f"
down_revision = "33ea50e88f24"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"chat_session", sa.Column("temperature_override", sa.Float(), nullable=True)
)
op.add_column(
"user",
sa.Column(
"temperature_override_enabled",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
),
)
def downgrade() -> None:
op.drop_column("chat_session", "temperature_override")
op.drop_column("user", "temperature_override_enabled")

View File

@@ -1,80 +0,0 @@
"""foreign key input prompts
Revision ID: 33ea50e88f24
Revises: a6df6b88ef81
Create Date: 2025-01-29 10:54:22.141765
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "33ea50e88f24"
down_revision = "a6df6b88ef81"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Safely drop constraints if exists
op.execute(
"""
ALTER TABLE inputprompt__user
DROP CONSTRAINT IF EXISTS inputprompt__user_input_prompt_id_fkey
"""
)
op.execute(
"""
ALTER TABLE inputprompt__user
DROP CONSTRAINT IF EXISTS inputprompt__user_user_id_fkey
"""
)
# Recreate with ON DELETE CASCADE
op.create_foreign_key(
"inputprompt__user_input_prompt_id_fkey",
"inputprompt__user",
"inputprompt",
["input_prompt_id"],
["id"],
ondelete="CASCADE",
)
op.create_foreign_key(
"inputprompt__user_user_id_fkey",
"inputprompt__user",
"user",
["user_id"],
["id"],
ondelete="CASCADE",
)
def downgrade() -> None:
# Drop the new FKs with ondelete
op.drop_constraint(
"inputprompt__user_input_prompt_id_fkey",
"inputprompt__user",
type_="foreignkey",
)
op.drop_constraint(
"inputprompt__user_user_id_fkey",
"inputprompt__user",
type_="foreignkey",
)
# Recreate them without cascading
op.create_foreign_key(
"inputprompt__user_input_prompt_id_fkey",
"inputprompt__user",
"inputprompt",
["input_prompt_id"],
["id"],
)
op.create_foreign_key(
"inputprompt__user_user_id_fkey",
"inputprompt__user",
"user",
["user_id"],
["id"],
)

View File

@@ -1,37 +0,0 @@
"""lowercase_user_emails
Revision ID: 4d58345da04a
Revises: f1ca58b2f2ec
Create Date: 2025-01-29 07:48:46.784041
"""
from alembic import op
from sqlalchemy.sql import text
# revision identifiers, used by Alembic.
revision = "4d58345da04a"
down_revision = "f1ca58b2f2ec"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Get database connection
connection = op.get_bind()
# Update all user emails to lowercase
connection.execute(
text(
"""
UPDATE "user"
SET email = LOWER(email)
WHERE email != LOWER(email)
"""
)
)
def downgrade() -> None:
# Cannot restore original case of emails
pass

View File

@@ -0,0 +1,35 @@
"""agent_metric_col_rename__s
Revision ID: 925b58bd75b6
Revises: 9787be927e58
Create Date: 2025-01-06 11:20:26.752441
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "925b58bd75b6"
down_revision = "9787be927e58"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Rename columns using PostgreSQL syntax
op.alter_column(
"agent__search_metrics", "base_duration_s", new_column_name="base_duration__s"
)
op.alter_column(
"agent__search_metrics", "full_duration_s", new_column_name="full_duration__s"
)
def downgrade() -> None:
# Revert the column renames
op.alter_column(
"agent__search_metrics", "base_duration__s", new_column_name="base_duration_s"
)
op.alter_column(
"agent__search_metrics", "full_duration__s", new_column_name="full_duration_s"
)

View File

@@ -0,0 +1,25 @@
"""agent_metric_table_renames__agent__
Revision ID: 9787be927e58
Revises: bceb76d618ec
Create Date: 2025-01-06 11:01:44.210160
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "9787be927e58"
down_revision = "bceb76d618ec"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Rename table from agent_search_metrics to agent__search_metrics
op.rename_table("agent_search_metrics", "agent__search_metrics")
def downgrade() -> None:
# Rename table back from agent__search_metrics to agent_search_metrics
op.rename_table("agent__search_metrics", "agent_search_metrics")

View File

@@ -1,25 +1,24 @@
"""agent_tracking
Revision ID: 98a5008d8711
Revises: 2f80c6a2550f
Create Date: 2025-01-29 17:00:00.000001
Revises: f1ca58b2f2ec
Create Date: 2025-01-04 14:41:52.732238
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import UUID
# revision identifiers, used by Alembic.
revision = "98a5008d8711"
down_revision = "2f80c6a2550f"
down_revision = "f1ca58b2f2ec"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"agent__search_metrics",
"agent_search_metrics",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=True),
sa.Column("persona_id", sa.Integer(), nullable=True),
@@ -38,70 +37,6 @@ def upgrade() -> None:
sa.PrimaryKeyConstraint("id"),
)
# Create sub_question table
op.create_table(
"agent__sub_question",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("primary_question_id", sa.Integer, sa.ForeignKey("chat_message.id")),
sa.Column(
"chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id")
),
sa.Column("sub_question", sa.Text),
sa.Column(
"time_created", sa.DateTime(timezone=True), server_default=sa.func.now()
),
sa.Column("sub_answer", sa.Text),
sa.Column("sub_question_doc_results", postgresql.JSONB(), nullable=True),
sa.Column("level", sa.Integer(), nullable=False),
sa.Column("level_question_num", sa.Integer(), nullable=False),
)
# Create sub_query table
op.create_table(
"agent__sub_query",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column(
"parent_question_id", sa.Integer, sa.ForeignKey("agent__sub_question.id")
),
sa.Column(
"chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id")
),
sa.Column("sub_query", sa.Text),
sa.Column(
"time_created", sa.DateTime(timezone=True), server_default=sa.func.now()
),
)
# Create sub_query__search_doc association table
op.create_table(
"agent__sub_query__search_doc",
sa.Column(
"sub_query_id",
sa.Integer,
sa.ForeignKey("agent__sub_query.id"),
primary_key=True,
),
sa.Column(
"search_doc_id",
sa.Integer,
sa.ForeignKey("search_doc.id"),
primary_key=True,
),
)
op.add_column(
"chat_message",
sa.Column(
"refined_answer_improvement",
sa.Boolean(),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("chat_message", "refined_answer_improvement")
op.drop_table("agent__sub_query__search_doc")
op.drop_table("agent__sub_query")
op.drop_table("agent__sub_question")
op.drop_table("agent__search_metrics")
op.drop_table("agent_search_metrics")

View File

@@ -1,29 +0,0 @@
"""remove recent assistants
Revision ID: a6df6b88ef81
Revises: 4d58345da04a
Create Date: 2025-01-29 10:25:52.790407
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "a6df6b88ef81"
down_revision = "4d58345da04a"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_column("user", "recent_assistants")
def downgrade() -> None:
op.add_column(
"user",
sa.Column(
"recent_assistants", postgresql.JSONB(), server_default="[]", nullable=False
),
)

View File

@@ -0,0 +1,84 @@
"""agent_table_renames__agent__
Revision ID: bceb76d618ec
Revises: c0132518a25b
Create Date: 2025-01-06 10:50:48.109285
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "bceb76d618ec"
down_revision = "c0132518a25b"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_constraint(
"sub_query__search_doc_sub_query_id_fkey",
"sub_query__search_doc",
type_="foreignkey",
)
op.drop_constraint(
"sub_query__search_doc_search_doc_id_fkey",
"sub_query__search_doc",
type_="foreignkey",
)
# Rename tables
op.rename_table("sub_query", "agent__sub_query")
op.rename_table("sub_question", "agent__sub_question")
op.rename_table("sub_query__search_doc", "agent__sub_query__search_doc")
# Update both foreign key constraints for agent__sub_query__search_doc
# Create new foreign keys with updated names
op.create_foreign_key(
"agent__sub_query__search_doc_sub_query_id_fkey",
"agent__sub_query__search_doc",
"agent__sub_query",
["sub_query_id"],
["id"],
)
op.create_foreign_key(
"agent__sub_query__search_doc_search_doc_id_fkey",
"agent__sub_query__search_doc",
"search_doc", # This table name doesn't change
["search_doc_id"],
["id"],
)
def downgrade() -> None:
# Update foreign key constraints for sub_query__search_doc
op.drop_constraint(
"agent__sub_query__search_doc_sub_query_id_fkey",
"agent__sub_query__search_doc",
type_="foreignkey",
)
op.drop_constraint(
"agent__sub_query__search_doc_search_doc_id_fkey",
"agent__sub_query__search_doc",
type_="foreignkey",
)
# Rename tables back
op.rename_table("agent__sub_query__search_doc", "sub_query__search_doc")
op.rename_table("agent__sub_question", "sub_question")
op.rename_table("agent__sub_query", "sub_query")
op.create_foreign_key(
"sub_query__search_doc_sub_query_id_fkey",
"sub_query__search_doc",
"sub_query",
["sub_query_id"],
["id"],
)
op.create_foreign_key(
"sub_query__search_doc_search_doc_id_fkey",
"sub_query__search_doc",
"search_doc", # This table name doesn't change
["search_doc_id"],
["id"],
)

View File

@@ -0,0 +1,40 @@
"""agent_table_changes_rename_level
Revision ID: c0132518a25b
Revises: 1adf5ea20d2b
Create Date: 2025-01-05 16:38:37.660152
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c0132518a25b"
down_revision = "1adf5ea20d2b"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add level and level_question_nr columns with NOT NULL constraint
op.add_column(
"sub_question",
sa.Column("level", sa.Integer(), nullable=False, server_default="0"),
)
op.add_column(
"sub_question",
sa.Column(
"level_question_nr", sa.Integer(), nullable=False, server_default="0"
),
)
# Remove the server_default after the columns are created
op.alter_column("sub_question", "level", server_default=None)
op.alter_column("sub_question", "level_question_nr", server_default=None)
def downgrade() -> None:
# Remove the columns
op.drop_column("sub_question", "level_question_nr")
op.drop_column("sub_question", "level")

View File

@@ -0,0 +1,68 @@
"""create pro search persistence tables
Revision ID: e9cf2bd7baed
Revises: 98a5008d8711
Create Date: 2025-01-02 17:55:56.544246
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID
# revision identifiers, used by Alembic.
revision = "e9cf2bd7baed"
down_revision = "98a5008d8711"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create sub_question table
op.create_table(
"sub_question",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("primary_question_id", sa.Integer, sa.ForeignKey("chat_message.id")),
sa.Column(
"chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id")
),
sa.Column("sub_question", sa.Text),
sa.Column(
"time_created", sa.DateTime(timezone=True), server_default=sa.func.now()
),
sa.Column("sub_answer", sa.Text),
)
# Create sub_query table
op.create_table(
"sub_query",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("parent_question_id", sa.Integer, sa.ForeignKey("sub_question.id")),
sa.Column(
"chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id")
),
sa.Column("sub_query", sa.Text),
sa.Column(
"time_created", sa.DateTime(timezone=True), server_default=sa.func.now()
),
)
# Create sub_query__search_doc association table
op.create_table(
"sub_query__search_doc",
sa.Column(
"sub_query_id", sa.Integer, sa.ForeignKey("sub_query.id"), primary_key=True
),
sa.Column(
"search_doc_id",
sa.Integer,
sa.ForeignKey("search_doc.id"),
primary_key=True,
),
)
def downgrade() -> None:
op.drop_table("sub_query__search_doc")
op.drop_table("sub_query")
op.drop_table("sub_question")

View File

@@ -1,76 +0,0 @@
"""add default slack channel config
Revision ID: eaa3b5593925
Revises: 98a5008d8711
Create Date: 2025-02-03 18:07:56.552526
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "eaa3b5593925"
down_revision = "98a5008d8711"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add is_default column
op.add_column(
"slack_channel_config",
sa.Column("is_default", sa.Boolean(), nullable=False, server_default="false"),
)
op.create_index(
"ix_slack_channel_config_slack_bot_id_default",
"slack_channel_config",
["slack_bot_id", "is_default"],
unique=True,
postgresql_where=sa.text("is_default IS TRUE"),
)
# Create default channel configs for existing slack bots without one
conn = op.get_bind()
slack_bots = conn.execute(sa.text("SELECT id FROM slack_bot")).fetchall()
for slack_bot in slack_bots:
slack_bot_id = slack_bot[0]
existing_default = conn.execute(
sa.text(
"SELECT id FROM slack_channel_config WHERE slack_bot_id = :bot_id AND is_default = TRUE"
),
{"bot_id": slack_bot_id},
).fetchone()
if not existing_default:
conn.execute(
sa.text(
"""
INSERT INTO slack_channel_config (
slack_bot_id, persona_id, channel_config, enable_auto_filters, is_default
) VALUES (
:bot_id, NULL,
'{"channel_name": null, "respond_member_group_list": [], "answer_filters": [], "follow_up_tags": []}',
FALSE, TRUE
)
"""
),
{"bot_id": slack_bot_id},
)
def downgrade() -> None:
# Delete default slack channel configs
conn = op.get_bind()
conn.execute(sa.text("DELETE FROM slack_channel_config WHERE is_default = TRUE"))
# Remove index
op.drop_index(
"ix_slack_channel_config_slack_bot_id_default",
table_name="slack_channel_config",
)
# Remove is_default column
op.drop_column("slack_channel_config", "is_default")

370
backend/chat_packets.log Normal file

File diff suppressed because one or more lines are too long

View File

@@ -32,7 +32,6 @@ def perform_ttl_management_task(
@celery_app.task(
name="check_ttl_management_task",
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def check_ttl_management_task(*, tenant_id: str | None) -> None:
@@ -57,7 +56,6 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:
@celery_app.task(
name="autogenerate_usage_report_task",
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:

View File

@@ -1,72 +1,30 @@
from datetime import timedelta
from typing import Any
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
from onyx.background.celery.tasks.beat_schedule import (
cloud_tasks_to_schedule as base_cloud_tasks_to_schedule,
)
from onyx.background.celery.tasks.beat_schedule import (
tasks_to_schedule as base_tasks_to_schedule,
)
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from shared_configs.configs import MULTI_TENANT
ee_cloud_tasks_to_schedule = [
ee_tasks_to_schedule = [
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_autogenerate-usage-report",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(days=30),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
},
"name": "autogenerate-usage-report",
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
"schedule": timedelta(days=30), # TODO: change this to config flag
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-ttl-management",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"name": "check-ttl-management",
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
},
},
]
ee_tasks_to_schedule: list[dict] = []
if not MULTI_TENANT:
ee_tasks_to_schedule = [
{
"name": "autogenerate-usage-report",
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
"schedule": timedelta(days=30), # TODO: change this to config flag
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-ttl-management",
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
]
def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
return ee_cloud_tasks_to_schedule + base_cloud_tasks_to_schedule
return base_cloud_tasks_to_schedule
def get_tasks_to_schedule() -> list[dict[str, Any]]:

View File

@@ -4,20 +4,6 @@ import os
# Applicable for OIDC Auth
OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL", "")
# Applicable for OIDC Auth, allows you to override the scopes that
# are requested from the OIDC provider. Currently used when passing
# over access tokens to tool calls and the tool needs more scopes
OIDC_SCOPE_OVERRIDE: list[str] | None = None
_OIDC_SCOPE_OVERRIDE = os.environ.get("OIDC_SCOPE_OVERRIDE")
if _OIDC_SCOPE_OVERRIDE:
try:
OIDC_SCOPE_OVERRIDE = [
scope.strip() for scope in _OIDC_SCOPE_OVERRIDE.split(",")
]
except Exception:
pass
# Applicable for SAML Auth
SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_config"

View File

@@ -13,7 +13,6 @@ from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.confluence.utils import get_user_email_from_username__server
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -258,7 +257,6 @@ def _fetch_all_page_restrictions(
slim_docs: list[SlimDocument],
space_permissions_by_space_key: dict[str, ExternalAccess],
is_cloud: bool,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
"""
For all pages, if a page has restrictions, then use those restrictions.
@@ -267,12 +265,6 @@ def _fetch_all_page_restrictions(
document_restrictions: list[DocExternalAccess] = []
for slim_doc in slim_docs:
if callback:
if callback.should_stop():
raise RuntimeError("confluence_doc_sync: Stop signal detected")
callback.progress("confluence_doc_sync:fetch_all_page_restrictions", 1)
if slim_doc.perm_sync_data is None:
raise ValueError(
f"No permission sync data found for document {slim_doc.id}"
@@ -342,7 +334,7 @@ def _fetch_all_page_restrictions(
def confluence_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@@ -367,12 +359,6 @@ def confluence_doc_sync(
logger.debug("Fetching all slim documents from confluence")
for doc_batch in confluence_connector.retrieve_all_slim_documents():
logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
if callback:
if callback.should_stop():
raise RuntimeError("confluence_doc_sync: Stop signal detected")
callback.progress("confluence_doc_sync", 1)
slim_docs.extend(doc_batch)
logger.debug("Fetching all page restrictions for space")
@@ -381,5 +367,4 @@ def confluence_doc_sync(
slim_docs=slim_docs,
space_permissions_by_space_key=space_permissions_by_space_key,
is_cloud=is_cloud,
callback=callback,
)

View File

@@ -14,8 +14,6 @@ def _build_group_member_email_map(
) -> dict[str, set[str]]:
group_member_emails: dict[str, set[str]] = {}
for user_result in confluence_client.paginated_cql_user_retrieval():
logger.debug(f"Processing groups for user: {user_result}")
user = user_result.get("user", {})
if not user:
logger.warning(f"user result missing user field: {user_result}")
@@ -35,17 +33,10 @@ def _build_group_member_email_map(
logger.warning(f"user result missing email field: {user_result}")
continue
all_users_groups: set[str] = set()
for group in confluence_client.paginated_groups_by_user_retrieval(user):
# group name uniqueness is enforced by Confluence, so we can use it as a group ID
group_id = group["name"]
group_member_emails.setdefault(group_id, set()).add(email)
all_users_groups.add(group_id)
if not group_member_emails:
logger.warning(f"No groups found for user with email: {email}")
else:
logger.debug(f"Found groups {all_users_groups} for user with email {email}")
return group_member_emails

View File

@@ -6,7 +6,6 @@ from onyx.access.models import ExternalAccess
from onyx.connectors.gmail.connector import GmailConnector
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -29,7 +28,7 @@ def _get_slim_doc_generator(
def gmail_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@@ -45,12 +44,6 @@ def gmail_doc_sync(
document_external_access: list[DocExternalAccess] = []
for slim_doc_batch in slim_doc_generator:
for slim_doc in slim_doc_batch:
if callback:
if callback.should_stop():
raise RuntimeError("gmail_doc_sync: Stop signal detected")
callback.progress("gmail_doc_sync", 1)
if slim_doc.perm_sync_data is None:
logger.warning(f"No permissions found for document {slim_doc.id}")
continue

View File

@@ -10,7 +10,6 @@ from onyx.connectors.google_utils.resources import get_drive_service
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -43,22 +42,24 @@ def _fetch_permissions_for_permission_ids(
if not permission_info or not doc_id:
return []
# Check cache first for all permission IDs
permissions = [
_PERMISSION_ID_PERMISSION_MAP[pid]
for pid in permission_ids
if pid in _PERMISSION_ID_PERMISSION_MAP
]
# If we found all permissions in cache, return them
if len(permissions) == len(permission_ids):
return permissions
owner_email = permission_info.get("owner_email")
drive_service = get_drive_service(
creds=google_drive_connector.creds,
user_email=(owner_email or google_drive_connector.primary_admin_email),
)
# Otherwise, fetch all permissions and update cache
fetched_permissions = execute_paginated_retrieval(
retrieval_function=drive_service.permissions().list,
list_key="permissions",
@@ -68,6 +69,7 @@ def _fetch_permissions_for_permission_ids(
)
permissions_for_doc_id = []
# Update cache and return all permissions
for permission in fetched_permissions:
permissions_for_doc_id.append(permission)
_PERMISSION_ID_PERMISSION_MAP[permission["id"]] = permission
@@ -129,7 +131,7 @@ def _get_permissions_from_slim_doc(
def gdrive_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@@ -147,12 +149,6 @@ def gdrive_doc_sync(
document_external_accesses = []
for slim_doc_batch in slim_doc_generator:
for slim_doc in slim_doc_batch:
if callback:
if callback.should_stop():
raise RuntimeError("gdrive_doc_sync: Stop signal detected")
callback.progress("gdrive_doc_sync", 1)
ext_access = _get_permissions_from_slim_doc(
google_drive_connector=google_drive_connector,
slim_doc=slim_doc,

View File

@@ -7,7 +7,6 @@ from onyx.connectors.slack.connector import get_channels
from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
from onyx.connectors.slack.connector import SlackPollConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -15,7 +14,7 @@ logger = setup_logger()
def _get_slack_document_ids_and_channels(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
) -> dict[str, list[str]]:
slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
slack_connector.load_credentials(cc_pair.credential.credential_json)
@@ -25,14 +24,6 @@ def _get_slack_document_ids_and_channels(
channel_doc_map: dict[str, list[str]] = {}
for doc_metadata_batch in slim_doc_generator:
for doc_metadata in doc_metadata_batch:
if callback:
if callback.should_stop():
raise RuntimeError(
"_get_slack_document_ids_and_channels: Stop signal detected"
)
callback.progress("_get_slack_document_ids_and_channels", 1)
if doc_metadata.perm_sync_data is None:
continue
channel_id = doc_metadata.perm_sync_data["channel_id"]
@@ -123,7 +114,7 @@ def _fetch_channel_permissions(
def slack_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@@ -136,7 +127,7 @@ def slack_doc_sync(
)
user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
channel_doc_map = _get_slack_document_ids_and_channels(
cc_pair=cc_pair, callback=callback
cc_pair=cc_pair,
)
workspace_permissions = _fetch_workspace_permissions(
user_id_to_email_map=user_id_to_email_map,

View File

@@ -15,13 +15,11 @@ from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
# Defining the input/output types for the sync functions
DocSyncFuncType = Callable[
[
ConnectorCredentialPair,
IndexingHeartbeatInterface | None,
],
list[DocExternalAccess],
]

View File

@@ -1,9 +1,7 @@
from fastapi import FastAPI
from httpx_oauth.clients.google import GoogleOAuth2
from httpx_oauth.clients.openid import BASE_SCOPES
from httpx_oauth.clients.openid import OpenID
from ee.onyx.configs.app_configs import OIDC_SCOPE_OVERRIDE
from ee.onyx.configs.app_configs import OPENID_CONFIG_URL
from ee.onyx.server.analytics.api import router as analytics_router
from ee.onyx.server.auth_check import check_ee_router_auth
@@ -90,13 +88,7 @@ def get_application() -> FastAPI:
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
OpenID(
OAUTH_CLIENT_ID,
OAUTH_CLIENT_SECRET,
OPENID_CONFIG_URL,
# BASE_SCOPES is the same as not setting this
base_scopes=OIDC_SCOPE_OVERRIDE or BASE_SCOPES,
),
OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL),
auth_backend,
USER_AUTH_SECRET,
associate_by_email=True,

View File

@@ -80,7 +80,7 @@ def oneoff_standard_answers(
def _handle_standard_answers(
message_info: SlackMessageInfo,
receiver_ids: list[str] | None,
slack_channel_config: SlackChannelConfig,
slack_channel_config: SlackChannelConfig | None,
prompt: Prompt | None,
logger: OnyxLoggingAdapter,
client: WebClient,
@@ -94,10 +94,13 @@ def _handle_standard_answers(
Returns True if standard answers are found to match the user's message and therefore,
we still need to respond to the users.
"""
# if no channel config, then no standard answers are configured
if not slack_channel_config:
return False
slack_thread_id = message_info.thread_to_respond
configured_standard_answer_categories = (
slack_channel_config.standard_answer_categories
slack_channel_config.standard_answer_categories if slack_channel_config else []
)
configured_standard_answers = set(
[

View File

@@ -10,7 +10,6 @@ from fastapi import Response
from ee.onyx.auth.users import decode_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from onyx.auth.api_key import extract_tenant_from_api_key_header
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.db.engine import is_valid_schema_name
from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis
from shared_configs.configs import MULTI_TENANT
@@ -44,7 +43,6 @@ async def _get_tenant_id_from_request(
Attempt to extract tenant_id from:
1) The API key header
2) The Redis-based token (stored in Cookie: fastapiusersauth)
3) Reset token cookie
Fallback: POSTGRES_DEFAULT_SCHEMA
"""
# Check for API key
@@ -92,12 +90,3 @@ async def _get_tenant_id_from_request(
except Exception as e:
logger.error(f"Unexpected error in _get_tenant_id_from_request: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
finally:
# As a final step, check for explicit tenant_id cookie
tenant_id_cookie = request.cookies.get(TENANT_ID_COOKIE_NAME)
if tenant_id_cookie and is_valid_schema_name(tenant_id_cookie):
return tenant_id_cookie
# If we've reached this point, return the default schema
return POSTGRES_DEFAULT_SCHEMA

View File

@@ -286,7 +286,6 @@ def prepare_authorization_request(
oauth_state = (
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
)
session: str
if connector == DocumentSource.SLACK:
oauth_url = SlackOAuth.generate_oauth_url(oauth_state)
@@ -555,7 +554,6 @@ def handle_google_drive_oauth_callback(
)
session_json = session_json_bytes.decode("utf-8")
session: GoogleDriveOAuth.OAuthSession
try:
session = GoogleDriveOAuth.parse_session(session_json)

View File

@@ -125,10 +125,10 @@ class OneShotQARequest(ChunkContext):
# will also disable Thread-based Rewording if specified
query_override: str | None = None
# If True, skips generating an AI response to the search query
# If True, skips generative an AI response to the search query
skip_gen_ai_answer_generation: bool = False
# If True, uses agentic search instead of basic search
# If True, uses pro search instead of basic search
use_agentic_search: bool = False
@model_validator(mode="after")

View File

@@ -34,7 +34,6 @@ from onyx.auth.users import get_redis_strategy
from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.auth import get_user_count
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
@@ -112,7 +111,6 @@ async def login_as_anonymous_user(
token = generate_anonymous_user_jwt_token(tenant_id)
response = Response()
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
response.set_cookie(
key=ANONYMOUS_USER_COOKIE_NAME,
value=token,

View File

@@ -58,7 +58,6 @@ class UserGroup(BaseModel):
credential=CredentialSnapshot.from_credential_db_model(
cc_pair_relationship.cc_pair.credential
),
access_type=cc_pair_relationship.cc_pair.access_type,
)
for cc_pair_relationship in user_group_model.cc_pair_relationships
if cc_pair_relationship.is_current

View File

@@ -2,17 +2,14 @@ from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.basic.nodes.basic_use_tool_response import (
basic_use_tool_response,
)
from onyx.agents.agent_search.basic.nodes.llm_tool_choice import llm_tool_choice
from onyx.agents.agent_search.basic.nodes.tool_call import tool_call
from onyx.agents.agent_search.basic.states import BasicInput
from onyx.agents.agent_search.basic.states import BasicOutput
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
basic_use_tool_response,
)
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
prepare_tool_input,
)
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -27,11 +24,6 @@ def basic_graph_builder() -> StateGraph:
### Add nodes ###
graph.add_node(
node="prepare_tool_input",
action=prepare_tool_input,
)
graph.add_node(
node="llm_tool_choice",
action=llm_tool_choice,
@@ -49,9 +41,7 @@ def basic_graph_builder() -> StateGraph:
### Add edges ###
graph.add_edge(start_key=START, end_key="prepare_tool_input")
graph.add_edge(start_key="prepare_tool_input", end_key="llm_tool_choice")
graph.add_edge(start_key=START, end_key="llm_tool_choice")
graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])
@@ -72,26 +62,10 @@ def should_continue(state: BasicState) -> str:
return (
# If there are no tool calls, basic graph already streamed the answer
END
if state.tool_choice is None
if state["tool_choice"] is None
else "tool_call"
)
if __name__ == "__main__":
from onyx.db.engine import get_session_context_manager
from onyx.context.search.models import SearchRequest
from onyx.llm.factory import get_default_llms
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
graph = basic_graph_builder()
compiled_graph = graph.compile()
input = BasicInput(_unused=True)
primary_llm, fast_llm = get_default_llms()
with get_session_context_manager() as db_session:
config, _ = get_test_config(
db_session=db_session,
primary_llm=primary_llm,
fast_llm=fast_llm,
search_request=SearchRequest(query="How does onyx use FastAPI?"),
)
compiled_graph.invoke(input, config={"metadata": {"config": config}})
pass

View File

@@ -0,0 +1,63 @@
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.basic.states import BasicOutput
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.chat.models import LlmDoc
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_DOC_CONTENT_ID,
)
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicOutput:
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
structured_response_format = agent_config.structured_response_format
llm = agent_config.primary_llm
tool_choice = state["tool_choice"]
if tool_choice is None:
raise ValueError("Tool choice is None")
tool = tool_choice["tool"]
prompt_builder = agent_config.prompt_builder
tool_call_summary = state["tool_call_summary"]
tool_call_responses = state["tool_call_responses"]
state["tool_call_final_result"]
new_prompt_builder = tool.build_next_prompt(
prompt_builder=prompt_builder,
tool_call_summary=tool_call_summary,
tool_responses=tool_call_responses,
using_tool_calling_llm=agent_config.using_tool_calling_llm,
)
final_search_results = []
initial_search_results = []
for yield_item in tool_call_responses:
if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
final_search_results = cast(list[LlmDoc], yield_item.response)
elif yield_item.id == SEARCH_DOC_CONTENT_ID:
search_contexts = yield_item.response.contexts
for doc in search_contexts:
if doc.document_id not in initial_search_results:
initial_search_results.append(doc)
initial_search_results = cast(list[LlmDoc], initial_search_results)
stream = llm.stream(
prompt=new_prompt_builder.build(),
structured_response_format=structured_response_format,
)
# For now, we don't do multiple tool calls, so we ignore the tool_message
process_llm_stream(
stream,
True,
final_search_results=final_search_results,
displayed_search_results=initial_search_results,
)
return BasicOutput()

View File

@@ -3,14 +3,12 @@ from uuid import uuid4
from langchain_core.messages import ToolCall
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.states import ToolChoice
from onyx.agents.agent_search.basic.states import ToolChoiceUpdate
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.orchestration.states import ToolChoice
from onyx.agents.agent_search.orchestration.states import ToolChoiceState
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.chat.tool_handling.tool_response_handler import (
get_tool_call_for_non_tool_calling_llm_impl,
@@ -25,29 +23,22 @@ logger = setup_logger()
# and a function that handles extracting the necessary fields
# from the state and config
# TODO: fan-out to multiple tool call nodes? Make this configurable?
def llm_tool_choice(
state: ToolChoiceState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> ToolChoiceUpdate:
def llm_tool_choice(state: BasicState, config: RunnableConfig) -> ToolChoiceUpdate:
"""
This node is responsible for calling the LLM to choose a tool. If no tool is chosen,
The node MAY emit an answer, depending on whether state["should_stream_answer"] is set.
"""
should_stream_answer = state.should_stream_answer
should_stream_answer = state["should_stream_answer"]
agent_config = cast(GraphConfig, config["metadata"]["config"])
using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
using_tool_calling_llm = agent_config.using_tool_calling_llm
prompt_builder = agent_config.prompt_builder
llm = agent_config.primary_llm
skip_gen_ai_answer_generation = agent_config.skip_gen_ai_answer_generation
llm = agent_config.tooling.primary_llm
skip_gen_ai_answer_generation = agent_config.behavior.skip_gen_ai_answer_generation
structured_response_format = agent_config.inputs.structured_response_format
tools = [
tool for tool in (agent_config.tooling.tools or []) if tool.name in state.tools
]
force_use_tool = agent_config.tooling.force_use_tool
structured_response_format = agent_config.structured_response_format
tools = agent_config.tools or []
force_use_tool = agent_config.force_use_tool
tool, tool_args = None, None
if force_use_tool.force_use and force_use_tool.args is not None:
@@ -68,7 +59,7 @@ def llm_tool_choice(
if chosen_tool_and_args:
tool, tool_args = chosen_tool_and_args
# If we have a tool and tool args, we are ready to request a tool call.
# If we have a tool and tool args, we are redy to request a tool call.
# This only happens if the tool call was forced or we are using a non-tool calling LLM.
if tool and tool_args:
return ToolChoiceUpdate(
@@ -87,32 +78,21 @@ def llm_tool_choice(
tool_choice=None,
)
built_prompt = (
prompt_builder.build()
if isinstance(prompt_builder, AnswerPromptBuilder)
else prompt_builder.built_prompt
)
# At this point, we are either using a tool calling LLM or we are skipping the tool call.
# DEBUG: good breakpoint
stream = llm.stream(
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
prompt=built_prompt,
prompt=prompt_builder.build(),
tools=[tool.tool_definition() for tool in tools] or None,
tool_choice=("required" if tools and force_use_tool.force_use else None),
structured_response_format=structured_response_format,
)
tool_message = process_llm_stream(
stream,
should_stream_answer
and not agent_config.behavior.skip_gen_ai_answer_generation,
writer,
)
tool_message = process_llm_stream(stream, should_stream_answer)
# If no tool calls are emitted by the LLM, we should not choose a tool
if len(tool_message.tool_calls) == 0:
logger.debug("No tool calls emitted by LLM")
return ToolChoiceUpdate(
tool_choice=None,
)
@@ -142,7 +122,7 @@ def llm_tool_choice(
f"Tool call attempted with tool {selected_tool}, request {selected_tool_call_request}"
)
logger.debug(f"Selected tool: {selected_tool.name}")
logger.info(f"Selected tool: {selected_tool.name}")
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
return ToolChoiceUpdate(

View File

@@ -0,0 +1,69 @@
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import AIMessageChunk
from langchain_core.messages.tool import ToolCall
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.states import ToolCallUpdate
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.chat.models import AnswerPacket
from onyx.tools.message import build_tool_message
from onyx.tools.message import ToolCallSummary
from onyx.tools.tool_runner import ToolRunner
from onyx.utils.logger import setup_logger
logger = setup_logger()
def emit_packet(packet: AnswerPacket) -> None:
dispatch_custom_event("basic_response", packet)
# TODO: handle is_cancelled
def tool_call(state: BasicState, config: RunnableConfig) -> ToolCallUpdate:
"""Calls the tool specified in the state and updates the state with the result"""
# TODO: implement
cast(AgentSearchConfig, config["metadata"]["config"])
# Unnecessary now, node should only be called if there is a tool call
# if not self.tool_call_chunk or not self.tool_call_chunk.tool_calls:
# return
tool_choice = state["tool_choice"]
if tool_choice is None:
raise ValueError("Cannot invoke tool call node without a tool choice")
tool = tool_choice["tool"]
tool_args = tool_choice["tool_args"]
tool_id = tool_choice["id"]
tool_runner = ToolRunner(tool, tool_args)
tool_kickoff = tool_runner.kickoff()
# TODO: custom events for yields
emit_packet(tool_kickoff)
tool_responses = []
for response in tool_runner.tool_responses():
tool_responses.append(response)
emit_packet(response)
tool_final_result = tool_runner.tool_final_result()
emit_packet(tool_final_result)
tool_call = ToolCall(name=tool.name, args=tool_args, id=tool_id)
tool_call_summary = ToolCallSummary(
tool_call_request=AIMessageChunk(content="", tool_calls=[tool_call]),
tool_call_result=build_tool_message(
tool_call, tool_runner.tool_message_content()
),
)
return ToolCallUpdate(
tool_call_summary=tool_call_summary,
tool_call_kickoff=tool_kickoff,
tool_call_responses=tool_responses,
tool_call_final_result=tool_final_result,
)

View File

@@ -1,35 +1,55 @@
from typing import TypedDict
from langchain_core.messages import AIMessageChunk
from pydantic import BaseModel
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
# States contain values that change over the course of graph execution,
# Config is for values that are set at the start and never change.
# If you are using a value from the config and realize it needs to change,
# you should add it to the state and use/update the version in the state.
## Graph Input State
class BasicInput(BaseModel):
# Langgraph needs a nonempty input, but we pass in all static
# data through a RunnableConfig.
_unused: bool = True
class BasicInput(TypedDict):
should_stream_answer: bool
## Graph Output State
class BasicOutput(TypedDict):
tool_call_chunk: AIMessageChunk
pass
## Update States
class ToolCallUpdate(TypedDict):
tool_call_summary: ToolCallSummary
tool_call_kickoff: ToolCallKickoff
tool_call_responses: list[ToolResponse]
tool_call_final_result: ToolCallFinalResult
class ToolChoice(TypedDict):
tool: Tool
tool_args: dict
id: str | None
class ToolChoiceUpdate(TypedDict):
tool_choice: ToolChoice | None
## Graph State
class BasicState(
BasicInput,
ToolChoiceInput,
ToolCallUpdate,
ToolChoiceUpdate,
BasicOutput,
):
pass

View File

@@ -1,13 +1,11 @@
from collections.abc import Iterator
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langgraph.types import StreamWriter
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
@@ -18,15 +16,23 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
# TODO: handle citations here; below is what was previously passed in
# see basic_use_tool_response.py for where these variables come from
# answer_handler = CitationResponseHandler(
# context_docs=final_search_results,
# final_doc_id_to_rank_map=map_document_id_order(final_search_results),
# display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
# )
def process_llm_stream(
messages: Iterator[BaseMessage],
stream: Iterator[BaseMessage],
should_stream_answer: bool,
writer: StreamWriter,
final_search_results: list[LlmDoc] | None = None,
displayed_search_results: list[OnyxContext] | list[LlmDoc] | None = None,
displayed_search_results: list[LlmDoc] | None = None,
) -> AIMessageChunk:
tool_call_chunk = AIMessageChunk(content="")
# for response in response_handler_manager.handle_llm_response(stream):
if final_search_results and displayed_search_results:
answer_handler: AnswerResponseHandler = CitationResponseHandler(
@@ -37,28 +43,25 @@ def process_llm_stream(
else:
answer_handler = PassThroughAnswerResponseHandler()
full_answer = ""
# This stream will be the llm answer if no tool is chosen. When a tool is chosen,
# the stream will contain AIMessageChunks with tool call information.
for message in messages:
answer_piece = message.content
for response in stream:
answer_piece = response.content
if not isinstance(answer_piece, str):
# this is only used for logging, so fine to
# just add the string representation
# TODO: handle non-string content
logger.warning(f"Received non-string content: {type(answer_piece)}")
answer_piece = str(answer_piece)
full_answer += answer_piece
if isinstance(message, AIMessageChunk) and (
message.tool_call_chunks or message.tool_calls
if isinstance(response, AIMessageChunk) and (
response.tool_call_chunks or response.tool_calls
):
tool_call_chunk += message # type: ignore
tool_call_chunk += response # type: ignore
elif should_stream_answer:
for response_part in answer_handler.handle_response_part(message, []):
write_custom_event(
# TODO: handle emitting of CitationInfo
for response_part in answer_handler.handle_response_part(response, []):
dispatch_custom_event(
"basic_response",
response_part,
writer,
)
logger.debug(f"Full answer: {full_answer}")
return cast(AIMessageChunk, tool_call_chunk)

View File

@@ -0,0 +1,66 @@
from uuid import UUID
from sqlalchemy.orm import Session
from onyx.db.models import AgentSubQuery
from onyx.db.models import AgentSubQuestion
def create_sub_question(
db_session: Session,
chat_session_id: UUID,
primary_message_id: int,
sub_question: str,
sub_answer: str,
) -> AgentSubQuestion:
"""Create a new sub-question record in the database."""
sub_q = AgentSubQuestion(
chat_session_id=chat_session_id,
primary_question_id=primary_message_id,
sub_question=sub_question,
sub_answer=sub_answer,
)
db_session.add(sub_q)
db_session.flush()
return sub_q
def create_sub_query(
db_session: Session,
chat_session_id: UUID,
parent_question_id: int,
sub_query: str,
) -> AgentSubQuery:
"""Create a new sub-query record in the database."""
sub_q = AgentSubQuery(
chat_session_id=chat_session_id,
parent_question_id=parent_question_id,
sub_query=sub_query,
)
db_session.add(sub_q)
db_session.flush()
return sub_q
def get_sub_questions_for_message(
db_session: Session,
primary_message_id: int,
) -> list[AgentSubQuestion]:
"""Get all sub-questions for a given primary message."""
return (
db_session.query(AgentSubQuestion)
.filter(AgentSubQuestion.primary_question_id == primary_message_id)
.all()
)
def get_sub_queries_for_question(
db_session: Session,
sub_question_id: int,
) -> list[AgentSubQuery]:
"""Get all sub-queries for a given sub-question."""
return (
db_session.query(AgentSubQuery)
.filter(AgentSubQuery.parent_question_id == sub_question_id)
.all()
)

View File

@@ -1,31 +0,0 @@
from collections.abc import Hashable
from datetime import datetime
from langgraph.types import Send
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def send_to_expanded_retrieval(state: SubQuestionAnsweringInput) -> Send | Hashable:
"""
LangGraph edge to send a sub-question to the expanded retrieval.
"""
edge_start_time = datetime.now()
return Send(
"initial_sub_question_expanded_retrieval",
ExpandedRetrievalInput(
question=state.question,
base_search=False,
sub_question_id=state.question_id,
log_messages=[f"{edge_start_time} -- Sending to expanded retrieval"],
),
)

View File

@@ -1,75 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnswerCheckUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.prompts.agent_search import SUB_ANSWER_CHECK_PROMPT
from onyx.prompts.agent_search import UNKNOWN_ANSWER
def check_sub_answer(
state: AnswerQuestionState, config: RunnableConfig
) -> SubQuestionAnswerCheckUpdate:
"""
LangGraph node to check the quality of the sub-answer. The answer
is represented as a boolean value.
"""
node_start_time = datetime.now()
level, question_num = parse_question_id(state.question_id)
if state.answer == UNKNOWN_ANSWER:
return SubQuestionAnswerCheckUpdate(
answer_quality=False,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate individual sub answer",
node_name="check sub answer",
node_start_time=node_start_time,
result="unknown answer",
)
],
)
msg = [
HumanMessage(
content=SUB_ANSWER_CHECK_PROMPT.format(
question=state.question,
base_answer=state.answer,
)
)
]
graph_config = cast(GraphConfig, config["metadata"]["config"])
fast_llm = graph_config.tooling.fast_llm
response = list(
fast_llm.stream(
prompt=msg,
)
)
quality_str: str = merge_message_runs(response, chunk_separator="")[0].content
answer_quality = "yes" in quality_str.lower()
return SubQuestionAnswerCheckUpdate(
answer_quality=answer_quality,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate individual sub answer",
node_name="check sub answer",
node_start_time=node_start_time,
result=f"Answer quality: {quality_str}",
)
],
)

View File

@@ -1,30 +0,0 @@
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.shared_graph_utils.models import (
SubQuestionAnswerResults,
)
def format_sub_answer(state: AnswerQuestionState) -> AnswerQuestionOutput:
"""
LangGraph node to generate the sub-answer format.
"""
return AnswerQuestionOutput(
answer_results=[
SubQuestionAnswerResults(
question=state.question,
question_id=state.question_id,
verified_high_quality=state.answer_quality,
answer=state.answer,
sub_query_retrieval_results=state.expanded_retrieval_results,
verified_reranked_documents=state.verified_reranked_documents,
context_documents=state.context_documents,
cited_documents=state.cited_documents,
sub_question_retrieval_stats=state.sub_question_retrieval_stats,
)
],
)

View File

@@ -1,137 +0,0 @@
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnswerGenerationUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_sub_question_answer_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.utils import get_answer_citation_ids
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_persona_agent_prompt_expressions,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.prompts.agent_search import NO_RECOVERED_DOCS
from onyx.utils.logger import setup_logger
logger = setup_logger()
def generate_sub_answer(
state: AnswerQuestionState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubQuestionAnswerGenerationUpdate:
"""
LangGraph node to generate a sub-answer.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = state.question
state.verified_reranked_documents
level, question_num = parse_question_id(state.question_id)
context_docs = state.context_documents[:AGENT_MAX_ANSWER_CONTEXT_DOCS]
persona_contextualized_prompt = get_persona_agent_prompt_expressions(
graph_config.inputs.search_request.persona
).contextualized_prompt
if len(context_docs) == 0:
answer_str = NO_RECOVERED_DOCS
write_custom_event(
"sub_answers",
AgentAnswerPiece(
answer_piece=answer_str,
level=level,
level_question_num=question_num,
answer_type="agent_sub_answer",
),
writer,
)
else:
fast_llm = graph_config.tooling.fast_llm
msg = build_sub_question_answer_prompt(
question=question,
original_question=graph_config.inputs.search_request.query,
docs=context_docs,
persona_specification=persona_contextualized_prompt,
config=fast_llm.config,
)
response: list[str | list[str | dict[str, Any]]] = []
dispatch_timings: list[float] = []
for message in fast_llm.stream(
prompt=msg,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
write_custom_event(
"sub_answers",
AgentAnswerPiece(
answer_piece=content,
level=level,
level_question_num=question_num,
answer_type="agent_sub_answer",
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
response.append(content)
answer_str = merge_message_runs(response, chunk_separator="")[0].content
logger.debug(
f"Average dispatch time: {sum(dispatch_timings) / len(dispatch_timings)}"
)
answer_citation_ids = get_answer_citation_ids(answer_str)
cited_documents = [
context_docs[id] for id in answer_citation_ids if id < len(context_docs)
]
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type=StreamType.SUB_ANSWER,
level=level,
level_question_num=question_num,
)
write_custom_event("stream_finished", stop_event, writer)
return SubQuestionAnswerGenerationUpdate(
answer=answer_str,
cited_documents=cited_documents,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate individual sub answer",
node_name="generate sub answer",
node_start_time=node_start_time,
result="",
)
],
)

View File

@@ -1,25 +0,0 @@
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionRetrievalIngestionUpdate,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalOutput,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
def ingest_retrieved_documents(
state: ExpandedRetrievalOutput,
) -> SubQuestionRetrievalIngestionUpdate:
"""
LangGraph node to ingest the retrieved documents to format it for the sub-answer.
"""
sub_question_retrieval_stats = state.expanded_retrieval_result.retrieval_stats
if sub_question_retrieval_stats is None:
sub_question_retrieval_stats = [AgentChunkRetrievalStats()]
return SubQuestionRetrievalIngestionUpdate(
expanded_retrieval_results=state.expanded_retrieval_result.expanded_query_results,
verified_reranked_documents=state.expanded_retrieval_result.verified_reranked_documents,
context_documents=state.expanded_retrieval_result.context_documents,
sub_question_retrieval_stats=sub_question_retrieval_stats,
)

View File

@@ -1,50 +0,0 @@
from collections.abc import Hashable
from datetime import datetime
from langgraph.types import Send
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
SubQuestionRetrievalState,
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
def parallelize_initial_sub_question_answering(
state: SubQuestionRetrievalState,
) -> list[Send | Hashable]:
"""
LangGraph edge to parallelize the initial sub-question answering. If there are no sub-questions,
we send empty answers to the initial answer generation, and that answer would be generated
solely based on the documents retrieved for the original question.
"""
edge_start_time = datetime.now()
if len(state.initial_sub_questions) > 0:
return [
Send(
"answer_query_subgraph",
SubQuestionAnsweringInput(
question=question,
question_id=make_question_id(0, question_num + 1),
log_messages=[
f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
],
),
)
for question_num, question in enumerate(state.initial_sub_questions)
]
else:
return [
Send(
"ingest_answers",
AnswerQuestionOutput(
answer_results=[],
),
)
]

View File

@@ -1,96 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.nodes.generate_initial_answer import (
generate_initial_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.nodes.validate_initial_answer import (
validate_initial_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
SubQuestionRetrievalInput,
)
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
SubQuestionRetrievalState,
)
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.graph_builder import (
generate_sub_answers_graph_builder,
)
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.graph_builder import (
retrieve_orig_question_docs_graph_builder,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def generate_initial_answer_graph_builder(test_mode: bool = False) -> StateGraph:
"""
LangGraph graph builder for the initial answer generation.
"""
graph = StateGraph(
state_schema=SubQuestionRetrievalState,
input=SubQuestionRetrievalInput,
)
# The sub-graph that generates the initial sub-answers
generate_sub_answers = generate_sub_answers_graph_builder().compile()
graph.add_node(
node="generate_sub_answers_subgraph",
action=generate_sub_answers,
)
# The sub-graph that retrieves the original question documents. This is run
# in parallel with the sub-answer generation process
retrieve_orig_question_docs = retrieve_orig_question_docs_graph_builder().compile()
graph.add_node(
node="retrieve_orig_question_docs_subgraph_wrapper",
action=retrieve_orig_question_docs,
)
# Node that generates the initial answer using the results of the previous
# two sub-graphs
graph.add_node(
node="generate_initial_answer",
action=generate_initial_answer,
)
# Node that validates the initial answer
graph.add_node(
node="validate_initial_answer",
action=validate_initial_answer,
)
### Add edges ###
graph.add_edge(
start_key=START,
end_key="retrieve_orig_question_docs_subgraph_wrapper",
)
graph.add_edge(
start_key=START,
end_key="generate_sub_answers_subgraph",
)
# Wait for both, the original question docs and the sub-answers to be generated before proceeding
graph.add_edge(
start_key=[
"retrieve_orig_question_docs_subgraph_wrapper",
"generate_sub_answers_subgraph",
],
end_key="generate_initial_answer",
)
graph.add_edge(
start_key="generate_initial_answer",
end_key="validate_initial_answer",
)
graph.add_edge(
start_key="validate_initial_answer",
end_key=END,
)
return graph

View File

@@ -1,313 +0,0 @@
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
SubQuestionRetrievalState,
)
from onyx.agents.agent_search.deep_search.main.models import AgentBaseMetrics
from onyx.agents.agent_search.deep_search.main.operations import (
calculate_initial_agent_stats,
)
from onyx.agents.agent_search.deep_search.main.operations import get_query_info
from onyx.agents.agent_search.deep_search.main.operations import logger
from onyx.agents.agent_search.deep_search.main.states import (
InitialAnswerUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
get_prompt_enrichment_components,
)
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
from onyx.agents.agent_search.shared_graph_utils.utils import remove_document_citations
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.context.search.models import InferenceSection
from onyx.prompts.agent_search import (
INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS,
)
from onyx.prompts.agent_search import (
INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS,
)
from onyx.prompts.agent_search import (
SUB_QUESTION_ANSWER_TEMPLATE,
)
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
def generate_initial_answer(
state: SubQuestionRetrievalState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> InitialAnswerUpdate:
"""
LangGraph node to generate the initial answer, using the initial sub-questions/sub-answers and the
documents retrieved for the original question.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
prompt_enrichment_components = get_prompt_enrichment_components(graph_config)
sub_questions_cited_documents = state.cited_documents
orig_question_retrieval_documents = state.orig_question_retrieved_documents
consolidated_context_docs: list[InferenceSection] = sub_questions_cited_documents
counter = 0
for original_doc_number, original_doc in enumerate(
orig_question_retrieval_documents
):
if original_doc_number not in sub_questions_cited_documents:
if (
counter <= AGENT_MIN_ORIG_QUESTION_DOCS
or len(consolidated_context_docs) < AGENT_MAX_ANSWER_CONTEXT_DOCS
):
consolidated_context_docs.append(original_doc)
counter += 1
# sort docs by their scores - though the scores refer to different questions
relevant_docs = dedup_inference_sections(
consolidated_context_docs, consolidated_context_docs
)
sub_questions: list[str] = []
streamed_documents = (
relevant_docs
if len(relevant_docs) > 0
else state.orig_question_retrieved_documents[:15]
)
# Use the query info from the base document retrieval
query_info = get_query_info(state.orig_question_sub_query_retrieval_results)
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
relevance_list = relevance_from_docs(relevant_docs)
for tool_response in yield_search_responses(
query=question,
reranked_sections=streamed_documents,
final_context_sections=streamed_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
):
write_custom_event(
"tool_response",
ExtendedToolResponse(
id=tool_response.id,
response=tool_response.response,
level=0,
level_question_num=0, # 0, 0 is the base question
),
writer,
)
if len(relevant_docs) == 0:
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=UNKNOWN_ANSWER,
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
dispatch_main_answer_stop_info(0, writer)
answer = UNKNOWN_ANSWER
initial_agent_stats = InitialAgentResultStats(
sub_questions={},
original_question={},
agent_effectiveness={},
)
else:
sub_question_answer_results = state.sub_question_results
# Collect the sub-questions and sub-answers and construct an appropriate
# prompt string.
# Consider replacing by a function.
answered_sub_questions: list[str] = []
all_sub_questions: list[str] = [] # Separate list for tracking all questions
for idx, sub_question_answer_result in enumerate(
sub_question_answer_results, start=1
):
all_sub_questions.append(sub_question_answer_result.question)
is_valid_answer = (
sub_question_answer_result.verified_high_quality
and sub_question_answer_result.answer
and sub_question_answer_result.answer != UNKNOWN_ANSWER
)
if is_valid_answer:
answered_sub_questions.append(
SUB_QUESTION_ANSWER_TEMPLATE.format(
sub_question=sub_question_answer_result.question,
sub_answer=sub_question_answer_result.answer,
sub_question_num=idx,
)
)
sub_question_answer_str = (
"\n\n------\n\n".join(answered_sub_questions)
if answered_sub_questions
else ""
)
# Use the appropriate prompt based on whether there are sub-questions.
base_prompt = (
INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS
if answered_sub_questions
else INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS
)
sub_questions = all_sub_questions # Replace the original assignment
model = graph_config.tooling.fast_llm
doc_context = format_docs(relevant_docs)
doc_context = trim_prompt_piece(
config=model.config,
prompt_piece=doc_context,
reserved_str=(
base_prompt
+ sub_question_answer_str
+ prompt_enrichment_components.persona_prompts.contextualized_prompt
+ prompt_enrichment_components.history
+ prompt_enrichment_components.date_str
),
)
msg = [
HumanMessage(
content=base_prompt.format(
question=question,
answered_sub_questions=remove_document_citations(
sub_question_answer_str
),
relevant_docs=doc_context,
persona_specification=prompt_enrichment_components.persona_prompts.contextualized_prompt,
history=prompt_enrichment_components.history,
date_prompt=prompt_enrichment_components.date_str,
)
)
]
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
dispatch_timings: list[float] = []
for message in model.stream(msg):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
streamed_tokens.append(content)
logger.debug(
f"Average dispatch time for initial answer: {sum(dispatch_timings) / len(dispatch_timings)}"
)
dispatch_main_answer_stop_info(0, writer)
response = merge_content(*streamed_tokens)
answer = cast(str, response)
initial_agent_stats = calculate_initial_agent_stats(
state.sub_question_results, state.orig_question_retrieval_stats
)
logger.debug(
f"\n\nYYYYY--Sub-Questions:\n\n{sub_question_answer_str}\n\nStats:\n\n"
)
if initial_agent_stats:
logger.debug(initial_agent_stats.original_question)
logger.debug(initial_agent_stats.sub_questions)
logger.debug(initial_agent_stats.agent_effectiveness)
agent_base_end_time = datetime.now()
if agent_base_end_time and state.agent_start_time:
duration_s = (agent_base_end_time - state.agent_start_time).total_seconds()
else:
duration_s = None
agent_base_metrics = AgentBaseMetrics(
num_verified_documents_total=len(relevant_docs),
num_verified_documents_core=state.orig_question_retrieval_stats.verified_count,
verified_avg_score_core=state.orig_question_retrieval_stats.verified_avg_scores,
num_verified_documents_base=initial_agent_stats.sub_questions.get(
"num_verified_documents"
),
verified_avg_score_base=initial_agent_stats.sub_questions.get(
"verified_avg_score"
),
base_doc_boost_factor=initial_agent_stats.agent_effectiveness.get(
"utilized_chunk_ratio"
),
support_boost_factor=initial_agent_stats.agent_effectiveness.get(
"support_ratio"
),
duration_s=duration_s,
)
return InitialAnswerUpdate(
initial_answer=answer,
initial_agent_stats=initial_agent_stats,
generated_sub_questions=sub_questions,
agent_base_end_time=agent_base_end_time,
agent_base_metrics=agent_base_metrics,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate initial answer",
node_name="generate initial answer",
node_start_time=node_start_time,
result="",
)
],
)

View File

@@ -1,40 +0,0 @@
from datetime import datetime
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
SubQuestionRetrievalState,
)
from onyx.agents.agent_search.deep_search.main.operations import logger
from onyx.agents.agent_search.deep_search.main.states import (
InitialAnswerQualityUpdate,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
def validate_initial_answer(
state: SubQuestionRetrievalState,
) -> InitialAnswerQualityUpdate:
"""
Check whether the initial answer sufficiently addresses the original user question.
"""
node_start_time = datetime.now()
logger.debug(
f"--------{node_start_time}--------Checking for base answer validity - for not set True/False manually"
)
verdict = True
return InitialAnswerQualityUpdate(
initial_answer_quality_eval=verdict,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate initial answer",
node_name="validate initial answer",
node_start_time=node_start_time,
result="",
)
],
)

View File

@@ -1,51 +0,0 @@
from operator import add
from typing import Annotated
from typing import TypedDict
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.deep_search.main.states import (
ExploratorySearchUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
InitialAnswerQualityUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
InitialAnswerUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
InitialQuestionDecompositionUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
OrigQuestionRetrievalUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
SubQuestionResultsUpdate,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.models import (
QuestionRetrievalResult,
)
from onyx.context.search.models import InferenceSection
### States ###
class SubQuestionRetrievalInput(CoreState):
exploratory_search_results: list[InferenceSection]
## Graph State
class SubQuestionRetrievalState(
# This includes the core state
SubQuestionRetrievalInput,
InitialQuestionDecompositionUpdate,
InitialAnswerUpdate,
SubQuestionResultsUpdate,
OrigQuestionRetrievalUpdate,
InitialAnswerQualityUpdate,
ExploratorySearchUpdate,
):
base_raw_search_result: Annotated[list[QuestionRetrievalResult], add]
## Graph Output State
class SubQuestionRetrievalOutput(TypedDict):
log_messages: list[str]

View File

@@ -1,48 +0,0 @@
from collections.abc import Hashable
from datetime import datetime
from langgraph.types import Send
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
SubQuestionRetrievalState,
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
def parallelize_initial_sub_question_answering(
state: SubQuestionRetrievalState,
) -> list[Send | Hashable]:
"""
LangGraph edge to parallelize the initial sub-question answering.
"""
edge_start_time = datetime.now()
if len(state.initial_sub_questions) > 0:
return [
Send(
"answer_sub_question_subgraphs",
SubQuestionAnsweringInput(
question=question,
question_id=make_question_id(0, question_num + 1),
log_messages=[
f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
],
),
)
for question_num, question in enumerate(state.initial_sub_questions)
]
else:
return [
Send(
"ingest_answers",
AnswerQuestionOutput(
answer_results=[],
),
)
]

View File

@@ -1,81 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.graph_builder import (
answer_query_graph_builder,
)
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.edges import (
parallelize_initial_sub_question_answering,
)
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.nodes.decompose_orig_question import (
decompose_orig_question,
)
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.nodes.format_initial_sub_answers import (
format_initial_sub_answers,
)
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.states import (
SubQuestionAnsweringState,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
test_mode = False
def generate_sub_answers_graph_builder() -> StateGraph:
"""
LangGraph graph builder for the initial sub-answer generation process.
It generates the initial sub-questions and produces the answers.
"""
graph = StateGraph(
state_schema=SubQuestionAnsweringState,
input=SubQuestionAnsweringInput,
)
# Decompose the original question into sub-questions
graph.add_node(
node="decompose_orig_question",
action=decompose_orig_question,
)
# The sub-graph that executes the initial sub-question answering for
# each of the sub-questions.
answer_sub_question_subgraphs = answer_query_graph_builder().compile()
graph.add_node(
node="answer_sub_question_subgraphs",
action=answer_sub_question_subgraphs,
)
# Node that collects and formats the initial sub-question answers
graph.add_node(
node="format_initial_sub_question_answers",
action=format_initial_sub_answers,
)
graph.add_edge(
start_key=START,
end_key="decompose_orig_question",
)
graph.add_conditional_edges(
source="decompose_orig_question",
path=parallelize_initial_sub_question_answering,
path_map=["answer_sub_question_subgraphs"],
)
graph.add_edge(
start_key=["answer_sub_question_subgraphs"],
end_key="format_initial_sub_question_answers",
)
graph.add_edge(
start_key="format_initial_sub_question_answers",
end_key=END,
)
return graph

View File

@@ -1,153 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
SubQuestionRetrievalState,
)
from onyx.agents.agent_search.deep_search.main.models import (
AgentRefinedMetrics,
)
from onyx.agents.agent_search.deep_search.main.operations import (
dispatch_subquestion,
)
from onyx.agents.agent_search.deep_search.main.states import (
InitialQuestionDecompositionUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
from onyx.prompts.agent_search import (
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH,
)
from onyx.prompts.agent_search import (
INITIAL_QUESTION_DECOMPOSITION_PROMPT,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def decompose_orig_question(
state: SubQuestionRetrievalState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> InitialQuestionDecompositionUpdate:
"""
LangGraph node to decompose the original question into sub-questions.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
perform_initial_search_decomposition = (
graph_config.behavior.perform_initial_search_decomposition
)
# Get the rewritten queries in a defined format
model = graph_config.tooling.fast_llm
history = build_history_prompt(graph_config, question)
# Use the initial search results to inform the decomposition
agent_start_time = datetime.now()
# Initial search to inform decomposition. Just get top 3 fits
if perform_initial_search_decomposition:
# Due to unfortunate state representation in LangGraph, we need here to double check that the retrieval has
# happened prior to this point, allowing silent failure here since it is not critical for decomposition in
# all queries.
if not state.exploratory_search_results:
logger.error("Initial search for decomposition failed")
sample_doc_str = "\n\n".join(
[
doc.combined_content
for doc in state.exploratory_search_results[
:AGENT_NUM_DOCS_FOR_DECOMPOSITION
]
]
)
decomposition_prompt = (
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH.format(
question=question, sample_doc_str=sample_doc_str, history=history
)
)
else:
decomposition_prompt = INITIAL_QUESTION_DECOMPOSITION_PROMPT.format(
question=question, history=history
)
# Start decomposition
msg = [HumanMessage(content=decomposition_prompt)]
# Send the initial question as a subquestion with number 0
write_custom_event(
"decomp_qs",
SubQuestionPiece(
sub_question=question,
level=0,
level_question_num=0,
),
writer,
)
# dispatches custom events for subquestion tokens, adding in subquestion ids.
streamed_tokens = dispatch_separated(
model.stream(msg), dispatch_subquestion(0, writer)
)
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type=StreamType.SUB_QUESTIONS,
level=0,
)
write_custom_event("stream_finished", stop_event, writer)
deomposition_response = merge_content(*streamed_tokens)
# this call should only return strings. Commenting out for efficiency
# assert [type(tok) == str for tok in streamed_tokens]
# use no-op cast() instead of str() which runs code
# list_of_subquestions = clean_and_parse_list_string(cast(str, response))
list_of_subqs = cast(str, deomposition_response).split("\n")
decomp_list: list[str] = [sq.strip() for sq in list_of_subqs if sq.strip() != ""]
return InitialQuestionDecompositionUpdate(
initial_sub_questions=decomp_list,
agent_start_time=agent_start_time,
agent_refined_start_time=None,
agent_refined_end_time=None,
agent_refined_metrics=AgentRefinedMetrics(
refined_doc_boost_factor=None,
refined_question_boost_factor=None,
duration_s=None,
),
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate sub answers",
node_name="decompose original question",
node_start_time=node_start_time,
result=f"decomposed original question into {len(decomp_list)} subquestions",
)
],
)

View File

@@ -1,50 +0,0 @@
from datetime import datetime
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.main.states import (
SubQuestionResultsUpdate,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
def format_initial_sub_answers(
state: AnswerQuestionOutput,
) -> SubQuestionResultsUpdate:
"""
LangGraph node to format the answers to the initial sub-questions, including
deduping verified documents and context documents.
"""
node_start_time = datetime.now()
documents = []
context_documents = []
cited_documents = []
answer_results = state.answer_results
for answer_result in answer_results:
documents.extend(answer_result.verified_reranked_documents)
context_documents.extend(answer_result.context_documents)
cited_documents.extend(answer_result.cited_documents)
return SubQuestionResultsUpdate(
# Deduping is done by the documents operator for the main graph
# so we might not need to dedup here
verified_reranked_documents=dedup_inference_sections(documents, []),
context_documents=dedup_inference_sections(context_documents, []),
cited_documents=dedup_inference_sections(cited_documents, []),
sub_question_results=answer_results,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate sub answers",
node_name="format initial sub answers",
node_start_time=node_start_time,
result="",
)
],
)

View File

@@ -1,34 +0,0 @@
from typing import TypedDict
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.deep_search.main.states import (
InitialAnswerUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
InitialQuestionDecompositionUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
SubQuestionResultsUpdate,
)
from onyx.context.search.models import InferenceSection
### States ###
class SubQuestionAnsweringInput(CoreState):
exploratory_search_results: list[InferenceSection]
## Graph State
class SubQuestionAnsweringState(
# This includes the core state
SubQuestionAnsweringInput,
InitialQuestionDecompositionUpdate,
InitialAnswerUpdate,
SubQuestionResultsUpdate,
):
pass
## Graph Output State
class SubQuestionAnsweringOutput(TypedDict):
log_messages: list[str]

View File

@@ -1,81 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.nodes.format_orig_question_search_input import (
format_orig_question_search_input,
)
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.nodes.format_orig_question_search_output import (
format_orig_question_search_output,
)
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.states import (
BaseRawSearchInput,
)
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.states import (
BaseRawSearchOutput,
)
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.states import (
BaseRawSearchState,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.graph_builder import (
expanded_retrieval_graph_builder,
)
def retrieve_orig_question_docs_graph_builder() -> StateGraph:
"""
LangGraph graph builder for the retrieval of documents
that are relevant to the original question. This is
largely a wrapper around the expanded retrieval process to
ensure parallelism with the sub-question answer process.
"""
graph = StateGraph(
state_schema=BaseRawSearchState,
input=BaseRawSearchInput,
output=BaseRawSearchOutput,
)
### Add nodes ###
# Format the original question search output
graph.add_node(
node="format_orig_question_search_output",
action=format_orig_question_search_output,
)
# The sub-graph that executes the expanded retrieval process
expanded_retrieval = expanded_retrieval_graph_builder().compile()
graph.add_node(
node="retrieve_orig_question_docs_subgraph",
action=expanded_retrieval,
)
# Format the original question search input
graph.add_node(
node="format_orig_question_search_input",
action=format_orig_question_search_input,
)
### Add edges ###
graph.add_edge(start_key=START, end_key="format_orig_question_search_input")
graph.add_edge(
start_key="format_orig_question_search_input",
end_key="retrieve_orig_question_docs_subgraph",
)
graph.add_edge(
start_key="retrieve_orig_question_docs_subgraph",
end_key="format_orig_question_search_output",
)
graph.add_edge(
start_key="format_orig_question_search_output",
end_key=END,
)
return graph
if __name__ == "__main__":
pass

View File

@@ -1,30 +0,0 @@
from onyx.agents.agent_search.deep_search.main.states import OrigQuestionRetrievalUpdate
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalOutput,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.utils.logger import setup_logger
logger = setup_logger()
def format_orig_question_search_output(
state: ExpandedRetrievalOutput,
) -> OrigQuestionRetrievalUpdate:
"""
LangGraph node to format the search result for the original question into the
proper format.
"""
sub_question_retrieval_stats = state.expanded_retrieval_result.retrieval_stats
if sub_question_retrieval_stats is None:
sub_question_retrieval_stats = AgentChunkRetrievalStats()
else:
sub_question_retrieval_stats = sub_question_retrieval_stats
return OrigQuestionRetrievalUpdate(
orig_question_verified_reranked_documents=state.expanded_retrieval_result.verified_reranked_documents,
orig_question_sub_query_retrieval_results=state.expanded_retrieval_result.expanded_query_results,
orig_question_retrieved_documents=state.retrieved_documents,
orig_question_retrieval_stats=sub_question_retrieval_stats,
log_messages=[],
)

View File

@@ -1,29 +0,0 @@
from onyx.agents.agent_search.deep_search.main.states import (
OrigQuestionRetrievalUpdate,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
## Graph Input State
class BaseRawSearchInput(ExpandedRetrievalInput):
pass
## Graph Output State
class BaseRawSearchOutput(OrigQuestionRetrievalUpdate):
"""
This is a list of results even though each call of this subgraph only returns one result.
This is because if we parallelize the answer query subgraph, there will be multiple
results in a list so the add operator is used to add them together.
"""
# base_expanded_retrieval_result: QuestionRetrievalResult = QuestionRetrievalResult()
## Graph State
class BaseRawSearchState(
BaseRawSearchInput, BaseRawSearchOutput, OrigQuestionRetrievalUpdate
):
pass

View File

@@ -1,113 +0,0 @@
from collections.abc import Hashable
from datetime import datetime
from typing import cast
from typing import Literal
from langchain_core.runnables import RunnableConfig
from langgraph.types import Send
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.main.states import (
RequireRefinemenEvalUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.utils.logger import setup_logger
logger = setup_logger()
def route_initial_tool_choice(
state: MainState, config: RunnableConfig
) -> Literal["tool_call", "start_agent_search", "logging_node"]:
"""
LangGraph edge to route to agent search.
"""
agent_config = cast(GraphConfig, config["metadata"]["config"])
if state.tool_choice is not None:
if (
agent_config.behavior.use_agentic_search
and agent_config.tooling.search_tool is not None
and state.tool_choice.tool.name == agent_config.tooling.search_tool.name
):
return "start_agent_search"
else:
return "tool_call"
else:
return "logging_node"
def parallelize_initial_sub_question_answering(
state: MainState,
) -> list[Send | Hashable]:
edge_start_time = datetime.now()
if len(state.initial_sub_questions) > 0:
return [
Send(
"answer_query_subgraph",
SubQuestionAnsweringInput(
question=question,
question_id=make_question_id(0, question_num + 1),
log_messages=[
f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
],
),
)
for question_num, question in enumerate(state.initial_sub_questions)
]
else:
return [
Send(
"ingest_answers",
AnswerQuestionOutput(
answer_results=[],
),
)
]
# Define the function that determines whether to continue or not
def continue_to_refined_answer_or_end(
state: RequireRefinemenEvalUpdate,
) -> Literal["create_refined_sub_questions", "logging_node"]:
if state.require_refined_answer_eval:
return "create_refined_sub_questions"
else:
return "logging_node"
def parallelize_refined_sub_question_answering(
state: MainState,
) -> list[Send | Hashable]:
edge_start_time = datetime.now()
if len(state.refined_sub_questions) > 0:
return [
Send(
"answer_refined_question_subgraphs",
SubQuestionAnsweringInput(
question=question_data.sub_question,
question_id=make_question_id(1, question_num),
log_messages=[
f"{edge_start_time} -- Main Edge - Parallelize Refined Sub-question Answering"
],
),
)
for question_num, question_data in state.refined_sub_questions.items()
]
else:
return [
Send(
"ingest_refined_sub_answers",
AnswerQuestionOutput(
answer_results=[],
),
)
]

View File

@@ -1,265 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.graph_builder import (
generate_initial_answer_graph_builder,
)
from onyx.agents.agent_search.deep_search.main.edges import (
continue_to_refined_answer_or_end,
)
from onyx.agents.agent_search.deep_search.main.edges import (
parallelize_refined_sub_question_answering,
)
from onyx.agents.agent_search.deep_search.main.edges import (
route_initial_tool_choice,
)
from onyx.agents.agent_search.deep_search.main.nodes.compare_answers import (
compare_answers,
)
from onyx.agents.agent_search.deep_search.main.nodes.create_refined_sub_questions import (
create_refined_sub_questions,
)
from onyx.agents.agent_search.deep_search.main.nodes.decide_refinement_need import (
decide_refinement_need,
)
from onyx.agents.agent_search.deep_search.main.nodes.extract_entities_terms import (
extract_entities_terms,
)
from onyx.agents.agent_search.deep_search.main.nodes.generate_refined_answer import (
generate_refined_answer,
)
from onyx.agents.agent_search.deep_search.main.nodes.ingest_refined_sub_answers import (
ingest_refined_sub_answers,
)
from onyx.agents.agent_search.deep_search.main.nodes.persist_agent_results import (
persist_agent_results,
)
from onyx.agents.agent_search.deep_search.main.nodes.start_agent_search import (
start_agent_search,
)
from onyx.agents.agent_search.deep_search.main.states import MainInput
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.refinement.consolidate_sub_answers.graph_builder import (
answer_refined_query_graph_builder,
)
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
basic_use_tool_response,
)
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
prepare_tool_input,
)
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger
logger = setup_logger()
test_mode = False
def main_graph_builder(test_mode: bool = False) -> StateGraph:
"""
LangGraph graph builder for the main agent search process.
"""
graph = StateGraph(
state_schema=MainState,
input=MainInput,
)
# Prepare the tool input
graph.add_node(
node="prepare_tool_input",
action=prepare_tool_input,
)
# Choose the initial tool
graph.add_node(
node="initial_tool_choice",
action=llm_tool_choice,
)
# Call the tool, if required
graph.add_node(
node="tool_call",
action=tool_call,
)
# Use the tool response
graph.add_node(
node="basic_use_tool_response",
action=basic_use_tool_response,
)
# Start the agent search process
graph.add_node(
node="start_agent_search",
action=start_agent_search,
)
# The sub-graph for the initial answer generation
generate_initial_answer_subgraph = generate_initial_answer_graph_builder().compile()
graph.add_node(
node="generate_initial_answer_subgraph",
action=generate_initial_answer_subgraph,
)
# Create the refined sub-questions
graph.add_node(
node="create_refined_sub_questions",
action=create_refined_sub_questions,
)
# Subgraph for the refined sub-answer generation
answer_refined_question = answer_refined_query_graph_builder().compile()
graph.add_node(
node="answer_refined_question_subgraphs",
action=answer_refined_question,
)
# Ingest the refined sub-answers
graph.add_node(
node="ingest_refined_sub_answers",
action=ingest_refined_sub_answers,
)
# Node to generate the refined answer
graph.add_node(
node="generate_refined_answer",
action=generate_refined_answer,
)
# Early node to extract the entities and terms from the initial answer,
# This information is used to inform the creation the refined sub-questions
graph.add_node(
node="extract_entity_term",
action=extract_entities_terms,
)
# Decide if the answer needs to be refined (currently always true)
graph.add_node(
node="decide_refinement_need",
action=decide_refinement_need,
)
# Compare the initial and refined answers, and determine whether
# the refined answer is sufficiently better
graph.add_node(
node="compare_answers",
action=compare_answers,
)
# Log the results. This will log the stats as well as the answers, sub-questions, and sub-answers
graph.add_node(
node="logging_node",
action=persist_agent_results,
)
### Add edges ###
graph.add_edge(start_key=START, end_key="prepare_tool_input")
graph.add_edge(
start_key="prepare_tool_input",
end_key="initial_tool_choice",
)
graph.add_conditional_edges(
"initial_tool_choice",
route_initial_tool_choice,
["tool_call", "start_agent_search", "logging_node"],
)
graph.add_edge(
start_key="tool_call",
end_key="basic_use_tool_response",
)
graph.add_edge(
start_key="basic_use_tool_response",
end_key="logging_node",
)
graph.add_edge(
start_key="start_agent_search",
end_key="generate_initial_answer_subgraph",
)
graph.add_edge(
start_key="start_agent_search",
end_key="extract_entity_term",
)
# Wait for the initial answer generation and the entity/term extraction to be complete
# before deciding if a refinement is needed.
graph.add_edge(
start_key=["generate_initial_answer_subgraph", "extract_entity_term"],
end_key="decide_refinement_need",
)
graph.add_conditional_edges(
source="decide_refinement_need",
path=continue_to_refined_answer_or_end,
path_map=["create_refined_sub_questions", "logging_node"],
)
graph.add_conditional_edges(
source="create_refined_sub_questions",
path=parallelize_refined_sub_question_answering,
path_map=["answer_refined_question_subgraphs"],
)
graph.add_edge(
start_key="answer_refined_question_subgraphs",
end_key="ingest_refined_sub_answers",
)
graph.add_edge(
start_key="ingest_refined_sub_answers",
end_key="generate_refined_answer",
)
graph.add_edge(
start_key="generate_refined_answer",
end_key="compare_answers",
)
graph.add_edge(
start_key="compare_answers",
end_key="logging_node",
)
graph.add_edge(
start_key="logging_node",
end_key=END,
)
return graph
if __name__ == "__main__":
pass
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
graph = main_graph_builder()
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
with get_session_context_manager() as db_session:
search_request = SearchRequest(query="Who created Excel?")
graph_config = get_test_config(
db_session, primary_llm, fast_llm, search_request
)
inputs = MainInput(
base_question=graph_config.inputs.search_request.query, log_messages=[]
)
for thing in compiled_graph.stream(
input=inputs,
config={"configurable": {"config": graph_config}},
stream_mode="custom",
subgraphs=True,
):
logger.debug(thing)

View File

@@ -1,71 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.main.states import (
InitialRefinedAnswerComparisonUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import RefinedAnswerImprovement
from onyx.prompts.agent_search import (
INITIAL_REFINED_ANSWER_COMPARISON_PROMPT,
)
def compare_answers(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> InitialRefinedAnswerComparisonUpdate:
"""
LangGraph node to compare the initial answer and the refined answer and determine if the
refined answer is sufficiently better than the initial answer.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
initial_answer = state.initial_answer
refined_answer = state.refined_answer
compare_answers_prompt = INITIAL_REFINED_ANSWER_COMPARISON_PROMPT.format(
question=question, initial_answer=initial_answer, refined_answer=refined_answer
)
msg = [HumanMessage(content=compare_answers_prompt)]
# Get the rewritten queries in a defined format
model = graph_config.tooling.fast_llm
# no need to stream this
resp = model.invoke(msg)
refined_answer_improvement = (
isinstance(resp.content, str) and "yes" in resp.content.lower()
)
write_custom_event(
"refined_answer_improvement",
RefinedAnswerImprovement(
refined_answer_improvement=refined_answer_improvement,
),
writer,
)
return InitialRefinedAnswerComparisonUpdate(
refined_answer_improvement_eval=refined_answer_improvement,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="compare answers",
node_start_time=node_start_time,
result=f"Answer comparison: {refined_answer_improvement}",
)
],
)

View File

@@ -1,131 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.main.models import (
RefinementSubQuestion,
)
from onyx.agents.agent_search.deep_search.main.operations import (
dispatch_subquestion,
)
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.main.states import (
RefinedQuestionDecompositionUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
format_entity_term_extraction,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.prompts.agent_search import (
REFINEMENT_QUESTION_DECOMPOSITION_PROMPT,
)
from onyx.tools.models import ToolCallKickoff
def create_refined_sub_questions(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> RefinedQuestionDecompositionUpdate:
"""
LangGraph node to create refined sub-questions based on the initial answer, the history,
the entity term extraction results found earlier, and the sub-questions that were answered and failed.
"""
graph_config = cast(GraphConfig, config["metadata"]["config"])
write_custom_event(
"start_refined_answer_creation",
ToolCallKickoff(
tool_name="agent_search_1",
tool_args={
"query": graph_config.inputs.search_request.query,
"answer": state.initial_answer,
},
),
writer,
)
node_start_time = datetime.now()
agent_refined_start_time = datetime.now()
question = graph_config.inputs.search_request.query
base_answer = state.initial_answer
history = build_history_prompt(graph_config, question)
# get the entity term extraction dict and properly format it
entity_retlation_term_extractions = state.entity_relation_term_extractions
entity_term_extraction_str = format_entity_term_extraction(
entity_retlation_term_extractions
)
initial_question_answers = state.sub_question_results
addressed_question_list = [
x.question for x in initial_question_answers if x.verified_high_quality
]
failed_question_list = [
x.question for x in initial_question_answers if not x.verified_high_quality
]
msg = [
HumanMessage(
content=REFINEMENT_QUESTION_DECOMPOSITION_PROMPT.format(
question=question,
history=history,
entity_term_extraction_str=entity_term_extraction_str,
base_answer=base_answer,
answered_sub_questions="\n - ".join(addressed_question_list),
failed_sub_questions="\n - ".join(failed_question_list),
),
)
]
# Grader
model = graph_config.tooling.fast_llm
streamed_tokens = dispatch_separated(
model.stream(msg), dispatch_subquestion(1, writer)
)
response = merge_content(*streamed_tokens)
if isinstance(response, str):
parsed_response = [q for q in response.split("\n") if q.strip() != ""]
else:
raise ValueError("LLM response is not a string")
refined_sub_question_dict = {}
for sub_question_num, sub_question in enumerate(parsed_response):
refined_sub_question = RefinementSubQuestion(
sub_question=sub_question,
sub_question_id=make_question_id(1, sub_question_num + 1),
verified=False,
answered=False,
answer="",
)
refined_sub_question_dict[sub_question_num + 1] = refined_sub_question
return RefinedQuestionDecompositionUpdate(
refined_sub_questions=refined_sub_question_dict,
agent_refined_start_time=agent_refined_start_time,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="create refined sub questions",
node_start_time=node_start_time,
result=f"Created {len(refined_sub_question_dict)} refined sub questions",
)
],
)

View File

@@ -1,47 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.main.states import (
RequireRefinemenEvalUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
def decide_refinement_need(
state: MainState, config: RunnableConfig
) -> RequireRefinemenEvalUpdate:
"""
LangGraph node to decide if refinement is needed based on the initial answer and the question.
At present, we always refine.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
decision = True # TODO: just for current testing purposes
log_messages = [
get_langgraph_node_log_string(
graph_component="main",
node_name="decide refinement need",
node_start_time=node_start_time,
result=f"Refinement decision: {decision}",
)
]
if graph_config.behavior.allow_refinement:
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=decision,
log_messages=log_messages,
)
else:
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=False,
log_messages=log_messages,
)

View File

@@ -1,116 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search.main.operations import logger
from onyx.agents.agent_search.deep_search.main.states import (
EntityTermExtractionUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.models import EntityExtractionResult
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
)
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.constants import NUM_EXPLORATORY_DOCS
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE
def extract_entities_terms(
state: MainState, config: RunnableConfig
) -> EntityTermExtractionUpdate:
"""
LangGraph node to extract entities, relationships, and terms from the initial search results.
This data is used to inform particularly the sub-questions that are created for the refined answer.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
if not graph_config.behavior.allow_refinement:
return EntityTermExtractionUpdate(
entity_relation_term_extractions=EntityRelationshipTermExtraction(
entities=[],
relationships=[],
terms=[],
),
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="extract entities terms",
node_start_time=node_start_time,
result="Refinement is not allowed",
)
],
)
# first four lines duplicates from generate_initial_answer
question = graph_config.inputs.search_request.query
initial_search_docs = state.exploratory_search_results[:NUM_EXPLORATORY_DOCS]
# start with the entity/term/extraction
doc_context = format_docs(initial_search_docs)
# Calculation here is only approximate
doc_context = trim_prompt_piece(
graph_config.tooling.fast_llm.config,
doc_context,
ENTITY_TERM_EXTRACTION_PROMPT
+ question
+ ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE,
)
msg = [
HumanMessage(
content=ENTITY_TERM_EXTRACTION_PROMPT.format(
question=question, context=doc_context
)
+ ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE,
)
]
fast_llm = graph_config.tooling.fast_llm
# Grader
llm_response = fast_llm.invoke(
prompt=msg,
)
cleaned_response = (
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
)
first_bracket = cleaned_response.find("{")
last_bracket = cleaned_response.rfind("}")
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
try:
entity_extraction_result = EntityExtractionResult.model_validate_json(
cleaned_response
)
except ValueError:
logger.error("Failed to parse LLM response as JSON in Entity-Term Extraction")
entity_extraction_result = EntityExtractionResult(
retrieved_entities_relationships=EntityRelationshipTermExtraction(
entities=[],
relationships=[],
terms=[],
),
)
return EntityTermExtractionUpdate(
entity_relation_term_extractions=entity_extraction_result.retrieved_entities_relationships,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="extract entities terms",
node_start_time=node_start_time,
)
],
)

View File

@@ -1,339 +0,0 @@
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.main.models import (
AgentRefinedMetrics,
)
from onyx.agents.agent_search.deep_search.main.operations import get_query_info
from onyx.agents.agent_search.deep_search.main.operations import logger
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.main.states import (
RefinedAnswerUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
get_prompt_enrichment_components,
)
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.models import InferenceSection
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
remove_document_citations,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.prompts.agent_search import (
REFINED_ANSWER_PROMPT_W_SUB_QUESTIONS,
)
from onyx.prompts.agent_search import (
REFINED_ANSWER_PROMPT_WO_SUB_QUESTIONS,
)
from onyx.prompts.agent_search import (
SUB_QUESTION_ANSWER_TEMPLATE_REFINED,
)
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
def generate_refined_answer(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> RefinedAnswerUpdate:
"""
LangGraph node to generate the refined answer.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
prompt_enrichment_components = get_prompt_enrichment_components(graph_config)
persona_contextualized_prompt = (
prompt_enrichment_components.persona_prompts.contextualized_prompt
)
verified_reranked_documents = state.verified_reranked_documents
sub_questions_cited_documents = state.cited_documents
original_question_verified_documents = (
state.orig_question_verified_reranked_documents
)
original_question_retrieved_documents = state.orig_question_retrieved_documents
consolidated_context_docs: list[InferenceSection] = sub_questions_cited_documents
counter = 0
for original_doc_number, original_doc in enumerate(
original_question_verified_documents
):
if original_doc_number not in sub_questions_cited_documents:
if (
counter <= AGENT_MIN_ORIG_QUESTION_DOCS
or len(consolidated_context_docs)
< 1.5
* AGENT_MAX_ANSWER_CONTEXT_DOCS # allow for larger context in refinement
):
consolidated_context_docs.append(original_doc)
counter += 1
# sort docs by their scores - though the scores refer to different questions
relevant_docs = dedup_inference_sections(
consolidated_context_docs, consolidated_context_docs
)
streaming_docs = (
relevant_docs
if len(relevant_docs) > 0
else original_question_retrieved_documents[:15]
)
query_info = get_query_info(state.orig_question_sub_query_retrieval_results)
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
# stream refined answer docs, or original question docs if no relevant docs are found
relevance_list = relevance_from_docs(relevant_docs)
for tool_response in yield_search_responses(
query=question,
reranked_sections=streaming_docs,
final_context_sections=streaming_docs,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
):
write_custom_event(
"tool_response",
ExtendedToolResponse(
id=tool_response.id,
response=tool_response.response,
level=1,
level_question_num=0, # 0, 0 is the base question
),
writer,
)
if len(verified_reranked_documents) > 0:
refined_doc_effectiveness = len(relevant_docs) / len(
verified_reranked_documents
)
else:
refined_doc_effectiveness = 10.0
sub_question_answer_results = state.sub_question_results
answered_sub_question_answer_list: list[str] = []
sub_questions: list[str] = []
initial_answered_sub_questions: set[str] = set()
refined_answered_sub_questions: set[str] = set()
for i, result in enumerate(sub_question_answer_results, 1):
question_level, _ = parse_question_id(result.question_id)
sub_questions.append(result.question)
if (
result.verified_high_quality
and result.answer
and result.answer != UNKNOWN_ANSWER
):
sub_question_type = "initial" if question_level == 0 else "refined"
question_set = (
initial_answered_sub_questions
if question_level == 0
else refined_answered_sub_questions
)
question_set.add(result.question)
answered_sub_question_answer_list.append(
SUB_QUESTION_ANSWER_TEMPLATE_REFINED.format(
sub_question=result.question,
sub_answer=result.answer,
sub_question_num=i,
sub_question_type=sub_question_type,
)
)
# Calculate efficiency
total_answered_questions = (
initial_answered_sub_questions | refined_answered_sub_questions
)
revision_question_efficiency = (
len(total_answered_questions) / len(initial_answered_sub_questions)
if initial_answered_sub_questions
else 10.0
if refined_answered_sub_questions
else 1.0
)
sub_question_answer_str = "\n\n------\n\n".join(
set(answered_sub_question_answer_list)
)
initial_answer = state.initial_answer or ""
# Choose appropriate prompt template
base_prompt = (
REFINED_ANSWER_PROMPT_W_SUB_QUESTIONS
if answered_sub_question_answer_list
else REFINED_ANSWER_PROMPT_WO_SUB_QUESTIONS
)
model = graph_config.tooling.fast_llm
relevant_docs_str = format_docs(relevant_docs)
relevant_docs_str = trim_prompt_piece(
model.config,
relevant_docs_str,
base_prompt
+ question
+ sub_question_answer_str
+ initial_answer
+ persona_contextualized_prompt
+ prompt_enrichment_components.history,
)
msg = [
HumanMessage(
content=base_prompt.format(
question=question,
history=prompt_enrichment_components.history,
answered_sub_questions=remove_document_citations(
sub_question_answer_str
),
relevant_docs=relevant_docs_str,
initial_answer=remove_document_citations(initial_answer)
if initial_answer
else None,
persona_specification=persona_contextualized_prompt,
date_prompt=prompt_enrichment_components.date_str,
)
)
]
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
dispatch_timings: list[float] = []
for message in model.stream(msg):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
write_custom_event(
"refined_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=1,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
streamed_tokens.append(content)
logger.debug(
f"Average dispatch time for refined answer: {sum(dispatch_timings) / len(dispatch_timings)}"
)
dispatch_main_answer_stop_info(1, writer)
response = merge_content(*streamed_tokens)
answer = cast(str, response)
refined_agent_stats = RefinedAgentStats(
revision_doc_efficiency=refined_doc_effectiveness,
revision_question_efficiency=revision_question_efficiency,
)
logger.debug(f"\n\n---INITIAL ANSWER ---\n\n Answer:\n Agent: {initial_answer}")
logger.debug("-" * 10)
logger.debug(f"\n\n---REVISED AGENT ANSWER ---\n\n Answer:\n Agent: {answer}")
logger.debug("-" * 100)
if state.initial_agent_stats:
initial_doc_boost_factor = state.initial_agent_stats.agent_effectiveness.get(
"utilized_chunk_ratio", "--"
)
initial_support_boost_factor = (
state.initial_agent_stats.agent_effectiveness.get("support_ratio", "--")
)
num_initial_verified_docs = state.initial_agent_stats.original_question.get(
"num_verified_documents", "--"
)
initial_verified_docs_avg_score = (
state.initial_agent_stats.original_question.get("verified_avg_score", "--")
)
initial_sub_questions_verified_docs = (
state.initial_agent_stats.sub_questions.get("num_verified_documents", "--")
)
logger.debug("INITIAL AGENT STATS")
logger.debug(f"Document Boost Factor: {initial_doc_boost_factor}")
logger.debug(f"Support Boost Factor: {initial_support_boost_factor}")
logger.debug(f"Originally Verified Docs: {num_initial_verified_docs}")
logger.debug(
f"Originally Verified Docs Avg Score: {initial_verified_docs_avg_score}"
)
logger.debug(
f"Sub-Questions Verified Docs: {initial_sub_questions_verified_docs}"
)
if refined_agent_stats:
logger.debug("-" * 10)
logger.debug("REFINED AGENT STATS")
logger.debug(
f"Revision Doc Factor: {refined_agent_stats.revision_doc_efficiency}"
)
logger.debug(
f"Revision Question Factor: {refined_agent_stats.revision_question_efficiency}"
)
agent_refined_end_time = datetime.now()
if state.agent_refined_start_time:
agent_refined_duration = (
agent_refined_end_time - state.agent_refined_start_time
).total_seconds()
else:
agent_refined_duration = None
agent_refined_metrics = AgentRefinedMetrics(
refined_doc_boost_factor=refined_agent_stats.revision_doc_efficiency,
refined_question_boost_factor=refined_agent_stats.revision_question_efficiency,
duration_s=agent_refined_duration,
)
return RefinedAnswerUpdate(
refined_answer=answer,
refined_answer_quality=True, # TODO: replace this with the actual check value
refined_agent_stats=refined_agent_stats,
agent_refined_end_time=agent_refined_end_time,
agent_refined_metrics=agent_refined_metrics,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="generate refined answer",
node_start_time=node_start_time,
)
],
)

View File

@@ -1,42 +0,0 @@
from datetime import datetime
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.main.states import (
SubQuestionResultsUpdate,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
def ingest_refined_sub_answers(
state: AnswerQuestionOutput,
) -> SubQuestionResultsUpdate:
"""
LangGraph node to ingest and format the refined sub-answers and retrieved documents.
"""
node_start_time = datetime.now()
documents = []
answer_results = state.answer_results
for answer_result in answer_results:
documents.extend(answer_result.verified_reranked_documents)
return SubQuestionResultsUpdate(
# Deduping is done by the documents operator for the main graph
# so we might not need to dedup here
verified_reranked_documents=dedup_inference_sections(documents, []),
sub_question_results=answer_results,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="ingest refined answers",
node_start_time=node_start_time,
)
],
)

View File

@@ -1,129 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search.main.models import (
AgentAdditionalMetrics,
)
from onyx.agents.agent_search.deep_search.main.models import AgentTimings
from onyx.agents.agent_search.deep_search.main.operations import logger
from onyx.agents.agent_search.deep_search.main.states import MainOutput
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.models import CombinedAgentMetrics
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.db.chat import log_agent_metrics
from onyx.db.chat import log_agent_sub_question_results
def persist_agent_results(state: MainState, config: RunnableConfig) -> MainOutput:
"""
LangGraph node to persist the agent results, including agent logging data.
"""
node_start_time = datetime.now()
agent_start_time = state.agent_start_time
agent_base_end_time = state.agent_base_end_time
agent_refined_start_time = state.agent_refined_start_time
agent_refined_end_time = state.agent_refined_end_time
agent_end_time = agent_refined_end_time or agent_base_end_time
agent_base_duration = None
if agent_base_end_time and agent_start_time:
agent_base_duration = (agent_base_end_time - agent_start_time).total_seconds()
agent_refined_duration = None
if agent_refined_start_time and agent_refined_end_time:
agent_refined_duration = (
agent_refined_end_time - agent_refined_start_time
).total_seconds()
agent_full_duration = None
if agent_end_time and agent_start_time:
agent_full_duration = (agent_end_time - agent_start_time).total_seconds()
agent_type = "refined" if agent_refined_duration else "base"
agent_base_metrics = state.agent_base_metrics
agent_refined_metrics = state.agent_refined_metrics
combined_agent_metrics = CombinedAgentMetrics(
timings=AgentTimings(
base_duration_s=agent_base_duration,
refined_duration_s=agent_refined_duration,
full_duration_s=agent_full_duration,
),
base_metrics=agent_base_metrics,
refined_metrics=agent_refined_metrics,
additional_metrics=AgentAdditionalMetrics(),
)
persona_id = None
graph_config = cast(GraphConfig, config["metadata"]["config"])
if graph_config.inputs.search_request.persona:
persona_id = graph_config.inputs.search_request.persona.id
user_id = None
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
user = graph_config.tooling.search_tool.user
if user:
user_id = user.id
# log the agent metrics
if graph_config.persistence:
if agent_base_duration is not None:
log_agent_metrics(
db_session=graph_config.persistence.db_session,
user_id=user_id,
persona_id=persona_id,
agent_type=agent_type,
start_time=agent_start_time,
agent_metrics=combined_agent_metrics,
)
# Persist the sub-answer in the database
db_session = graph_config.persistence.db_session
chat_session_id = graph_config.persistence.chat_session_id
primary_message_id = graph_config.persistence.message_id
sub_question_answer_results = state.sub_question_results
log_agent_sub_question_results(
db_session=db_session,
chat_session_id=chat_session_id,
primary_message_id=primary_message_id,
sub_question_answer_results=sub_question_answer_results,
)
main_output = MainOutput(
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="persist agent results",
node_start_time=node_start_time,
)
],
)
for log_message in state.log_messages:
logger.debug(log_message)
if state.agent_base_metrics:
logger.debug(f"Initial loop: {state.agent_base_metrics.duration_s}")
if state.agent_refined_metrics:
logger.debug(f"Refined loop: {state.agent_refined_metrics.duration_s}")
if (
state.agent_base_metrics
and state.agent_refined_metrics
and state.agent_base_metrics.duration_s
and state.agent_refined_metrics.duration_s
):
logger.debug(
f"Total time: {float(state.agent_base_metrics.duration_s) + float(state.agent_refined_metrics.duration_s)}"
)
return main_output

View File

@@ -1,52 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search.main.states import (
ExploratorySearchUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import retrieve_search_docs
from onyx.configs.agent_configs import AGENT_EXPLORATORY_SEARCH_RESULTS
from onyx.context.search.models import InferenceSection
def start_agent_search(
state: MainState, config: RunnableConfig
) -> ExploratorySearchUpdate:
"""
LangGraph node to start the agentic search process.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
history = build_history_prompt(graph_config, question)
# Initial search to inform decomposition. Just get top 3 fits
search_tool = graph_config.tooling.search_tool
assert search_tool, "search_tool must be provided for agentic search"
retrieved_docs: list[InferenceSection] = retrieve_search_docs(search_tool, question)
exploratory_search_results = retrieved_docs[:AGENT_EXPLORATORY_SEARCH_RESULTS]
return ExploratorySearchUpdate(
exploratory_search_results=exploratory_search_results,
previous_history_summary=history,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="start agent search",
node_start_time=node_start_time,
)
],
)

View File

@@ -1,172 +0,0 @@
from datetime import datetime
from operator import add
from typing import Annotated
from typing import TypedDict
from pydantic import BaseModel
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.deep_search.main.models import AgentBaseMetrics
from onyx.agents.agent_search.deep_search.main.models import (
AgentRefinedMetrics,
)
from onyx.agents.agent_search.deep_search.main.models import (
RefinementSubQuestion,
)
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
)
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
from onyx.agents.agent_search.shared_graph_utils.models import (
SubQuestionAnswerResults,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_question_answer_results,
)
from onyx.context.search.models import InferenceSection
### States ###
class LoggerUpdate(BaseModel):
log_messages: Annotated[list[str], add] = []
class RefinedAgentStartStats(BaseModel):
agent_refined_start_time: datetime | None = None
class RefinedAgentEndStats(BaseModel):
agent_refined_end_time: datetime | None = None
agent_refined_metrics: AgentRefinedMetrics = AgentRefinedMetrics()
class InitialQuestionDecompositionUpdate(
RefinedAgentStartStats, RefinedAgentEndStats, LoggerUpdate
):
agent_start_time: datetime | None = None
previous_history: str | None = None
initial_sub_questions: list[str] = []
class ExploratorySearchUpdate(LoggerUpdate):
exploratory_search_results: list[InferenceSection] = []
previous_history_summary: str | None = None
class InitialRefinedAnswerComparisonUpdate(LoggerUpdate):
"""
Evaluation of whether the refined answer is better than the initial answer
"""
refined_answer_improvement_eval: bool = False
class InitialAnswerUpdate(LoggerUpdate):
"""
Initial answer information
"""
initial_answer: str | None = None
initial_agent_stats: InitialAgentResultStats | None = None
generated_sub_questions: list[str] = []
agent_base_end_time: datetime | None = None
agent_base_metrics: AgentBaseMetrics | None = None
class RefinedAnswerUpdate(RefinedAgentEndStats, LoggerUpdate):
"""
Refined answer information
"""
refined_answer: str | None = None
refined_agent_stats: RefinedAgentStats | None = None
refined_answer_quality: bool = False
class InitialAnswerQualityUpdate(LoggerUpdate):
"""
Initial answer quality evaluation
"""
initial_answer_quality_eval: bool = False
class RequireRefinemenEvalUpdate(LoggerUpdate):
require_refined_answer_eval: bool = True
class SubQuestionResultsUpdate(LoggerUpdate):
verified_reranked_documents: Annotated[
list[InferenceSection], dedup_inference_sections
] = []
context_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
cited_documents: Annotated[
list[InferenceSection], dedup_inference_sections
] = [] # cited docs from sub-answers are used for answer context
sub_question_results: Annotated[
list[SubQuestionAnswerResults], dedup_question_answer_results
] = []
class OrigQuestionRetrievalUpdate(LoggerUpdate):
orig_question_retrieved_documents: Annotated[
list[InferenceSection], dedup_inference_sections
]
orig_question_verified_reranked_documents: Annotated[
list[InferenceSection], dedup_inference_sections
]
orig_question_sub_query_retrieval_results: list[QueryRetrievalResult] = []
orig_question_retrieval_stats: AgentChunkRetrievalStats = AgentChunkRetrievalStats()
class EntityTermExtractionUpdate(LoggerUpdate):
entity_relation_term_extractions: EntityRelationshipTermExtraction = (
EntityRelationshipTermExtraction()
)
class RefinedQuestionDecompositionUpdate(RefinedAgentStartStats, LoggerUpdate):
refined_sub_questions: dict[int, RefinementSubQuestion] = {}
## Graph Input State
class MainInput(CoreState):
pass
## Graph State
class MainState(
# This includes the core state
MainInput,
ToolChoiceInput,
ToolCallUpdate,
ToolChoiceUpdate,
InitialQuestionDecompositionUpdate,
InitialAnswerUpdate,
SubQuestionResultsUpdate,
OrigQuestionRetrievalUpdate,
EntityTermExtractionUpdate,
InitialAnswerQualityUpdate,
RequireRefinemenEvalUpdate,
RefinedQuestionDecompositionUpdate,
RefinedAnswerUpdate,
RefinedAgentStartStats,
RefinedAgentEndStats,
InitialRefinedAnswerComparisonUpdate,
ExploratorySearchUpdate,
):
pass
## Graph Output State - presently not used
class MainOutput(TypedDict):
log_messages: list[str]

View File

@@ -1,161 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.edges import (
parallel_retrieval_edge,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.expand_queries import (
expand_queries,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.format_queries import (
format_queries,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.format_results import (
format_results,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.kickoff_verification import (
kickoff_verification,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.rerank_documents import (
rerank_documents,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.retrieve_documents import (
retrieve_documents,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.verify_documents import (
verify_documents,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalOutput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger
logger = setup_logger()
def expanded_retrieval_graph_builder() -> StateGraph:
"""
LangGraph graph builder for the expanded retrieval process.
"""
graph = StateGraph(
state_schema=ExpandedRetrievalState,
input=ExpandedRetrievalInput,
output=ExpandedRetrievalOutput,
)
### Add nodes ###
# Convert the question into multiple sub-queries
graph.add_node(
node="expand_queries",
action=expand_queries,
)
# Format the sub-queries into a list of strings
graph.add_node(
node="format_queries",
action=format_queries,
)
# Retrieve the documents for each sub-query
graph.add_node(
node="retrieve_documents",
action=retrieve_documents,
)
# Start verification process that the documents are relevant to the question (not the query)
graph.add_node(
node="kickoff_verification",
action=kickoff_verification,
)
# Verify that a given document is relevant to the question (not the query)
graph.add_node(
node="verify_documents",
action=verify_documents,
)
# Rerank the documents that have been verified
graph.add_node(
node="rerank_documents",
action=rerank_documents,
)
# Format the results into a list of strings
graph.add_node(
node="format_results",
action=format_results,
)
### Add edges ###
graph.add_edge(
start_key=START,
end_key="expand_queries",
)
graph.add_edge(
start_key="expand_queries",
end_key="format_queries",
)
graph.add_conditional_edges(
source="format_queries",
path=parallel_retrieval_edge,
path_map=["retrieve_documents"],
)
graph.add_edge(
start_key="retrieve_documents",
end_key="kickoff_verification",
)
graph.add_edge(
start_key="verify_documents",
end_key="rerank_documents",
)
graph.add_edge(
start_key="rerank_documents",
end_key="format_results",
)
graph.add_edge(
start_key="format_results",
end_key=END,
)
return graph
if __name__ == "__main__":
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
graph = expanded_retrieval_graph_builder()
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
search_request = SearchRequest(
query="what can you do with onyx or danswer?",
)
with get_session_context_manager() as db_session:
graph_config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)
inputs = ExpandedRetrievalInput(
question="what can you do with onyx?",
base_search=False,
sub_question_id=None,
log_messages=[],
)
for thing in compiled_graph.stream(
input=inputs,
config={"configurable": {"config": graph_config}},
stream_mode="custom",
subgraphs=True,
):
logger.debug(thing)

View File

@@ -1,13 +0,0 @@
from pydantic import BaseModel
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
from onyx.context.search.models import InferenceSection
class QuestionRetrievalResult(BaseModel):
expanded_query_results: list[QueryRetrievalResult] = []
retrieved_documents: list[InferenceSection] = []
verified_reranked_documents: list[InferenceSection] = []
context_documents: list[InferenceSection] = []
retrieval_stats: AgentChunkRetrievalStats = AgentChunkRetrievalStats()

View File

@@ -1,75 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.operations import (
dispatch_subquery,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
QueryExpansionUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.prompts.agent_search import (
QUERY_REWRITING_PROMPT,
)
def expand_queries(
state: ExpandedRetrievalInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> QueryExpansionUpdate:
"""
LangGraph node to expand a question into multiple search queries.
"""
# Sometimes we want to expand the original question, sometimes we want to expand a sub-question.
# When we are running this node on the original question, no question is explictly passed in.
# Instead, we use the original question from the search request.
graph_config = cast(GraphConfig, config["metadata"]["config"])
node_start_time = datetime.now()
question = state.question
llm = graph_config.tooling.fast_llm
sub_question_id = state.sub_question_id
if sub_question_id is None:
level, question_num = 0, 0
else:
level, question_num = parse_question_id(sub_question_id)
msg = [
HumanMessage(
content=QUERY_REWRITING_PROMPT.format(question=question),
)
]
llm_response_list = dispatch_separated(
llm.stream(prompt=msg), dispatch_subquery(level, question_num, writer)
)
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
rewritten_queries = llm_response.split("\n")
return QueryExpansionUpdate(
expanded_queries=rewritten_queries,
log_messages=[
get_langgraph_node_log_string(
graph_component="shared - expanded retrieval",
node_name="expand queries",
node_start_time=node_start_time,
result=f"Number of expanded queries: {len(rewritten_queries)}",
)
],
)

View File

@@ -1,91 +0,0 @@
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.main.operations import get_query_info
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.models import (
QuestionRetrievalResult,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.operations import (
calculate_sub_question_retrieval_stats,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import ExtendedToolResponse
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
def format_results(
state: ExpandedRetrievalState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> ExpandedRetrievalUpdate:
"""
LangGraph node that constructs the proper expanded retrieval format.
"""
level, question_num = parse_question_id(state.sub_question_id or "0_0")
query_info = get_query_info(state.query_retrieval_results)
graph_config = cast(GraphConfig, config["metadata"]["config"])
# Main question docs will be sent later after aggregation and deduping with sub-question docs
reranked_documents = state.reranked_documents
if not (level == 0 and question_num == 0):
if len(reranked_documents) == 0:
# The sub-question is used as the last query. If no verified documents are found, stream
# the top 3 for that one. We may want to revisit this.
reranked_documents = state.query_retrieval_results[-1].retrieved_documents[
:3
]
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
relevance_list = relevance_from_docs(reranked_documents)
for tool_response in yield_search_responses(
query=state.question,
reranked_sections=state.retrieved_documents,
final_context_sections=reranked_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
):
write_custom_event(
"tool_response",
ExtendedToolResponse(
id=tool_response.id,
response=tool_response.response,
level=level,
level_question_num=question_num,
),
writer,
)
sub_question_retrieval_stats = calculate_sub_question_retrieval_stats(
verified_documents=state.verified_documents,
expanded_retrieval_results=state.query_retrieval_results,
)
if sub_question_retrieval_stats is None:
sub_question_retrieval_stats = AgentChunkRetrievalStats()
return ExpandedRetrievalUpdate(
expanded_retrieval_result=QuestionRetrievalResult(
expanded_query_results=state.query_retrieval_results,
retrieved_documents=state.retrieved_documents,
verified_reranked_documents=reranked_documents,
context_documents=state.reranked_documents,
retrieval_stats=sub_question_retrieval_stats,
),
)

View File

@@ -1,44 +0,0 @@
from typing import Literal
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import Command
from langgraph.types import Send
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
DocVerificationInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalState,
)
def kickoff_verification(
state: ExpandedRetrievalState,
config: RunnableConfig,
) -> Command[Literal["verify_documents"]]:
"""
LangGraph node (Command node!) that kicks off the verification process for the retrieved documents.
Note that this is a Command node and does the routing as well. (At present, no state updates
are done here, so this could be replaced with an edge. But we may choose to make state
updates later.)
"""
retrieved_documents = state.retrieved_documents
verification_question = state.question
sub_question_id = state.sub_question_id
return Command(
update={},
goto=[
Send(
node="verify_documents",
arg=DocVerificationInput(
retrieved_document_to_verify=document,
question=verification_question,
base_search=False,
sub_question_id=sub_question_id,
log_messages=[],
),
)
for document in retrieved_documents
],
)

View File

@@ -1,105 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.operations import (
logger,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
DocRerankingUpdate,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.calculations import get_fit_scores
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS
from onyx.configs.agent_configs import AGENT_RERANKING_STATS
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import SearchRequest
from onyx.context.search.pipeline import retrieval_preprocessing
from onyx.context.search.postprocessing.postprocessing import rerank_sections
from onyx.db.engine import get_session_context_manager
def rerank_documents(
state: ExpandedRetrievalState, config: RunnableConfig
) -> DocRerankingUpdate:
"""
LangGraph node to rerank the retrieved and verified documents. A part of the
pre-existing pipeline is used here.
"""
node_start_time = datetime.now()
verified_documents = state.verified_documents
# Rerank post retrieval and verification. First, create a search query
# then create the list of reranked sections
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = (
state.question if state.question else graph_config.inputs.search_request.query
)
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
with get_session_context_manager() as db_session:
# we ignore some of the user specified fields since this search is
# internal to agentic search, but we still want to pass through
# persona (for stuff like document sets) and rerank settings
# (to not make an unnecessary db call).
search_request = SearchRequest(
query=question,
persona=graph_config.inputs.search_request.persona,
rerank_settings=graph_config.inputs.search_request.rerank_settings,
)
_search_query = retrieval_preprocessing(
search_request=search_request,
user=graph_config.tooling.search_tool.user, # bit of a hack
llm=graph_config.tooling.fast_llm,
db_session=db_session,
)
# skip section filtering
if (
_search_query.rerank_settings
and _search_query.rerank_settings.rerank_model_name
and _search_query.rerank_settings.num_rerank > 0
and len(verified_documents) > 0
):
if len(verified_documents) > 1:
reranked_documents = rerank_sections(
_search_query,
verified_documents,
)
else:
num = "No" if len(verified_documents) == 0 else "One"
logger.warning(f"{num} verified document(s) found, skipping reranking")
reranked_documents = verified_documents
else:
logger.warning("No reranking settings found, using unranked documents")
reranked_documents = verified_documents
if AGENT_RERANKING_STATS:
fit_scores = get_fit_scores(verified_documents, reranked_documents)
else:
fit_scores = RetrievalFitStats(fit_score_lift=0, rerank_effect=0, fit_scores={})
return DocRerankingUpdate(
reranked_documents=[
doc for doc in reranked_documents if type(doc) == InferenceSection
][:AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS],
sub_question_retrieval_stats=fit_scores,
log_messages=[
get_langgraph_node_log_string(
graph_component="shared - expanded retrieval",
node_name="rerank documents",
node_start_time=node_start_time,
)
],
)

View File

@@ -0,0 +1,29 @@
from collections.abc import Hashable
from datetime import datetime
from langgraph.types import Send
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
AnswerQuestionInput,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def send_to_expanded_retrieval(state: AnswerQuestionInput) -> Send | Hashable:
logger.debug("sending to expanded retrieval via edge")
now_start = datetime.now()
return Send(
"initial_sub_question_expanded_retrieval",
ExpandedRetrievalInput(
question=state.question,
base_search=False,
sub_question_id=state.question_id,
log_messages=[f"{now_start} -- Sending to expanded retrieval"],
),
)

View File

@@ -2,31 +2,31 @@ from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.edges import (
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.edges import (
send_to_expanded_retrieval,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.check_sub_answer import (
check_sub_answer,
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.answer_check import (
answer_check,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.format_sub_answer import (
format_sub_answer,
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.answer_generation import (
answer_generation,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.generate_sub_answer import (
generate_sub_answer,
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.format_answer import (
format_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.ingest_retrieved_documents import (
ingest_retrieved_documents,
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.ingest_retrieval import (
ingest_retrieval,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
AnswerQuestionInput,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.graph_builder import (
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.graph_builder import (
expanded_retrieval_graph_builder,
)
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
@@ -36,47 +36,34 @@ logger = setup_logger()
def answer_query_graph_builder() -> StateGraph:
"""
LangGraph sub-graph builder for the initial individual sub-answer generation.
"""
graph = StateGraph(
state_schema=AnswerQuestionState,
input=SubQuestionAnsweringInput,
input=AnswerQuestionInput,
output=AnswerQuestionOutput,
)
### Add nodes ###
# The sub-graph that executes the expanded retrieval process for a sub-question
expanded_retrieval = expanded_retrieval_graph_builder().compile()
graph.add_node(
node="initial_sub_question_expanded_retrieval",
action=expanded_retrieval,
)
# The node that ingests the retrieved documents and puts them into the proper
# state keys.
graph.add_node(
node="ingest_retrieval",
action=ingest_retrieved_documents,
)
# The node that generates the sub-answer
graph.add_node(
node="generate_sub_answer",
action=generate_sub_answer,
)
# The node that checks the sub-answer
graph.add_node(
node="answer_check",
action=check_sub_answer,
action=answer_check,
)
graph.add_node(
node="answer_generation",
action=answer_generation,
)
# The node that formats the sub-answer for the following initial answer generation
graph.add_node(
node="format_answer",
action=format_sub_answer,
action=format_answer,
)
graph.add_node(
node="ingest_retrieval",
action=ingest_retrieval,
)
### Add edges ###
@@ -92,10 +79,10 @@ def answer_query_graph_builder() -> StateGraph:
)
graph.add_edge(
start_key="ingest_retrieval",
end_key="generate_sub_answer",
end_key="answer_generation",
)
graph.add_edge(
start_key="generate_sub_answer",
start_key="answer_generation",
end_key="answer_check",
)
graph.add_edge(
@@ -122,16 +109,18 @@ if __name__ == "__main__":
query="what can you do with onyx or danswer?",
)
with get_session_context_manager() as db_session:
graph_config, search_tool = get_test_config(
agent_search_config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)
inputs = SubQuestionAnsweringInput(
inputs = AnswerQuestionInput(
question="what can you do with onyx?",
question_id="0_0",
log_messages=[],
)
for thing in compiled_graph.stream(
input=inputs,
config={"configurable": {"config": graph_config}},
config={"configurable": {"config": agent_search_config}},
# debug=True,
# subgraphs=True,
):
logger.debug(thing)

View File

@@ -0,0 +1,8 @@
from pydantic import BaseModel
### Models ###
class AnswerRetrievalStats(BaseModel):
answer_retrieval_stats: dict[str, float | int]

View File

@@ -0,0 +1,59 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
QACheckUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_CHECK_NO
from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
def answer_check(state: AnswerQuestionState, config: RunnableConfig) -> QACheckUpdate:
now_start = datetime.now()
level, question_num = parse_question_id(state.question_id)
if state.answer == UNKNOWN_ANSWER:
now_end = datetime.now()
return QACheckUpdate(
answer_quality=SUB_CHECK_NO,
log_messages=[
f"{now_end} -- Answer check SQ-{level}-{question_num} - unknown answer, Time taken: {now_end - now_start}"
],
)
msg = [
HumanMessage(
content=SUB_CHECK_PROMPT.format(
question=state.question,
base_answer=state.answer,
)
)
]
agent_searchch_config = cast(AgentSearchConfig, config["metadata"]["config"])
fast_llm = agent_searchch_config.fast_llm
response = list(
fast_llm.stream(
prompt=msg,
)
)
quality_str = merge_message_runs(response, chunk_separator="")[0].content
now_end = datetime.now()
return QACheckUpdate(
answer_quality=quality_str,
log_messages=[
f"""{now_end} -- Answer check SQ-{level}-{question_num} - Answer quality: {quality_str},
Time taken: {now_end - now_start}"""
],
)

View File

@@ -0,0 +1,116 @@
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
QAGenerationUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_sub_question_answer_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
ASSISTANT_SYSTEM_PROMPT_PERSONA,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import NO_RECOVERED_DOCS
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_prompt
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.utils.logger import setup_logger
logger = setup_logger()
def answer_generation(
state: AnswerQuestionState, config: RunnableConfig
) -> QAGenerationUpdate:
now_start = datetime.now()
logger.debug(f"--------{now_start}--------START ANSWER GENERATION---")
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = state.question
docs = state.documents
level, question_nr = parse_question_id(state.question_id)
context_docs = state.context_documents
persona_prompt = get_persona_prompt(agent_search_config.search_request.persona)
if len(context_docs) == 0:
answer_str = NO_RECOVERED_DOCS
dispatch_custom_event(
"sub_answers",
AgentAnswerPiece(
answer_piece=answer_str,
level=level,
level_question_nr=question_nr,
answer_type="agent_sub_answer",
),
)
else:
if len(persona_prompt) > 0:
persona_specification = ASSISTANT_SYSTEM_PROMPT_DEFAULT
else:
persona_specification = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
persona_prompt=persona_prompt
)
logger.debug(f"Number of verified retrieval docs: {len(docs)}")
fast_llm = agent_search_config.fast_llm
msg = build_sub_question_answer_prompt(
question=question,
original_question=agent_search_config.search_request.query,
docs=docs,
persona_specification=persona_specification,
config=fast_llm.config,
)
response: list[str | list[str | dict[str, Any]]] = []
for message in fast_llm.stream(
prompt=msg,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
dispatch_custom_event(
"sub_answers",
AgentAnswerPiece(
answer_piece=content,
level=level,
level_question_nr=question_nr,
answer_type="agent_sub_answer",
),
)
response.append(content)
answer_str = merge_message_runs(response, chunk_separator="")[0].content
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type="sub_answer",
level=level,
level_question_nr=question_nr,
)
dispatch_custom_event("stream_finished", stop_event)
now_end = datetime.now()
return QAGenerationUpdate(
answer=answer_str,
log_messages=[
f"{now_end} -- Answer generation SQ-{level} - Q{question_nr} - Time taken: {now_end - now_start}"
],
)

View File

@@ -0,0 +1,28 @@
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.shared_graph_utils.models import (
QuestionAnswerResults,
)
def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput:
return AnswerQuestionOutput(
answer_results=[
QuestionAnswerResults(
question=state.question,
question_id=state.question_id,
quality=state.answer_quality
if hasattr(state, "answer_quality")
else "No",
answer=state.answer,
expanded_retrieval_results=state.expanded_retrieval_results,
documents=state.documents,
context_documents=state.context_documents,
sub_question_retrieval_stats=state.sub_question_retrieval_stats,
)
],
)

View File

@@ -0,0 +1,22 @@
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
RetrievalIngestionUpdate,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
ExpandedRetrievalOutput,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
def ingest_retrieval(state: ExpandedRetrievalOutput) -> RetrievalIngestionUpdate:
sub_question_retrieval_stats = (
state.expanded_retrieval_result.sub_question_retrieval_stats
)
if sub_question_retrieval_stats is None:
sub_question_retrieval_stats = [AgentChunkStats()]
return RetrievalIngestionUpdate(
expanded_retrieval_results=state.expanded_retrieval_result.expanded_queries_results,
documents=state.expanded_retrieval_result.all_documents,
context_documents=state.expanded_retrieval_result.context_documents,
sub_question_retrieval_stats=sub_question_retrieval_stats,
)

View File

@@ -4,11 +4,10 @@ from typing import Annotated
from pydantic import BaseModel
from onyx.agents.agent_search.core_state import SubgraphCoreState
from onyx.agents.agent_search.deep_search.main.states import LoggerUpdate
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.agents.agent_search.shared_graph_utils.models import (
SubQuestionAnswerResults,
QuestionAnswerResults,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
@@ -17,31 +16,28 @@ from onyx.context.search.models import InferenceSection
## Update States
class SubQuestionAnswerCheckUpdate(LoggerUpdate, BaseModel):
answer_quality: bool = False
class QACheckUpdate(BaseModel):
answer_quality: str = ""
log_messages: list[str] = []
class SubQuestionAnswerGenerationUpdate(LoggerUpdate, BaseModel):
class QAGenerationUpdate(BaseModel):
answer: str = ""
log_messages: list[str] = []
cited_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
# answer_stat: AnswerStats
class SubQuestionRetrievalIngestionUpdate(LoggerUpdate, BaseModel):
expanded_retrieval_results: list[QueryRetrievalResult] = []
verified_reranked_documents: Annotated[
list[InferenceSection], dedup_inference_sections
] = []
class RetrievalIngestionUpdate(BaseModel):
expanded_retrieval_results: list[QueryResult] = []
documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
context_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
sub_question_retrieval_stats: AgentChunkRetrievalStats = AgentChunkRetrievalStats()
sub_question_retrieval_stats: AgentChunkStats = AgentChunkStats()
## Graph Input State
class SubQuestionAnsweringInput(SubgraphCoreState):
class AnswerQuestionInput(SubgraphCoreState):
question: str = ""
question_id: str = (
"" # 0_0 is original question, everything else is <level>_<question_num>.
@@ -54,10 +50,10 @@ class SubQuestionAnsweringInput(SubgraphCoreState):
class AnswerQuestionState(
SubQuestionAnsweringInput,
SubQuestionAnswerGenerationUpdate,
SubQuestionAnswerCheckUpdate,
SubQuestionRetrievalIngestionUpdate,
AnswerQuestionInput,
QAGenerationUpdate,
QACheckUpdate,
RetrievalIngestionUpdate,
):
pass
@@ -65,11 +61,11 @@ class AnswerQuestionState(
## Graph Output State
class AnswerQuestionOutput(LoggerUpdate, BaseModel):
class AnswerQuestionOutput(BaseModel):
"""
This is a list of results even though each call of this subgraph only returns one result.
This is because if we parallelize the answer query subgraph, there will be multiple
results in a list so the add operator is used to add them together.
"""
answer_results: Annotated[list[SubQuestionAnswerResults], add] = []
answer_results: Annotated[list[QuestionAnswerResults], add] = []

View File

@@ -3,10 +3,10 @@ from datetime import datetime
from langgraph.types import Send
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnsweringInput,
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
AnswerQuestionInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.utils.logger import setup_logger
@@ -14,12 +14,7 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
def send_to_expanded_refined_retrieval(
state: SubQuestionAnsweringInput,
) -> Send | Hashable:
"""
LangGraph edge to sends a refined sub-question extended retrieval.
"""
def send_to_expanded_refined_retrieval(state: AnswerQuestionInput) -> Send | Hashable:
logger.debug("sending to expanded retrieval for follow up question via edge")
datetime.now()
return Send(

View File

@@ -2,31 +2,31 @@ from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.check_sub_answer import (
check_sub_answer,
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.answer_check import (
answer_check,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.format_sub_answer import (
format_sub_answer,
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.answer_generation import (
answer_generation,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.generate_sub_answer import (
generate_sub_answer,
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.format_answer import (
format_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.ingest_retrieved_documents import (
ingest_retrieved_documents,
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.ingest_retrieval import (
ingest_retrieval,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
AnswerQuestionInput,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.refinement.consolidate_sub_answers.edges import (
from onyx.agents.agent_search.deep_search_a.answer_refinement_sub_question.edges import (
send_to_expanded_refined_retrieval,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.graph_builder import (
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.graph_builder import (
expanded_retrieval_graph_builder,
)
from onyx.utils.logger import setup_logger
@@ -35,46 +35,34 @@ logger = setup_logger()
def answer_refined_query_graph_builder() -> StateGraph:
"""
LangGraph graph builder for the refined sub-answer generation process.
"""
graph = StateGraph(
state_schema=AnswerQuestionState,
input=SubQuestionAnsweringInput,
input=AnswerQuestionInput,
output=AnswerQuestionOutput,
)
### Add nodes ###
# Subgraph for the expanded retrieval process
expanded_retrieval = expanded_retrieval_graph_builder().compile()
graph.add_node(
node="refined_sub_question_expanded_retrieval",
action=expanded_retrieval,
)
# Ingest the retrieved documents
graph.add_node(
node="ingest_refined_retrieval",
action=ingest_retrieved_documents,
)
# Generate the refined sub-answer
graph.add_node(
node="generate_refined_sub_answer",
action=generate_sub_answer,
)
# Check if the refined sub-answer is correct
graph.add_node(
node="refined_sub_answer_check",
action=check_sub_answer,
action=answer_check,
)
graph.add_node(
node="refined_sub_answer_generation",
action=answer_generation,
)
# Format the refined sub-answer
graph.add_node(
node="format_refined_sub_answer",
action=format_sub_answer,
action=format_answer,
)
graph.add_node(
node="ingest_refined_retrieval",
action=ingest_retrieval,
)
### Add edges ###
@@ -90,10 +78,10 @@ def answer_refined_query_graph_builder() -> StateGraph:
)
graph.add_edge(
start_key="ingest_refined_retrieval",
end_key="generate_refined_sub_answer",
end_key="refined_sub_answer_generation",
)
graph.add_edge(
start_key="generate_refined_sub_answer",
start_key="refined_sub_answer_generation",
end_key="refined_sub_answer_check",
)
graph.add_edge(
@@ -120,13 +108,16 @@ if __name__ == "__main__":
query="what can you do with onyx or danswer?",
)
with get_session_context_manager() as db_session:
inputs = SubQuestionAnsweringInput(
inputs = AnswerQuestionInput(
question="what can you do with onyx?",
question_id="0_0",
log_messages=[],
)
for thing in compiled_graph.stream(
input=inputs,
stream_mode="custom",
# debug=True,
# subgraphs=True,
):
logger.debug(thing)
# output = compiled_graph.invoke(inputs)
# logger.debug(output)

View File

@@ -0,0 +1,19 @@
from pydantic import BaseModel
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.context.search.models import InferenceSection
### Models ###
class AnswerRetrievalStats(BaseModel):
answer_retrieval_stats: dict[str, float | int]
class QuestionAnswerResults(BaseModel):
question: str
answer: str
quality: str
# expanded_retrieval_results: list[QueryResult]
documents: list[InferenceSection]
sub_question_retrieval_stats: AgentChunkStats

View File

@@ -0,0 +1,76 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search_a.base_raw_search.nodes.format_raw_search_results import (
format_raw_search_results,
)
from onyx.agents.agent_search.deep_search_a.base_raw_search.nodes.generate_raw_search_data import (
generate_raw_search_data,
)
from onyx.agents.agent_search.deep_search_a.base_raw_search.states import (
BaseRawSearchInput,
)
from onyx.agents.agent_search.deep_search_a.base_raw_search.states import (
BaseRawSearchOutput,
)
from onyx.agents.agent_search.deep_search_a.base_raw_search.states import (
BaseRawSearchState,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.graph_builder import (
expanded_retrieval_graph_builder,
)
def base_raw_search_graph_builder() -> StateGraph:
graph = StateGraph(
state_schema=BaseRawSearchState,
input=BaseRawSearchInput,
output=BaseRawSearchOutput,
)
### Add nodes ###
graph.add_node(
node="generate_raw_search_data",
action=generate_raw_search_data,
)
expanded_retrieval = expanded_retrieval_graph_builder().compile()
graph.add_node(
node="expanded_retrieval_base_search",
action=expanded_retrieval,
)
graph.add_node(
node="format_raw_search_results",
action=format_raw_search_results,
)
### Add edges ###
graph.add_edge(start_key=START, end_key="generate_raw_search_data")
graph.add_edge(
start_key="generate_raw_search_data",
end_key="expanded_retrieval_base_search",
)
graph.add_edge(
start_key="expanded_retrieval_base_search",
end_key="format_raw_search_results",
)
# graph.add_edge(
# start_key="expanded_retrieval_base_search",
# end_key=END,
# )
graph.add_edge(
start_key="format_raw_search_results",
end_key=END,
)
return graph
if __name__ == "__main__":
pass

View File

@@ -0,0 +1,20 @@
from pydantic import BaseModel
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.context.search.models import InferenceSection
### Models ###
class AnswerRetrievalStats(BaseModel):
answer_retrieval_stats: dict[str, float | int]
class QuestionAnswerResults(BaseModel):
question: str
answer: str
quality: str
expanded_retrieval_results: list[QueryResult]
documents: list[InferenceSection]
sub_question_retrieval_stats: list[AgentChunkStats]

View File

@@ -0,0 +1,18 @@
from onyx.agents.agent_search.deep_search_a.base_raw_search.states import (
BaseRawSearchOutput,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
ExpandedRetrievalOutput,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def format_raw_search_results(state: ExpandedRetrievalOutput) -> BaseRawSearchOutput:
logger.debug("format_raw_search_results")
return BaseRawSearchOutput(
base_expanded_retrieval_result=state.expanded_retrieval_result,
# base_retrieval_results=[state.expanded_retrieval_result],
# base_search_documents=[],
)

View File

@@ -3,25 +3,22 @@ from typing import cast
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.utils.logger import setup_logger
logger = setup_logger()
def format_orig_question_search_input(
def generate_raw_search_data(
state: CoreState, config: RunnableConfig
) -> ExpandedRetrievalInput:
"""
LangGraph node to format the search input for the original question.
"""
logger.debug("generate_raw_search_data")
graph_config = cast(GraphConfig, config["metadata"]["config"])
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
return ExpandedRetrievalInput(
question=graph_config.inputs.search_request.query,
question=agent_a_config.search_request.query,
base_search=True,
sub_question_id=None, # This graph is always and only used for the original question
log_messages=[],

View File

@@ -0,0 +1,43 @@
from pydantic import BaseModel
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import (
ExpandedRetrievalResult,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
## Update States
## Graph Input State
class BaseRawSearchInput(ExpandedRetrievalInput):
pass
## Graph Output State
class BaseRawSearchOutput(BaseModel):
"""
This is a list of results even though each call of this subgraph only returns one result.
This is because if we parallelize the answer query subgraph, there will be multiple
results in a list so the add operator is used to add them together.
"""
# base_search_documents: Annotated[list[InferenceSection], dedup_inference_sections]
# base_retrieval_results: Annotated[list[ExpandedRetrievalResult], add]
base_expanded_retrieval_result: ExpandedRetrievalResult = ExpandedRetrievalResult()
## Graph State
class BaseRawSearchState(
BaseRawSearchInput,
BaseRawSearchOutput,
):
pass

View File

@@ -4,32 +4,27 @@ from typing import cast
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import Send
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
RetrievalInput,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.models import AgentSearchConfig
def parallel_retrieval_edge(
state: ExpandedRetrievalState, config: RunnableConfig
) -> list[Send | Hashable]:
"""
LangGraph edge to parallelize the retrieval process for each of the
generated sub-queries and the original question.
"""
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = (
state.question if state.question else graph_config.inputs.search_request.query
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = state.question if state.question else agent_a_config.search_request.query
query_expansions = (
state.expanded_queries if state.expanded_queries else [] + [question]
)
query_expansions = state.expanded_queries + [question]
return [
Send(
"retrieve_documents",
"doc_retrieval",
RetrievalInput(
query_to_retrieve=query,
question=question,

View File

@@ -0,0 +1,147 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.edges import (
parallel_retrieval_edge,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.doc_reranking import (
doc_reranking,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.doc_retrieval import (
doc_retrieval,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.doc_verification import (
doc_verification,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.dummy import (
dummy,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.expand_queries import (
expand_queries,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.format_results import (
format_results,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.verification_kickoff import (
verification_kickoff,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
ExpandedRetrievalOutput,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger
logger = setup_logger()
def expanded_retrieval_graph_builder() -> StateGraph:
graph = StateGraph(
state_schema=ExpandedRetrievalState,
input=ExpandedRetrievalInput,
output=ExpandedRetrievalOutput,
)
### Add nodes ###
graph.add_node(
node="expand_queries",
action=expand_queries,
)
graph.add_node(
node="dummy",
action=dummy,
)
graph.add_node(
node="doc_retrieval",
action=doc_retrieval,
)
graph.add_node(
node="verification_kickoff",
action=verification_kickoff,
)
graph.add_node(
node="doc_verification",
action=doc_verification,
)
graph.add_node(
node="doc_reranking",
action=doc_reranking,
)
graph.add_node(
node="format_results",
action=format_results,
)
### Add edges ###
graph.add_edge(
start_key=START,
end_key="expand_queries",
)
graph.add_edge(
start_key="expand_queries",
end_key="dummy",
)
graph.add_conditional_edges(
source="dummy",
path=parallel_retrieval_edge,
path_map=["doc_retrieval"],
)
graph.add_edge(
start_key="doc_retrieval",
end_key="verification_kickoff",
)
graph.add_edge(
start_key="doc_verification",
end_key="doc_reranking",
)
graph.add_edge(
start_key="doc_reranking",
end_key="format_results",
)
graph.add_edge(
start_key="format_results",
end_key=END,
)
return graph
if __name__ == "__main__":
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
graph = expanded_retrieval_graph_builder()
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
search_request = SearchRequest(
query="what can you do with onyx or danswer?",
)
with get_session_context_manager() as db_session:
agent_a_config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)
inputs = ExpandedRetrievalInput(
question="what can you do with onyx?",
base_search=False,
sub_question_id=None,
log_messages=[],
)
for thing in compiled_graph.stream(
input=inputs,
config={"configurable": {"config": agent_a_config}},
# debug=True,
subgraphs=True,
):
logger.debug(thing)

View File

@@ -0,0 +1,12 @@
from pydantic import BaseModel
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.context.search.models import InferenceSection
class ExpandedRetrievalResult(BaseModel):
expanded_queries_results: list[QueryResult] = []
all_documents: list[InferenceSection] = []
context_documents: list[InferenceSection] = []
sub_question_retrieval_stats: AgentChunkStats = AgentChunkStats()

View File

@@ -0,0 +1,74 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.operations import logger
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
DocRerankingUpdate,
)
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.calculations import get_fit_scores
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.configs.dev_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS
from onyx.configs.dev_configs import AGENT_RERANKING_STATS
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import SearchRequest
from onyx.context.search.pipeline import retrieval_preprocessing
from onyx.context.search.postprocessing.postprocessing import rerank_sections
from onyx.db.engine import get_session_context_manager
def doc_reranking(
state: ExpandedRetrievalState, config: RunnableConfig
) -> DocRerankingUpdate:
now_start = datetime.now()
verified_documents = state.verified_documents
# Rerank post retrieval and verification. First, create a search query
# then create the list of reranked sections
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = state.question if state.question else agent_a_config.search_request.query
with get_session_context_manager() as db_session:
_search_query = retrieval_preprocessing(
search_request=SearchRequest(query=question),
user=agent_a_config.search_tool.user, # bit of a hack
llm=agent_a_config.fast_llm,
db_session=db_session,
)
# skip section filtering
if (
_search_query.rerank_settings
and _search_query.rerank_settings.rerank_model_name
and _search_query.rerank_settings.num_rerank > 0
):
reranked_documents = rerank_sections(
_search_query,
verified_documents,
)
else:
logger.warning("No reranking settings found, using unranked documents")
reranked_documents = verified_documents
if AGENT_RERANKING_STATS:
fit_scores = get_fit_scores(verified_documents, reranked_documents)
else:
fit_scores = RetrievalFitStats(fit_score_lift=0, rerank_effect=0, fit_scores={})
# TODO: stream deduped docs here, or decide to use search tool ranking/verification
now_end = datetime.now()
return DocRerankingUpdate(
reranked_documents=[
doc for doc in reranked_documents if type(doc) == InferenceSection
][:AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS],
sub_question_retrieval_stats=fit_scores,
log_messages=[
f"{now_end} -- Expanded Retrieval - Reranking - Time taken: {now_end - now_start}"
],
)

View File

@@ -3,23 +3,18 @@ from typing import cast
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.operations import (
logger,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.operations import logger
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
DocRetrievalUpdate,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
RetrievalInput,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.calculations import get_fit_scores
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import AGENT_MAX_QUERY_RETRIEVAL_RESULTS
from onyx.configs.agent_configs import AGENT_RETRIEVAL_STATS
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.configs.dev_configs import AGENT_MAX_QUERY_RETRIEVAL_RESULTS
from onyx.configs.dev_configs import AGENT_RETRIEVAL_STATS
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_context_manager
from onyx.tools.models import SearchQueryInfo
@@ -29,47 +24,42 @@ from onyx.tools.tool_implementations.search.search_tool import (
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
def retrieve_documents(
state: RetrievalInput, config: RunnableConfig
) -> DocRetrievalUpdate:
def doc_retrieval(state: RetrievalInput, config: RunnableConfig) -> DocRetrievalUpdate:
"""
LangGraph node to retrieve documents from the search tool.
Retrieve documents
Args:
state (RetrievalInput): Primary state + the query to retrieve
config (RunnableConfig): Configuration containing ProSearchConfig
Updates:
expanded_retrieval_results: list[ExpandedRetrievalResult]
retrieved_documents: list[InferenceSection]
"""
node_start_time = datetime.now()
now_start = datetime.now()
query_to_retrieve = state.query_to_retrieve
graph_config = cast(GraphConfig, config["metadata"]["config"])
search_tool = graph_config.tooling.search_tool
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
search_tool = agent_a_config.search_tool
retrieved_docs: list[InferenceSection] = []
if not query_to_retrieve.strip():
logger.warning("Empty query, skipping retrieval")
now_end = datetime.now()
return DocRetrievalUpdate(
query_retrieval_results=[],
expanded_retrieval_results=[],
retrieved_documents=[],
log_messages=[
get_langgraph_node_log_string(
graph_component="shared - expanded retrieval",
node_name="retrieve documents",
node_start_time=node_start_time,
result="Empty query, skipping retrieval",
)
f"{now_end} -- Expanded Retrieval - Retrieval - Empty Query - Time taken: {now_end - now_start}"
],
)
query_info = None
if search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
callback_container: list[list[InferenceSection]] = []
# new db session to avoid concurrency issues
with get_session_context_manager() as db_session:
for tool_response in search_tool.run(
query=query_to_retrieve,
force_no_rerank=True,
alternate_db_session=db_session,
retrieved_sections_callback=callback_container.append,
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
@@ -83,9 +73,13 @@ def retrieve_documents(
break
retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS]
pre_rerank_docs = retrieved_docs
if search_tool.search_pipeline is not None:
pre_rerank_docs = (
search_tool.search_pipeline._retrieved_sections or retrieved_docs
)
if AGENT_RETRIEVAL_STATS:
pre_rerank_docs = callback_container[0]
fit_scores = get_fit_scores(
pre_rerank_docs,
retrieved_docs,
@@ -93,21 +87,17 @@ def retrieve_documents(
else:
fit_scores = None
expanded_retrieval_result = QueryRetrievalResult(
expanded_retrieval_result = QueryResult(
query=query_to_retrieve,
retrieved_documents=retrieved_docs,
search_results=retrieved_docs,
stats=fit_scores,
query_info=query_info,
)
now_end = datetime.now()
return DocRetrievalUpdate(
query_retrieval_results=[expanded_retrieval_result],
expanded_retrieval_results=[expanded_retrieval_result],
retrieved_documents=retrieved_docs,
log_messages=[
get_langgraph_node_log_string(
graph_component="shared - expanded retrieval",
node_name="retrieve documents",
node_start_time=node_start_time,
)
f"{now_end} -- Expanded Retrieval - Retrieval - Time taken: {now_end - now_start}"
],
)

View File

@@ -3,26 +3,24 @@ from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
DocVerificationInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
DocVerificationUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.prompts.agent_search import (
DOCUMENT_VERIFICATION_PROMPT,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
def verify_documents(
def doc_verification(
state: DocVerificationInput, config: RunnableConfig
) -> DocVerificationUpdate:
"""
LangGraph node to check whether the document is relevant for the original user question
Check whether the document is relevant for the original user question
Args:
state (DocVerificationInput): The current state
@@ -33,19 +31,19 @@ def verify_documents(
"""
question = state.question
retrieved_document_to_verify = state.retrieved_document_to_verify
document_content = retrieved_document_to_verify.combined_content
doc_to_verify = state.doc_to_verify
document_content = doc_to_verify.combined_content
graph_config = cast(GraphConfig, config["metadata"]["config"])
fast_llm = graph_config.tooling.fast_llm
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
fast_llm = agent_a_config.fast_llm
document_content = trim_prompt_piece(
fast_llm.config, document_content, DOCUMENT_VERIFICATION_PROMPT + question
fast_llm.config, document_content, VERIFIER_PROMPT + question
)
msg = [
HumanMessage(
content=DOCUMENT_VERIFICATION_PROMPT.format(
content=VERIFIER_PROMPT.format(
question=question, document_content=document_content
)
)
@@ -55,7 +53,7 @@ def verify_documents(
verified_documents = []
if isinstance(response.content, str) and "yes" in response.content.lower():
verified_documents.append(retrieved_document_to_verify)
verified_documents.append(doc_to_verify)
return DocVerificationUpdate(
verified_documents=verified_documents,

Some files were not shown because too many files have changed in this diff Show More