mirror of https://github.com/onyx-dot-app/onyx.git
synced 2026-02-20 09:15:47 +00:00

Compare commits: error_supp... ... improvemen... (46 commits)
| Author | SHA1 | Date |
|---|---|---|
| | b5f20ab12c | |
| | c052d6251d | |
| | 4cb603586d | |
| | 07e64caeed | |
| | dd6e623cab | |
| | 4e4f7bf823 | |
| | fa70770033 | |
| | 789d98ff9a | |
| | 30c2b54a15 | |
| | 346152aecb | |
| | eb98c3d069 | |
| | a387eddd1f | |
| | f23e9e2d96 | |
| | 704607fac4 | |
| | 5c6c0c4f56 | |
| | ef4762a523 | |
| | 4f5f89df11 | |
| | 66e675c36a | |
| | 803ae4511b | |
| | f85941a206 | |
| | 4402c5e550 | |
| | a031ac4b17 | |
| | e99d1d49e3 | |
| | 150d6f7acc | |
| | 316ebc0d12 | |
| | 4122675725 | |
| | 48d43e049d | |
| | 364dd4bb19 | |
| | e07a00d353 | |
| | 398263712d | |
| | 14dd40d155 | |
| | 88d5796b28 | |
| | 12fa4fb0be | |
| | 5aa7fce418 | |
| | 90ea89e70b | |
| | 4f4cbfeeb1 | |
| | 5500357585 | |
| | 3ed10b85f7 | |
| | 3233f31010 | |
| | 67c4ec74ff | |
| | 63908c90c5 | |
| | 90e0aba34b | |
| | 0b2ff44a54 | |
| | 6d4af4c6af | |
| | c4c0b42ca4 | |
| | a4eae8a302 | |
.github/workflows/pr-chromatic-tests.yml (vendored, 2 changes)

@@ -8,8 +8,6 @@ on: push
env:
  OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
  SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
  GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
  MOCK_LLM_RESPONSE: true

jobs:
  playwright-tests:
.github/workflows/pr-helm-chart-testing.yml (vendored, 22 changes)

@@ -21,10 +21,10 @@ jobs:
      - name: Set up Helm
        uses: azure/setup-helm@v4.2.0
        with:
          version: v3.17.0
          version: v3.14.4

      - name: Set up chart-testing
        uses: helm/chart-testing-action@v2.7.0
        uses: helm/chart-testing-action@v2.6.1

      # even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
      - name: Run chart-testing (list-changed)
@@ -37,6 +37,22 @@ jobs:
            echo "changed=true" >> "$GITHUB_OUTPUT"
          fi

      # rkuo: I don't think we need python?
      # - name: Set up Python
      # uses: actions/setup-python@v5
      # with:
      # python-version: '3.11'
      # cache: 'pip'
      # cache-dependency-path: |
      # backend/requirements/default.txt
      # backend/requirements/dev.txt
      # backend/requirements/model_server.txt
      # - run: |
      # python -m pip install --upgrade pip
      # pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
      # pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
      # pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt

      # lint all charts if any changes were detected
      - name: Run chart-testing (lint)
        if: steps.list-changed.outputs.changed == 'true'
@@ -46,7 +62,7 @@ jobs:

      - name: Create kind cluster
        if: steps.list-changed.outputs.changed == 'true'
        uses: helm/kind-action@v1.12.0
        uses: helm/kind-action@v1.10.0

      - name: Run chart-testing (install)
        if: steps.list-changed.outputs.changed == 'true'
.github/workflows/pr-linear-check.yml (vendored, 4 changes)

@@ -9,9 +9,9 @@ jobs:
    runs-on: ubuntu-latest
    steps:
      - name: Check PR body for Linear link or override
        env:
          PR_BODY: ${{ github.event.pull_request.body }}
        run: |
          PR_BODY="${{ github.event.pull_request.body }}"

          # Looking for "https://linear.app" in the body
          if echo "$PR_BODY" | grep -qE "https://linear\.app"; then
            echo "Found a Linear link. Check passed."
@@ -39,12 +39,6 @@ env:
  AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
  AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
  AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
  # Sharepoint
  SHAREPOINT_CLIENT_ID: ${{ secrets.SHAREPOINT_CLIENT_ID }}
  SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
  SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
  SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}

jobs:
  connectors-check:
    # See https://runs-on.com/runners/linux/
.gitignore (vendored, 4 changes)

@@ -7,4 +7,6 @@
.vscode/
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml
/web/test-results/
/web/test-results/
backend/onyx/agent_search/main/test_data.json
backend/tests/regression/answer_quality/test_data.json
.vscode/env_template.txt (vendored, 6 changes)

@@ -52,3 +52,9 @@ BING_API_KEY=<REPLACE THIS>
# Enable the full set of Danswer Enterprise Edition features
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False

# Agent Search configs # TODO: Remove give proper namings
AGENT_RETRIEVAL_STATS=False # Note: This setting will incur substantial re-ranking effort
AGENT_RERANKING_STATS=True
AGENT_MAX_QUERY_RETRIEVAL_RESULTS=20
AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS=20
Untitled-12 (new file, 1 line)

@@ -0,0 +1 @@
@@ -0,0 +1,29 @@
"""agent_doc_result_col

Revision ID: 1adf5ea20d2b
Revises: e9cf2bd7baed
Create Date: 2025-01-05 13:14:58.344316

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "1adf5ea20d2b"
down_revision = "e9cf2bd7baed"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add the new column with JSONB type
    op.add_column(
        "sub_question",
        sa.Column("sub_question_doc_results", postgresql.JSONB(), nullable=True),
    )


def downgrade() -> None:
    # Drop the column
    op.drop_column("sub_question", "sub_question_doc_results")
@@ -0,0 +1,31 @@
"""refined answer improvement

Revision ID: 211b14ab5a91
Revises: 925b58bd75b6
Create Date: 2025-01-24 14:05:03.334309

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "211b14ab5a91"
down_revision = "925b58bd75b6"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "chat_message",
        sa.Column(
            "refined_answer_improvement",
            sa.Boolean(),
            nullable=True,
        ),
    )


def downgrade() -> None:
    op.drop_column("chat_message", "refined_answer_improvement")
@@ -1,36 +0,0 @@
"""add chat session specific temperature override

Revision ID: 2f80c6a2550f
Revises: 33ea50e88f24
Create Date: 2025-01-31 10:30:27.289646

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "2f80c6a2550f"
down_revision = "33ea50e88f24"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "chat_session", sa.Column("temperature_override", sa.Float(), nullable=True)
    )
    op.add_column(
        "user",
        sa.Column(
            "temperature_override_enabled",
            sa.Boolean(),
            nullable=False,
            server_default=sa.false(),
        ),
    )


def downgrade() -> None:
    op.drop_column("chat_session", "temperature_override")
    op.drop_column("user", "temperature_override_enabled")
@@ -1,80 +0,0 @@
"""foreign key input prompts

Revision ID: 33ea50e88f24
Revises: a6df6b88ef81
Create Date: 2025-01-29 10:54:22.141765

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "33ea50e88f24"
down_revision = "a6df6b88ef81"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Safely drop constraints if exists
    op.execute(
        """
        ALTER TABLE inputprompt__user
        DROP CONSTRAINT IF EXISTS inputprompt__user_input_prompt_id_fkey
        """
    )
    op.execute(
        """
        ALTER TABLE inputprompt__user
        DROP CONSTRAINT IF EXISTS inputprompt__user_user_id_fkey
        """
    )

    # Recreate with ON DELETE CASCADE
    op.create_foreign_key(
        "inputprompt__user_input_prompt_id_fkey",
        "inputprompt__user",
        "inputprompt",
        ["input_prompt_id"],
        ["id"],
        ondelete="CASCADE",
    )

    op.create_foreign_key(
        "inputprompt__user_user_id_fkey",
        "inputprompt__user",
        "user",
        ["user_id"],
        ["id"],
        ondelete="CASCADE",
    )


def downgrade() -> None:
    # Drop the new FKs with ondelete
    op.drop_constraint(
        "inputprompt__user_input_prompt_id_fkey",
        "inputprompt__user",
        type_="foreignkey",
    )
    op.drop_constraint(
        "inputprompt__user_user_id_fkey",
        "inputprompt__user",
        type_="foreignkey",
    )

    # Recreate them without cascading
    op.create_foreign_key(
        "inputprompt__user_input_prompt_id_fkey",
        "inputprompt__user",
        "inputprompt",
        ["input_prompt_id"],
        ["id"],
    )
    op.create_foreign_key(
        "inputprompt__user_user_id_fkey",
        "inputprompt__user",
        "user",
        ["user_id"],
        ["id"],
    )
@@ -1,37 +0,0 @@
"""lowercase_user_emails

Revision ID: 4d58345da04a
Revises: f1ca58b2f2ec
Create Date: 2025-01-29 07:48:46.784041

"""
from alembic import op
from sqlalchemy.sql import text


# revision identifiers, used by Alembic.
revision = "4d58345da04a"
down_revision = "f1ca58b2f2ec"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Get database connection
    connection = op.get_bind()

    # Update all user emails to lowercase
    connection.execute(
        text(
            """
            UPDATE "user"
            SET email = LOWER(email)
            WHERE email != LOWER(email)
            """
        )
    )


def downgrade() -> None:
    # Cannot restore original case of emails
    pass
@@ -0,0 +1,35 @@
"""agent_metric_col_rename__s

Revision ID: 925b58bd75b6
Revises: 9787be927e58
Create Date: 2025-01-06 11:20:26.752441

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "925b58bd75b6"
down_revision = "9787be927e58"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Rename columns using PostgreSQL syntax
    op.alter_column(
        "agent__search_metrics", "base_duration_s", new_column_name="base_duration__s"
    )
    op.alter_column(
        "agent__search_metrics", "full_duration_s", new_column_name="full_duration__s"
    )


def downgrade() -> None:
    # Revert the column renames
    op.alter_column(
        "agent__search_metrics", "base_duration__s", new_column_name="base_duration_s"
    )
    op.alter_column(
        "agent__search_metrics", "full_duration__s", new_column_name="full_duration_s"
    )
@@ -0,0 +1,25 @@
"""agent_metric_table_renames__agent__

Revision ID: 9787be927e58
Revises: bceb76d618ec
Create Date: 2025-01-06 11:01:44.210160

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "9787be927e58"
down_revision = "bceb76d618ec"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Rename table from agent_search_metrics to agent__search_metrics
    op.rename_table("agent_search_metrics", "agent__search_metrics")


def downgrade() -> None:
    # Rename table back from agent__search_metrics to agent_search_metrics
    op.rename_table("agent__search_metrics", "agent_search_metrics")
backend/alembic/versions/98a5008d8711_agent_tracking.py (new file, 42 lines)

@@ -0,0 +1,42 @@
"""agent_tracking

Revision ID: 98a5008d8711
Revises: f1ca58b2f2ec
Create Date: 2025-01-04 14:41:52.732238

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "98a5008d8711"
down_revision = "f1ca58b2f2ec"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "agent_search_metrics",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=True),
        sa.Column("persona_id", sa.Integer(), nullable=True),
        sa.Column("agent_type", sa.String(), nullable=False),
        sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
        sa.Column("base_duration_s", sa.Float(), nullable=False),
        sa.Column("full_duration_s", sa.Float(), nullable=False),
        sa.Column("base_metrics", postgresql.JSONB(), nullable=True),
        sa.Column("refined_metrics", postgresql.JSONB(), nullable=True),
        sa.Column("all_metrics", postgresql.JSONB(), nullable=True),
        sa.ForeignKeyConstraint(
            ["persona_id"],
            ["persona.id"],
        ),
        sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
    )


def downgrade() -> None:
    op.drop_table("agent_search_metrics")
@@ -1,29 +0,0 @@
"""remove recent assistants

Revision ID: a6df6b88ef81
Revises: 4d58345da04a
Create Date: 2025-01-29 10:25:52.790407

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "a6df6b88ef81"
down_revision = "4d58345da04a"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.drop_column("user", "recent_assistants")


def downgrade() -> None:
    op.add_column(
        "user",
        sa.Column(
            "recent_assistants", postgresql.JSONB(), server_default="[]", nullable=False
        ),
    )
@@ -0,0 +1,84 @@
"""agent_table_renames__agent__

Revision ID: bceb76d618ec
Revises: c0132518a25b
Create Date: 2025-01-06 10:50:48.109285

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "bceb76d618ec"
down_revision = "c0132518a25b"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.drop_constraint(
        "sub_query__search_doc_sub_query_id_fkey",
        "sub_query__search_doc",
        type_="foreignkey",
    )
    op.drop_constraint(
        "sub_query__search_doc_search_doc_id_fkey",
        "sub_query__search_doc",
        type_="foreignkey",
    )
    # Rename tables
    op.rename_table("sub_query", "agent__sub_query")
    op.rename_table("sub_question", "agent__sub_question")
    op.rename_table("sub_query__search_doc", "agent__sub_query__search_doc")

    # Update both foreign key constraints for agent__sub_query__search_doc

    # Create new foreign keys with updated names
    op.create_foreign_key(
        "agent__sub_query__search_doc_sub_query_id_fkey",
        "agent__sub_query__search_doc",
        "agent__sub_query",
        ["sub_query_id"],
        ["id"],
    )
    op.create_foreign_key(
        "agent__sub_query__search_doc_search_doc_id_fkey",
        "agent__sub_query__search_doc",
        "search_doc", # This table name doesn't change
        ["search_doc_id"],
        ["id"],
    )


def downgrade() -> None:
    # Update foreign key constraints for sub_query__search_doc
    op.drop_constraint(
        "agent__sub_query__search_doc_sub_query_id_fkey",
        "agent__sub_query__search_doc",
        type_="foreignkey",
    )
    op.drop_constraint(
        "agent__sub_query__search_doc_search_doc_id_fkey",
        "agent__sub_query__search_doc",
        type_="foreignkey",
    )

    # Rename tables back
    op.rename_table("agent__sub_query__search_doc", "sub_query__search_doc")
    op.rename_table("agent__sub_question", "sub_question")
    op.rename_table("agent__sub_query", "sub_query")

    op.create_foreign_key(
        "sub_query__search_doc_sub_query_id_fkey",
        "sub_query__search_doc",
        "sub_query",
        ["sub_query_id"],
        ["id"],
    )
    op.create_foreign_key(
        "sub_query__search_doc_search_doc_id_fkey",
        "sub_query__search_doc",
        "search_doc", # This table name doesn't change
        ["search_doc_id"],
        ["id"],
    )
@@ -0,0 +1,40 @@
"""agent_table_changes_rename_level

Revision ID: c0132518a25b
Revises: 1adf5ea20d2b
Create Date: 2025-01-05 16:38:37.660152

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "c0132518a25b"
down_revision = "1adf5ea20d2b"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add level and level_question_nr columns with NOT NULL constraint
    op.add_column(
        "sub_question",
        sa.Column("level", sa.Integer(), nullable=False, server_default="0"),
    )
    op.add_column(
        "sub_question",
        sa.Column(
            "level_question_nr", sa.Integer(), nullable=False, server_default="0"
        ),
    )

    # Remove the server_default after the columns are created
    op.alter_column("sub_question", "level", server_default=None)
    op.alter_column("sub_question", "level_question_nr", server_default=None)


def downgrade() -> None:
    # Remove the columns
    op.drop_column("sub_question", "level_question_nr")
    op.drop_column("sub_question", "level")
@@ -0,0 +1,68 @@
"""create pro search persistence tables

Revision ID: e9cf2bd7baed
Revises: 98a5008d8711
Create Date: 2025-01-02 17:55:56.544246

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID


# revision identifiers, used by Alembic.
revision = "e9cf2bd7baed"
down_revision = "98a5008d8711"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Create sub_question table
    op.create_table(
        "sub_question",
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("primary_question_id", sa.Integer, sa.ForeignKey("chat_message.id")),
        sa.Column(
            "chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id")
        ),
        sa.Column("sub_question", sa.Text),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.func.now()
        ),
        sa.Column("sub_answer", sa.Text),
    )

    # Create sub_query table
    op.create_table(
        "sub_query",
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("parent_question_id", sa.Integer, sa.ForeignKey("sub_question.id")),
        sa.Column(
            "chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id")
        ),
        sa.Column("sub_query", sa.Text),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.func.now()
        ),
    )

    # Create sub_query__search_doc association table
    op.create_table(
        "sub_query__search_doc",
        sa.Column(
            "sub_query_id", sa.Integer, sa.ForeignKey("sub_query.id"), primary_key=True
        ),
        sa.Column(
            "search_doc_id",
            sa.Integer,
            sa.ForeignKey("search_doc.id"),
            primary_key=True,
        ),
    )


def downgrade() -> None:
    op.drop_table("sub_query__search_doc")
    op.drop_table("sub_query")
    op.drop_table("sub_question")
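Revision chain for the new migrations above, reconstructed from the revision / down_revision pairs shown in this diff (not stated anywhere else in the compare):

f1ca58b2f2ec -> 98a5008d8711 (agent_tracking) -> e9cf2bd7baed (create pro search persistence tables) -> 1adf5ea20d2b (agent_doc_result_col) -> c0132518a25b (agent_table_changes_rename_level) -> bceb76d618ec (agent_table_renames__agent__) -> 9787be927e58 (agent_metric_table_renames__agent__) -> 925b58bd75b6 (agent_metric_col_rename__s) -> 211b14ab5a91 (refined answer improvement)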
backend/chat_packets.log (new file, 370 lines)
File diff suppressed because one or more lines are too long
@@ -32,7 +32,6 @@ def perform_ttl_management_task(

@celery_app.task(
    name="check_ttl_management_task",
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
)
def check_ttl_management_task(*, tenant_id: str | None) -> None:
@@ -57,7 +56,6 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:

@celery_app.task(
    name="autogenerate_usage_report_task",
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
)
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:

@@ -13,7 +13,6 @@ from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.connectors.confluence.utils import get_user_email_from_username__server
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -258,7 +257,6 @@ def _fetch_all_page_restrictions(
|
||||
slim_docs: list[SlimDocument],
|
||||
space_permissions_by_space_key: dict[str, ExternalAccess],
|
||||
is_cloud: bool,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
For all pages, if a page has restrictions, then use those restrictions.
|
||||
@@ -267,12 +265,6 @@ def _fetch_all_page_restrictions(
|
||||
document_restrictions: list[DocExternalAccess] = []
|
||||
|
||||
for slim_doc in slim_docs:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("confluence_doc_sync: Stop signal detected")
|
||||
|
||||
callback.progress("confluence_doc_sync:fetch_all_page_restrictions", 1)
|
||||
|
||||
if slim_doc.perm_sync_data is None:
|
||||
raise ValueError(
|
||||
f"No permission sync data found for document {slim_doc.id}"
|
||||
@@ -342,7 +334,7 @@ def _fetch_all_page_restrictions(
|
||||
|
||||
|
||||
def confluence_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
@@ -367,12 +359,6 @@ def confluence_doc_sync(
|
||||
logger.debug("Fetching all slim documents from confluence")
|
||||
for doc_batch in confluence_connector.retrieve_all_slim_documents():
|
||||
logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("confluence_doc_sync: Stop signal detected")
|
||||
|
||||
callback.progress("confluence_doc_sync", 1)
|
||||
|
||||
slim_docs.extend(doc_batch)
|
||||
|
||||
logger.debug("Fetching all page restrictions for space")
|
||||
@@ -381,5 +367,4 @@ def confluence_doc_sync(
|
||||
slim_docs=slim_docs,
|
||||
space_permissions_by_space_key=space_permissions_by_space_key,
|
||||
is_cloud=is_cloud,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
@@ -14,8 +14,6 @@ def _build_group_member_email_map(
|
||||
) -> dict[str, set[str]]:
|
||||
group_member_emails: dict[str, set[str]] = {}
|
||||
for user_result in confluence_client.paginated_cql_user_retrieval():
|
||||
logger.debug(f"Processing groups for user: {user_result}")
|
||||
|
||||
user = user_result.get("user", {})
|
||||
if not user:
|
||||
logger.warning(f"user result missing user field: {user_result}")
|
||||
@@ -35,17 +33,10 @@ def _build_group_member_email_map(
|
||||
logger.warning(f"user result missing email field: {user_result}")
|
||||
continue
|
||||
|
||||
all_users_groups: set[str] = set()
|
||||
for group in confluence_client.paginated_groups_by_user_retrieval(user):
|
||||
# group name uniqueness is enforced by Confluence, so we can use it as a group ID
|
||||
group_id = group["name"]
|
||||
group_member_emails.setdefault(group_id, set()).add(email)
|
||||
all_users_groups.add(group_id)
|
||||
|
||||
if not group_member_emails:
|
||||
logger.warning(f"No groups found for user with email: {email}")
|
||||
else:
|
||||
logger.debug(f"Found groups {all_users_groups} for user with email {email}")
|
||||
|
||||
return group_member_emails
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.gmail.connector import GmailConnector
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -29,7 +28,7 @@ def _get_slim_doc_generator(
|
||||
|
||||
|
||||
def gmail_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
@@ -45,12 +44,6 @@ def gmail_doc_sync(
|
||||
document_external_access: list[DocExternalAccess] = []
|
||||
for slim_doc_batch in slim_doc_generator:
|
||||
for slim_doc in slim_doc_batch:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("gmail_doc_sync: Stop signal detected")
|
||||
|
||||
callback.progress("gmail_doc_sync", 1)
|
||||
|
||||
if slim_doc.perm_sync_data is None:
|
||||
logger.warning(f"No permissions found for document {slim_doc.id}")
|
||||
continue
|
||||
|
||||
@@ -10,7 +10,6 @@ from onyx.connectors.google_utils.resources import get_drive_service
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -43,22 +42,24 @@ def _fetch_permissions_for_permission_ids(
|
||||
if not permission_info or not doc_id:
|
||||
return []
|
||||
|
||||
# Check cache first for all permission IDs
|
||||
permissions = [
|
||||
_PERMISSION_ID_PERMISSION_MAP[pid]
|
||||
for pid in permission_ids
|
||||
if pid in _PERMISSION_ID_PERMISSION_MAP
|
||||
]
|
||||
|
||||
# If we found all permissions in cache, return them
|
||||
if len(permissions) == len(permission_ids):
|
||||
return permissions
|
||||
|
||||
owner_email = permission_info.get("owner_email")
|
||||
|
||||
drive_service = get_drive_service(
|
||||
creds=google_drive_connector.creds,
|
||||
user_email=(owner_email or google_drive_connector.primary_admin_email),
|
||||
)
|
||||
|
||||
# Otherwise, fetch all permissions and update cache
|
||||
fetched_permissions = execute_paginated_retrieval(
|
||||
retrieval_function=drive_service.permissions().list,
|
||||
list_key="permissions",
|
||||
@@ -68,6 +69,7 @@ def _fetch_permissions_for_permission_ids(
|
||||
)
|
||||
|
||||
permissions_for_doc_id = []
|
||||
# Update cache and return all permissions
|
||||
for permission in fetched_permissions:
|
||||
permissions_for_doc_id.append(permission)
|
||||
_PERMISSION_ID_PERMISSION_MAP[permission["id"]] = permission
|
||||
@@ -129,7 +131,7 @@ def _get_permissions_from_slim_doc(
|
||||
|
||||
|
||||
def gdrive_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
@@ -147,12 +149,6 @@ def gdrive_doc_sync(
|
||||
document_external_accesses = []
|
||||
for slim_doc_batch in slim_doc_generator:
|
||||
for slim_doc in slim_doc_batch:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("gdrive_doc_sync: Stop signal detected")
|
||||
|
||||
callback.progress("gdrive_doc_sync", 1)
|
||||
|
||||
ext_access = _get_permissions_from_slim_doc(
|
||||
google_drive_connector=google_drive_connector,
|
||||
slim_doc=slim_doc,
|
||||
|
||||
@@ -7,7 +7,6 @@ from onyx.connectors.slack.connector import get_channels
|
||||
from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
|
||||
from onyx.connectors.slack.connector import SlackPollConnector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -15,7 +14,7 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def _get_slack_document_ids_and_channels(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> dict[str, list[str]]:
|
||||
slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
|
||||
slack_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
@@ -25,14 +24,6 @@ def _get_slack_document_ids_and_channels(
|
||||
channel_doc_map: dict[str, list[str]] = {}
|
||||
for doc_metadata_batch in slim_doc_generator:
|
||||
for doc_metadata in doc_metadata_batch:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"_get_slack_document_ids_and_channels: Stop signal detected"
|
||||
)
|
||||
|
||||
callback.progress("_get_slack_document_ids_and_channels", 1)
|
||||
|
||||
if doc_metadata.perm_sync_data is None:
|
||||
continue
|
||||
channel_id = doc_metadata.perm_sync_data["channel_id"]
|
||||
@@ -123,7 +114,7 @@ def _fetch_channel_permissions(
|
||||
|
||||
|
||||
def slack_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
@@ -136,7 +127,7 @@ def slack_doc_sync(
|
||||
)
|
||||
user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
|
||||
channel_doc_map = _get_slack_document_ids_and_channels(
|
||||
cc_pair=cc_pair, callback=callback
|
||||
cc_pair=cc_pair,
|
||||
)
|
||||
workspace_permissions = _fetch_workspace_permissions(
|
||||
user_id_to_email_map=user_id_to_email_map,
|
||||
|
||||
@@ -15,13 +15,11 @@ from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
|
||||
# Defining the input/output types for the sync functions
|
||||
DocSyncFuncType = Callable[
|
||||
[
|
||||
ConnectorCredentialPair,
|
||||
IndexingHeartbeatInterface | None,
|
||||
],
|
||||
list[DocExternalAccess],
|
||||
]
|
||||
|
||||
@@ -179,6 +179,7 @@ def handle_simplified_chat_message(
|
||||
chunks_below=0,
|
||||
full_doc=chat_message_req.full_doc,
|
||||
structured_response_format=chat_message_req.structured_response_format,
|
||||
use_agentic_search=chat_message_req.use_agentic_search,
|
||||
)
|
||||
|
||||
packets = stream_chat_message_objects(
|
||||
@@ -301,6 +302,7 @@ def handle_send_message_simple_with_history(
|
||||
chunks_below=0,
|
||||
full_doc=req.full_doc,
|
||||
structured_response_format=req.structured_response_format,
|
||||
use_agentic_search=req.use_agentic_search,
|
||||
)
|
||||
|
||||
packets = stream_chat_message_objects(
|
||||
|
||||
@@ -57,6 +57,9 @@ class BasicCreateChatMessageRequest(ChunkContext):
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
structured_response_format: dict | None = None
|
||||
|
||||
# If True, uses agentic search instead of basic search
|
||||
use_agentic_search: bool = False
|
||||
|
||||
|
||||
class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
|
||||
# Last element is the new query. All previous elements are historical context
|
||||
@@ -71,6 +74,8 @@ class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
|
||||
# only works if using an OpenAI model. See the following for more details:
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
structured_response_format: dict | None = None
|
||||
# If True, uses agentic search instead of basic search
|
||||
use_agentic_search: bool = False
|
||||
|
||||
|
||||
class SimpleDoc(BaseModel):
|
||||
@@ -123,6 +128,9 @@ class OneShotQARequest(ChunkContext):
|
||||
# If True, skips generative an AI response to the search query
|
||||
skip_gen_ai_answer_generation: bool = False
|
||||
|
||||
# If True, uses pro search instead of basic search
|
||||
use_agentic_search: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_persona_fields(self) -> "OneShotQARequest":
|
||||
if self.persona_override_config is None and self.persona_id is None:
|
||||
|
||||
@@ -196,6 +196,8 @@ def get_answer_stream(
|
||||
retrieval_details=query_request.retrieval_options,
|
||||
rerank_settings=query_request.rerank_settings,
|
||||
db_session=db_session,
|
||||
use_agentic_search=query_request.use_agentic_search,
|
||||
skip_gen_ai_answer_generation=query_request.skip_gen_ai_answer_generation,
|
||||
)
|
||||
|
||||
packets = stream_chat_message_objects(
|
||||
|
||||
@@ -111,7 +111,6 @@ async def login_as_anonymous_user(
|
||||
token = generate_anonymous_user_jwt_token(tenant_id)
|
||||
|
||||
response = Response()
|
||||
response.delete_cookie("fastapiusersauth")
|
||||
response.set_cookie(
|
||||
key=ANONYMOUS_USER_COOKIE_NAME,
|
||||
value=token,
|
||||
|
||||
@@ -58,7 +58,6 @@ class UserGroup(BaseModel):
|
||||
credential=CredentialSnapshot.from_credential_db_model(
|
||||
cc_pair_relationship.cc_pair.credential
|
||||
),
|
||||
access_type=cc_pair_relationship.cc_pair.access_type,
|
||||
)
|
||||
for cc_pair_relationship in user_group_model.cc_pair_relationships
|
||||
if cc_pair_relationship.is_current
|
||||
|
||||
98
backend/onyx/agents/agent_search/basic/graph_builder.py
Normal file
98
backend/onyx/agents/agent_search/basic/graph_builder.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.basic.states import BasicInput
|
||||
from onyx.agents.agent_search.basic.states import BasicOutput
|
||||
from onyx.agents.agent_search.basic.states import BasicState
|
||||
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
|
||||
basic_use_tool_response,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
|
||||
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
|
||||
prepare_tool_input,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def basic_graph_builder() -> StateGraph:
|
||||
graph = StateGraph(
|
||||
state_schema=BasicState,
|
||||
input=BasicInput,
|
||||
output=BasicOutput,
|
||||
)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
graph.add_node(
|
||||
node="prepare_tool_input",
|
||||
action=prepare_tool_input,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="llm_tool_choice",
|
||||
action=llm_tool_choice,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="tool_call",
|
||||
action=tool_call,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="basic_use_tool_response",
|
||||
action=basic_use_tool_response,
|
||||
)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_edge(start_key=START, end_key="prepare_tool_input")
|
||||
|
||||
graph.add_edge(start_key="prepare_tool_input", end_key="llm_tool_choice")
|
||||
|
||||
graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])
|
||||
|
||||
graph.add_edge(
|
||||
start_key="tool_call",
|
||||
end_key="basic_use_tool_response",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="basic_use_tool_response",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
def should_continue(state: BasicState) -> str:
|
||||
return (
|
||||
# If there are no tool calls, basic graph already streamed the answer
|
||||
END
|
||||
if state.tool_choice is None
|
||||
else "tool_call"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
|
||||
graph = basic_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
# TODO: unify basic input
|
||||
input = BasicInput(logs="")
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
with get_session_context_manager() as db_session:
|
||||
config, _ = get_test_config(
|
||||
db_session=db_session,
|
||||
primary_llm=primary_llm,
|
||||
fast_llm=fast_llm,
|
||||
search_request=SearchRequest(query="How does onyx use FastAPI?"),
|
||||
)
|
||||
compiled_graph.invoke(input, config={"metadata": {"config": config}})
|
||||
42
backend/onyx/agents/agent_search/basic/states.py
Normal file
42
backend/onyx/agents/agent_search/basic/states.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from typing import TypedDict
|
||||
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
|
||||
# States contain values that change over the course of graph execution,
|
||||
# Config is for values that are set at the start and never change.
|
||||
# If you are using a value from the config and realize it needs to change,
|
||||
# you should add it to the state and use/update the version in the state.
|
||||
|
||||
## Graph Input State
|
||||
|
||||
|
||||
class BasicInput(BaseModel):
|
||||
# TODO: subclass global log update state
|
||||
logs: str = ""
|
||||
|
||||
|
||||
## Graph Output State
|
||||
|
||||
|
||||
class BasicOutput(TypedDict):
|
||||
tool_call_chunk: AIMessageChunk
|
||||
|
||||
|
||||
## Update States
|
||||
|
||||
|
||||
## Graph State
|
||||
|
||||
|
||||
class BasicState(
|
||||
BasicInput,
|
||||
ToolChoiceInput,
|
||||
ToolCallUpdate,
|
||||
ToolChoiceUpdate,
|
||||
):
|
||||
pass
|
||||
69
backend/onyx/agents/agent_search/basic/utils.py
Normal file
69
backend/onyx/agents/agent_search/basic/utils.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from collections.abc import Iterator
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
|
||||
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
|
||||
from onyx.chat.stream_processing.answer_response_handler import (
|
||||
PassThroughAnswerResponseHandler,
|
||||
)
|
||||
from onyx.chat.stream_processing.utils import map_document_id_order
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# TODO: handle citations here; below is what was previously passed in
|
||||
# see basic_use_tool_response.py for where these variables come from
|
||||
# answer_handler = CitationResponseHandler(
|
||||
# context_docs=final_search_results,
|
||||
# final_doc_id_to_rank_map=map_document_id_order(final_search_results),
|
||||
# display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
|
||||
# )
|
||||
|
||||
|
||||
def process_llm_stream(
|
||||
stream: Iterator[BaseMessage],
|
||||
should_stream_answer: bool,
|
||||
final_search_results: list[LlmDoc] | None = None,
|
||||
displayed_search_results: list[LlmDoc] | None = None,
|
||||
) -> AIMessageChunk:
|
||||
tool_call_chunk = AIMessageChunk(content="")
|
||||
# for response in response_handler_manager.handle_llm_response(stream):
|
||||
|
||||
if final_search_results and displayed_search_results:
|
||||
answer_handler: AnswerResponseHandler = CitationResponseHandler(
|
||||
context_docs=final_search_results,
|
||||
final_doc_id_to_rank_map=map_document_id_order(final_search_results),
|
||||
display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
|
||||
)
|
||||
else:
|
||||
answer_handler = PassThroughAnswerResponseHandler()
|
||||
|
||||
full_answer = ""
|
||||
# This stream will be the llm answer if no tool is chosen. When a tool is chosen,
|
||||
# the stream will contain AIMessageChunks with tool call information.
|
||||
for response in stream:
|
||||
answer_piece = response.content
|
||||
if not isinstance(answer_piece, str):
|
||||
# TODO: handle non-string content
|
||||
logger.warning(f"Received non-string content: {type(answer_piece)}")
|
||||
answer_piece = str(answer_piece)
|
||||
full_answer += answer_piece
|
||||
|
||||
if isinstance(response, AIMessageChunk) and (
|
||||
response.tool_call_chunks or response.tool_calls
|
||||
):
|
||||
tool_call_chunk += response # type: ignore
|
||||
elif should_stream_answer:
|
||||
for response_part in answer_handler.handle_response_part(response, []):
|
||||
dispatch_custom_event(
|
||||
"basic_response",
|
||||
response_part,
|
||||
)
|
||||
|
||||
logger.info(f"Full answer: {full_answer}")
|
||||
return cast(AIMessageChunk, tool_call_chunk)
|
||||
21
backend/onyx/agents/agent_search/core_state.py
Normal file
21
backend/onyx/agents/agent_search/core_state.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CoreState(BaseModel):
|
||||
"""
|
||||
This is the core state that is shared across all subgraphs.
|
||||
"""
|
||||
|
||||
base_question: str = ""
|
||||
log_messages: Annotated[list[str], add] = []
|
||||
|
||||
|
||||
class SubgraphCoreState(BaseModel):
|
||||
"""
|
||||
This is the core state that is shared across all subgraphs.
|
||||
"""
|
||||
|
||||
log_messages: Annotated[list[str], add]
|
||||
66
backend/onyx/agents/agent_search/db_operations.py
Normal file
66
backend/onyx/agents/agent_search/db_operations.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import AgentSubQuery
|
||||
from onyx.db.models import AgentSubQuestion
|
||||
|
||||
|
||||
def create_sub_question(
|
||||
db_session: Session,
|
||||
chat_session_id: UUID,
|
||||
primary_message_id: int,
|
||||
sub_question: str,
|
||||
sub_answer: str,
|
||||
) -> AgentSubQuestion:
|
||||
"""Create a new sub-question record in the database."""
|
||||
sub_q = AgentSubQuestion(
|
||||
chat_session_id=chat_session_id,
|
||||
primary_question_id=primary_message_id,
|
||||
sub_question=sub_question,
|
||||
sub_answer=sub_answer,
|
||||
)
|
||||
db_session.add(sub_q)
|
||||
db_session.flush()
|
||||
return sub_q
|
||||
|
||||
|
||||
def create_sub_query(
|
||||
db_session: Session,
|
||||
chat_session_id: UUID,
|
||||
parent_question_id: int,
|
||||
sub_query: str,
|
||||
) -> AgentSubQuery:
|
||||
"""Create a new sub-query record in the database."""
|
||||
sub_q = AgentSubQuery(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_question_id=parent_question_id,
|
||||
sub_query=sub_query,
|
||||
)
|
||||
db_session.add(sub_q)
|
||||
db_session.flush()
|
||||
return sub_q
|
||||
|
||||
|
||||
def get_sub_questions_for_message(
|
||||
db_session: Session,
|
||||
primary_message_id: int,
|
||||
) -> list[AgentSubQuestion]:
|
||||
"""Get all sub-questions for a given primary message."""
|
||||
return (
|
||||
db_session.query(AgentSubQuestion)
|
||||
.filter(AgentSubQuestion.primary_question_id == primary_message_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_sub_queries_for_question(
|
||||
db_session: Session,
|
||||
sub_question_id: int,
|
||||
) -> list[AgentSubQuery]:
|
||||
"""Get all sub-queries for a given sub-question."""
|
||||
return (
|
||||
db_session.query(AgentSubQuery)
|
||||
.filter(AgentSubQuery.parent_question_id == sub_question_id)
|
||||
.all()
|
||||
)
|
||||
@@ -0,0 +1,29 @@
|
||||
from collections.abc import Hashable
|
||||
from datetime import datetime
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
AnswerQuestionInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalInput,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def send_to_expanded_retrieval(state: AnswerQuestionInput) -> Send | Hashable:
|
||||
logger.debug("sending to expanded retrieval via edge")
|
||||
now_start = datetime.now()
|
||||
|
||||
return Send(
|
||||
"initial_sub_question_expanded_retrieval",
|
||||
ExpandedRetrievalInput(
|
||||
question=state.question,
|
||||
base_search=False,
|
||||
sub_question_id=state.question_id,
|
||||
log_messages=[f"{now_start} -- Sending to expanded retrieval"],
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,126 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.edges import (
|
||||
send_to_expanded_retrieval,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.answer_check import (
|
||||
answer_check,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.answer_generation import (
|
||||
answer_generation,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.format_answer import (
|
||||
format_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.ingest_retrieval import (
|
||||
ingest_retrieval,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
AnswerQuestionInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
AnswerQuestionOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
AnswerQuestionState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.graph_builder import (
|
||||
expanded_retrieval_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def answer_query_graph_builder() -> StateGraph:
|
||||
graph = StateGraph(
|
||||
state_schema=AnswerQuestionState,
|
||||
input=AnswerQuestionInput,
|
||||
output=AnswerQuestionOutput,
|
||||
)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
expanded_retrieval = expanded_retrieval_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="initial_sub_question_expanded_retrieval",
|
||||
action=expanded_retrieval,
|
||||
)
|
||||
graph.add_node(
|
||||
node="answer_check",
|
||||
action=answer_check,
|
||||
)
|
||||
graph.add_node(
|
||||
node="answer_generation",
|
||||
action=answer_generation,
|
||||
)
|
||||
graph.add_node(
|
||||
node="format_answer",
|
||||
action=format_answer,
|
||||
)
|
||||
graph.add_node(
|
||||
node="ingest_retrieval",
|
||||
action=ingest_retrieval,
|
||||
)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source=START,
|
||||
path=send_to_expanded_retrieval,
|
||||
path_map=["initial_sub_question_expanded_retrieval"],
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="initial_sub_question_expanded_retrieval",
|
||||
end_key="ingest_retrieval",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="ingest_retrieval",
|
||||
end_key="answer_generation",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="answer_generation",
|
||||
end_key="answer_check",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="answer_check",
|
||||
end_key="format_answer",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="format_answer",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.context.search.models import SearchRequest
|
||||
|
||||
graph = answer_query_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
search_request = SearchRequest(
|
||||
query="what can you do with onyx or danswer?",
|
||||
)
|
||||
with get_session_context_manager() as db_session:
|
||||
agent_search_config, search_tool = get_test_config(
|
||||
db_session, primary_llm, fast_llm, search_request
|
||||
)
|
||||
inputs = AnswerQuestionInput(
|
||||
question="what can you do with onyx?",
|
||||
question_id="0_0",
|
||||
log_messages=[],
|
||||
)
|
||||
for thing in compiled_graph.stream(
|
||||
input=inputs,
|
||||
config={"configurable": {"config": agent_search_config}},
|
||||
# debug=True,
|
||||
# subgraphs=True,
|
||||
):
|
||||
logger.debug(thing)
|
||||
@@ -0,0 +1,8 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
### Models ###
|
||||
|
||||
|
||||
class AnswerRetrievalStats(BaseModel):
|
||||
answer_retrieval_stats: dict[str, float | int]
|
||||
@@ -0,0 +1,59 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_message_runs
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
AnswerQuestionState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
QACheckUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_CHECK_NO
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
|
||||
|
||||
def answer_check(state: AnswerQuestionState, config: RunnableConfig) -> QACheckUpdate:
|
||||
now_start = datetime.now()
|
||||
|
||||
level, question_num = parse_question_id(state.question_id)
|
||||
if state.answer == UNKNOWN_ANSWER:
|
||||
now_end = datetime.now()
|
||||
return QACheckUpdate(
|
||||
answer_quality=SUB_CHECK_NO,
|
||||
log_messages=[
|
||||
f"{now_start} -- Answer check SQ-{level}-{question_num} - unknown answer, Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=SUB_CHECK_PROMPT.format(
|
||||
question=state.question,
|
||||
base_answer=state.answer,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
agent_searchch_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
fast_llm = agent_searchch_config.fast_llm
|
||||
response = list(
|
||||
fast_llm.stream(
|
||||
prompt=msg,
|
||||
)
|
||||
)
|
||||
|
||||
quality_str = merge_message_runs(response, chunk_separator="")[0].content
|
||||
|
||||
now_end = datetime.now()
|
||||
return QACheckUpdate(
|
||||
answer_quality=quality_str,
|
||||
log_messages=[
|
||||
f"""{now_start} -- Answer check SQ-{level}-{question_num} - Answer quality: {quality_str},
|
||||
Time taken: {now_end - now_start}"""
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,116 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
from langchain_core.messages import merge_message_runs
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
AnswerQuestionState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
QAGenerationUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_sub_question_answer_prompt,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
ASSISTANT_SYSTEM_PROMPT_PERSONA,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import NO_RECOVERED_DOCS
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_prompt
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def answer_generation(
|
||||
state: AnswerQuestionState, config: RunnableConfig
|
||||
) -> QAGenerationUpdate:
|
||||
now_start = datetime.now()
|
||||
logger.debug(f"--------{now_start}--------START ANSWER GENERATION---")
|
||||
|
||||
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = state.question
|
||||
docs = state.documents
|
||||
level, question_nr = parse_question_id(state.question_id)
|
||||
context_docs = state.context_documents
|
||||
persona_prompt = get_persona_prompt(agent_search_config.search_request.persona)
|
||||
|
||||
if len(context_docs) == 0:
|
||||
answer_str = NO_RECOVERED_DOCS
|
||||
dispatch_custom_event(
|
||||
"sub_answers",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=answer_str,
|
||||
level=level,
|
||||
level_question_nr=question_nr,
|
||||
answer_type="agent_sub_answer",
|
||||
),
|
||||
)
|
||||
else:
|
||||
if len(persona_prompt) > 0:
|
||||
persona_specification = ASSISTANT_SYSTEM_PROMPT_DEFAULT
|
||||
else:
|
||||
persona_specification = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
|
||||
persona_prompt=persona_prompt
|
||||
)
|
||||
|
||||
logger.debug(f"Number of verified retrieval docs: {len(docs)}")
|
||||
|
||||
fast_llm = agent_search_config.fast_llm
|
||||
msg = build_sub_question_answer_prompt(
|
||||
question=question,
|
||||
original_question=agent_search_config.search_request.query,
|
||||
docs=docs,
|
||||
persona_specification=persona_specification,
|
||||
config=fast_llm.config,
|
||||
)
|
||||
|
||||
response: list[str | list[str | dict[str, Any]]] = []
|
||||
for message in fast_llm.stream(
|
||||
prompt=msg,
|
||||
):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
)
|
||||
dispatch_custom_event(
|
||||
"sub_answers",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=content,
|
||||
level=level,
|
||||
level_question_nr=question_nr,
|
||||
answer_type="agent_sub_answer",
|
||||
),
|
||||
)
|
||||
response.append(content)
|
||||
|
||||
answer_str = merge_message_runs(response, chunk_separator="")[0].content
|
||||
|
||||
stop_event = StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
stream_type="sub_answer",
|
||||
level=level,
|
||||
level_question_nr=question_nr,
|
||||
)
|
||||
dispatch_custom_event("stream_finished", stop_event)
|
||||
|
||||
now_end = datetime.now()
|
||||
return QAGenerationUpdate(
|
||||
answer=answer_str,
|
||||
log_messages=[
|
||||
f"{now_end} -- Answer generation SQ-{level} - Q{question_nr} - Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
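For illustration only -- a minimal sketch of how the streamed chunks above are stitched back together, assuming langchain_core's merge_message_runs with an empty chunk_separator concatenates the contents of consecutive same-type messages; the chunks variable is hypothetical, standing in for the collected fast_llm.stream(...) output.
from langchain_core.messages import AIMessage
from langchain_core.messages import merge_message_runs

# hypothetical streamed pieces of a sub-answer
chunks = [AIMessage(content="Onyx lets you "), AIMessage(content="search your docs.")]
merged = merge_message_runs(chunks, chunk_separator="")
print(merged[0].content)  # "Onyx lets you search your docs."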
@@ -0,0 +1,28 @@
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
AnswerQuestionOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
AnswerQuestionState,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
QuestionAnswerResults,
|
||||
)
|
||||
|
||||
|
||||
def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput:
|
||||
return AnswerQuestionOutput(
|
||||
answer_results=[
|
||||
QuestionAnswerResults(
|
||||
question=state.question,
|
||||
question_id=state.question_id,
|
||||
quality=state.answer_quality if state.answer_quality else "No",
|
||||
answer=state.answer,
|
||||
expanded_retrieval_results=state.expanded_retrieval_results,
|
||||
documents=state.documents,
|
||||
context_documents=state.context_documents,
|
||||
sub_question_retrieval_stats=state.sub_question_retrieval_stats,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,22 @@
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
RetrievalIngestionUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
|
||||
|
||||
|
||||
def ingest_retrieval(state: ExpandedRetrievalOutput) -> RetrievalIngestionUpdate:
|
||||
sub_question_retrieval_stats = (
|
||||
state.expanded_retrieval_result.sub_question_retrieval_stats
|
||||
)
|
||||
if sub_question_retrieval_stats is None:
|
||||
sub_question_retrieval_stats = [AgentChunkStats()]
|
||||
|
||||
return RetrievalIngestionUpdate(
|
||||
expanded_retrieval_results=state.expanded_retrieval_result.expanded_queries_results,
|
||||
documents=state.expanded_retrieval_result.all_documents,
|
||||
context_documents=state.expanded_retrieval_result.context_documents,
|
||||
sub_question_retrieval_stats=sub_question_retrieval_stats,
|
||||
)
|
||||
@@ -0,0 +1,71 @@
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.core_state import SubgraphCoreState
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
QuestionAnswerResults,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
## Update States
|
||||
class QACheckUpdate(BaseModel):
|
||||
answer_quality: str = ""
|
||||
log_messages: list[str] = []
|
||||
|
||||
|
||||
class QAGenerationUpdate(BaseModel):
|
||||
answer: str = ""
|
||||
log_messages: list[str] = []
|
||||
# answer_stat: AnswerStats
|
||||
|
||||
|
||||
class RetrievalIngestionUpdate(BaseModel):
|
||||
expanded_retrieval_results: list[QueryResult] = []
|
||||
documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
|
||||
context_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
|
||||
sub_question_retrieval_stats: AgentChunkStats = AgentChunkStats()
|
||||
|
||||
|
||||
## Graph Input State
|
||||
|
||||
|
||||
class AnswerQuestionInput(SubgraphCoreState):
|
||||
question: str = ""
|
||||
question_id: str = (
|
||||
"" # 0_0 is original question, everything else is <level>_<question_num>.
|
||||
)
|
||||
# level 0 is original question and first decomposition, level 1 is follow up, etc
|
||||
# question_num is a unique number per original question per level.
|
||||
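For illustration only -- a sketch of the question-id format described above, assuming the ids are simply "<level>_<question_num>" strings joined and split on "_"; the *_sketch names are hypothetical and not the repo's make_question_id/parse_question_id helpers.
def make_question_id_sketch(level: int, question_num: int) -> str:
    # "0_0" is the original question; "1_2" would be refined sub-question 2
    return f"{level}_{question_num}"

def parse_question_id_sketch(question_id: str) -> tuple[int, int]:
    level_str, question_num_str = question_id.split("_")
    return int(level_str), int(question_num_str)

assert parse_question_id_sketch(make_question_id_sketch(1, 2)) == (1, 2)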
|
||||
|
||||
## Graph State
|
||||
|
||||
|
||||
class AnswerQuestionState(
|
||||
AnswerQuestionInput,
|
||||
QAGenerationUpdate,
|
||||
QACheckUpdate,
|
||||
RetrievalIngestionUpdate,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
## Graph Output State
|
||||
|
||||
|
||||
class AnswerQuestionOutput(BaseModel):
|
||||
"""
|
||||
This is a list of results even though each call of this subgraph only returns one result.
|
||||
This is because if we parallelize the answer query subgraph, there will be multiple
|
||||
results in a list so the add operator is used to add them together.
|
||||
"""
|
||||
|
||||
answer_results: Annotated[list[QuestionAnswerResults], add] = []
|
||||
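For illustration only -- a minimal sketch of what the Annotated[..., add] reducer above implies: when parallel branches each return a one-element answer_results list, the reducer (plain operator.add here) concatenates them into the accumulated state. The branch_a/branch_b values are hypothetical.
from operator import add

branch_a = ["answer for sub-question 1"]  # output of one parallel branch
branch_b = ["answer for sub-question 2"]  # output of another parallel branch
merged = add(branch_a, branch_b)  # list concatenation
print(merged)  # ['answer for sub-question 1', 'answer for sub-question 2']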
@@ -0,0 +1,28 @@
|
||||
from collections.abc import Hashable
|
||||
from datetime import datetime
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
AnswerQuestionInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalInput,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def send_to_expanded_refined_retrieval(state: AnswerQuestionInput) -> Send | Hashable:
|
||||
logger.debug("sending to expanded retrieval for follow up question via edge")
|
||||
datetime.now()
|
||||
return Send(
|
||||
"refined_sub_question_expanded_retrieval",
|
||||
ExpandedRetrievalInput(
|
||||
question=state.question,
|
||||
sub_question_id=state.question_id,
|
||||
base_search=False,
|
||||
log_messages=[f"{datetime.now()} -- Sending to expanded retrieval"],
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,123 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.answer_check import (
|
||||
answer_check,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.answer_generation import (
|
||||
answer_generation,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.format_answer import (
|
||||
format_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.ingest_retrieval import (
|
||||
ingest_retrieval,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
AnswerQuestionInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
AnswerQuestionOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
AnswerQuestionState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_refinement_sub_question.edges import (
|
||||
send_to_expanded_refined_retrieval,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.graph_builder import (
|
||||
expanded_retrieval_graph_builder,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def answer_refined_query_graph_builder() -> StateGraph:
|
||||
graph = StateGraph(
|
||||
state_schema=AnswerQuestionState,
|
||||
input=AnswerQuestionInput,
|
||||
output=AnswerQuestionOutput,
|
||||
)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
expanded_retrieval = expanded_retrieval_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="refined_sub_question_expanded_retrieval",
|
||||
action=expanded_retrieval,
|
||||
)
|
||||
graph.add_node(
|
||||
node="refined_sub_answer_check",
|
||||
action=answer_check,
|
||||
)
|
||||
graph.add_node(
|
||||
node="refined_sub_answer_generation",
|
||||
action=answer_generation,
|
||||
)
|
||||
graph.add_node(
|
||||
node="format_refined_sub_answer",
|
||||
action=format_answer,
|
||||
)
|
||||
graph.add_node(
|
||||
node="ingest_refined_retrieval",
|
||||
action=ingest_retrieval,
|
||||
)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source=START,
|
||||
path=send_to_expanded_refined_retrieval,
|
||||
path_map=["refined_sub_question_expanded_retrieval"],
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="refined_sub_question_expanded_retrieval",
|
||||
end_key="ingest_refined_retrieval",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="ingest_refined_retrieval",
|
||||
end_key="refined_sub_answer_generation",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="refined_sub_answer_generation",
|
||||
end_key="refined_sub_answer_check",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="refined_sub_answer_check",
|
||||
end_key="format_refined_sub_answer",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="format_refined_sub_answer",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.context.search.models import SearchRequest
|
||||
|
||||
graph = answer_refined_query_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
search_request = SearchRequest(
|
||||
query="what can you do with onyx or danswer?",
|
||||
)
|
||||
with get_session_context_manager() as db_session:
|
||||
inputs = AnswerQuestionInput(
|
||||
question="what can you do with onyx?",
|
||||
question_id="0_0",
|
||||
log_messages=[],
|
||||
)
|
||||
for thing in compiled_graph.stream(
|
||||
input=inputs,
|
||||
# debug=True,
|
||||
# subgraphs=True,
|
||||
):
|
||||
logger.debug(thing)
|
||||
# output = compiled_graph.invoke(inputs)
|
||||
# logger.debug(output)
|
||||
@@ -0,0 +1,19 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
### Models ###
|
||||
|
||||
|
||||
class AnswerRetrievalStats(BaseModel):
|
||||
answer_retrieval_stats: dict[str, float | int]
|
||||
|
||||
|
||||
class QuestionAnswerResults(BaseModel):
|
||||
question: str
|
||||
answer: str
|
||||
quality: str
|
||||
# expanded_retrieval_results: list[QueryResult]
|
||||
documents: list[InferenceSection]
|
||||
sub_question_retrieval_stats: AgentChunkStats
|
||||
@@ -0,0 +1,76 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.base_raw_search.nodes.format_raw_search_results import (
|
||||
format_raw_search_results,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.base_raw_search.nodes.generate_raw_search_data import (
|
||||
generate_raw_search_data,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.base_raw_search.states import (
|
||||
BaseRawSearchInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.base_raw_search.states import (
|
||||
BaseRawSearchOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.base_raw_search.states import (
|
||||
BaseRawSearchState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.graph_builder import (
|
||||
expanded_retrieval_graph_builder,
|
||||
)
|
||||
|
||||
|
||||
def base_raw_search_graph_builder() -> StateGraph:
|
||||
graph = StateGraph(
|
||||
state_schema=BaseRawSearchState,
|
||||
input=BaseRawSearchInput,
|
||||
output=BaseRawSearchOutput,
|
||||
)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
graph.add_node(
|
||||
node="generate_raw_search_data",
|
||||
action=generate_raw_search_data,
|
||||
)
|
||||
|
||||
expanded_retrieval = expanded_retrieval_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="expanded_retrieval_base_search",
|
||||
action=expanded_retrieval,
|
||||
)
|
||||
graph.add_node(
|
||||
node="format_raw_search_results",
|
||||
action=format_raw_search_results,
|
||||
)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_edge(start_key=START, end_key="generate_raw_search_data")
|
||||
|
||||
graph.add_edge(
|
||||
start_key="generate_raw_search_data",
|
||||
end_key="expanded_retrieval_base_search",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="expanded_retrieval_base_search",
|
||||
end_key="format_raw_search_results",
|
||||
)
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="expanded_retrieval_base_search",
|
||||
# end_key=END,
|
||||
# )
|
||||
|
||||
graph.add_edge(
|
||||
start_key="format_raw_search_results",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
@@ -0,0 +1,20 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
### Models ###
|
||||
|
||||
|
||||
class AnswerRetrievalStats(BaseModel):
|
||||
answer_retrieval_stats: dict[str, float | int]
|
||||
|
||||
|
||||
class QuestionAnswerResults(BaseModel):
|
||||
question: str
|
||||
answer: str
|
||||
quality: str
|
||||
expanded_retrieval_results: list[QueryResult]
|
||||
documents: list[InferenceSection]
|
||||
sub_question_retrieval_stats: list[AgentChunkStats]
|
||||
@@ -0,0 +1,18 @@
|
||||
from onyx.agents.agent_search.deep_search_a.base_raw_search.states import (
|
||||
BaseRawSearchOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalOutput,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def format_raw_search_results(state: ExpandedRetrievalOutput) -> BaseRawSearchOutput:
|
||||
logger.debug("format_raw_search_results")
|
||||
return BaseRawSearchOutput(
|
||||
base_expanded_retrieval_result=state.expanded_retrieval_result,
|
||||
# base_retrieval_results=[state.expanded_retrieval_result],
|
||||
# base_search_documents=[],
|
||||
)
|
||||
@@ -0,0 +1,25 @@
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.core_state import CoreState
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalInput,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def generate_raw_search_data(
|
||||
state: CoreState, config: RunnableConfig
|
||||
) -> ExpandedRetrievalInput:
|
||||
logger.debug("generate_raw_search_data")
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
return ExpandedRetrievalInput(
|
||||
question=agent_a_config.search_request.query,
|
||||
base_search=True,
|
||||
sub_question_id=None, # This graph is always and only used for the original question
|
||||
log_messages=[],
|
||||
)
|
||||
@@ -0,0 +1,43 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import (
|
||||
ExpandedRetrievalResult,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalInput,
|
||||
)
|
||||
|
||||
|
||||
## Update States
|
||||
|
||||
|
||||
## Graph Input State
|
||||
|
||||
|
||||
class BaseRawSearchInput(ExpandedRetrievalInput):
|
||||
pass
|
||||
|
||||
|
||||
## Graph Output State
|
||||
|
||||
|
||||
class BaseRawSearchOutput(BaseModel):
|
||||
"""
|
||||
Output of the base raw search subgraph for the original question. Unlike the
sub-question answer output, this holds a single expanded retrieval result rather
than a list accumulated across parallel branches.
|
||||
"""
|
||||
|
||||
# base_search_documents: Annotated[list[InferenceSection], dedup_inference_sections]
|
||||
# base_retrieval_results: Annotated[list[ExpandedRetrievalResult], add]
|
||||
base_expanded_retrieval_result: ExpandedRetrievalResult = ExpandedRetrievalResult()
|
||||
|
||||
|
||||
## Graph State
|
||||
|
||||
|
||||
class BaseRawSearchState(
|
||||
BaseRawSearchInput,
|
||||
BaseRawSearchOutput,
|
||||
):
|
||||
pass
|
||||
@@ -0,0 +1,37 @@
|
||||
from collections.abc import Hashable
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
RetrievalInput,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
|
||||
|
||||
def parallel_retrieval_edge(
|
||||
state: ExpandedRetrievalState, config: RunnableConfig
|
||||
) -> list[Send | Hashable]:
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = state.question if state.question else agent_a_config.search_request.query
|
||||
|
||||
query_expansions = (
    state.expanded_queries if state.expanded_queries else [question]
)
|
||||
return [
|
||||
Send(
|
||||
"doc_retrieval",
|
||||
RetrievalInput(
|
||||
query_to_retrieve=query,
|
||||
question=question,
|
||||
base_search=False,
|
||||
sub_question_id=state.sub_question_id,
|
||||
log_messages=[],
|
||||
),
|
||||
)
|
||||
for query in query_expansions
|
||||
]
|
||||
@@ -0,0 +1,147 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.edges import (
|
||||
parallel_retrieval_edge,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.doc_reranking import (
|
||||
doc_reranking,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.doc_retrieval import (
|
||||
doc_retrieval,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.doc_verification import (
|
||||
doc_verification,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.dummy import (
|
||||
dummy,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.expand_queries import (
|
||||
expand_queries,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.format_results import (
|
||||
format_results,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.verification_kickoff import (
|
||||
verification_kickoff,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def expanded_retrieval_graph_builder() -> StateGraph:
|
||||
graph = StateGraph(
|
||||
state_schema=ExpandedRetrievalState,
|
||||
input=ExpandedRetrievalInput,
|
||||
output=ExpandedRetrievalOutput,
|
||||
)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
graph.add_node(
|
||||
node="expand_queries",
|
||||
action=expand_queries,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="dummy",
|
||||
action=dummy,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="doc_retrieval",
|
||||
action=doc_retrieval,
|
||||
)
|
||||
graph.add_node(
|
||||
node="verification_kickoff",
|
||||
action=verification_kickoff,
|
||||
)
|
||||
graph.add_node(
|
||||
node="doc_verification",
|
||||
action=doc_verification,
|
||||
)
|
||||
graph.add_node(
|
||||
node="doc_reranking",
|
||||
action=doc_reranking,
|
||||
)
|
||||
graph.add_node(
|
||||
node="format_results",
|
||||
action=format_results,
|
||||
)
|
||||
|
||||
### Add edges ###
|
||||
graph.add_edge(
|
||||
start_key=START,
|
||||
end_key="expand_queries",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="expand_queries",
|
||||
end_key="dummy",
|
||||
)
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source="dummy",
|
||||
path=parallel_retrieval_edge,
|
||||
path_map=["doc_retrieval"],
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="doc_retrieval",
|
||||
end_key="verification_kickoff",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="doc_verification",
|
||||
end_key="doc_reranking",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="doc_reranking",
|
||||
end_key="format_results",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="format_results",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.context.search.models import SearchRequest
|
||||
|
||||
graph = expanded_retrieval_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
search_request = SearchRequest(
|
||||
query="what can you do with onyx or danswer?",
|
||||
)
|
||||
|
||||
with get_session_context_manager() as db_session:
|
||||
agent_a_config, search_tool = get_test_config(
|
||||
db_session, primary_llm, fast_llm, search_request
|
||||
)
|
||||
inputs = ExpandedRetrievalInput(
|
||||
question="what can you do with onyx?",
|
||||
base_search=False,
|
||||
sub_question_id=None,
|
||||
log_messages=[],
|
||||
)
|
||||
for thing in compiled_graph.stream(
|
||||
input=inputs,
|
||||
config={"configurable": {"config": agent_a_config}},
|
||||
# debug=True,
|
||||
subgraphs=True,
|
||||
):
|
||||
logger.debug(thing)
|
||||
@@ -0,0 +1,12 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
class ExpandedRetrievalResult(BaseModel):
|
||||
expanded_queries_results: list[QueryResult] = []
|
||||
all_documents: list[InferenceSection] = []
|
||||
context_documents: list[InferenceSection] = []
|
||||
sub_question_retrieval_stats: AgentChunkStats = AgentChunkStats()
|
||||
@@ -0,0 +1,82 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
DocRerankingUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.calculations import get_fit_scores
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
|
||||
from onyx.configs.agent_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS
|
||||
from onyx.configs.agent_configs import AGENT_RERANKING_STATS
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.context.search.pipeline import retrieval_preprocessing
|
||||
from onyx.context.search.postprocessing.postprocessing import rerank_sections
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
|
||||
|
||||
def doc_reranking(
|
||||
state: ExpandedRetrievalState, config: RunnableConfig
|
||||
) -> DocRerankingUpdate:
|
||||
now_start = datetime.now()
|
||||
verified_documents = state.verified_documents
|
||||
|
||||
# Rerank post retrieval and verification. First, create a search query
|
||||
# then create the list of reranked sections
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = state.question if state.question else agent_a_config.search_request.query
|
||||
if agent_a_config.search_tool is None:
|
||||
raise ValueError("search_tool must be provided for agentic search")
|
||||
with get_session_context_manager() as db_session:
|
||||
_search_query = retrieval_preprocessing(
|
||||
search_request=SearchRequest(query=question),
|
||||
user=agent_a_config.search_tool.user, # bit of a hack
|
||||
llm=agent_a_config.fast_llm,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# skip section filtering
|
||||
|
||||
if (
|
||||
_search_query.rerank_settings
|
||||
and _search_query.rerank_settings.rerank_model_name
|
||||
and _search_query.rerank_settings.num_rerank > 0
|
||||
and len(verified_documents) > 0
|
||||
):
|
||||
if len(verified_documents) > 1:
|
||||
reranked_documents = rerank_sections(
|
||||
_search_query,
|
||||
verified_documents,
|
||||
)
|
||||
else:
|
||||
num = "No" if len(verified_documents) == 0 else "One"
|
||||
logger.warning(f"{num} verified document(s) found, skipping reranking")
|
||||
reranked_documents = verified_documents
|
||||
else:
|
||||
logger.warning("No reranking settings found, using unranked documents")
|
||||
reranked_documents = verified_documents
|
||||
|
||||
if AGENT_RERANKING_STATS:
|
||||
fit_scores = get_fit_scores(verified_documents, reranked_documents)
|
||||
else:
|
||||
fit_scores = RetrievalFitStats(fit_score_lift=0, rerank_effect=0, fit_scores={})
|
||||
|
||||
# TODO: stream deduped docs here, or decide to use search tool ranking/verification
|
||||
now_end = datetime.now()
|
||||
return DocRerankingUpdate(
|
||||
reranked_documents=[
    doc for doc in reranked_documents if isinstance(doc, InferenceSection)
][:AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS],
|
||||
sub_question_retrieval_stats=fit_scores,
|
||||
log_messages=[
|
||||
f"{now_end} -- Expanded Retrieval - Reranking - Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,105 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
DocRetrievalUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
RetrievalInput,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.calculations import get_fit_scores
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
|
||||
from onyx.configs.agent_configs import AGENT_MAX_QUERY_RETRIEVAL_RESULTS
|
||||
from onyx.configs.agent_configs import AGENT_RETRIEVAL_STATS
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.tools.models import SearchQueryInfo
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
|
||||
|
||||
def doc_retrieval(state: RetrievalInput, config: RunnableConfig) -> DocRetrievalUpdate:
|
||||
"""
|
||||
Retrieve documents
|
||||
|
||||
Args:
|
||||
state (RetrievalInput): Primary state + the query to retrieve
|
||||
config (RunnableConfig): Configuration containing ProSearchConfig
|
||||
|
||||
Updates:
|
||||
expanded_retrieval_results: list[ExpandedRetrievalResult]
|
||||
retrieved_documents: list[InferenceSection]
|
||||
"""
|
||||
now_start = datetime.now()
|
||||
query_to_retrieve = state.query_to_retrieve
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
search_tool = agent_a_config.search_tool
|
||||
|
||||
retrieved_docs: list[InferenceSection] = []
|
||||
if not query_to_retrieve.strip():
|
||||
logger.warning("Empty query, skipping retrieval")
|
||||
now_end = datetime.now()
|
||||
return DocRetrievalUpdate(
|
||||
expanded_retrieval_results=[],
|
||||
retrieved_documents=[],
|
||||
log_messages=[
|
||||
f"{now_end} -- Expanded Retrieval - Retrieval - Empty Query - Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
|
||||
query_info = None
|
||||
if search_tool is None:
|
||||
raise ValueError("search_tool must be provided for agentic search")
|
||||
# new db session to avoid concurrency issues
|
||||
with get_session_context_manager() as db_session:
|
||||
for tool_response in search_tool.run(
|
||||
query=query_to_retrieve,
|
||||
force_no_rerank=True,
|
||||
alternate_db_session=db_session,
|
||||
):
|
||||
# get retrieved docs to send to the rest of the graph
|
||||
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
response = cast(SearchResponseSummary, tool_response.response)
|
||||
retrieved_docs = response.top_sections
|
||||
query_info = SearchQueryInfo(
|
||||
predicted_search=response.predicted_search,
|
||||
final_filters=response.final_filters,
|
||||
recency_bias_multiplier=response.recency_bias_multiplier,
|
||||
)
|
||||
break
|
||||
|
||||
retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS]
|
||||
pre_rerank_docs = retrieved_docs
|
||||
if search_tool.search_pipeline is not None:
|
||||
pre_rerank_docs = (
|
||||
search_tool.search_pipeline._retrieved_sections or retrieved_docs
|
||||
)
|
||||
|
||||
if AGENT_RETRIEVAL_STATS:
|
||||
fit_scores = get_fit_scores(
|
||||
pre_rerank_docs,
|
||||
retrieved_docs,
|
||||
)
|
||||
else:
|
||||
fit_scores = None
|
||||
|
||||
expanded_retrieval_result = QueryResult(
|
||||
query=query_to_retrieve,
|
||||
search_results=retrieved_docs,
|
||||
stats=fit_scores,
|
||||
query_info=query_info,
|
||||
)
|
||||
now_end = datetime.now()
|
||||
return DocRetrievalUpdate(
|
||||
expanded_retrieval_results=[expanded_retrieval_result],
|
||||
retrieved_documents=retrieved_docs,
|
||||
log_messages=[
|
||||
f"{now_end} -- Expanded Retrieval - Retrieval - Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,60 @@
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
DocVerificationInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
DocVerificationUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
|
||||
|
||||
|
||||
def doc_verification(
|
||||
state: DocVerificationInput, config: RunnableConfig
|
||||
) -> DocVerificationUpdate:
|
||||
"""
|
||||
Check whether the document is relevant for the original user question
|
||||
|
||||
Args:
|
||||
state (DocVerificationInput): The current state
|
||||
config (RunnableConfig): Configuration containing ProSearchConfig
|
||||
|
||||
Updates:
|
||||
verified_documents: list[InferenceSection]
|
||||
"""
|
||||
|
||||
question = state.question
|
||||
doc_to_verify = state.doc_to_verify
|
||||
document_content = doc_to_verify.combined_content
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
fast_llm = agent_a_config.fast_llm
|
||||
|
||||
document_content = trim_prompt_piece(
|
||||
fast_llm.config, document_content, VERIFIER_PROMPT + question
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=VERIFIER_PROMPT.format(
|
||||
question=question, document_content=document_content
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
response = fast_llm.invoke(msg)
|
||||
|
||||
verified_documents = []
|
||||
if isinstance(response.content, str) and "yes" in response.content.lower():
|
||||
verified_documents.append(doc_to_verify)
|
||||
|
||||
return DocVerificationUpdate(
|
||||
verified_documents=verified_documents,
|
||||
)
|
||||
@@ -0,0 +1,16 @@
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
QueryExpansionUpdate,
|
||||
)
|
||||
|
||||
|
||||
def dummy(
|
||||
state: ExpandedRetrievalState, config: RunnableConfig
|
||||
) -> QueryExpansionUpdate:
|
||||
return QueryExpansionUpdate(
|
||||
expanded_queries=state.expanded_queries,
|
||||
)
|
||||
@@ -0,0 +1,68 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_message_runs
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.operations import (
|
||||
dispatch_subquery,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
QueryExpansionUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
REWRITE_PROMPT_MULTI_ORIGINAL,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
|
||||
|
||||
def expand_queries(
|
||||
state: ExpandedRetrievalInput, config: RunnableConfig
|
||||
) -> QueryExpansionUpdate:
|
||||
# Sometimes we want to expand the original question, sometimes we want to expand a sub-question.
|
||||
# When we are running this node on the original question, no question is explicitly passed in.
|
||||
# Instead, we use the original question from the search request.
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
now_start = datetime.now()
|
||||
question = (
    state.question if state.question else agent_a_config.search_request.query
)
|
||||
llm = agent_a_config.fast_llm
|
||||
chat_session_id = agent_a_config.chat_session_id
|
||||
sub_question_id = state.sub_question_id
|
||||
if sub_question_id is None:
|
||||
level, question_nr = 0, 0
|
||||
else:
|
||||
level, question_nr = parse_question_id(sub_question_id)
|
||||
|
||||
if chat_session_id is None:
|
||||
raise ValueError("chat_session_id must be provided for agent search")
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question),
|
||||
)
|
||||
]
|
||||
|
||||
llm_response_list = dispatch_separated(
|
||||
llm.stream(prompt=msg), dispatch_subquery(level, question_nr)
|
||||
)
|
||||
|
||||
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
|
||||
|
||||
rewritten_queries = llm_response.split("\n")
|
||||
now_end = datetime.now()
|
||||
return QueryExpansionUpdate(
|
||||
expanded_queries=rewritten_queries,
|
||||
log_messages=[
|
||||
f"{now_end} -- Expanded Retrieval - Query Expansion - Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,84 @@
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import (
|
||||
ExpandedRetrievalResult,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.operations import (
|
||||
calculate_sub_question_retrieval_stats,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
|
||||
|
||||
|
||||
def format_results(
|
||||
state: ExpandedRetrievalState, config: RunnableConfig
|
||||
) -> ExpandedRetrievalUpdate:
|
||||
level, question_nr = parse_question_id(state.sub_question_id or "0_0")
|
||||
query_infos = [
|
||||
result.query_info
|
||||
for result in state.expanded_retrieval_results
|
||||
if result.query_info is not None
|
||||
]
|
||||
if len(query_infos) == 0:
|
||||
raise ValueError("No query info found")
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
# main question docs will be sent later after aggregation and deduping with sub-question docs
|
||||
|
||||
stream_documents = state.reranked_documents
|
||||
|
||||
if not (level == 0 and question_nr == 0):
|
||||
if len(stream_documents) == 0:
|
||||
# The sub-question is used as the last query. If no verified documents are found, stream
|
||||
# the top 3 for that one. We may want to revisit this.
|
||||
stream_documents = state.expanded_retrieval_results[-1].search_results[:3]
|
||||
|
||||
if agent_a_config.search_tool is None:
|
||||
raise ValueError("search_tool must be provided for agentic search")
|
||||
for tool_response in yield_search_responses(
|
||||
query=state.question,
|
||||
reranked_sections=state.retrieved_documents, # TODO: rename params. (sections pre-merging here.)
|
||||
final_context_sections=stream_documents,
|
||||
search_query_info=query_infos[0], # TODO: handle differing query infos?
|
||||
get_section_relevance=lambda: None, # TODO: add relevance
|
||||
search_tool=agent_a_config.search_tool,
|
||||
):
|
||||
dispatch_custom_event(
|
||||
"tool_response",
|
||||
ExtendedToolResponse(
|
||||
id=tool_response.id,
|
||||
response=tool_response.response,
|
||||
level=level,
|
||||
level_question_nr=question_nr,
|
||||
),
|
||||
)
|
||||
sub_question_retrieval_stats = calculate_sub_question_retrieval_stats(
|
||||
verified_documents=state.verified_documents,
|
||||
expanded_retrieval_results=state.expanded_retrieval_results,
|
||||
)
|
||||
|
||||
if sub_question_retrieval_stats is None:
|
||||
sub_question_retrieval_stats = AgentChunkStats()
|
||||
# else:
|
||||
# sub_question_retrieval_stats = [sub_question_retrieval_stats]
|
||||
|
||||
return ExpandedRetrievalUpdate(
|
||||
expanded_retrieval_result=ExpandedRetrievalResult(
|
||||
expanded_queries_results=state.expanded_retrieval_results,
|
||||
all_documents=stream_documents,
|
||||
context_documents=state.reranked_documents,
|
||||
sub_question_retrieval_stats=sub_question_retrieval_stats,
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,44 @@
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langgraph.types import Command
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
DocVerificationInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
|
||||
|
||||
def verification_kickoff(
|
||||
state: ExpandedRetrievalState,
|
||||
config: RunnableConfig,
|
||||
) -> Command[Literal["doc_verification"]]:
|
||||
documents = state.retrieved_documents
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
verification_question = (
    state.question if state.question else agent_a_config.search_request.query
)
|
||||
sub_question_id = state.sub_question_id
|
||||
return Command(
|
||||
update={},
|
||||
goto=[
|
||||
Send(
|
||||
node="doc_verification",
|
||||
arg=DocVerificationInput(
|
||||
doc_to_verify=doc,
|
||||
question=verification_question,
|
||||
base_search=False,
|
||||
sub_question_id=sub_question_id,
|
||||
log_messages=[],
|
||||
),
|
||||
)
|
||||
for doc in documents
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,97 @@
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
|
||||
import numpy as np
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
|
||||
from onyx.chat.models import SubQueryPiece
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def dispatch_subquery(level: int, question_nr: int) -> Callable[[str, int], None]:
|
||||
def helper(token: str, num: int) -> None:
|
||||
dispatch_custom_event(
|
||||
"subqueries",
|
||||
SubQueryPiece(
|
||||
sub_query=token,
|
||||
level=level,
|
||||
level_question_nr=question_nr,
|
||||
query_id=num,
|
||||
),
|
||||
)
|
||||
|
||||
return helper
|
||||
|
||||
|
||||
def calculate_sub_question_retrieval_stats(
|
||||
verified_documents: list[InferenceSection],
|
||||
expanded_retrieval_results: list[QueryResult],
|
||||
) -> AgentChunkStats:
|
||||
chunk_scores: dict[str, dict[str, list[int | float]]] = defaultdict(
|
||||
lambda: defaultdict(list)
|
||||
)
|
||||
|
||||
for expanded_retrieval_result in expanded_retrieval_results:
|
||||
for doc in expanded_retrieval_result.search_results:
|
||||
doc_chunk_id = f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}"
|
||||
if doc.center_chunk.score is not None:
|
||||
chunk_scores[doc_chunk_id]["score"].append(doc.center_chunk.score)
|
||||
|
||||
verified_doc_chunk_ids = [
|
||||
f"{verified_document.center_chunk.document_id}_{verified_document.center_chunk.chunk_id}"
|
||||
for verified_document in verified_documents
|
||||
]
|
||||
dismissed_doc_chunk_ids = []
|
||||
|
||||
raw_chunk_stats_counts: dict[str, int] = defaultdict(int)
|
||||
raw_chunk_stats_scores: dict[str, float] = defaultdict(float)
|
||||
for doc_chunk_id, chunk_data in chunk_scores.items():
|
||||
if doc_chunk_id in verified_doc_chunk_ids:
|
||||
raw_chunk_stats_counts["verified_count"] += 1
|
||||
|
||||
valid_chunk_scores = [
|
||||
score for score in chunk_data["score"] if score is not None
|
||||
]
|
||||
raw_chunk_stats_scores["verified_scores"] += float(
|
||||
np.mean(valid_chunk_scores)
|
||||
)
|
||||
else:
|
||||
raw_chunk_stats_counts["rejected_count"] += 1
|
||||
valid_chunk_scores = [
|
||||
score for score in chunk_data["score"] if score is not None
|
||||
]
|
||||
raw_chunk_stats_scores["rejected_scores"] += float(
|
||||
np.mean(valid_chunk_scores)
|
||||
)
|
||||
dismissed_doc_chunk_ids.append(doc_chunk_id)
|
||||
|
||||
if raw_chunk_stats_counts["verified_count"] == 0:
|
||||
verified_avg_scores = 0.0
|
||||
else:
|
||||
verified_avg_scores = raw_chunk_stats_scores["verified_scores"] / float(
|
||||
raw_chunk_stats_counts["verified_count"]
|
||||
)
|
||||
|
||||
rejected_scores = raw_chunk_stats_scores.get("rejected_scores", None)
|
||||
if rejected_scores is not None:
|
||||
rejected_avg_scores = rejected_scores / float(
|
||||
raw_chunk_stats_counts["rejected_count"]
|
||||
)
|
||||
else:
|
||||
rejected_avg_scores = None
|
||||
|
||||
chunk_stats = AgentChunkStats(
|
||||
verified_count=raw_chunk_stats_counts["verified_count"],
|
||||
verified_avg_scores=verified_avg_scores,
|
||||
rejected_count=raw_chunk_stats_counts["rejected_count"],
|
||||
rejected_avg_scores=rejected_avg_scores,
|
||||
verified_doc_chunk_ids=verified_doc_chunk_ids,
|
||||
dismissed_doc_chunk_ids=dismissed_doc_chunk_ids,
|
||||
)
|
||||
|
||||
return chunk_stats
|
||||
@@ -0,0 +1,91 @@
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.core_state import SubgraphCoreState
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import (
|
||||
ExpandedRetrievalResult,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
### States ###
|
||||
|
||||
## Graph Input State
|
||||
|
||||
|
||||
class ExpandedRetrievalInput(SubgraphCoreState):
|
||||
question: str = ""
|
||||
base_search: bool = False
|
||||
sub_question_id: str | None = None
|
||||
|
||||
|
||||
## Update/Return States
|
||||
|
||||
|
||||
class QueryExpansionUpdate(BaseModel):
|
||||
expanded_queries: list[str] = ["aaa", "bbb"]
|
||||
log_messages: list[str] = []
|
||||
|
||||
|
||||
class DocVerificationUpdate(BaseModel):
|
||||
verified_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
|
||||
|
||||
|
||||
class DocRetrievalUpdate(BaseModel):
|
||||
expanded_retrieval_results: Annotated[list[QueryResult], add] = []
|
||||
retrieved_documents: Annotated[
|
||||
list[InferenceSection], dedup_inference_sections
|
||||
] = []
|
||||
log_messages: list[str] = []
|
||||
|
||||
|
||||
class DocRerankingUpdate(BaseModel):
|
||||
reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
|
||||
sub_question_retrieval_stats: RetrievalFitStats | None = None
|
||||
log_messages: list[str] = []
|
||||
|
||||
|
||||
class ExpandedRetrievalUpdate(BaseModel):
|
||||
expanded_retrieval_result: ExpandedRetrievalResult
|
||||
|
||||
|
||||
## Graph Output State
|
||||
|
||||
|
||||
class ExpandedRetrievalOutput(BaseModel):
|
||||
expanded_retrieval_result: ExpandedRetrievalResult = ExpandedRetrievalResult()
|
||||
base_expanded_retrieval_result: ExpandedRetrievalResult = ExpandedRetrievalResult()
|
||||
log_messages: list[str] = []
|
||||
|
||||
|
||||
## Graph State
|
||||
|
||||
|
||||
class ExpandedRetrievalState(
|
||||
# This includes the core state
|
||||
ExpandedRetrievalInput,
|
||||
QueryExpansionUpdate,
|
||||
DocRetrievalUpdate,
|
||||
DocVerificationUpdate,
|
||||
DocRerankingUpdate,
|
||||
ExpandedRetrievalOutput,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
## Conditional Input States
|
||||
|
||||
|
||||
class DocVerificationInput(ExpandedRetrievalInput):
|
||||
doc_to_verify: InferenceSection
|
||||
|
||||
|
||||
class RetrievalInput(ExpandedRetrievalInput):
|
||||
query_to_retrieve: str = ""
|
||||
backend/onyx/agents/agent_search/deep_search_a/main/edges.py (new file, 120 lines)
@@ -0,0 +1,120 @@
|
||||
from collections.abc import Hashable
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
AnswerQuestionInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
AnswerQuestionOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainState
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import (
|
||||
RequireRefinedAnswerUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def route_initial_tool_choice(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> Literal["tool_call", "agent_search_start", "logging_node"]:
|
||||
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
if state.tool_choice is not None:
|
||||
if (
|
||||
agent_config.use_agentic_search
|
||||
and agent_config.search_tool is not None
|
||||
and state.tool_choice.tool.name == agent_config.search_tool.name
|
||||
):
|
||||
return "agent_search_start"
|
||||
else:
|
||||
return "tool_call"
|
||||
else:
|
||||
return "logging_node"
|
||||
|
||||
|
||||
def parallelize_initial_sub_question_answering(
|
||||
state: MainState,
|
||||
) -> list[Send | Hashable]:
|
||||
now_start = datetime.now()
|
||||
if len(state.initial_decomp_questions) > 0:
|
||||
# sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]]
|
||||
# if len(state["sub_question_records"]) == 0:
|
||||
# if state["config"].use_persistence:
|
||||
# raise ValueError("No sub-questions found for initial decompozed questions")
|
||||
# else:
|
||||
# # in this case, we are doing retrieval on the original question.
|
||||
# # to make all the logic consistent, we create a new sub-question
|
||||
# # with the same content as the original question
|
||||
# sub_question_record_ids = [1] * len(state["initial_decomp_questions"])
|
||||
|
||||
return [
|
||||
Send(
|
||||
"answer_query_subgraph",
|
||||
AnswerQuestionInput(
|
||||
question=question,
|
||||
question_id=make_question_id(0, question_nr + 1),
|
||||
log_messages=[
|
||||
f"{now_start} -- Main Edge - Parallelize Initial Sub-question Answering"
|
||||
],
|
||||
),
|
||||
)
|
||||
for question_nr, question in enumerate(state.initial_decomp_questions)
|
||||
]
|
||||
|
||||
else:
|
||||
return [
|
||||
Send(
|
||||
"ingest_answers",
|
||||
AnswerQuestionOutput(
|
||||
answer_results=[],
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
# Define the function that determines whether to continue or not
|
||||
def continue_to_refined_answer_or_end(
|
||||
state: RequireRefinedAnswerUpdate,
|
||||
) -> Literal["refined_sub_question_creation", "logging_node"]:
|
||||
if state.require_refined_answer:
|
||||
return "refined_sub_question_creation"
|
||||
else:
|
||||
return "logging_node"
|
||||
|
||||
|
||||
def parallelize_refined_sub_question_answering(
|
||||
state: MainState,
|
||||
) -> list[Send | Hashable]:
|
||||
now_start = datetime.now()
|
||||
if len(state.refined_sub_questions) > 0:
|
||||
return [
|
||||
Send(
|
||||
"answer_refined_question",
|
||||
AnswerQuestionInput(
|
||||
question=question_data.sub_question,
|
||||
question_id=make_question_id(1, question_nr),
|
||||
log_messages=[
|
||||
f"{now_start} -- Main Edge - Parallelize Refined Sub-question Answering"
|
||||
],
|
||||
),
|
||||
)
|
||||
for question_nr, question_data in state.refined_sub_questions.items()
|
||||
]
|
||||
|
||||
else:
|
||||
return [
|
||||
Send(
|
||||
"ingest_refined_sub_answers",
|
||||
AnswerQuestionOutput(
|
||||
answer_results=[],
|
||||
),
|
||||
)
|
||||
]
|
||||
@@ -0,0 +1,423 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.graph_builder import (
|
||||
answer_query_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_refinement_sub_question.graph_builder import (
|
||||
answer_refined_query_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.base_raw_search.graph_builder import (
|
||||
base_raw_search_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.edges import (
|
||||
continue_to_refined_answer_or_end,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.edges import (
|
||||
parallelize_initial_sub_question_answering,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.edges import (
|
||||
parallelize_refined_sub_question_answering,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.edges import (
|
||||
route_initial_tool_choice,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.agent_logging import (
|
||||
agent_logging,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.agent_search_start import (
|
||||
agent_search_start,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.answer_comparison import (
|
||||
answer_comparison,
|
||||
)
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.entity_term_extraction_llm import (
|
||||
entity_term_extraction_llm,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.direct_llm_handling import (
|
||||
direct_llm_handling,
|
||||
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.generate_initial_answer import (
|
||||
generate_initial_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.generate_refined_answer import (
|
||||
generate_refined_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.ingest_initial_base_retrieval import (
|
||||
ingest_initial_base_retrieval,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.ingest_initial_sub_question_answers import (
|
||||
ingest_initial_sub_question_answers,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.ingest_refined_answers import (
|
||||
ingest_refined_answers,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.initial_answer_quality_check import (
|
||||
initial_answer_quality_check,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.initial_sub_question_creation import (
|
||||
initial_sub_question_creation,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.refined_answer_decision import (
|
||||
refined_answer_decision,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.refined_sub_question_creation import (
|
||||
refined_sub_question_creation,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.retrieval_consolidation import (
|
||||
retrieval_consolidation,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainInput
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainState
|
||||
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
|
||||
basic_use_tool_response,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
|
||||
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
|
||||
prepare_tool_input,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
test_mode = False
|
||||
|
||||
|
||||
def main_graph_builder(test_mode: bool = False) -> StateGraph:
|
||||
graph = StateGraph(
|
||||
state_schema=MainState,
|
||||
input=MainInput,
|
||||
)
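# Two main paths through this graph: a basic tool-calling path
# (prepare_tool_input -> initial_tool_choice -> tool_call -> basic_use_tool_response)
# and the agentic deep-search path (agent_search_start -> sub-question answering ->
# initial answer -> optional refinement), both of which end at the logging node.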
|
||||
|
||||
# graph.add_node(
|
||||
# node="agent_path_decision",
|
||||
# action=agent_path_decision,
|
||||
# )
|
||||
|
||||
# graph.add_node(
|
||||
# node="agent_path_routing",
|
||||
# action=agent_path_routing,
|
||||
# )
|
||||
|
||||
# graph.add_node(
|
||||
# node="LLM",
|
||||
# action=direct_llm_handling,
|
||||
# )
|
||||
graph.add_node(
|
||||
node="prepare_tool_input",
|
||||
action=prepare_tool_input,
|
||||
)
|
||||
graph.add_node(
|
||||
node="initial_tool_choice",
|
||||
action=llm_tool_choice,
|
||||
)
|
||||
graph.add_node(
|
||||
node="tool_call",
|
||||
action=tool_call,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="basic_use_tool_response",
|
||||
action=basic_use_tool_response,
|
||||
)
|
||||
graph.add_node(
|
||||
node="agent_search_start",
|
||||
action=agent_search_start,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="initial_sub_question_creation",
|
||||
action=initial_sub_question_creation,
|
||||
)
|
||||
answer_query_subgraph = answer_query_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="answer_query_subgraph",
|
||||
action=answer_query_subgraph,
|
||||
)
|
||||
|
||||
base_raw_search_subgraph = base_raw_search_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="base_raw_search_subgraph",
|
||||
action=base_raw_search_subgraph,
|
||||
)
|
||||
|
||||
# refined_answer_subgraph = refined_answers_graph_builder().compile()
|
||||
# graph.add_node(
|
||||
# node="refined_answer_subgraph",
|
||||
# action=refined_answer_subgraph,
|
||||
# )
|
||||
|
||||
graph.add_node(
|
||||
node="refined_sub_question_creation",
|
||||
action=refined_sub_question_creation,
|
||||
)
|
||||
|
||||
answer_refined_question = answer_refined_query_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="answer_refined_question",
|
||||
action=answer_refined_question,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="ingest_refined_answers",
|
||||
action=ingest_refined_answers,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="generate_refined_answer",
|
||||
action=generate_refined_answer,
|
||||
)
|
||||
|
||||
# graph.add_node(
|
||||
# node="check_refined_answer",
|
||||
# action=check_refined_answer,
|
||||
# )
|
||||
|
||||
graph.add_node(
|
||||
node="ingest_initial_retrieval",
|
||||
action=ingest_initial_base_retrieval,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="retrieval_consolidation",
|
||||
action=retrieval_consolidation,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="ingest_initial_sub_question_answers",
|
||||
action=ingest_initial_sub_question_answers,
|
||||
)
|
||||
graph.add_node(
|
||||
node="generate_initial_answer",
|
||||
action=generate_initial_answer,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="initial_answer_quality_check",
|
||||
action=initial_answer_quality_check,
|
||||
)
|
||||
|
||||
# graph.add_node(
|
||||
# node="entity_term_extraction_llm",
|
||||
# action=entity_term_extraction_llm,
|
||||
# )
|
||||
graph.add_node(
|
||||
node="refined_answer_decision",
|
||||
action=refined_answer_decision,
|
||||
)
|
||||
graph.add_node(
|
||||
node="answer_comparison",
|
||||
action=answer_comparison,
|
||||
)
|
||||
graph.add_node(
|
||||
node="logging_node",
|
||||
action=agent_logging,
|
||||
)
|
||||
# if test_mode:
|
||||
# graph.add_node(
|
||||
# node="generate_initial_base_answer",
|
||||
# action=generate_initial_base_answer,
|
||||
# )
|
||||
|
||||
### Add edges ###
|
||||
|
||||
# graph.add_edge(start_key=START, end_key="base_raw_search_subgraph")
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key=START,
|
||||
# end_key="agent_path_decision",
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="agent_path_decision",
|
||||
# end_key="agent_path_routing",
|
||||
# )
|
||||
graph.add_edge(start_key=START, end_key="prepare_tool_input")
|
||||
|
||||
graph.add_edge(
|
||||
start_key="prepare_tool_input",
|
||||
end_key="initial_tool_choice",
|
||||
)
|
||||
|
||||
graph.add_conditional_edges(
|
||||
"initial_tool_choice",
|
||||
route_initial_tool_choice,
|
||||
["tool_call", "agent_search_start", "logging_node"],
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="tool_call",
|
||||
end_key="basic_use_tool_response",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="basic_use_tool_response",
|
||||
end_key="logging_node",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="agent_search_start",
|
||||
end_key="base_raw_search_subgraph",
|
||||
)
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="agent_search_start",
|
||||
# end_key="entity_term_extraction_llm",
|
||||
# )
|
||||
|
||||
graph.add_edge(
|
||||
start_key="agent_search_start",
|
||||
end_key="initial_sub_question_creation",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="base_raw_search_subgraph",
|
||||
end_key="ingest_initial_retrieval",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key=["ingest_initial_retrieval", "ingest_initial_sub_question_answers"],
|
||||
end_key="retrieval_consolidation",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="retrieval_consolidation",
|
||||
end_key="generate_initial_answer",
|
||||
)
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="LLM",
|
||||
# end_key=END,
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key=START,
|
||||
# end_key="initial_sub_question_creation",
|
||||
# )
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source="initial_sub_question_creation",
|
||||
path=parallelize_initial_sub_question_answering,
|
||||
path_map=["answer_query_subgraph"],
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="answer_query_subgraph",
|
||||
end_key="ingest_initial_sub_question_answers",
|
||||
)
|
||||
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="generate_initial_answer",
|
||||
# end_key="entity_term_extraction_llm",
|
||||
# )
|
||||
|
||||
graph.add_edge(
|
||||
start_key="generate_initial_answer",
|
||||
end_key="initial_answer_quality_check",
|
||||
)
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key=["initial_answer_quality_check", "entity_term_extraction_llm"],
|
||||
# end_key="refined_answer_decision",
|
||||
# )
|
||||
graph.add_edge(
|
||||
start_key="initial_answer_quality_check",
|
||||
end_key="refined_answer_decision",
|
||||
)
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source="refined_answer_decision",
|
||||
path=continue_to_refined_answer_or_end,
|
||||
path_map=["refined_sub_question_creation", "logging_node"],
|
||||
)
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source="refined_sub_question_creation", # DONE
|
||||
path=parallelize_refined_sub_question_answering,
|
||||
path_map=["answer_refined_question"],
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="answer_refined_question", # HERE
|
||||
end_key="ingest_refined_answers",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="ingest_refined_answers",
|
||||
end_key="generate_refined_answer",
|
||||
)
|
||||
|
||||
# graph.add_conditional_edges(
|
||||
# source="refined_answer_decision",
|
||||
# path=continue_to_refined_answer_or_end,
|
||||
# path_map=["refined_answer_subgraph", END],
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="refined_answer_subgraph",
|
||||
# end_key="generate_refined_answer",
|
||||
# )
|
||||
|
||||
graph.add_edge(
|
||||
start_key="generate_refined_answer",
|
||||
end_key="answer_comparison",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="answer_comparison",
|
||||
end_key="logging_node",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="logging_node",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="generate_refined_answer",
|
||||
# end_key="check_refined_answer",
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="check_refined_answer",
|
||||
# end_key=END,
|
||||
# )
|
||||
|
||||
return graph


if __name__ == "__main__":
    from onyx.db.engine import get_session_context_manager
    from onyx.llm.factory import get_default_llms
    from onyx.context.search.models import SearchRequest

    graph = main_graph_builder()
    compiled_graph = graph.compile()
    primary_llm, fast_llm = get_default_llms()

    with get_session_context_manager() as db_session:
        search_request = SearchRequest(query="Who created Excel?")
        agent_a_config, search_tool = get_test_config(
            db_session, primary_llm, fast_llm, search_request
        )

        inputs = MainInput(
            base_question=agent_a_config.search_request.query, log_messages=[]
        )

        for thing in compiled_graph.stream(
            input=inputs,
            config={"configurable": {"config": agent_a_config}},
            # stream_mode="debug",
            # debug=True,
            subgraphs=True,
        ):
            logger.debug(thing)
@@ -0,0 +1,36 @@
from pydantic import BaseModel


class FollowUpSubQuestion(BaseModel):
    sub_question: str
    sub_question_id: str
    verified: bool
    answered: bool
    answer: str


class AgentTimings(BaseModel):
    base_duration__s: float | None
    refined_duration__s: float | None
    full_duration__s: float | None


class AgentBaseMetrics(BaseModel):
    num_verified_documents_total: int | None
    num_verified_documents_core: int | None
    verified_avg_score_core: float | None
    num_verified_documents_base: int | float | None
    verified_avg_score_base: float | None = None
    base_doc_boost_factor: float | None = None
    support_boost_factor: float | None = None
    duration__s: float | None = None


class AgentRefinedMetrics(BaseModel):
    refined_doc_boost_factor: float | None = None
    refined_question_boost_factor: float | None = None
    duration__s: float | None = None


class AgentAdditionalMetrics(BaseModel):
    pass
@@ -0,0 +1,115 @@
from datetime import datetime
from typing import cast

from langchain_core.runnables import RunnableConfig

from onyx.agents.agent_search.deep_search_a.main.models import AgentAdditionalMetrics
from onyx.agents.agent_search.deep_search_a.main.models import AgentTimings
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import MainOutput
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.models import CombinedAgentMetrics
from onyx.db.chat import log_agent_metrics
from onyx.db.chat import log_agent_sub_question_results


def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput:
    now_start = datetime.now()

    logger.debug(f"--------{now_start}--------LOGGING NODE---")

    agent_start_time = state.agent_start_time
    agent_base_end_time = state.agent_base_end_time
    agent_refined_start_time = state.agent_refined_start_time
    agent_refined_end_time = state.agent_refined_end_time
    agent_end_time = agent_refined_end_time or agent_base_end_time

    agent_base_duration = None
    if agent_base_end_time:
        agent_base_duration = (agent_base_end_time - agent_start_time).total_seconds()

    agent_refined_duration = None
    if agent_refined_start_time and agent_refined_end_time:
        agent_refined_duration = (
            agent_refined_end_time - agent_refined_start_time
        ).total_seconds()

    agent_full_duration = None
    if agent_end_time:
        agent_full_duration = (agent_end_time - agent_start_time).total_seconds()

    agent_type = "refined" if agent_refined_duration else "base"

    agent_base_metrics = state.agent_base_metrics
    agent_refined_metrics = state.agent_refined_metrics

    combined_agent_metrics = CombinedAgentMetrics(
        timings=AgentTimings(
            base_duration__s=agent_base_duration,
            refined_duration__s=agent_refined_duration,
            full_duration__s=agent_full_duration,
        ),
        base_metrics=agent_base_metrics,
        refined_metrics=agent_refined_metrics,
        additional_metrics=AgentAdditionalMetrics(),
    )

    persona_id = None
    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    if agent_a_config.search_request.persona:
        persona_id = agent_a_config.search_request.persona.id

    user_id = None
    if agent_a_config.search_tool is not None:
        user = agent_a_config.search_tool.user
        if user:
            user_id = user.id

    # log the agent metrics
    if agent_a_config.db_session is not None:
        if agent_base_duration is not None:
            log_agent_metrics(
                db_session=agent_a_config.db_session,
                user_id=user_id,
                persona_id=persona_id,
                agent_type=agent_type,
                start_time=agent_start_time,
                agent_metrics=combined_agent_metrics,
            )

    if agent_a_config.use_persistence:
        # Persist the sub-answer in the database
        db_session = agent_a_config.db_session
        chat_session_id = agent_a_config.chat_session_id
        primary_message_id = agent_a_config.message_id
        sub_question_answer_results = state.decomp_answer_results

        log_agent_sub_question_results(
            db_session=db_session,
            chat_session_id=chat_session_id,
            primary_message_id=primary_message_id,
            sub_question_answer_results=sub_question_answer_results,
        )

        # if chat_session_id is not None and primary_message_id is not None and sub_question_id is not None:
        #     create_sub_answer(
        #         db_session=db_session,
        #         chat_session_id=chat_session_id,
        #         primary_message_id=primary_message_id,
        #         sub_question_id=sub_question_id,
        #         answer=answer_str,
        #     )
        #     pass

    now_end = datetime.now()
    main_output = MainOutput(
        log_messages=[
            f"{now_end} -- Main - Logging, Time taken: {now_end - now_start}"
        ],
    )

    logger.debug(f"--------{now_end}--{now_end - now_start}--------LOGGING NODE END---")

    return main_output
@@ -0,0 +1,92 @@
from datetime import datetime
from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig

from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.deep_search_a.main.states import RoutingDecision
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import AGENT_DECISION_PROMPT
from onyx.llm.utils import check_number_of_tokens


def agent_path_decision(state: MainState, config: RunnableConfig) -> RoutingDecision:
    now_start = datetime.now()

    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    question = agent_a_config.search_request.query
    # perform_initial_search_path_decision = (
    #     agent_a_config.perform_initial_search_path_decision
    # )

    history = build_history_prompt(agent_a_config.prompt_builder)

    logger.debug(f"--------{now_start}--------DECIDING TO SEARCH OR GO TO LLM---")

    # if perform_initial_search_path_decision:
    #     search_tool = agent_a_config.search_tool
    #     retrieved_docs: list[InferenceSection] = []

    #     # new db session to avoid concurrency issues
    #     with get_session_context_manager() as db_session:
    #         for tool_response in search_tool.run(
    #             query=question,
    #             force_no_rerank=True,
    #             alternate_db_session=db_session,
    #         ):
    #             # get retrieved docs to send to the rest of the graph
    #             if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
    #                 response = cast(SearchResponseSummary, tool_response.response)
    #                 retrieved_docs = response.top_sections
    #                 break

    #     sample_doc_str = "\n\n".join(
    #         [doc.combined_content for _, doc in enumerate(retrieved_docs[:3])]
    #     )

    #     agent_decision_prompt = AGENT_DECISION_PROMPT_AFTER_SEARCH.format(
    #         question=question, sample_doc_str=sample_doc_str, history=history
    #     )

    # else:
    sample_doc_str = ""
    agent_decision_prompt = AGENT_DECISION_PROMPT.format(
        question=question, history=history
    )

    msg = [HumanMessage(content=agent_decision_prompt)]

    # Ask the fast LLM whether the question needs research or a direct answer
    model = agent_a_config.fast_llm

    # no need to stream this
    resp = model.invoke(msg)

    if isinstance(resp.content, str) and "research" in resp.content.lower():
        routing = "agent_search"
    else:
        routing = "LLM"

    # NOTE: the decision above is currently overridden; the agent-search path is always taken
    routing = "agent_search"

    now_end = datetime.now()

    logger.debug(
        f"--------{now_end}--{now_end - now_start}--------DECIDING TO SEARCH OR GO TO LLM END---"
    )

    check_number_of_tokens(agent_decision_prompt)

    return RoutingDecision(
        # Decide which route to take
        routing=routing,
        sample_doc_str=sample_doc_str,
        log_messages=[
            f"{now_end} -- Path decision: {routing}, Time taken: {now_end - now_start}"
        ],
    )
@@ -0,0 +1,31 @@
from datetime import datetime
from typing import Literal

from langgraph.types import Command

from onyx.agents.agent_search.deep_search_a.main.states import MainState


def agent_path_routing(
    state: MainState,
) -> Command[Literal["agent_search_start", "LLM"]]:
    now_start = datetime.now()
    routing = state.routing if hasattr(state, "routing") else "agent_search"

    if routing == "agent_search":
        agent_path = "agent_search_start"
    else:
        agent_path = "LLM"

    now_end = datetime.now()

    return Command(
        # state update
        update={
            "log_messages": [
                f"{now_end} -- Main - Path routing: {agent_path}, Time taken: {now_end - now_start}"
            ]
        },
        # control flow
        goto=agent_path,
    )
@@ -0,0 +1,50 @@
from datetime import datetime
from typing import cast

from langchain_core.runnables import RunnableConfig

from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import ExploratorySearchUpdate
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.utils import retrieve_search_docs
from onyx.configs.agent_configs import AGENT_EXPLORATORY_SEARCH_RESULTS
from onyx.context.search.models import InferenceSection


def agent_search_start(
    state: MainState, config: RunnableConfig
) -> ExploratorySearchUpdate:
    now_start = datetime.now()

    logger.debug(f"--------{now_start}--------EXPLORATORY SEARCH START---")

    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    question = agent_a_config.search_request.query
    chat_session_id = agent_a_config.chat_session_id
    primary_message_id = agent_a_config.message_id

    if chat_session_id is None or primary_message_id is None:
        raise ValueError(
            "chat_session_id and message_id must be provided for agent search"
        )

    # Initial search to inform decomposition. Just get the top few hits.
    search_tool = agent_a_config.search_tool
    if search_tool is None:
        raise ValueError("search_tool must be provided for agentic search")
    retrieved_docs: list[InferenceSection] = retrieve_search_docs(search_tool, question)

    exploratory_search_results = retrieved_docs[:AGENT_EXPLORATORY_SEARCH_RESULTS]
    now_end = datetime.now()
    logger.debug(
        f"--------{now_end}--{now_end - now_start}--------EXPLORATORY SEARCH END---"
    )

    return ExploratorySearchUpdate(
        exploratory_search_results=exploratory_search_results,
        log_messages=[
            f"{now_start} -- Main - Exploratory Search, Time taken: {now_end - now_start}"
        ],
    )
@@ -0,0 +1,60 @@
from datetime import datetime
from typing import cast

from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig

from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import AnswerComparison
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.prompts import ANSWER_COMPARISON_PROMPT
from onyx.chat.models import RefinedAnswerImprovement


def answer_comparison(state: MainState, config: RunnableConfig) -> AnswerComparison:
    now_start = datetime.now()

    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    question = agent_a_config.search_request.query
    initial_answer = state.initial_answer
    refined_answer = state.refined_answer

    logger.debug(f"--------{now_start}--------ANSWER COMPARISON STARTED--")

    answer_comparison_prompt = ANSWER_COMPARISON_PROMPT.format(
        question=question, initial_answer=initial_answer, refined_answer=refined_answer
    )

    msg = [HumanMessage(content=answer_comparison_prompt)]

    # Ask the fast LLM whether the refined answer improves on the initial answer
    model = agent_a_config.fast_llm

    # no need to stream this
    resp = model.invoke(msg)

    refined_answer_improvement = (
        isinstance(resp.content, str) and "yes" in resp.content.lower()
    )

    dispatch_custom_event(
        "refined_answer_improvement",
        RefinedAnswerImprovement(
            refined_answer_improvement=refined_answer_improvement,
        ),
    )

    now_end = datetime.now()

    logger.debug(
        f"--------{now_end}--{now_end - now_start}--------ANSWER COMPARISON COMPLETED---"
    )

    return AnswerComparison(
        refined_answer_improvement=refined_answer_improvement,
        log_messages=[
            f"{now_start} -- Answer comparison: {refined_answer_improvement}, Time taken: {now_end - now_start}"
        ],
    )
@@ -0,0 +1,89 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import InitialAnswerUpdate
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainState
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
ASSISTANT_SYSTEM_PROMPT_PERSONA,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import DIRECT_LLM_PROMPT
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_prompt
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
|
||||
|
||||
def direct_llm_handling(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> InitialAnswerUpdate:
|
||||
now_start = datetime.now()
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = agent_a_config.search_request.query
|
||||
persona_prompt = get_persona_prompt(agent_a_config.search_request.persona)
|
||||
|
||||
if len(persona_prompt) == 0:
|
||||
persona_specification = ASSISTANT_SYSTEM_PROMPT_DEFAULT
|
||||
else:
|
||||
persona_specification = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
|
||||
persona_prompt=persona_prompt
|
||||
)
|
||||
|
||||
logger.debug(f"--------{now_start}--------LLM HANDLING START---")
|
||||
|
||||
model = agent_a_config.fast_llm
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=DIRECT_LLM_PROMPT.format(
|
||||
persona_specification=persona_specification, question=question
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
|
||||
|
||||
for message in model.stream(msg):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
)
|
||||
dispatch_custom_event(
|
||||
"initial_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=content,
|
||||
level=0,
|
||||
level_question_nr=0,
|
||||
answer_type="agent_level_answer",
|
||||
),
|
||||
)
|
||||
streamed_tokens.append(content)
|
||||
|
||||
response = merge_content(*streamed_tokens)
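# merge_content stitches the streamed chunks back into a single answer string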
|
||||
answer = cast(str, response)
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.debug(f"--------{now_end}--{now_end - now_start}--------LLM HANDLING END---")
|
||||
|
||||
return InitialAnswerUpdate(
|
||||
initial_answer=answer,
|
||||
initial_agent_stats=None,
|
||||
generated_sub_questions=[],
|
||||
agent_base_end_time=now_end,
|
||||
agent_base_metrics=None,
|
||||
log_messages=[
|
||||
f"{now_end} -- Main - LLM handling: {answer}, Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,126 @@
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import (
|
||||
EntityTermExtractionUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainState
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import Entity
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
EntityRelationshipTermExtraction,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import Relationship
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import Term
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROMPT
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
|
||||
|
||||
def entity_term_extraction_llm(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> EntityTermExtractionUpdate:
|
||||
now_start = datetime.now()
|
||||
|
||||
logger.debug(f"--------{now_start}--------GENERATE ENTITIES & TERMS---")
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
if not agent_a_config.allow_refinement:
|
||||
now_end = datetime.now()
|
||||
return EntityTermExtractionUpdate(
|
||||
entity_relation_term_extractions=EntityRelationshipTermExtraction(
|
||||
entities=[],
|
||||
relationships=[],
|
||||
terms=[],
|
||||
),
|
||||
log_messages=[
|
||||
f"{now_end} -- Main - ETR Extraction, Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
|
||||
# first four lines duplicates from generate_initial_answer
|
||||
question = agent_a_config.search_request.query
|
||||
initial_search_docs = state.exploratory_search_results[:15]
|
||||
|
||||
# start with the entity/term/extraction
|
||||
doc_context = format_docs(initial_search_docs)
|
||||
|
||||
doc_context = trim_prompt_piece(
|
||||
agent_a_config.fast_llm.config, doc_context, ENTITY_TERM_PROMPT + question
|
||||
)
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=ENTITY_TERM_PROMPT.format(question=question, context=doc_context),
|
||||
)
|
||||
]
|
||||
fast_llm = agent_a_config.fast_llm
|
||||
# Grader
|
||||
llm_response = fast_llm.invoke(
|
||||
prompt=msg,
|
||||
)
|
||||
|
||||
cleaned_response = re.sub(r"```json\n|\n```", "", str(llm_response.content))
|
||||
parsed_response = json.loads(cleaned_response)
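# The LLM response is expected to be JSON of the form
# {"retrieved_entities_relationships": {"entities": [...], "relationships": [...], "terms": [...]}}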
|
||||
|
||||
entities = []
|
||||
relationships = []
|
||||
terms = []
|
||||
for entity in parsed_response.get("retrieved_entities_relationships", {}).get(
|
||||
"entities", {}
|
||||
):
|
||||
entity_name = entity.get("entity_name", "")
|
||||
entity_type = entity.get("entity_type", "")
|
||||
entities.append(Entity(entity_name=entity_name, entity_type=entity_type))
|
||||
|
||||
for relationship in parsed_response.get("retrieved_entities_relationships", {}).get(
|
||||
"relationships", {}
|
||||
):
|
||||
relationship_name = relationship.get("relationship_name", "")
|
||||
relationship_type = relationship.get("relationship_type", "")
|
||||
relationship_entities = relationship.get("relationship_entities", [])
|
||||
relationships.append(
|
||||
Relationship(
|
||||
relationship_name=relationship_name,
|
||||
relationship_type=relationship_type,
|
||||
relationship_entities=relationship_entities,
|
||||
)
|
||||
)
|
||||
|
||||
for term in parsed_response.get("retrieved_entities_relationships", {}).get(
|
||||
"terms", {}
|
||||
):
|
||||
term_name = term.get("term_name", "")
|
||||
term_type = term.get("term_type", "")
|
||||
term_similar_to = term.get("term_similar_to", [])
|
||||
terms.append(
|
||||
Term(
|
||||
term_name=term_name,
|
||||
term_type=term_type,
|
||||
term_similar_to=term_similar_to,
|
||||
)
|
||||
)
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.debug(
|
||||
f"--------{now_end}--{now_end - now_start}--------ENTITY TERM EXTRACTION END---"
|
||||
)
|
||||
|
||||
return EntityTermExtractionUpdate(
|
||||
entity_relation_term_extractions=EntityRelationshipTermExtraction(
|
||||
entities=entities,
|
||||
relationships=relationships,
|
||||
terms=terms,
|
||||
),
|
||||
log_messages=[
|
||||
f"{now_start} -- Main - ETR Extraction, Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,270 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.main.models import AgentBaseMetrics
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import (
|
||||
calculate_initial_agent_stats,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import get_query_info
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import (
|
||||
remove_document_citations,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import InitialAnswerUpdate
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainState
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_history_prompt,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
ASSISTANT_SYSTEM_PROMPT_PERSONA,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import INITIAL_RAG_PROMPT
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
SUB_QUESTION_ANSWER_TEMPLATE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
dispatch_main_answer_stop_info,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_prompt
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_today_prompt
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
|
||||
|
||||
|
||||
def generate_initial_answer(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> InitialAnswerUpdate:
|
||||
now_start = datetime.now()
|
||||
|
||||
logger.debug(f"--------{now_start}--------GENERATE INITIAL---")
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = agent_a_config.search_request.query
|
||||
persona_prompt = get_persona_prompt(agent_a_config.search_request.persona)
|
||||
|
||||
history = build_history_prompt(agent_a_config.prompt_builder)
|
||||
date_str = get_today_prompt()
|
||||
|
||||
sub_question_docs = state.context_documents
|
||||
all_original_question_documents = state.all_original_question_documents
|
||||
|
||||
relevant_docs = dedup_inference_sections(
|
||||
sub_question_docs, all_original_question_documents
|
||||
)
|
||||
decomp_questions = []
|
||||
|
||||
if len(relevant_docs) == 0:
|
||||
dispatch_custom_event(
|
||||
"initial_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=UNKNOWN_ANSWER,
|
||||
level=0,
|
||||
level_question_nr=0,
|
||||
answer_type="agent_level_answer",
|
||||
),
|
||||
)
|
||||
dispatch_main_answer_stop_info(0)
|
||||
|
||||
answer = UNKNOWN_ANSWER
|
||||
initial_agent_stats = InitialAgentResultStats(
|
||||
sub_questions={},
|
||||
original_question={},
|
||||
agent_effectiveness={},
|
||||
)
|
||||
|
||||
else:
|
||||
# Use the query info from the base document retrieval
|
||||
query_info = get_query_info(state.original_question_retrieval_results)
|
||||
if agent_a_config.search_tool is None:
|
||||
raise ValueError("search_tool must be provided for agentic search")
|
||||
for tool_response in yield_search_responses(
|
||||
query=question,
|
||||
reranked_sections=relevant_docs,
|
||||
final_context_sections=relevant_docs,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: None, # TODO: add relevance
|
||||
search_tool=agent_a_config.search_tool,
|
||||
):
|
||||
dispatch_custom_event(
|
||||
"tool_response",
|
||||
ExtendedToolResponse(
|
||||
id=tool_response.id,
|
||||
response=tool_response.response,
|
||||
level=0,
|
||||
level_question_nr=0, # 0, 0 is the base question
|
||||
),
|
||||
)
|
||||
|
||||
net_new_original_question_docs = []
|
||||
for all_original_question_doc in all_original_question_documents:
|
||||
if all_original_question_doc not in sub_question_docs:
|
||||
net_new_original_question_docs.append(all_original_question_doc)
|
||||
|
||||
decomp_answer_results = state.decomp_answer_results
|
||||
|
||||
good_qa_list: list[str] = []
|
||||
|
||||
sub_question_nr = 1
|
||||
|
||||
for decomp_answer_result in decomp_answer_results:
|
||||
decomp_questions.append(decomp_answer_result.question)
|
||||
_, question_nr = parse_question_id(decomp_answer_result.question_id)
|
||||
if (
|
||||
decomp_answer_result.quality.lower().startswith("yes")
|
||||
and len(decomp_answer_result.answer) > 0
|
||||
and decomp_answer_result.answer != UNKNOWN_ANSWER
|
||||
):
|
||||
good_qa_list.append(
|
||||
SUB_QUESTION_ANSWER_TEMPLATE.format(
|
||||
sub_question=decomp_answer_result.question,
|
||||
sub_answer=decomp_answer_result.answer,
|
||||
sub_question_nr=sub_question_nr,
|
||||
)
|
||||
)
|
||||
sub_question_nr += 1
|
||||
|
||||
if len(good_qa_list) > 0:
|
||||
sub_question_answer_str = "\n\n------\n\n".join(good_qa_list)
|
||||
else:
|
||||
sub_question_answer_str = ""
|
||||
|
||||
# Determine which persona-specification prompt to use
|
||||
|
||||
if len(persona_prompt) == 0:
|
||||
persona_specification = ASSISTANT_SYSTEM_PROMPT_DEFAULT
|
||||
else:
|
||||
persona_specification = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
|
||||
persona_prompt=persona_prompt
|
||||
)
|
||||
|
||||
# Determine which base prompt to use given the sub-question information
|
||||
if len(good_qa_list) > 0:
|
||||
base_prompt = INITIAL_RAG_PROMPT
|
||||
else:
|
||||
base_prompt = INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS
|
||||
|
||||
model = agent_a_config.fast_llm
|
||||
|
||||
doc_context = format_docs(relevant_docs)
|
||||
doc_context = trim_prompt_piece(
|
||||
model.config,
|
||||
doc_context,
|
||||
base_prompt
|
||||
+ sub_question_answer_str
|
||||
+ persona_specification
|
||||
+ history
|
||||
+ date_str,
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=base_prompt.format(
|
||||
question=question,
|
||||
answered_sub_questions=remove_document_citations(
|
||||
sub_question_answer_str
|
||||
),
|
||||
relevant_docs=format_docs(relevant_docs),
|
||||
persona_specification=persona_specification,
|
||||
history=history,
|
||||
date_prompt=date_str,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
|
||||
for message in model.stream(msg):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
)
|
||||
dispatch_custom_event(
|
||||
"initial_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=content,
|
||||
level=0,
|
||||
level_question_nr=0,
|
||||
answer_type="agent_level_answer",
|
||||
),
|
||||
)
|
||||
streamed_tokens.append(content)
|
||||
|
||||
dispatch_main_answer_stop_info(0)
|
||||
response = merge_content(*streamed_tokens)
|
||||
answer = cast(str, response)
|
||||
|
||||
initial_agent_stats = calculate_initial_agent_stats(
|
||||
state.decomp_answer_results, state.original_question_retrieval_stats
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"\n\nYYYYY--Sub-Questions:\n\n{sub_question_answer_str}\n\nStats:\n\n"
|
||||
)
|
||||
|
||||
if initial_agent_stats:
|
||||
logger.debug(initial_agent_stats.original_question)
|
||||
logger.debug(initial_agent_stats.sub_questions)
|
||||
logger.debug(initial_agent_stats.agent_effectiveness)
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.debug(
|
||||
f"--------{now_end}--{now_end - now_start}--------INITIAL AGENT ANSWER END---\n\n"
|
||||
)
|
||||
|
||||
agent_base_end_time = datetime.now()
|
||||
|
||||
agent_base_metrics = AgentBaseMetrics(
|
||||
num_verified_documents_total=len(relevant_docs),
|
||||
num_verified_documents_core=state.original_question_retrieval_stats.verified_count,
|
||||
verified_avg_score_core=state.original_question_retrieval_stats.verified_avg_scores,
|
||||
num_verified_documents_base=initial_agent_stats.sub_questions.get(
|
||||
"num_verified_documents", None
|
||||
),
|
||||
verified_avg_score_base=initial_agent_stats.sub_questions.get(
|
||||
"verified_avg_score", None
|
||||
),
|
||||
base_doc_boost_factor=initial_agent_stats.agent_effectiveness.get(
|
||||
"utilized_chunk_ratio", None
|
||||
),
|
||||
support_boost_factor=initial_agent_stats.agent_effectiveness.get(
|
||||
"support_ratio", None
|
||||
),
|
||||
duration__s=(agent_base_end_time - state.agent_start_time).total_seconds(),
|
||||
)
|
||||
|
||||
return InitialAnswerUpdate(
|
||||
initial_answer=answer,
|
||||
initial_agent_stats=initial_agent_stats,
|
||||
generated_sub_questions=decomp_questions,
|
||||
agent_base_end_time=agent_base_end_time,
|
||||
agent_base_metrics=agent_base_metrics,
|
||||
log_messages=[
|
||||
f"{now_end} -- Main - Initial Answer generation, Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,56 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import InitialAnswerBASEUpdate
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainState
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import INITIAL_RAG_BASE_PROMPT
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
|
||||
|
||||
def generate_initial_base_search_only_answer(
|
||||
state: MainState,
|
||||
config: RunnableConfig,
|
||||
) -> InitialAnswerBASEUpdate:
|
||||
now_start = datetime.now()
|
||||
|
||||
logger.debug(f"--------{now_start}--------GENERATE INITIAL BASE ANSWER---")
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = agent_a_config.search_request.query
|
||||
original_question_docs = state.all_original_question_documents
|
||||
|
||||
model = agent_a_config.fast_llm
|
||||
|
||||
doc_context = format_docs(original_question_docs)
|
||||
doc_context = trim_prompt_piece(
|
||||
model.config, doc_context, INITIAL_RAG_BASE_PROMPT + question
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=INITIAL_RAG_BASE_PROMPT.format(
|
||||
question=question,
|
||||
context=doc_context,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
# Grader
|
||||
response = model.invoke(msg)
|
||||
answer = response.pretty_repr()
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.debug(
|
||||
f"--------{now_end}--{now_end - now_start}--------INITIAL BASE ANSWER END---\n\n"
|
||||
)
|
||||
|
||||
return InitialAnswerBASEUpdate(initial_base_answer=answer)
|
||||
@@ -0,0 +1,332 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.main.models import AgentRefinedMetrics
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import get_query_info
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import (
|
||||
remove_document_citations,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainState
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import RefinedAnswerUpdate
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_history_prompt,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
ASSISTANT_SYSTEM_PROMPT_PERSONA,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import REVISED_RAG_PROMPT
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
REVISED_RAG_PROMPT_NO_SUB_QUESTIONS,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
SUB_QUESTION_ANSWER_TEMPLATE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
dispatch_main_answer_stop_info,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_prompt
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_today_prompt
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
|
||||
|
||||
|
||||
def generate_refined_answer(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> RefinedAnswerUpdate:
|
||||
now_start = datetime.now()
|
||||
|
||||
logger.debug(f"--------{now_start}--------GENERATE REFINED ANSWER---")
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = agent_a_config.search_request.query
|
||||
persona_prompt = get_persona_prompt(agent_a_config.search_request.persona)
|
||||
|
||||
history = build_history_prompt(agent_a_config.prompt_builder)
|
||||
date_str = get_today_prompt()
|
||||
initial_documents = state.documents
|
||||
revised_documents = state.refined_documents
|
||||
|
||||
combined_documents = dedup_inference_sections(initial_documents, revised_documents)
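# combined_documents now holds the deduplicated union of initial and refinement-phase docs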
|
||||
|
||||
query_info = get_query_info(state.original_question_retrieval_results)
|
||||
if agent_a_config.search_tool is None:
|
||||
raise ValueError("search_tool must be provided for agentic search")
|
||||
# stream refined answer docs
|
||||
for tool_response in yield_search_responses(
|
||||
query=question,
|
||||
reranked_sections=combined_documents,
|
||||
final_context_sections=combined_documents,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: None, # TODO: add relevance
|
||||
search_tool=agent_a_config.search_tool,
|
||||
):
|
||||
dispatch_custom_event(
|
||||
"tool_response",
|
||||
ExtendedToolResponse(
|
||||
id=tool_response.id,
|
||||
response=tool_response.response,
|
||||
level=1,
|
||||
level_question_nr=0, # 0, 0 is the base question
|
||||
),
|
||||
)
|
||||
|
||||
if len(initial_documents) > 0:
|
||||
revision_doc_effectiveness = len(combined_documents) / len(initial_documents)
|
||||
elif len(revised_documents) == 0:
|
||||
revision_doc_effectiveness = 0.0
|
||||
else:
|
||||
revision_doc_effectiveness = 10.0
|
||||
|
||||
decomp_answer_results = state.decomp_answer_results
|
||||
# revised_answer_results = state.refined_decomp_answer_results
|
||||
|
||||
good_qa_list: list[str] = []
|
||||
decomp_questions = []
|
||||
|
||||
initial_good_sub_questions: list[str] = []
|
||||
new_revised_good_sub_questions: list[str] = []
|
||||
|
||||
sub_question_nr = 1
|
||||
|
||||
for decomp_answer_result in decomp_answer_results:
|
||||
question_level, question_nr = parse_question_id(
|
||||
decomp_answer_result.question_id
|
||||
)
|
||||
|
||||
decomp_questions.append(decomp_answer_result.question)
|
||||
if (
|
||||
decomp_answer_result.quality.lower().startswith("yes")
|
||||
and len(decomp_answer_result.answer) > 0
|
||||
and decomp_answer_result.answer != UNKNOWN_ANSWER
|
||||
):
|
||||
good_qa_list.append(
|
||||
SUB_QUESTION_ANSWER_TEMPLATE.format(
|
||||
sub_question=decomp_answer_result.question,
|
||||
sub_answer=decomp_answer_result.answer,
|
||||
sub_question_nr=sub_question_nr,
|
||||
)
|
||||
)
|
||||
if question_level == 0:
|
||||
initial_good_sub_questions.append(decomp_answer_result.question)
|
||||
else:
|
||||
new_revised_good_sub_questions.append(decomp_answer_result.question)
|
||||
|
||||
sub_question_nr += 1
|
||||
|
||||
initial_good_sub_questions = list(set(initial_good_sub_questions))
|
||||
new_revised_good_sub_questions = list(set(new_revised_good_sub_questions))
|
||||
total_good_sub_questions = list(
|
||||
set(initial_good_sub_questions + new_revised_good_sub_questions)
|
||||
)
|
||||
if len(initial_good_sub_questions) > 0:
|
||||
revision_question_efficiency: float = len(total_good_sub_questions) / len(
|
||||
initial_good_sub_questions
|
||||
)
|
||||
elif len(new_revised_good_sub_questions) > 0:
|
||||
revision_question_efficiency = 10.0
|
||||
else:
|
||||
revision_question_efficiency = 1.0
|
||||
|
||||
sub_question_answer_str = "\n\n------\n\n".join(list(set(good_qa_list)))
|
||||
|
||||
# original answer
|
||||
|
||||
initial_answer = state.initial_answer
|
||||
|
||||
# Determine which persona-specification prompt to use
|
||||
|
||||
if len(persona_prompt) == 0:
|
||||
persona_specification = ASSISTANT_SYSTEM_PROMPT_DEFAULT
|
||||
else:
|
||||
persona_specification = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
|
||||
persona_prompt=persona_prompt
|
||||
)
|
||||
|
||||
# Determine which base prompt to use given the sub-question information
|
||||
if len(good_qa_list) > 0:
|
||||
base_prompt = REVISED_RAG_PROMPT
|
||||
else:
|
||||
base_prompt = REVISED_RAG_PROMPT_NO_SUB_QUESTIONS
|
||||
|
||||
model = agent_a_config.fast_llm
|
||||
relevant_docs = format_docs(combined_documents)
|
||||
relevant_docs = trim_prompt_piece(
|
||||
model.config,
|
||||
relevant_docs,
|
||||
base_prompt
|
||||
+ question
|
||||
+ sub_question_answer_str
|
||||
+ relevant_docs
|
||||
+ initial_answer
|
||||
+ persona_specification
|
||||
+ history,
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=base_prompt.format(
|
||||
question=question,
|
||||
history=history,
|
||||
answered_sub_questions=remove_document_citations(
|
||||
sub_question_answer_str
|
||||
),
|
||||
relevant_docs=relevant_docs,
|
||||
initial_answer=remove_document_citations(initial_answer),
|
||||
persona_specification=persona_specification,
|
||||
date_prompt=date_str,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
# Grader
|
||||
|
||||
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
|
||||
for message in model.stream(msg):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
)
|
||||
dispatch_custom_event(
|
||||
"refined_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=content,
|
||||
level=1,
|
||||
level_question_nr=0,
|
||||
answer_type="agent_level_answer",
|
||||
),
|
||||
)
|
||||
streamed_tokens.append(content)
|
||||
|
||||
dispatch_main_answer_stop_info(1)
|
||||
response = merge_content(*streamed_tokens)
|
||||
answer = cast(str, response)
|
||||
|
||||
# refined_agent_stats = _calculate_refined_agent_stats(
|
||||
# state.decomp_answer_results, state.original_question_retrieval_stats
|
||||
# )
|
||||
|
||||
initial_good_sub_questions_str = "\n".join(list(set(initial_good_sub_questions)))
|
||||
new_revised_good_sub_questions_str = "\n".join(
|
||||
list(set(new_revised_good_sub_questions))
|
||||
)
|
||||
|
||||
refined_agent_stats = RefinedAgentStats(
|
||||
revision_doc_efficiency=revision_doc_effectiveness,
|
||||
revision_question_efficiency=revision_question_efficiency,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"\n\n---INITIAL ANSWER START---\n\n Answer:\n Agent: {initial_answer}"
|
||||
)
|
||||
logger.debug("-" * 10)
|
||||
logger.debug(f"\n\n---REVISED AGENT ANSWER START---\n\n Answer:\n Agent: {answer}")
|
||||
|
||||
logger.debug("-" * 100)
|
||||
logger.debug(f"\n\nINITAL Sub-Questions\n\n{initial_good_sub_questions_str}\n\n")
|
||||
logger.debug("-" * 10)
|
||||
logger.debug(
|
||||
f"\n\nNEW REVISED Sub-Questions\n\n{new_revised_good_sub_questions_str}\n\n"
|
||||
)
|
||||
|
||||
logger.debug("-" * 100)
|
||||
|
||||
logger.debug(
|
||||
f"\n\nINITAL & REVISED Sub-Questions & Answers:\n\n{sub_question_answer_str}\n\nStas:\n\n"
|
||||
)
|
||||
|
||||
logger.debug("-" * 100)
|
||||
|
||||
if state.initial_agent_stats:
|
||||
initial_doc_boost_factor = state.initial_agent_stats.agent_effectiveness.get(
|
||||
"utilized_chunk_ratio", "--"
|
||||
)
|
||||
initial_support_boost_factor = (
|
||||
state.initial_agent_stats.agent_effectiveness.get("support_ratio", "--")
|
||||
)
|
||||
num_initial_verified_docs = state.initial_agent_stats.original_question.get(
|
||||
"num_verified_documents", "--"
|
||||
)
|
||||
initial_verified_docs_avg_score = (
|
||||
state.initial_agent_stats.original_question.get("verified_avg_score", "--")
|
||||
)
|
||||
initial_sub_questions_verified_docs = (
|
||||
state.initial_agent_stats.sub_questions.get("num_verified_documents", "--")
|
||||
)
|
||||
|
||||
logger.debug("INITIAL AGENT STATS")
|
||||
logger.debug(f"Document Boost Factor: {initial_doc_boost_factor}")
|
||||
logger.debug(f"Support Boost Factor: {initial_support_boost_factor}")
|
||||
logger.debug(f"Originally Verified Docs: {num_initial_verified_docs}")
|
||||
logger.debug(
|
||||
f"Originally Verified Docs Avg Score: {initial_verified_docs_avg_score}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Sub-Questions Verified Docs: {initial_sub_questions_verified_docs}"
|
||||
)
|
||||
if refined_agent_stats:
|
||||
logger.debug("-" * 10)
|
||||
logger.debug("REFINED AGENT STATS")
|
||||
logger.debug(
|
||||
f"Revision Doc Factor: {refined_agent_stats.revision_doc_efficiency}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Revision Question Factor: {refined_agent_stats.revision_question_efficiency}"
|
||||
)
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.debug(
|
||||
f"--------{now_end}--{now_end - now_start}--------INITIAL AGENT ANSWER END---\n\n"
|
||||
)
|
||||
|
||||
agent_refined_end_time = datetime.now()
|
||||
if state.agent_refined_start_time:
|
||||
agent_refined_duration = (
|
||||
agent_refined_end_time - state.agent_refined_start_time
|
||||
).total_seconds()
|
||||
else:
|
||||
agent_refined_duration = None
|
||||
|
||||
agent_refined_metrics = AgentRefinedMetrics(
|
||||
refined_doc_boost_factor=refined_agent_stats.revision_doc_efficiency,
|
||||
refined_question_boost_factor=refined_agent_stats.revision_question_efficiency,
|
||||
duration__s=agent_refined_duration,
|
||||
)
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.debug(
|
||||
f"--------{now_end}--{now_end - now_start}--------REFINED ANSWER UPDATE END---"
|
||||
)
|
||||
|
||||
return RefinedAnswerUpdate(
|
||||
refined_answer=answer,
|
||||
refined_answer_quality=True, # TODO: replace this with the actual check value
|
||||
refined_agent_stats=refined_agent_stats,
|
||||
agent_refined_end_time=agent_refined_end_time,
|
||||
agent_refined_metrics=agent_refined_metrics,
|
||||
)
|
||||
@@ -0,0 +1,39 @@
|
||||
from datetime import datetime
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.base_raw_search.states import (
|
||||
BaseRawSearchOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import ExpandedRetrievalUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
|
||||
|
||||
|
||||
def ingest_initial_base_retrieval(
|
||||
state: BaseRawSearchOutput,
|
||||
) -> ExpandedRetrievalUpdate:
|
||||
now_start = datetime.now()
|
||||
|
||||
logger.debug(f"--------{now_start}--------INGEST INITIAL RETRIEVAL---")
|
||||
|
||||
sub_question_retrieval_stats = (
|
||||
state.base_expanded_retrieval_result.sub_question_retrieval_stats
|
||||
)
|
||||
if sub_question_retrieval_stats is None:
sub_question_retrieval_stats = AgentChunkStats()
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.debug(
|
||||
f"--------{now_end}--{now_end - now_start}--------INGEST INITIAL RETRIEVAL END---"
|
||||
)
|
||||
|
||||
return ExpandedRetrievalUpdate(
|
||||
original_question_retrieval_results=state.base_expanded_retrieval_result.expanded_queries_results,
|
||||
all_original_question_documents=state.base_expanded_retrieval_result.context_documents,
|
||||
original_question_retrieval_stats=sub_question_retrieval_stats,
|
||||
log_messages=[
|
||||
f"{now_end} -- Main - Ingestion base retrieval, Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,41 @@
|
||||
from datetime import datetime
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
AnswerQuestionOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import DecompAnswersUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
)
|
||||
|
||||
|
||||
def ingest_initial_sub_question_answers(
|
||||
state: AnswerQuestionOutput,
|
||||
) -> DecompAnswersUpdate:
|
||||
now_start = datetime.now()
|
||||
|
||||
logger.debug(f"--------{now_start}--------INGEST ANSWERS---")
|
||||
documents = []
|
||||
context_documents = []
|
||||
answer_results = state.answer_results if hasattr(state, "answer_results") else []
|
||||
for answer_result in answer_results:
|
||||
documents.extend(answer_result.documents)
|
||||
context_documents.extend(answer_result.context_documents)
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.debug(
|
||||
f"--------{now_end}--{now_end - now_start}--------INGEST ANSWERS END---"
|
||||
)
|
||||
|
||||
return DecompAnswersUpdate(
|
||||
# Deduping is done by the documents operator for the main graph
|
||||
# so we might not need to dedup here
|
||||
documents=dedup_inference_sections(documents, []),
|
||||
context_documents=dedup_inference_sections(context_documents, []),
|
||||
decomp_answer_results=answer_results,
|
||||
log_messages=[
|
||||
f"{now_end} -- Main - Ingest initial processed sub questions, Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,39 @@
|
||||
from datetime import datetime
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
AnswerQuestionOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import DecompAnswersUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
)
|
||||
|
||||
|
||||
def ingest_refined_answers(
|
||||
state: AnswerQuestionOutput,
|
||||
) -> DecompAnswersUpdate:
|
||||
now_start = datetime.now()
|
||||
|
||||
logger.debug(f"--------{now_start}--------INGEST FOLLOW UP ANSWERS---")
|
||||
|
||||
documents = []
|
||||
answer_results = state.answer_results if hasattr(state, "answer_results") else []
|
||||
for answer_result in answer_results:
|
||||
documents.extend(answer_result.documents)
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.debug(
|
||||
f"--------{now_end}--{now_end - now_start}--------INGEST FOLLOW UP ANSWERS END---"
|
||||
)
|
||||
|
||||
return DecompAnswersUpdate(
|
||||
# Deduping is done by the documents operator for the main graph
|
||||
# so we might not need to dedup here
|
||||
documents=dedup_inference_sections(documents, []),
|
||||
decomp_answer_results=answer_results,
|
||||
log_messages=[
|
||||
f"{now_end} -- Main - Ingest refined answers, Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,40 @@
|
||||
from datetime import datetime
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import (
|
||||
InitialAnswerQualityUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainState
|
||||
|
||||
|
||||
def initial_answer_quality_check(state: MainState) -> InitialAnswerQualityUpdate:
|
||||
"""
|
||||
Check whether the final output satisfies the original user question
|
||||
|
||||
Args:
|
||||
state (messages): The current state
|
||||
|
||||
Returns:
|
||||
InitialAnswerQualityUpdate
|
||||
"""
|
||||
|
||||
now_start = datetime.now()
|
||||
|
||||
logger.debug(
|
||||
f"--------{now_start}--------Checking for base answer validity - for not set True/False manually"
|
||||
)
|
||||
|
||||
verdict = True
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.debug(
|
||||
f"--------{now_end}--{now_end - now_start}--------INITIAL ANSWER QUALITY CHECK END---"
|
||||
)
|
||||
|
||||
return InitialAnswerQualityUpdate(
|
||||
initial_answer_quality=verdict,
|
||||
log_messages=[
|
||||
f"{now_end} -- Main - Initial answer quality check, Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,133 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.main.models import AgentRefinedMetrics
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import dispatch_subquestion
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import BaseDecompUpdate
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainState
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_history_prompt,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
|
||||
|
||||
|
||||
def initial_sub_question_creation(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> BaseDecompUpdate:
|
||||
now_start = datetime.now()
|
||||
|
||||
logger.debug(f"--------{now_start}--------BASE DECOMP START---")
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = agent_a_config.search_request.query
|
||||
chat_session_id = agent_a_config.chat_session_id
|
||||
primary_message_id = agent_a_config.message_id
|
||||
perform_initial_search_decomposition = (
|
||||
agent_a_config.perform_initial_search_decomposition
|
||||
)
|
||||
# perform_initial_search_path_decision = (
|
||||
# agent_a_config.perform_initial_search_path_decision
|
||||
# )
|
||||
history = build_history_prompt(agent_a_config.prompt_builder)
|
||||
|
||||
# Use the initial search results to inform the decomposition
|
||||
sample_doc_str = state.sample_doc_str if hasattr(state, "sample_doc_str") else ""
|
||||
|
||||
if not chat_session_id or not primary_message_id:
|
||||
raise ValueError(
|
||||
"chat_session_id and message_id must be provided for agent search"
|
||||
)
|
||||
agent_start_time = datetime.now()
|
||||
|
||||
# Initial search to inform decomposition. Just get top 3 fits
|
||||
|
||||
if perform_initial_search_decomposition:
|
||||
sample_doc_str = "\n\n".join(
|
||||
[
|
||||
doc.combined_content
|
||||
for doc in state.exploratory_search_results[
|
||||
:AGENT_NUM_DOCS_FOR_DECOMPOSITION
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
decomposition_prompt = (
|
||||
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH.format(
|
||||
question=question, sample_doc_str=sample_doc_str, history=history
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
decomposition_prompt = INITIAL_DECOMPOSITION_PROMPT_QUESTIONS.format(
|
||||
question=question, history=history
|
||||
)
|
||||
|
||||
# Start decomposition
|
||||
|
||||
msg = [HumanMessage(content=decomposition_prompt)]
|
||||
|
||||
# Get the rewritten queries in a defined format
|
||||
model = agent_a_config.fast_llm
|
||||
|
||||
# Send the initial question as a subquestion with number 0
|
||||
dispatch_custom_event(
|
||||
"decomp_qs",
|
||||
SubQuestionPiece(
|
||||
sub_question=question,
|
||||
level=0,
|
||||
level_question_nr=0,
|
||||
),
|
||||
)
|
||||
# dispatches custom events for subquestion tokens, adding in subquestion ids.
|
||||
streamed_tokens = dispatch_separated(model.stream(msg), dispatch_subquestion(0))
|
||||
|
||||
stop_event = StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
stream_type="sub_questions",
|
||||
level=0,
|
||||
)
|
||||
dispatch_custom_event("stream_finished", stop_event)
|
||||
|
||||
decomposition_response = merge_content(*streamed_tokens)
|
||||
|
||||
# this call should only return strings. Commenting out for efficiency
|
||||
# assert [type(tok) == str for tok in streamed_tokens]
|
||||
|
||||
# use no-op cast() instead of str() which runs code
|
||||
# list_of_subquestions = clean_and_parse_list_string(cast(str, response))
|
||||
list_of_subqs = cast(str, decomposition_response).split("\n")
|
||||
|
||||
decomp_list: list[str] = [sq.strip() for sq in list_of_subqs if sq.strip() != ""]
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.debug(f"--------{now_end}--{now_end - now_start}--------BASE DECOMP END---")
|
||||
|
||||
return BaseDecompUpdate(
|
||||
initial_decomp_questions=decomp_list,
|
||||
agent_start_time=agent_start_time,
|
||||
agent_refined_start_time=None,
|
||||
agent_refined_end_time=None,
|
||||
agent_refined_metrics=AgentRefinedMetrics(
|
||||
refined_doc_boost_factor=None,
|
||||
refined_question_boost_factor=None,
|
||||
duration__s=None,
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,47 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainState
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import (
|
||||
RequireRefinedAnswerUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
|
||||
|
||||
def refined_answer_decision(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> RequireRefinedAnswerUpdate:
|
||||
now_start = datetime.now()
|
||||
|
||||
logger.debug(f"--------{now_start}--------REFINED ANSWER DECISION---")
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
if "?" in agent_a_config.search_request.query:
|
||||
decision = False
|
||||
else:
|
||||
decision = True
|
||||
|
||||
decision = True
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.debug(
|
||||
f"--------{now_end}--{now_end - now_start}--------REFINED ANSWER DECISION END---"
|
||||
)
|
||||
log_messages = [
|
||||
f"{now_end} -- Main - Refined answer decision: {decision}, Time taken: {now_end - now_start}"
|
||||
]
|
||||
if agent_a_config.allow_refinement:
|
||||
return RequireRefinedAnswerUpdate(
|
||||
require_refined_answer=decision,
|
||||
log_messages=log_messages,
|
||||
)
|
||||
|
||||
else:
|
||||
return RequireRefinedAnswerUpdate(
|
||||
require_refined_answer=False,
|
||||
log_messages=log_messages,
|
||||
)
|
||||
@@ -0,0 +1,119 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.main.models import FollowUpSubQuestion
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import dispatch_subquestion
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import (
|
||||
FollowUpSubQuestionsUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainState
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_history_prompt,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import DEEP_DECOMPOSE_PROMPT
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
|
||||
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_REFINED_DECOMPOSITION
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
|
||||
|
||||
def refined_sub_question_creation(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> FollowUpSubQuestionsUpdate:
|
||||
""" """
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
dispatch_custom_event(
|
||||
"start_refined_answer_creation",
|
||||
ToolCallKickoff(
|
||||
tool_name="agent_search_1",
|
||||
tool_args={
|
||||
"query": agent_a_config.search_request.query,
|
||||
"answer": state.initial_answer,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
now_start = datetime.now()
|
||||
|
||||
logger.debug(f"--------{now_start}--------FOLLOW UP DECOMPOSE---")
|
||||
|
||||
agent_refined_start_time = datetime.now()
|
||||
|
||||
question = agent_a_config.search_request.query
|
||||
base_answer = state.initial_answer
|
||||
history = build_history_prompt(agent_a_config.prompt_builder)
|
||||
# get the entity term extraction dict and properly format it
|
||||
# entity_retlation_term_extractions = state.entity_relation_term_extractions
|
||||
|
||||
# entity_term_extraction_str = format_entity_term_extraction(
|
||||
# entity_retlation_term_extractions
|
||||
# )
|
||||
|
||||
docs_str = format_docs(
|
||||
state.all_original_question_documents[:AGENT_NUM_DOCS_FOR_REFINED_DECOMPOSITION]
|
||||
)
|
||||
|
||||
initial_question_answers = state.decomp_answer_results
|
||||
|
||||
addressed_question_list = [
|
||||
x.question for x in initial_question_answers if "yes" in x.quality.lower()
|
||||
]
|
||||
|
||||
failed_question_list = [
|
||||
x.question for x in initial_question_answers if "no" in x.quality.lower()
|
||||
]
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=DEEP_DECOMPOSE_PROMPT.format(
|
||||
question=question,
|
||||
history=history,
|
||||
docs_str=docs_str,
|
||||
base_answer=base_answer,
|
||||
answered_sub_questions="\n - ".join(addressed_question_list),
|
||||
failed_sub_questions="\n - ".join(failed_question_list),
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
# Grader
|
||||
model = agent_a_config.fast_llm
|
||||
|
||||
streamed_tokens = dispatch_separated(model.stream(msg), dispatch_subquestion(1))
|
||||
response = merge_content(*streamed_tokens)
|
||||
|
||||
if isinstance(response, str):
|
||||
parsed_response = [q for q in response.split("\n") if q.strip() != ""]
|
||||
else:
|
||||
raise ValueError("LLM response is not a string")
|
||||
|
||||
refined_sub_question_dict = {}
|
||||
for sub_question_nr, sub_question in enumerate(parsed_response):
|
||||
refined_sub_question = FollowUpSubQuestion(
|
||||
sub_question=sub_question,
|
||||
sub_question_id=make_question_id(1, sub_question_nr + 1),
|
||||
verified=False,
|
||||
answered=False,
|
||||
answer="",
|
||||
)
|
||||
|
||||
refined_sub_question_dict[sub_question_nr + 1] = refined_sub_question
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.debug(
|
||||
f"--------{now_end}--{now_end - now_start}--------FOLLOW UP DECOMPOSE END---"
|
||||
)
|
||||
|
||||
return FollowUpSubQuestionsUpdate(
|
||||
refined_sub_questions=refined_sub_question_dict,
|
||||
agent_refined_start_time=agent_refined_start_time,
|
||||
)
|
||||
@@ -0,0 +1,12 @@
|
||||
from datetime import datetime
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainState
|
||||
|
||||
|
||||
def retrieval_consolidation(
|
||||
state: MainState,
|
||||
) -> LoggerUpdate:
|
||||
now_start = datetime.now()
|
||||
|
||||
return LoggerUpdate(log_messages=[f"{now_start} -- Retrieval consolidation"])
|
||||
@@ -0,0 +1,145 @@
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
QuestionAnswerResults,
|
||||
)
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.tools.models import SearchQueryInfo
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def remove_document_citations(text: str) -> str:
|
||||
"""
|
||||
Removes citation expressions of the form '[[D1]]()' or '[[Q1]]()' from text.
The number after D/Q can vary.
|
||||
|
||||
Args:
|
||||
text: Input text containing citations
|
||||
|
||||
Returns:
|
||||
Text with citations removed
|
||||
"""
|
||||
# Pattern explanation:
# \[\[(?:D|Q)\d+\]\]\(\) matches:
# \[\[ - literal [[ characters
# (?:D|Q) - a literal D or Q character
# \d+ - one or more digits
# \]\] - literal ]] characters
# \(\) - literal () characters
return re.sub(r"\[\[(?:D|Q)\d+\]\]\(\)", "", text)
|
||||
|
||||
|
||||
def dispatch_subquestion(level: int) -> Callable[[str, int], None]:
|
||||
def _helper(sub_question_part: str, num: int) -> None:
|
||||
dispatch_custom_event(
|
||||
"decomp_qs",
|
||||
SubQuestionPiece(
|
||||
sub_question=sub_question_part,
|
||||
level=level,
|
||||
level_question_nr=num,
|
||||
),
|
||||
)
|
||||
|
||||
return _helper
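# Example of how this factory is used as a per-token callback elsewhere in this change (level 0):
# streamed_tokens = dispatch_separated(model.stream(msg), dispatch_subquestion(0))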
|
||||
|
||||
|
||||
def calculate_initial_agent_stats(
|
||||
decomp_answer_results: list[QuestionAnswerResults],
|
||||
original_question_stats: AgentChunkStats,
|
||||
) -> InitialAgentResultStats:
|
||||
initial_agent_result_stats: InitialAgentResultStats = InitialAgentResultStats(
|
||||
sub_questions={},
|
||||
original_question={},
|
||||
agent_effectiveness={},
|
||||
)
|
||||
|
||||
orig_verified = original_question_stats.verified_count
|
||||
orig_support_score = original_question_stats.verified_avg_scores
|
||||
|
||||
verified_document_chunk_ids = []
|
||||
support_scores = 0.0
|
||||
|
||||
for decomp_answer_result in decomp_answer_results:
|
||||
verified_document_chunk_ids += (
|
||||
decomp_answer_result.sub_question_retrieval_stats.verified_doc_chunk_ids
|
||||
)
|
||||
if (
|
||||
decomp_answer_result.sub_question_retrieval_stats.verified_avg_scores
|
||||
is not None
|
||||
):
|
||||
support_scores += (
|
||||
decomp_answer_result.sub_question_retrieval_stats.verified_avg_scores
|
||||
)
|
||||
|
||||
verified_document_chunk_ids = list(set(verified_document_chunk_ids))
|
||||
|
||||
# Calculate sub-question stats
|
||||
if (
|
||||
verified_document_chunk_ids
|
||||
and len(verified_document_chunk_ids) > 0
|
||||
and support_scores is not None
|
||||
):
|
||||
sub_question_stats: dict[str, float | int | None] = {
|
||||
"num_verified_documents": len(verified_document_chunk_ids),
|
||||
"verified_avg_score": float(support_scores / len(decomp_answer_results)),
|
||||
}
|
||||
else:
|
||||
sub_question_stats = {"num_verified_documents": 0, "verified_avg_score": None}
|
||||
|
||||
initial_agent_result_stats.sub_questions.update(sub_question_stats)
|
||||
|
||||
# Get original question stats
|
||||
initial_agent_result_stats.original_question.update(
|
||||
{
|
||||
"num_verified_documents": original_question_stats.verified_count,
|
||||
"verified_avg_score": original_question_stats.verified_avg_scores,
|
||||
}
|
||||
)
|
||||
|
||||
# Calculate chunk utilization ratio
|
||||
sub_verified = initial_agent_result_stats.sub_questions["num_verified_documents"]
|
||||
|
||||
chunk_ratio: float | None = None
|
||||
if sub_verified is not None and orig_verified is not None and orig_verified > 0:
|
||||
chunk_ratio = (float(sub_verified) / orig_verified) if sub_verified > 0 else 0.0
|
||||
elif sub_verified is not None and sub_verified > 0:
|
||||
chunk_ratio = 10.0
|
||||
|
||||
initial_agent_result_stats.agent_effectiveness["utilized_chunk_ratio"] = chunk_ratio
|
||||
|
||||
if (
|
||||
(orig_support_score is None or orig_support_score == 0.0)
and initial_agent_result_stats.sub_questions["verified_avg_score"] is None
|
||||
):
|
||||
initial_agent_result_stats.agent_effectiveness["support_ratio"] = None
|
||||
elif orig_support_score is None or orig_support_score == 0.0:
|
||||
initial_agent_result_stats.agent_effectiveness["support_ratio"] = 10
|
||||
elif initial_agent_result_stats.sub_questions["verified_avg_score"] is None:
|
||||
initial_agent_result_stats.agent_effectiveness["support_ratio"] = 0
|
||||
else:
|
||||
initial_agent_result_stats.agent_effectiveness["support_ratio"] = (
|
||||
initial_agent_result_stats.sub_questions["verified_avg_score"]
|
||||
/ orig_support_score
|
||||
)
|
||||
|
||||
return initial_agent_result_stats
|
||||
|
||||
|
||||
def get_query_info(results: list[QueryResult]) -> SearchQueryInfo:
|
||||
# Use the query info from the base document retrieval
|
||||
# TODO: see if this is the right way to do this
|
||||
query_infos = [
|
||||
result.query_info for result in results if result.query_info is not None
|
||||
]
|
||||
if len(query_infos) == 0:
|
||||
raise ValueError("No query info found")
|
||||
return query_infos[0]
|
||||
backend/onyx/agents/agent_search/deep_search_a/main/states.py
@@ -0,0 +1,171 @@
|
||||
from datetime import datetime
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.core_state import CoreState
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import (
|
||||
ExpandedRetrievalResult,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.models import AgentBaseMetrics
|
||||
from onyx.agents.agent_search.deep_search_a.main.models import AgentRefinedMetrics
|
||||
from onyx.agents.agent_search.deep_search_a.main.models import FollowUpSubQuestion
|
||||
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
EntityRelationshipTermExtraction,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
QuestionAnswerResults,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_question_answer_results,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
### States ###
|
||||
|
||||
## Update States
|
||||
|
||||
|
||||
class LoggerUpdate(BaseModel):
|
||||
log_messages: Annotated[list[str], add] = []
|
||||
|
||||
|
||||
class RefinedAgentStartStats(BaseModel):
|
||||
agent_refined_start_time: datetime | None = None
|
||||
|
||||
|
||||
class RefinedAgentEndStats(BaseModel):
|
||||
agent_refined_end_time: datetime | None = None
|
||||
agent_refined_metrics: AgentRefinedMetrics = AgentRefinedMetrics()
|
||||
|
||||
|
||||
class BaseDecompUpdate(RefinedAgentStartStats, RefinedAgentEndStats):
|
||||
agent_start_time: datetime = datetime.now()
|
||||
initial_decomp_questions: list[str] = []
|
||||
|
||||
|
||||
class ExploratorySearchUpdate(LoggerUpdate):
|
||||
exploratory_search_results: list[InferenceSection] = []
|
||||
|
||||
|
||||
class AnswerComparison(LoggerUpdate):
|
||||
refined_answer_improvement: bool = False
|
||||
|
||||
|
||||
class RoutingDecision(LoggerUpdate):
|
||||
routing: str = ""
|
||||
sample_doc_str: str = ""
|
||||
|
||||
|
||||
class InitialAnswerBASEUpdate(BaseModel):
|
||||
initial_base_answer: str = ""
|
||||
|
||||
|
||||
class InitialAnswerUpdate(LoggerUpdate):
|
||||
initial_answer: str = ""
|
||||
initial_agent_stats: InitialAgentResultStats | None = None
|
||||
generated_sub_questions: list[str] = []
|
||||
agent_base_end_time: datetime | None = None
|
||||
agent_base_metrics: AgentBaseMetrics | None = None
|
||||
|
||||
|
||||
class RefinedAnswerUpdate(RefinedAgentEndStats):
|
||||
refined_answer: str = ""
|
||||
refined_agent_stats: RefinedAgentStats | None = None
|
||||
refined_answer_quality: bool = False
|
||||
|
||||
|
||||
class InitialAnswerQualityUpdate(LoggerUpdate):
|
||||
initial_answer_quality: bool = False
|
||||
|
||||
|
||||
class RequireRefinedAnswerUpdate(LoggerUpdate):
|
||||
require_refined_answer: bool = True
|
||||
|
||||
|
||||
class DecompAnswersUpdate(LoggerUpdate):
|
||||
documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
|
||||
context_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
|
||||
decomp_answer_results: Annotated[
|
||||
list[QuestionAnswerResults], dedup_question_answer_results
|
||||
] = []
|
||||
|
||||
|
||||
class FollowUpDecompAnswersUpdate(LoggerUpdate):
|
||||
refined_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
|
||||
refined_decomp_answer_results: Annotated[list[QuestionAnswerResults], add] = []
|
||||
|
||||
|
||||
class ExpandedRetrievalUpdate(LoggerUpdate):
|
||||
all_original_question_documents: Annotated[
|
||||
list[InferenceSection], dedup_inference_sections
|
||||
]
|
||||
original_question_retrieval_results: list[QueryResult] = []
|
||||
original_question_retrieval_stats: AgentChunkStats = AgentChunkStats()
|
||||
|
||||
|
||||
class EntityTermExtractionUpdate(LoggerUpdate):
|
||||
entity_relation_term_extractions: EntityRelationshipTermExtraction = (
|
||||
EntityRelationshipTermExtraction()
|
||||
)
|
||||
|
||||
|
||||
class FollowUpSubQuestionsUpdate(RefinedAgentStartStats):
|
||||
refined_sub_questions: dict[int, FollowUpSubQuestion] = {}
|
||||
|
||||
|
||||
## Graph Input State
|
||||
|
||||
|
||||
class MainInput(CoreState):
|
||||
pass
|
||||
|
||||
|
||||
## Graph State
|
||||
|
||||
|
||||
class MainState(
|
||||
# This includes the core state
|
||||
MainInput,
|
||||
ToolChoiceInput,
|
||||
ToolCallUpdate,
|
||||
ToolChoiceUpdate,
|
||||
BaseDecompUpdate,
|
||||
InitialAnswerUpdate,
|
||||
InitialAnswerBASEUpdate,
|
||||
DecompAnswersUpdate,
|
||||
ExpandedRetrievalUpdate,
|
||||
EntityTermExtractionUpdate,
|
||||
InitialAnswerQualityUpdate,
|
||||
RequireRefinedAnswerUpdate,
|
||||
FollowUpSubQuestionsUpdate,
|
||||
FollowUpDecompAnswersUpdate,
|
||||
RefinedAnswerUpdate,
|
||||
RefinedAgentStartStats,
|
||||
RefinedAgentEndStats,
|
||||
RoutingDecision,
|
||||
AnswerComparison,
|
||||
ExploratorySearchUpdate,
|
||||
):
|
||||
# expanded_retrieval_result: Annotated[list[ExpandedRetrievalResult], add]
|
||||
base_raw_search_result: Annotated[list[ExpandedRetrievalResult], add]
|
||||
|
||||
|
||||
## Graph Output State - presently not used
|
||||
|
||||
|
||||
class MainOutput(TypedDict):
|
||||
log_messages: list[str]
|
||||
backend/onyx/agents/agent_search/models.py
@@ -0,0 +1,93 @@
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.file_store.utils import InMemoryChatFile
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentSearchConfig:
|
||||
"""
|
||||
Configuration for the Agent Search feature.
|
||||
"""
|
||||
|
||||
# The search request that was used to generate the Pro Search
|
||||
search_request: SearchRequest
|
||||
|
||||
primary_llm: LLM
|
||||
fast_llm: LLM
|
||||
|
||||
# Whether to force use of a tool, or to
|
||||
# force tool args IF the tool is used
|
||||
force_use_tool: ForceUseTool
|
||||
|
||||
# contains message history for the current chat session
|
||||
# has the following (at most one is non-None)
|
||||
# message_history: list[PreviousMessage] | None = None
|
||||
# single_message_history: str | None = None
|
||||
prompt_builder: AnswerPromptBuilder
|
||||
|
||||
search_tool: SearchTool | None = None
|
||||
|
||||
use_agentic_search: bool = False
|
||||
|
||||
# For persisting agent search data
|
||||
chat_session_id: UUID | None = None
|
||||
|
||||
# The message ID of the user message that triggered the Pro Search
|
||||
message_id: int | None = None
|
||||
|
||||
# Whether to persist data for the Pro Search (turned off for testing)
|
||||
use_persistence: bool = True
|
||||
|
||||
# The database session for the Pro Search
|
||||
db_session: Session | None = None
|
||||
|
||||
# Whether to perform initial search to inform decomposition
|
||||
# perform_initial_search_path_decision: bool = True
|
||||
|
||||
# Whether to perform initial search to inform decomposition
|
||||
perform_initial_search_decomposition: bool = True
|
||||
|
||||
# Whether to allow creation of refinement questions (and entity extraction, etc.)
|
||||
allow_refinement: bool = True
|
||||
|
||||
# Tools available for use
|
||||
tools: list[Tool] | None = None
|
||||
|
||||
using_tool_calling_llm: bool = False
|
||||
|
||||
files: list[InMemoryChatFile] | None = None
|
||||
|
||||
structured_response_format: dict | None = None
|
||||
|
||||
skip_gen_ai_answer_generation: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_db_session(self) -> "AgentSearchConfig":
|
||||
if self.use_persistence and self.db_session is None:
|
||||
raise ValueError(
|
||||
"db_session must be provided for pro search when using persistence"
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_search_tool(self) -> "AgentSearchConfig":
|
||||
if self.use_agentic_search and self.search_tool is None:
|
||||
raise ValueError("search_tool must be provided for agentic search")
|
||||
return self
|
||||
|
||||
|
||||
class AgentDocumentCitations(BaseModel):
|
||||
document_id: str
|
||||
document_title: str
|
||||
link: str
|
||||
@@ -0,0 +1,72 @@
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.basic.states import BasicOutput
|
||||
from onyx.agents.agent_search.basic.states import BasicState
|
||||
from onyx.agents.agent_search.basic.utils import process_llm_stream
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_DOC_CONTENT_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search_like_tool_utils import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicOutput:
|
||||
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
structured_response_format = agent_config.structured_response_format
|
||||
llm = agent_config.primary_llm
|
||||
tool_choice = state.tool_choice
|
||||
if tool_choice is None:
|
||||
raise ValueError("Tool choice is None")
|
||||
tool = tool_choice.tool
|
||||
prompt_builder = agent_config.prompt_builder
|
||||
if state.tool_call_output is None:
|
||||
raise ValueError("Tool call output is None")
|
||||
tool_call_output = state.tool_call_output
|
||||
tool_call_summary = tool_call_output.tool_call_summary
|
||||
tool_call_responses = tool_call_output.tool_call_responses
|
||||
|
||||
new_prompt_builder = tool.build_next_prompt(
|
||||
prompt_builder=prompt_builder,
|
||||
tool_call_summary=tool_call_summary,
|
||||
tool_responses=tool_call_responses,
|
||||
using_tool_calling_llm=agent_config.using_tool_calling_llm,
|
||||
)
|
||||
|
||||
final_search_results = []
|
||||
initial_search_results = []
|
||||
for yield_item in tool_call_responses:
|
||||
if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
final_search_results = cast(list[LlmDoc], yield_item.response)
|
||||
elif yield_item.id == SEARCH_DOC_CONTENT_ID:
|
||||
search_contexts = yield_item.response.contexts
|
||||
for doc in search_contexts:
|
||||
if doc.document_id not in initial_search_results:
|
||||
initial_search_results.append(doc)
|
||||
|
||||
initial_search_results = cast(list[LlmDoc], initial_search_results)
|
||||
|
||||
new_tool_call_chunk = AIMessageChunk(content="")
|
||||
if not agent_config.skip_gen_ai_answer_generation:
|
||||
stream = llm.stream(
|
||||
prompt=new_prompt_builder.build(),
|
||||
structured_response_format=structured_response_format,
|
||||
)
|
||||
|
||||
# For now, we don't do multiple tool calls, so we ignore the tool_message
|
||||
new_tool_call_chunk = process_llm_stream(
|
||||
stream,
|
||||
True,
|
||||
final_search_results=final_search_results,
|
||||
displayed_search_results=initial_search_results,
|
||||
)
|
||||
|
||||
return BasicOutput(tool_call_chunk=new_tool_call_chunk)
|
||||
@@ -0,0 +1,144 @@
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
from langchain_core.messages import ToolCall
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.basic.utils import process_llm_stream
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoice
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceState
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
|
||||
from onyx.chat.tool_handling.tool_response_handler import (
|
||||
get_tool_call_for_non_tool_calling_llm_impl,
|
||||
)
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# TODO: break this out into an implementation function
|
||||
# and a function that handles extracting the necessary fields
|
||||
# from the state and config
|
||||
# TODO: fan-out to multiple tool call nodes? Make this configurable?
|
||||
def llm_tool_choice(state: ToolChoiceState, config: RunnableConfig) -> ToolChoiceUpdate:
|
||||
"""
|
||||
This node is responsible for calling the LLM to choose a tool. If no tool is chosen,
|
||||
The node MAY emit an answer, depending on whether state["should_stream_answer"] is set.
|
||||
"""
|
||||
should_stream_answer = state.should_stream_answer
|
||||
|
||||
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
using_tool_calling_llm = agent_config.using_tool_calling_llm
|
||||
prompt_builder = state.prompt_snapshot or agent_config.prompt_builder
|
||||
|
||||
llm = agent_config.primary_llm
|
||||
skip_gen_ai_answer_generation = agent_config.skip_gen_ai_answer_generation
|
||||
|
||||
structured_response_format = agent_config.structured_response_format
|
||||
tools = [tool for tool in (agent_config.tools or []) if tool.name in state.tools]
|
||||
force_use_tool = agent_config.force_use_tool
|
||||
|
||||
tool, tool_args = None, None
|
||||
if force_use_tool.force_use and force_use_tool.args is not None:
|
||||
tool_name, tool_args = (
|
||||
force_use_tool.tool_name,
|
||||
force_use_tool.args,
|
||||
)
|
||||
tool = get_tool_by_name(tools, tool_name)
|
||||
|
||||
# special pre-logic for non-tool calling LLM case
|
||||
elif not using_tool_calling_llm and tools:
|
||||
chosen_tool_and_args = get_tool_call_for_non_tool_calling_llm_impl(
|
||||
force_use_tool=force_use_tool,
|
||||
tools=tools,
|
||||
prompt_builder=prompt_builder,
|
||||
llm=llm,
|
||||
)
|
||||
if chosen_tool_and_args:
|
||||
tool, tool_args = chosen_tool_and_args
|
||||
|
||||
# If we have a tool and tool args, we are ready to request a tool call.
|
||||
# This only happens if the tool call was forced or we are using a non-tool calling LLM.
|
||||
if tool and tool_args:
|
||||
return ToolChoiceUpdate(
|
||||
tool_choice=ToolChoice(
|
||||
tool=tool,
|
||||
tool_args=tool_args,
|
||||
id=str(uuid4()),
|
||||
),
|
||||
)
|
||||
|
||||
# if we're skipping gen ai answer generation, we should only
|
||||
# continue if we're forcing a tool call (which will be emitted by
|
||||
# the tool calling llm in the stream() below)
|
||||
if skip_gen_ai_answer_generation and not force_use_tool.force_use:
|
||||
return ToolChoiceUpdate(
|
||||
tool_choice=None,
|
||||
)
|
||||
|
||||
built_prompt = (
|
||||
prompt_builder.build()
|
||||
if isinstance(prompt_builder, AnswerPromptBuilder)
|
||||
else prompt_builder.built_prompt
|
||||
)
|
||||
# At this point, we are either using a tool calling LLM or we are skipping the tool call.
|
||||
# DEBUG: good breakpoint
|
||||
stream = llm.stream(
|
||||
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
|
||||
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
|
||||
prompt=built_prompt,
|
||||
tools=[tool.tool_definition() for tool in tools] or None,
|
||||
tool_choice=("required" if tools and force_use_tool.force_use else None),
|
||||
structured_response_format=structured_response_format,
|
||||
)
|
||||
|
||||
tool_message = process_llm_stream(
|
||||
stream, should_stream_answer and not agent_config.skip_gen_ai_answer_generation
|
||||
)
|
||||
|
||||
# If no tool calls are emitted by the LLM, we should not choose a tool
|
||||
if len(tool_message.tool_calls) == 0:
|
||||
logger.info("No tool calls emitted by LLM")
|
||||
return ToolChoiceUpdate(
|
||||
tool_choice=None,
|
||||
)
|
||||
|
||||
# TODO: here we could handle parallel tool calls. Right now
|
||||
# we just pick the first one that matches.
|
||||
selected_tool: Tool | None = None
|
||||
selected_tool_call_request: ToolCall | None = None
|
||||
for tool_call_request in tool_message.tool_calls:
|
||||
known_tools_by_name = [
|
||||
tool for tool in tools if tool.name == tool_call_request["name"]
|
||||
]
|
||||
|
||||
if known_tools_by_name:
|
||||
selected_tool = known_tools_by_name[0]
|
||||
selected_tool_call_request = tool_call_request
|
||||
break
|
||||
|
||||
logger.error(
|
||||
"Tool call requested with unknown name field. \n"
|
||||
f"tools: {tools}"
|
||||
f"tool_call_request: {tool_call_request}"
|
||||
)
|
||||
|
||||
if not selected_tool or not selected_tool_call_request:
|
||||
raise ValueError(
|
||||
f"Tool call attempted with tool {selected_tool}, request {selected_tool_call_request}"
|
||||
)
|
||||
|
||||
logger.info(f"Selected tool: {selected_tool.name}")
|
||||
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
|
||||
|
||||
return ToolChoiceUpdate(
|
||||
tool_choice=ToolChoice(
|
||||
tool=selected_tool,
|
||||
tool_args=selected_tool_call_request["args"],
|
||||
id=selected_tool_call_request["id"],
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,17 @@
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
|
||||
|
||||
|
||||
def prepare_tool_input(state: Any, config: RunnableConfig) -> ToolChoiceInput:
|
||||
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
return ToolChoiceInput(
|
||||
# NOTE: this node is used at the top level of the agent, so we always stream
|
||||
should_stream_answer=True,
|
||||
prompt_snapshot=None, # uses default prompt builder
|
||||
tools=[tool.name for tool in (agent_config.tools or [])],
|
||||
)
|
||||
@@ -0,0 +1,66 @@
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages.tool import ToolCall
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.orchestration.states import ToolCallOutput
|
||||
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
from onyx.chat.models import AnswerPacket
|
||||
from onyx.tools.message import build_tool_message
|
||||
from onyx.tools.message import ToolCallSummary
|
||||
from onyx.tools.tool_runner import ToolRunner
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def emit_packet(packet: AnswerPacket) -> None:
|
||||
dispatch_custom_event("basic_response", packet)
|
||||
|
||||
|
||||
def tool_call(state: ToolChoiceUpdate, config: RunnableConfig) -> ToolCallUpdate:
|
||||
"""Calls the tool specified in the state and updates the state with the result"""
|
||||
|
||||
cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
|
||||
tool_choice = state.tool_choice
|
||||
if tool_choice is None:
|
||||
raise ValueError("Cannot invoke tool call node without a tool choice")
|
||||
|
||||
tool = tool_choice.tool
|
||||
tool_args = tool_choice.tool_args
|
||||
tool_id = tool_choice.id
|
||||
tool_runner = ToolRunner(tool, tool_args)
|
||||
tool_kickoff = tool_runner.kickoff()
|
||||
|
||||
# TODO: custom events for yields
|
||||
emit_packet(tool_kickoff)
|
||||
|
||||
tool_responses = []
|
||||
for response in tool_runner.tool_responses():
|
||||
tool_responses.append(response)
|
||||
emit_packet(response)
|
||||
|
||||
tool_final_result = tool_runner.tool_final_result()
|
||||
emit_packet(tool_final_result)
|
||||
|
||||
tool_call = ToolCall(name=tool.name, args=tool_args, id=tool_id)
|
||||
tool_call_summary = ToolCallSummary(
|
||||
tool_call_request=AIMessageChunk(content="", tool_calls=[tool_call]),
|
||||
tool_call_result=build_tool_message(
|
||||
tool_call, tool_runner.tool_message_content()
|
||||
),
|
||||
)
|
||||
|
||||
tool_call_output = ToolCallOutput(
|
||||
tool_call_summary=tool_call_summary,
|
||||
tool_call_kickoff=tool_kickoff,
|
||||
tool_call_responses=tool_responses,
|
||||
tool_call_final_result=tool_final_result,
|
||||
)
|
||||
return ToolCallUpdate(tool_call_output=tool_call_output)
|
||||
backend/onyx/agents/agent_search/orchestration/states.py
@@ -0,0 +1,48 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
|
||||
from onyx.tools.message import ToolCallSummary
|
||||
from onyx.tools.models import ToolCallFinalResult
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import Tool
|
||||
|
||||
|
||||
# TODO: adapt the tool choice/tool call to allow for parallel tool calls by
|
||||
# creating a subgraph that can be invoked in parallel via Send/Command APIs
|
||||
class ToolChoiceInput(BaseModel):
|
||||
should_stream_answer: bool = True
|
||||
# default to the prompt builder from the config, but
|
||||
# allow overrides for arbitrary tool calls
|
||||
prompt_snapshot: PromptSnapshot | None = None
|
||||
|
||||
# names of tools to use for tool calling. Filters the tools available in the config
|
||||
tools: list[str] = []
|
||||
|
||||
|
||||
class ToolCallOutput(BaseModel):
|
||||
tool_call_summary: ToolCallSummary
|
||||
tool_call_kickoff: ToolCallKickoff
|
||||
tool_call_responses: list[ToolResponse]
|
||||
tool_call_final_result: ToolCallFinalResult
|
||||
|
||||
|
||||
class ToolCallUpdate(BaseModel):
|
||||
tool_call_output: ToolCallOutput | None = None
|
||||
|
||||
|
||||
class ToolChoice(BaseModel):
|
||||
tool: Tool
|
||||
tool_args: dict
|
||||
id: str | None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class ToolChoiceUpdate(BaseModel):
|
||||
tool_choice: ToolChoice | None = None
|
||||
|
||||
|
||||
class ToolChoiceState(ToolChoiceUpdate, ToolChoiceInput):
|
||||
pass
|
||||
backend/onyx/agents/agent_search/run_graph.py
@@ -0,0 +1,274 @@
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable
|
||||
from collections.abc import Iterable
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables.schema import StreamEvent
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from onyx.agents.agent_search.basic.graph_builder import basic_graph_builder
|
||||
from onyx.agents.agent_search.basic.states import BasicInput
|
||||
from onyx.agents.agent_search.deep_search_a.main.graph_builder import (
|
||||
main_graph_builder as main_graph_builder_a,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainInput as MainInput_a
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import AnswerPacket
|
||||
from onyx.chat.models import AnswerStream
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.chat.models import RefinedAnswerImprovement
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import SubQueryPiece
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.chat.models import ToolResponse
|
||||
from onyx.configs.agent_configs import GRAPH_NAME
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.tools.tool_runner import ToolCallKickoff
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_COMPILED_GRAPH: CompiledStateGraph | None = None
|
||||
|
||||
|
||||
def _set_combined_token_value(
|
||||
combined_token: str, parsed_object: AgentAnswerPiece
|
||||
) -> AgentAnswerPiece:
|
||||
parsed_object.answer_piece = combined_token
|
||||
|
||||
return parsed_object
|
||||
|
||||
|
||||
def _parse_agent_event(
|
||||
event: StreamEvent,
|
||||
) -> AnswerPacket | None:
|
||||
"""
|
||||
Parse the event into a typed object.
|
||||
Return None if we are not interested in the event.
|
||||
"""
|
||||
|
||||
event_type = event["event"]
|
||||
|
||||
# We always just yield the event data, but this piece is useful for two development reasons:
|
||||
# 1. It's a list of the names of every place we dispatch a custom event
|
||||
# 2. We maintain the intended types yielded by each event
|
||||
if event_type == "on_custom_event":
|
||||
# TODO: different AnswerStream types for different events
|
||||
if event["name"] == "decomp_qs":
|
||||
return cast(SubQuestionPiece, event["data"])
|
||||
elif event["name"] == "subqueries":
|
||||
return cast(SubQueryPiece, event["data"])
|
||||
elif event["name"] == "sub_answers":
|
||||
return cast(AgentAnswerPiece, event["data"])
|
||||
elif event["name"] == "stream_finished":
|
||||
return cast(StreamStopInfo, event["data"])
|
||||
elif event["name"] == "initial_agent_answer":
|
||||
return cast(AgentAnswerPiece, event["data"])
|
||||
elif event["name"] == "refined_agent_answer":
|
||||
return cast(AgentAnswerPiece, event["data"])
|
||||
elif event["name"] == "start_refined_answer_creation":
|
||||
return cast(ToolCallKickoff, event["data"])
|
||||
elif event["name"] == "tool_response":
|
||||
return cast(ToolResponse, event["data"])
|
||||
elif event["name"] == "basic_response":
|
||||
return cast(AnswerPacket, event["data"])
|
||||
elif event["name"] == "refined_answer_improvement":
|
||||
return cast(RefinedAnswerImprovement, event["data"])
|
||||
return None
|
||||
|
||||
|
||||
# https://stackoverflow.com/questions/60226557/how-to-forcefully-close-an-async-generator
|
||||
# https://stackoverflow.com/questions/40897428/please-explain-task-was-destroyed-but-it-is-pending-after-cancelling-tasks
|
||||
task_references: set[asyncio.Task[StreamEvent]] = set()
|
||||
|
||||
|
||||
def _manage_async_event_streaming(
|
||||
compiled_graph: CompiledStateGraph,
|
||||
config: AgentSearchConfig | None,
|
||||
graph_input: MainInput_a | BasicInput,
|
||||
) -> Iterable[StreamEvent]:
|
||||
async def _run_async_event_stream() -> AsyncIterable[StreamEvent]:
|
||||
message_id = config.message_id if config else None
|
||||
async for event in compiled_graph.astream_events(
|
||||
input=graph_input,
|
||||
config={"metadata": {"config": config, "thread_id": str(message_id)}},
|
||||
# debug=True,
|
||||
# indicating v2 here deserves further scrutiny
|
||||
version="v2",
|
||||
):
|
||||
yield event
|
||||
|
||||
# This might be able to be simplified
|
||||
def _yield_async_to_sync() -> Iterable[StreamEvent]:
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
# Get the async generator
|
||||
async_gen = _run_async_event_stream()
|
||||
# Convert to AsyncIterator
|
||||
async_iter = async_gen.__aiter__()
|
||||
while True:
|
||||
try:
|
||||
# Create a coroutine by calling anext with the async iterator
|
||||
next_coro = anext(async_iter)
|
||||
task = asyncio.ensure_future(next_coro, loop=loop)
|
||||
task_references.add(task)
|
||||
# Run the coroutine to get the next event
|
||||
event = loop.run_until_complete(task)
|
||||
yield event
|
||||
except (StopAsyncIteration, GeneratorExit):
|
||||
break
|
||||
finally:
|
||||
try:
|
||||
for task in task_references:
|
||||
task.cancel()
|
||||
except StopAsyncIteration:
|
||||
pass
|
||||
loop.close()
|
||||
|
||||
return _yield_async_to_sync()
|
||||
|
||||
|
||||
def run_graph(
|
||||
compiled_graph: CompiledStateGraph,
|
||||
config: AgentSearchConfig,
|
||||
input: BasicInput | MainInput_a,
|
||||
) -> AnswerStream:
|
||||
# TODO: add these to the environment
|
||||
# config.perform_initial_search_path_decision = False
|
||||
config.perform_initial_search_decomposition = True
|
||||
config.allow_refinement = True
|
||||
|
||||
for event in _manage_async_event_streaming(
|
||||
compiled_graph=compiled_graph, config=config, graph_input=input
|
||||
):
|
||||
if not (parsed_object := _parse_agent_event(event)):
|
||||
continue
|
||||
|
||||
yield parsed_object
|
||||
|
||||
|
||||
# TODO: call this once on startup, TBD where and if it should be gated based
|
||||
# on dev mode or not
|
||||
def load_compiled_graph(graph_name: str) -> CompiledStateGraph:
|
||||
main_graph_builder = (
|
||||
main_graph_builder_a if graph_name == "a" else main_graph_builder_a
|
||||
)
|
||||
global _COMPILED_GRAPH
|
||||
if _COMPILED_GRAPH is None:
|
||||
graph = main_graph_builder()
|
||||
_COMPILED_GRAPH = graph.compile()
|
||||
return _COMPILED_GRAPH
|
||||
|
||||
|
||||
def run_main_graph(
|
||||
config: AgentSearchConfig,
|
||||
graph_name: str = "a",
|
||||
) -> AnswerStream:
|
||||
compiled_graph = load_compiled_graph(graph_name)
|
||||
if graph_name == "a":
|
||||
input = MainInput_a(base_question=config.search_request.query, log_messages=[])
|
||||
else:
|
||||
input = MainInput_a(base_question=config.search_request.query, log_messages=[])
|
||||
|
||||
# Agent search is not a Tool per se, but this is helpful for the frontend
|
||||
yield ToolCallKickoff(
|
||||
tool_name="agent_search_0",
|
||||
tool_args={"query": config.search_request.query},
|
||||
)
|
||||
yield from run_graph(compiled_graph, config, input)
|
||||
|
||||
|
||||
# TODO: unify input types, especially prosearchconfig
|
||||
def run_basic_graph(
|
||||
config: AgentSearchConfig,
|
||||
) -> AnswerStream:
|
||||
graph = basic_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
# TODO: unify basic input
|
||||
input = BasicInput()
|
||||
return run_graph(compiled_graph, config, input)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    from onyx.llm.factory import get_default_llms

    now_start = datetime.now()
    logger.debug(f"Start at {now_start}")

    if GRAPH_NAME == "a":
        graph = main_graph_builder_a()
    else:
        graph = main_graph_builder_a()
    compiled_graph = graph.compile()
    now_end = datetime.now()
    logger.debug(f"Graph compiled in {now_end - now_start} seconds")
    primary_llm, fast_llm = get_default_llms()
    search_request = SearchRequest(
        # query="what can you do with gitlab?",
        # query="What are the guiding principles behind the development of cockroachDB",
        # query="What are the temperatures in Munich, Hawaii, and New York?",
        # query="When was Washington born?",
        # query="What is Onyx?",
        query="What is the difference between astronomy and astrology?",
    )
    # Joachim custom persona

    with get_session_context_manager() as db_session:
        config, search_tool = get_test_config(
            db_session, primary_llm, fast_llm, search_request
        )
        # search_request.persona = get_persona_by_id(1, None, db_session)
        config.use_persistence = True
        # config.perform_initial_search_path_decision = False
        config.perform_initial_search_decomposition = True
        if GRAPH_NAME == "a":
            input = MainInput_a(
                base_question=config.search_request.query, log_messages=[]
            )
        else:
            input = MainInput_a(
                base_question=config.search_request.query, log_messages=[]
            )
        # with open("output.txt", "w") as f:
        tool_responses: list = []
        for output in run_graph(compiled_graph, config, input):
            # pass

            if isinstance(output, ToolCallKickoff):
                pass
            elif isinstance(output, ExtendedToolResponse):
                tool_responses.append(output.response)
                logger.info(
                    f" ---- ET {output.level} - {output.level_question_nr} | "
                )
            elif isinstance(output, SubQueryPiece):
                logger.info(
                    f"Sq {output.level} - {output.level_question_nr} - {output.sub_query} | "
                )
            elif isinstance(output, SubQuestionPiece):
                logger.info(
                    f"SQ {output.level} - {output.level_question_nr} - {output.sub_question} | "
                )
            elif (
                isinstance(output, AgentAnswerPiece)
                and output.answer_type == "agent_sub_answer"
            ):
                logger.info(
                    f" ---- SA {output.level} - {output.level_question_nr} {output.answer_piece} | "
                )
            elif (
                isinstance(output, AgentAnswerPiece)
                and output.answer_type == "agent_level_answer"
            ):
                logger.info(
                    f" ---------- FA {output.level} - {output.level_question_nr} {output.answer_piece} | "
                )
            elif isinstance(output, RefinedAnswerImprovement):
                logger.info(f" ---------- RE {output.refined_answer_improvement} | ")

        # for tool_response in tool_responses:
        #     logger.debug(tool_response)
@@ -0,0 +1,100 @@
from langchain.schema import AIMessage
from langchain.schema import HumanMessage
from langchain.schema import SystemMessage
from langchain_core.messages.tool import ToolMessage

from onyx.agents.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT_v2
from onyx.agents.agent_search.shared_graph_utils.prompts import HISTORY_PROMPT
from onyx.agents.agent_search.shared_graph_utils.utils import get_today_prompt
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.context.search.models import InferenceSection
from onyx.llm.interfaces import LLMConfig
from onyx.llm.utils import get_max_input_tokens
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.natural_language_processing.utils import tokenizer_trim_content


def build_sub_question_answer_prompt(
    question: str,
    original_question: str,
    docs: list[InferenceSection],
    persona_specification: str,
    config: LLMConfig,
) -> list[SystemMessage | HumanMessage | AIMessage | ToolMessage]:
    system_message = SystemMessage(
        content=persona_specification,
    )

    date_str = get_today_prompt()

    docs_format_list = [
        f"""Document Number: [D{doc_nr + 1}]\n
             Content: {doc.combined_content}\n\n"""
        for doc_nr, doc in enumerate(docs)
    ]

    docs_str = "\n\n".join(docs_format_list)

    docs_str = trim_prompt_piece(
        config, docs_str, BASE_RAG_PROMPT_v2 + question + original_question + date_str
    )
    human_message = HumanMessage(
        content=BASE_RAG_PROMPT_v2.format(
            question=question,
            original_question=original_question,
            context=docs_str,
            date_prompt=date_str,
        )
    )

    return [system_message, human_message]

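The returned message list is ready to hand to a chat model. A sketch of that, assuming a LangChain-compatible model such as langchain_openai.ChatOpenAI (the model choice, retrieved_sections, and llm_config below are illustrative placeholders, not part of this diff):

from langchain_openai import ChatOpenAI

messages = build_sub_question_answer_prompt(
    question="What storage engine does CockroachDB use?",
    original_question="What are the guiding principles behind CockroachDB?",
    docs=retrieved_sections,  # placeholder: list[InferenceSection] from a retrieval step
    persona_specification="You are a concise technical assistant.",
    config=llm_config,  # placeholder: LLMConfig matching the model below
)
answer = ChatOpenAI(model="gpt-4o-mini").invoke(messages)
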
def trim_prompt_piece(config: LLMConfig, prompt_piece: str, reserved_str: str) -> str:
    # TODO: this truncating might add latency. We could do a rougher + faster check
    # first to determine whether truncation is needed

    # TODO: maybe save the tokenizer and max input tokens if this is getting called multiple times?
    llm_tokenizer = get_tokenizer(
        provider_type=config.model_provider,
        model_name=config.model_name,
    )

    max_tokens = get_max_input_tokens(
        model_provider=config.model_provider,
        model_name=config.model_name,
    )

    # slightly conservative trimming
    return tokenizer_trim_content(
        content=prompt_piece,
        desired_length=max_tokens - len(llm_tokenizer.encode(reserved_str)),
        tokenizer=llm_tokenizer,
    )

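The second TODO above hints at caching the tokenizer and token-limit lookups. A rough sketch of that, assuming both lookups depend only on the provider/model pair (the helper name and cache size are illustrative):

from functools import lru_cache


@lru_cache(maxsize=8)
def _cached_tokenizer_and_limit(model_provider: str, model_name: str) -> tuple:
    # Memoize per provider/model so repeated calls to trim_prompt_piece
    # (one per sub-question) do not redo the same work.
    tokenizer = get_tokenizer(provider_type=model_provider, model_name=model_name)
    max_tokens = get_max_input_tokens(
        model_provider=model_provider, model_name=model_name
    )
    return tokenizer, max_tokens
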
def build_history_prompt(prompt_builder: AnswerPromptBuilder | None) -> str:
    if prompt_builder is None:
        return ""

    if prompt_builder.single_message_history is not None:
        history = prompt_builder.single_message_history
    else:
        history_components = []
        previous_message_type = None
        for message in prompt_builder.raw_message_history:
            if "user" in message.message_type:
                history_components.append(f"User: {message.message}\n")
                previous_message_type = "user"
            elif "assistant" in message.message_type:
                # only use the last agent answer for the history
                if previous_message_type != "assistant":
                    history_components.append(f"You/Agent: {message.message}\n")
                else:
                    history_components = history_components[:-1]
                    history_components.append(f"You/Agent: {message.message}\n")
                previous_message_type = "assistant"
            else:
                continue
        history = "\n".join(history_components)
    return HISTORY_PROMPT.format(history=history) if history else ""
@@ -0,0 +1,98 @@
import numpy as np

from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitScoreMetrics
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.chat.models import SectionRelevancePiece
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger

logger = setup_logger()


def unique_chunk_id(doc: InferenceSection) -> str:
    return f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}"


def calculate_rank_shift(list1: list, list2: list, top_n: int = 20) -> float:
    shift = 0
    for rank_first, doc_id in enumerate(list1[:top_n], 1):
        try:
            rank_second = list2.index(doc_id) + 1
        except ValueError:
            rank_second = len(list2)  # Document not found in second list

        shift += np.abs(rank_first - rank_second) / np.log(1 + rank_first * rank_second)

    return shift / top_n

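A small worked example of the shift metric (the document ids are made up): identical rankings score 0, and displacements near the top of the list are weighted more heavily by the log term.

example_shift = calculate_rank_shift(["d1", "d2", "d3"], ["d2", "d1", "d3"], top_n=3)
# d1: |1 - 2| / ln(1 + 1 * 2) ≈ 0.910
# d2: |2 - 1| / ln(1 + 2 * 1) ≈ 0.910
# d3: |3 - 3| / ln(1 + 3 * 3) = 0.0
# example_shift ≈ (0.910 + 0.910 + 0.0) / 3 ≈ 0.61
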
def get_fit_scores(
    pre_reranked_results: list[InferenceSection],
    post_reranked_results: list[InferenceSection] | list[SectionRelevancePiece],
) -> RetrievalFitStats | None:
    """
    Calculate retrieval metrics for search purposes
    """

    if len(pre_reranked_results) == 0 or len(post_reranked_results) == 0:
        return None

    ranked_sections = {
        "initial": pre_reranked_results,
        "reranked": post_reranked_results,
    }

    fit_eval: RetrievalFitStats = RetrievalFitStats(
        fit_score_lift=0,
        rerank_effect=0,
        fit_scores={
            "initial": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
            "reranked": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
        },
    )

    for rank_type, docs in ranked_sections.items():
        logger.debug(f"rank_type: {rank_type}")

        for i in [1, 5, 10]:
            fit_eval.fit_scores[rank_type].scores[str(i)] = (
                sum(
                    [
                        float(doc.center_chunk.score)
                        for doc in docs[:i]
                        if type(doc) == InferenceSection
                        and doc.center_chunk.score is not None
                    ]
                )
                / i
            )

        fit_eval.fit_scores[rank_type].scores["fit_score"] = (
            1
            / 3
            * (
                fit_eval.fit_scores[rank_type].scores["1"]
                + fit_eval.fit_scores[rank_type].scores["5"]
                + fit_eval.fit_scores[rank_type].scores["10"]
            )
        )

        # NOTE: this immediately overwrites the averaged fit_score above with the @1 score
        fit_eval.fit_scores[rank_type].scores["fit_score"] = fit_eval.fit_scores[
            rank_type
        ].scores["1"]

        fit_eval.fit_scores[rank_type].chunk_ids = [
            unique_chunk_id(doc) for doc in docs if type(doc) == InferenceSection
        ]

    fit_eval.fit_score_lift = (
        fit_eval.fit_scores["reranked"].scores["fit_score"]
        / fit_eval.fit_scores["initial"].scores["fit_score"]
    )

    fit_eval.rerank_effect = calculate_rank_shift(
        fit_eval.fit_scores["initial"].chunk_ids,
        fit_eval.fit_scores["reranked"].chunk_ids,
    )

    return fit_eval
113 backend/onyx/agents/agent_search/shared_graph_utils/models.py Normal file
@@ -0,0 +1,113 @@
from typing import Literal

from pydantic import BaseModel

from onyx.agents.agent_search.deep_search_a.main.models import AgentAdditionalMetrics
from onyx.agents.agent_search.deep_search_a.main.models import AgentBaseMetrics
from onyx.agents.agent_search.deep_search_a.main.models import AgentRefinedMetrics
from onyx.agents.agent_search.deep_search_a.main.models import AgentTimings
from onyx.context.search.models import InferenceSection
from onyx.tools.models import SearchQueryInfo


# Pydantic models for structured outputs
class RewrittenQueries(BaseModel):
    rewritten_queries: list[str]


class BinaryDecision(BaseModel):
    decision: Literal["yes", "no"]


class BinaryDecisionWithReasoning(BaseModel):
    reasoning: str
    decision: Literal["yes", "no"]

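These schemas are shaped for structured LLM output. A sketch of how BinaryDecision could be used that way, assuming a LangChain chat model (the onyx LLM wrappers may expose structured output differently; the model below is illustrative):

from langchain_openai import ChatOpenAI

decider = ChatOpenAI(model="gpt-4o-mini").with_structured_output(BinaryDecision)
result = decider.invoke("Does the retrieved document answer the question? Answer yes or no.")
# result.decision is constrained to "yes" or "no" by the Literal field.
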
class RetrievalFitScoreMetrics(BaseModel):
    scores: dict[str, float]
    chunk_ids: list[str]


class RetrievalFitStats(BaseModel):
    fit_score_lift: float
    rerank_effect: float
    fit_scores: dict[str, RetrievalFitScoreMetrics]


class AgentChunkScores(BaseModel):
    scores: dict[str, dict[str, list[int | float]]]


class AgentChunkStats(BaseModel):
    verified_count: int | None = None
    verified_avg_scores: float | None = None
    rejected_count: int | None = None
    rejected_avg_scores: float | None = None
    verified_doc_chunk_ids: list[str] = []
    dismissed_doc_chunk_ids: list[str] = []


class InitialAgentResultStats(BaseModel):
    sub_questions: dict[str, float | int | None]
    original_question: dict[str, float | int | None]
    agent_effectiveness: dict[str, float | int | None]


class RefinedAgentStats(BaseModel):
    revision_doc_efficiency: float | None
    revision_question_efficiency: float | None


class Term(BaseModel):
    term_name: str = ""
    term_type: str = ""
    term_similar_to: list[str] = []


### Models ###


class Entity(BaseModel):
    entity_name: str = ""
    entity_type: str = ""


class Relationship(BaseModel):
    relationship_name: str = ""
    relationship_type: str = ""
    relationship_entities: list[str] = []


class EntityRelationshipTermExtraction(BaseModel):
    entities: list[Entity] = []
    relationships: list[Relationship] = []
    terms: list[Term] = []


### Models ###


class QueryResult(BaseModel):
    query: str
    search_results: list[InferenceSection]
    stats: RetrievalFitStats | None
    query_info: SearchQueryInfo | None


class QuestionAnswerResults(BaseModel):
    question: str
    question_id: str
    answer: str
    quality: str
    expanded_retrieval_results: list[QueryResult]
    documents: list[InferenceSection]
    context_documents: list[InferenceSection]
    sub_question_retrieval_stats: AgentChunkStats


class CombinedAgentMetrics(BaseModel):
    timings: AgentTimings
    base_metrics: AgentBaseMetrics | None
    refined_metrics: AgentRefinedMetrics
    additional_metrics: AgentAdditionalMetrics
@@ -0,0 +1,31 @@
from onyx.agents.agent_search.shared_graph_utils.models import (
    QuestionAnswerResults,
)
from onyx.chat.prune_and_merge import _merge_sections
from onyx.context.search.models import InferenceSection


def dedup_inference_sections(
    list1: list[InferenceSection], list2: list[InferenceSection]
) -> list[InferenceSection]:
    deduped = _merge_sections(list1 + list2)
    return deduped


def dedup_question_answer_results(
    question_answer_results_1: list[QuestionAnswerResults],
    question_answer_results_2: list[QuestionAnswerResults],
) -> list[QuestionAnswerResults]:
    # NOTE: this aliases question_answer_results_1, so the caller's first list is
    # extended in place rather than copied
    deduped_question_answer_results: list[
        QuestionAnswerResults
    ] = question_answer_results_1
    utilized_question_ids: set[str] = set(
        [x.question_id for x in question_answer_results_1]
    )

    for question_answer_result in question_answer_results_2:
        if question_answer_result.question_id not in utilized_question_ids:
            deduped_question_answer_results.append(question_answer_result)
            utilized_question_ids.add(question_answer_result.question_id)

    return deduped_question_answer_results
1076 backend/onyx/agents/agent_search/shared_graph_utils/prompts.py Normal file
File diff suppressed because it is too large
Some files were not shown because too many files have changed in this diff.