Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-17 15:55:45 +00:00)

Compare commits: error_supp ... improved_u (27 commits)
| SHA1 |
|---|
| e55dd89444 |
| 811d564a0f |
| 206b247ca5 |
| cd445f0e3f |
| 967b0e5b0f |
| fac9525833 |
| 79fc4ae47d |
| 1c09c75e5f |
| 23ec33e411 |
| fea429e11b |
| f9d7d21d8e |
| d2a8938545 |
| ac909f8437 |
| 2f32111169 |
| 23acb163f5 |
| ebe15b42d2 |
| 6cbb237945 |
| 6803548066 |
| 1111ce6ce4 |
| 3f68e8ea8e |
| 06a8373ff4 |
| 86e770d968 |
| f11216132e |
| 1f7d05cd75 |
| c8bf051fb6 |
| 14b54db033 |
| 0e9f9301ba |
2 changes: .github/workflows/pr-chromatic-tests.yml (vendored)

@@ -8,8 +8,6 @@ on: push
env:
  OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
  SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
  GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
  MOCK_LLM_RESPONSE: true

jobs:
  playwright-tests:
22 changes: .github/workflows/pr-helm-chart-testing.yml (vendored)

@@ -21,10 +21,10 @@ jobs:
      - name: Set up Helm
        uses: azure/setup-helm@v4.2.0
        with:
          version: v3.17.0
          version: v3.14.4

      - name: Set up chart-testing
        uses: helm/chart-testing-action@v2.7.0
        uses: helm/chart-testing-action@v2.6.1

      # even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
      - name: Run chart-testing (list-changed)
@@ -37,6 +37,22 @@
            echo "changed=true" >> "$GITHUB_OUTPUT"
          fi

      # rkuo: I don't think we need python?
      # - name: Set up Python
      #   uses: actions/setup-python@v5
      #   with:
      #     python-version: '3.11'
      #     cache: 'pip'
      #     cache-dependency-path: |
      #       backend/requirements/default.txt
      #       backend/requirements/dev.txt
      #       backend/requirements/model_server.txt
      # - run: |
      #     python -m pip install --upgrade pip
      #     pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
      #     pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
      #     pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt

      # lint all charts if any changes were detected
      - name: Run chart-testing (lint)
        if: steps.list-changed.outputs.changed == 'true'
@@ -46,7 +62,7 @@ jobs:

      - name: Create kind cluster
        if: steps.list-changed.outputs.changed == 'true'
        uses: helm/kind-action@v1.12.0
        uses: helm/kind-action@v1.10.0

      - name: Run chart-testing (install)
        if: steps.list-changed.outputs.changed == 'true'
4 changes: .github/workflows/pr-linear-check.yml (vendored)

@@ -9,9 +9,9 @@ jobs:
    runs-on: ubuntu-latest
    steps:
      - name: Check PR body for Linear link or override
        env:
          PR_BODY: ${{ github.event.pull_request.body }}
        run: |
          PR_BODY="${{ github.event.pull_request.body }}"

          # Looking for "https://linear.app" in the body
          if echo "$PR_BODY" | grep -qE "https://linear\.app"; then
            echo "Found a Linear link. Check passed."
@@ -39,12 +39,6 @@ env:
  AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
  AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
  AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
  # Sharepoint
  SHAREPOINT_CLIENT_ID: ${{ secrets.SHAREPOINT_CLIENT_ID }}
  SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
  SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
  SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}

jobs:
  connectors-check:
    # See https://runs-on.com/runners/linux/
4 changes: .gitignore (vendored)

@@ -7,4 +7,6 @@
.vscode/
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml
/web/test-results/
/web/test-results/
backend/onyx/agent_search/main/test_data.json
backend/tests/regression/answer_quality/test_data.json
6 changes: .vscode/env_template.txt (vendored)

@@ -52,3 +52,9 @@ BING_API_KEY=<REPLACE THIS>
# Enable the full set of Danswer Enterprise Edition features
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False

# Agent Search configs # TODO: Remove give proper namings
AGENT_RETRIEVAL_STATS=False # Note: This setting will incur substantial re-ranking effort
AGENT_RERANKING_STATS=True
AGENT_MAX_QUERY_RETRIEVAL_RESULTS=20
AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS=20
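These AGENT_* settings are only declared in the VS Code env template here; the backend would typically read them into typed constants with os.environ.get, in the same style as the config modules shown later in this diff. The sketch below is illustrative only: the parsing, defaults, and module location are assumptions, not code from the repository.

```python
# Illustrative sketch: reading the AGENT_* variables from the env template above
# into typed config values. Defaults mirror the template; this is not repo code.
import os

AGENT_RETRIEVAL_STATS = os.environ.get("AGENT_RETRIEVAL_STATS", "False").lower() == "true"
AGENT_RERANKING_STATS = os.environ.get("AGENT_RERANKING_STATS", "True").lower() == "true"
AGENT_MAX_QUERY_RETRIEVAL_RESULTS = int(
    os.environ.get("AGENT_MAX_QUERY_RETRIEVAL_RESULTS", "20")
)
AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS = int(
    os.environ.get("AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS", "20")
)
```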
1 change: Untitled-12 (new file)

@@ -0,0 +1 @@
@@ -9,10 +9,8 @@ founders@onyx.app for more information. Please visit https://github.com/onyx-dot

# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
ARG ONYX_VERSION=0.8-dev
# DO_NOT_TRACK is used to disable telemetry for Unstructured
ENV ONYX_VERSION=${ONYX_VERSION} \
    DANSWER_RUNNING_IN_DOCKER="true" \
    DO_NOT_TRACK="true"
    DANSWER_RUNNING_IN_DOCKER="true"

RUN echo "ONYX_VERSION: ${ONYX_VERSION}"
@@ -0,0 +1,29 @@
"""agent_doc_result_col

Revision ID: 1adf5ea20d2b
Revises: e9cf2bd7baed
Create Date: 2025-01-05 13:14:58.344316

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "1adf5ea20d2b"
down_revision = "e9cf2bd7baed"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add the new column with JSONB type
    op.add_column(
        "sub_question",
        sa.Column("sub_question_doc_results", postgresql.JSONB(), nullable=True),
    )


def downgrade() -> None:
    # Drop the column
    op.drop_column("sub_question", "sub_question_doc_results")
@@ -1,36 +0,0 @@
"""add chat session specific temperature override

Revision ID: 2f80c6a2550f
Revises: 33ea50e88f24
Create Date: 2025-01-31 10:30:27.289646

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "2f80c6a2550f"
down_revision = "33ea50e88f24"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "chat_session", sa.Column("temperature_override", sa.Float(), nullable=True)
    )
    op.add_column(
        "user",
        sa.Column(
            "temperature_override_enabled",
            sa.Boolean(),
            nullable=False,
            server_default=sa.false(),
        ),
    )


def downgrade() -> None:
    op.drop_column("chat_session", "temperature_override")
    op.drop_column("user", "temperature_override_enabled")
@@ -1,80 +0,0 @@
"""foreign key input prompts

Revision ID: 33ea50e88f24
Revises: a6df6b88ef81
Create Date: 2025-01-29 10:54:22.141765

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "33ea50e88f24"
down_revision = "a6df6b88ef81"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Safely drop constraints if exists
    op.execute(
        """
        ALTER TABLE inputprompt__user
        DROP CONSTRAINT IF EXISTS inputprompt__user_input_prompt_id_fkey
        """
    )
    op.execute(
        """
        ALTER TABLE inputprompt__user
        DROP CONSTRAINT IF EXISTS inputprompt__user_user_id_fkey
        """
    )

    # Recreate with ON DELETE CASCADE
    op.create_foreign_key(
        "inputprompt__user_input_prompt_id_fkey",
        "inputprompt__user",
        "inputprompt",
        ["input_prompt_id"],
        ["id"],
        ondelete="CASCADE",
    )

    op.create_foreign_key(
        "inputprompt__user_user_id_fkey",
        "inputprompt__user",
        "user",
        ["user_id"],
        ["id"],
        ondelete="CASCADE",
    )


def downgrade() -> None:
    # Drop the new FKs with ondelete
    op.drop_constraint(
        "inputprompt__user_input_prompt_id_fkey",
        "inputprompt__user",
        type_="foreignkey",
    )
    op.drop_constraint(
        "inputprompt__user_user_id_fkey",
        "inputprompt__user",
        type_="foreignkey",
    )

    # Recreate them without cascading
    op.create_foreign_key(
        "inputprompt__user_input_prompt_id_fkey",
        "inputprompt__user",
        "inputprompt",
        ["input_prompt_id"],
        ["id"],
    )
    op.create_foreign_key(
        "inputprompt__user_user_id_fkey",
        "inputprompt__user",
        "user",
        ["user_id"],
        ["id"],
    )
@@ -1,37 +0,0 @@
"""lowercase_user_emails

Revision ID: 4d58345da04a
Revises: f1ca58b2f2ec
Create Date: 2025-01-29 07:48:46.784041

"""
from alembic import op
from sqlalchemy.sql import text


# revision identifiers, used by Alembic.
revision = "4d58345da04a"
down_revision = "f1ca58b2f2ec"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Get database connection
    connection = op.get_bind()

    # Update all user emails to lowercase
    connection.execute(
        text(
            """
            UPDATE "user"
            SET email = LOWER(email)
            WHERE email != LOWER(email)
            """
        )
    )


def downgrade() -> None:
    # Cannot restore original case of emails
    pass
@@ -0,0 +1,35 @@
"""agent_metric_col_rename__s

Revision ID: 925b58bd75b6
Revises: 9787be927e58
Create Date: 2025-01-06 11:20:26.752441

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "925b58bd75b6"
down_revision = "9787be927e58"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Rename columns using PostgreSQL syntax
    op.alter_column(
        "agent__search_metrics", "base_duration_s", new_column_name="base_duration__s"
    )
    op.alter_column(
        "agent__search_metrics", "full_duration_s", new_column_name="full_duration__s"
    )


def downgrade() -> None:
    # Revert the column renames
    op.alter_column(
        "agent__search_metrics", "base_duration__s", new_column_name="base_duration_s"
    )
    op.alter_column(
        "agent__search_metrics", "full_duration__s", new_column_name="full_duration_s"
    )
@@ -0,0 +1,25 @@
"""agent_metric_table_renames__agent__

Revision ID: 9787be927e58
Revises: bceb76d618ec
Create Date: 2025-01-06 11:01:44.210160

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "9787be927e58"
down_revision = "bceb76d618ec"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Rename table from agent_search_metrics to agent__search_metrics
    op.rename_table("agent_search_metrics", "agent__search_metrics")


def downgrade() -> None:
    # Rename table back from agent__search_metrics to agent_search_metrics
    op.rename_table("agent__search_metrics", "agent_search_metrics")
42 changes: backend/alembic/versions/98a5008d8711_agent_tracking.py (new file)

@@ -0,0 +1,42 @@
"""agent_tracking

Revision ID: 98a5008d8711
Revises: f1ca58b2f2ec
Create Date: 2025-01-04 14:41:52.732238

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "98a5008d8711"
down_revision = "f1ca58b2f2ec"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "agent_search_metrics",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=True),
        sa.Column("persona_id", sa.Integer(), nullable=True),
        sa.Column("agent_type", sa.String(), nullable=False),
        sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
        sa.Column("base_duration_s", sa.Float(), nullable=False),
        sa.Column("full_duration_s", sa.Float(), nullable=False),
        sa.Column("base_metrics", postgresql.JSONB(), nullable=True),
        sa.Column("refined_metrics", postgresql.JSONB(), nullable=True),
        sa.Column("all_metrics", postgresql.JSONB(), nullable=True),
        sa.ForeignKeyConstraint(
            ["persona_id"],
            ["persona.id"],
        ),
        sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
    )


def downgrade() -> None:
    op.drop_table("agent_search_metrics")
@@ -1,29 +0,0 @@
"""remove recent assistants

Revision ID: a6df6b88ef81
Revises: 4d58345da04a
Create Date: 2025-01-29 10:25:52.790407

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "a6df6b88ef81"
down_revision = "4d58345da04a"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.drop_column("user", "recent_assistants")


def downgrade() -> None:
    op.add_column(
        "user",
        sa.Column(
            "recent_assistants", postgresql.JSONB(), server_default="[]", nullable=False
        ),
    )
@@ -0,0 +1,84 @@
"""agent_table_renames__agent__

Revision ID: bceb76d618ec
Revises: c0132518a25b
Create Date: 2025-01-06 10:50:48.109285

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "bceb76d618ec"
down_revision = "c0132518a25b"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.drop_constraint(
        "sub_query__search_doc_sub_query_id_fkey",
        "sub_query__search_doc",
        type_="foreignkey",
    )
    op.drop_constraint(
        "sub_query__search_doc_search_doc_id_fkey",
        "sub_query__search_doc",
        type_="foreignkey",
    )
    # Rename tables
    op.rename_table("sub_query", "agent__sub_query")
    op.rename_table("sub_question", "agent__sub_question")
    op.rename_table("sub_query__search_doc", "agent__sub_query__search_doc")

    # Update both foreign key constraints for agent__sub_query__search_doc

    # Create new foreign keys with updated names
    op.create_foreign_key(
        "agent__sub_query__search_doc_sub_query_id_fkey",
        "agent__sub_query__search_doc",
        "agent__sub_query",
        ["sub_query_id"],
        ["id"],
    )
    op.create_foreign_key(
        "agent__sub_query__search_doc_search_doc_id_fkey",
        "agent__sub_query__search_doc",
        "search_doc",  # This table name doesn't change
        ["search_doc_id"],
        ["id"],
    )


def downgrade() -> None:
    # Update foreign key constraints for sub_query__search_doc
    op.drop_constraint(
        "agent__sub_query__search_doc_sub_query_id_fkey",
        "agent__sub_query__search_doc",
        type_="foreignkey",
    )
    op.drop_constraint(
        "agent__sub_query__search_doc_search_doc_id_fkey",
        "agent__sub_query__search_doc",
        type_="foreignkey",
    )

    # Rename tables back
    op.rename_table("agent__sub_query__search_doc", "sub_query__search_doc")
    op.rename_table("agent__sub_question", "sub_question")
    op.rename_table("agent__sub_query", "sub_query")

    op.create_foreign_key(
        "sub_query__search_doc_sub_query_id_fkey",
        "sub_query__search_doc",
        "sub_query",
        ["sub_query_id"],
        ["id"],
    )
    op.create_foreign_key(
        "sub_query__search_doc_search_doc_id_fkey",
        "sub_query__search_doc",
        "search_doc",  # This table name doesn't change
        ["search_doc_id"],
        ["id"],
    )
@@ -0,0 +1,40 @@
"""agent_table_changes_rename_level

Revision ID: c0132518a25b
Revises: 1adf5ea20d2b
Create Date: 2025-01-05 16:38:37.660152

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "c0132518a25b"
down_revision = "1adf5ea20d2b"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add level and level_question_nr columns with NOT NULL constraint
    op.add_column(
        "sub_question",
        sa.Column("level", sa.Integer(), nullable=False, server_default="0"),
    )
    op.add_column(
        "sub_question",
        sa.Column(
            "level_question_nr", sa.Integer(), nullable=False, server_default="0"
        ),
    )

    # Remove the server_default after the columns are created
    op.alter_column("sub_question", "level", server_default=None)
    op.alter_column("sub_question", "level_question_nr", server_default=None)


def downgrade() -> None:
    # Remove the columns
    op.drop_column("sub_question", "level_question_nr")
    op.drop_column("sub_question", "level")
@@ -0,0 +1,68 @@
"""create pro search persistence tables

Revision ID: e9cf2bd7baed
Revises: 98a5008d8711
Create Date: 2025-01-02 17:55:56.544246

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID


# revision identifiers, used by Alembic.
revision = "e9cf2bd7baed"
down_revision = "98a5008d8711"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Create sub_question table
    op.create_table(
        "sub_question",
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("primary_question_id", sa.Integer, sa.ForeignKey("chat_message.id")),
        sa.Column(
            "chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id")
        ),
        sa.Column("sub_question", sa.Text),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.func.now()
        ),
        sa.Column("sub_answer", sa.Text),
    )

    # Create sub_query table
    op.create_table(
        "sub_query",
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("parent_question_id", sa.Integer, sa.ForeignKey("sub_question.id")),
        sa.Column(
            "chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id")
        ),
        sa.Column("sub_query", sa.Text),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.func.now()
        ),
    )

    # Create sub_query__search_doc association table
    op.create_table(
        "sub_query__search_doc",
        sa.Column(
            "sub_query_id", sa.Integer, sa.ForeignKey("sub_query.id"), primary_key=True
        ),
        sa.Column(
            "search_doc_id",
            sa.Integer,
            sa.ForeignKey("search_doc.id"),
            primary_key=True,
        ),
    )


def downgrade() -> None:
    op.drop_table("sub_query__search_doc")
    op.drop_table("sub_query")
    op.drop_table("sub_question")
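The compare view lists these migration files out of dependency order, so the chain is easier to follow from the revision/down_revision pairs themselves: the new agent-search migrations run from f1ca58b2f2ec through 98a5008d8711 (agent_tracking), e9cf2bd7baed (pro search persistence tables), 1adf5ea20d2b, c0132518a25b, bceb76d618ec, and 9787be927e58, ending at 925b58bd75b6, while the deleted files (4d58345da04a, a6df6b88ef81, 33ea50e88f24, 2f80c6a2550f) formed a separate chain off f1ca58b2f2ec. A minimal sketch of recovering that order from the declared links follows; the (revision, down_revision) pairs are copied from the migrations above, but the helper itself is illustrative rather than repository code and assumes a single linear chain.

```python
# Illustrative helper: order Alembic revisions by walking down_revision links.
# The (revision, down_revision) pairs are taken from the migration files above.
revisions = {
    "98a5008d8711": "f1ca58b2f2ec",
    "e9cf2bd7baed": "98a5008d8711",
    "1adf5ea20d2b": "e9cf2bd7baed",
    "c0132518a25b": "1adf5ea20d2b",
    "bceb76d618ec": "c0132518a25b",
    "9787be927e58": "bceb76d618ec",
    "925b58bd75b6": "9787be927e58",
}


def upgrade_order(revs: dict[str, str]) -> list[str]:
    """Return revisions oldest-first by following down_revision parents."""
    children = {parent: child for child, parent in revs.items()}
    # The root is a parent that is not itself a listed revision (here: f1ca58b2f2ec).
    root = next(parent for parent in revs.values() if parent not in revs)
    order, current = [], root
    while current in children:
        current = children[current]
        order.append(current)
    return order


print(upgrade_order(revisions))
# ['98a5008d8711', 'e9cf2bd7baed', '1adf5ea20d2b', 'c0132518a25b',
#  'bceb76d618ec', '9787be927e58', '925b58bd75b6']
```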
370 changes: backend/chat_packets.log (new file)

File diff suppressed because one or more lines are too long
@@ -32,7 +32,6 @@ def perform_ttl_management_task(
|
||||
|
||||
@celery_app.task(
|
||||
name="check_ttl_management_task",
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_ttl_management_task(*, tenant_id: str | None) -> None:
|
||||
@@ -57,7 +56,6 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:
|
||||
|
||||
@celery_app.task(
|
||||
name="autogenerate_usage_report_task",
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
|
||||
|
||||
@@ -1,72 +1,30 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
|
||||
from onyx.background.celery.tasks.beat_schedule import (
|
||||
cloud_tasks_to_schedule as base_cloud_tasks_to_schedule,
|
||||
)
|
||||
from onyx.background.celery.tasks.beat_schedule import (
|
||||
tasks_to_schedule as base_tasks_to_schedule,
|
||||
)
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
ee_cloud_tasks_to_schedule = [
|
||||
ee_tasks_to_schedule = [
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_autogenerate-usage-report",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(days=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
|
||||
},
|
||||
"name": "autogenerate-usage-report",
|
||||
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
|
||||
"schedule": timedelta(days=30), # TODO: change this to config flag
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-ttl-management",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"name": "check-ttl-management",
|
||||
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
|
||||
"schedule": timedelta(hours=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
ee_tasks_to_schedule: list[dict] = []
|
||||
|
||||
if not MULTI_TENANT:
|
||||
ee_tasks_to_schedule = [
|
||||
{
|
||||
"name": "autogenerate-usage-report",
|
||||
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
|
||||
"schedule": timedelta(days=30), # TODO: change this to config flag
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-ttl-management",
|
||||
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
|
||||
"schedule": timedelta(hours=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
return ee_cloud_tasks_to_schedule + base_cloud_tasks_to_schedule
|
||||
return base_cloud_tasks_to_schedule
|
||||
|
||||
|
||||
def get_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
|
||||
@@ -4,20 +4,6 @@ import os
|
||||
# Applicable for OIDC Auth
|
||||
OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL", "")
|
||||
|
||||
# Applicable for OIDC Auth, allows you to override the scopes that
|
||||
# are requested from the OIDC provider. Currently used when passing
|
||||
# over access tokens to tool calls and the tool needs more scopes
|
||||
OIDC_SCOPE_OVERRIDE: list[str] | None = None
|
||||
_OIDC_SCOPE_OVERRIDE = os.environ.get("OIDC_SCOPE_OVERRIDE")
|
||||
|
||||
if _OIDC_SCOPE_OVERRIDE:
|
||||
try:
|
||||
OIDC_SCOPE_OVERRIDE = [
|
||||
scope.strip() for scope in _OIDC_SCOPE_OVERRIDE.split(",")
|
||||
]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Applicable for SAML Auth
|
||||
SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_config"
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.connectors.confluence.utils import get_user_email_from_username__server
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -258,7 +257,6 @@ def _fetch_all_page_restrictions(
|
||||
slim_docs: list[SlimDocument],
|
||||
space_permissions_by_space_key: dict[str, ExternalAccess],
|
||||
is_cloud: bool,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
For all pages, if a page has restrictions, then use those restrictions.
|
||||
@@ -267,12 +265,6 @@ def _fetch_all_page_restrictions(
|
||||
document_restrictions: list[DocExternalAccess] = []
|
||||
|
||||
for slim_doc in slim_docs:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("confluence_doc_sync: Stop signal detected")
|
||||
|
||||
callback.progress("confluence_doc_sync:fetch_all_page_restrictions", 1)
|
||||
|
||||
if slim_doc.perm_sync_data is None:
|
||||
raise ValueError(
|
||||
f"No permission sync data found for document {slim_doc.id}"
|
||||
@@ -342,7 +334,7 @@ def _fetch_all_page_restrictions(
|
||||
|
||||
|
||||
def confluence_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
@@ -367,12 +359,6 @@ def confluence_doc_sync(
|
||||
logger.debug("Fetching all slim documents from confluence")
|
||||
for doc_batch in confluence_connector.retrieve_all_slim_documents():
|
||||
logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("confluence_doc_sync: Stop signal detected")
|
||||
|
||||
callback.progress("confluence_doc_sync", 1)
|
||||
|
||||
slim_docs.extend(doc_batch)
|
||||
|
||||
logger.debug("Fetching all page restrictions for space")
|
||||
@@ -381,5 +367,4 @@ def confluence_doc_sync(
|
||||
slim_docs=slim_docs,
|
||||
space_permissions_by_space_key=space_permissions_by_space_key,
|
||||
is_cloud=is_cloud,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
@@ -14,8 +14,6 @@ def _build_group_member_email_map(
|
||||
) -> dict[str, set[str]]:
|
||||
group_member_emails: dict[str, set[str]] = {}
|
||||
for user_result in confluence_client.paginated_cql_user_retrieval():
|
||||
logger.debug(f"Processing groups for user: {user_result}")
|
||||
|
||||
user = user_result.get("user", {})
|
||||
if not user:
|
||||
logger.warning(f"user result missing user field: {user_result}")
|
||||
@@ -35,17 +33,10 @@ def _build_group_member_email_map(
|
||||
logger.warning(f"user result missing email field: {user_result}")
|
||||
continue
|
||||
|
||||
all_users_groups: set[str] = set()
|
||||
for group in confluence_client.paginated_groups_by_user_retrieval(user):
|
||||
# group name uniqueness is enforced by Confluence, so we can use it as a group ID
|
||||
group_id = group["name"]
|
||||
group_member_emails.setdefault(group_id, set()).add(email)
|
||||
all_users_groups.add(group_id)
|
||||
|
||||
if not group_member_emails:
|
||||
logger.warning(f"No groups found for user with email: {email}")
|
||||
else:
|
||||
logger.debug(f"Found groups {all_users_groups} for user with email {email}")
|
||||
|
||||
return group_member_emails
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.gmail.connector import GmailConnector
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -29,7 +28,7 @@ def _get_slim_doc_generator(
|
||||
|
||||
|
||||
def gmail_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
@@ -45,12 +44,6 @@ def gmail_doc_sync(
|
||||
document_external_access: list[DocExternalAccess] = []
|
||||
for slim_doc_batch in slim_doc_generator:
|
||||
for slim_doc in slim_doc_batch:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("gmail_doc_sync: Stop signal detected")
|
||||
|
||||
callback.progress("gmail_doc_sync", 1)
|
||||
|
||||
if slim_doc.perm_sync_data is None:
|
||||
logger.warning(f"No permissions found for document {slim_doc.id}")
|
||||
continue
|
||||
|
||||
@@ -10,7 +10,6 @@ from onyx.connectors.google_utils.resources import get_drive_service
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -43,22 +42,24 @@ def _fetch_permissions_for_permission_ids(
|
||||
if not permission_info or not doc_id:
|
||||
return []
|
||||
|
||||
# Check cache first for all permission IDs
|
||||
permissions = [
|
||||
_PERMISSION_ID_PERMISSION_MAP[pid]
|
||||
for pid in permission_ids
|
||||
if pid in _PERMISSION_ID_PERMISSION_MAP
|
||||
]
|
||||
|
||||
# If we found all permissions in cache, return them
|
||||
if len(permissions) == len(permission_ids):
|
||||
return permissions
|
||||
|
||||
owner_email = permission_info.get("owner_email")
|
||||
|
||||
drive_service = get_drive_service(
|
||||
creds=google_drive_connector.creds,
|
||||
user_email=(owner_email or google_drive_connector.primary_admin_email),
|
||||
)
|
||||
|
||||
# Otherwise, fetch all permissions and update cache
|
||||
fetched_permissions = execute_paginated_retrieval(
|
||||
retrieval_function=drive_service.permissions().list,
|
||||
list_key="permissions",
|
||||
@@ -68,6 +69,7 @@ def _fetch_permissions_for_permission_ids(
|
||||
)
|
||||
|
||||
permissions_for_doc_id = []
|
||||
# Update cache and return all permissions
|
||||
for permission in fetched_permissions:
|
||||
permissions_for_doc_id.append(permission)
|
||||
_PERMISSION_ID_PERMISSION_MAP[permission["id"]] = permission
|
||||
@@ -129,7 +131,7 @@ def _get_permissions_from_slim_doc(
|
||||
|
||||
|
||||
def gdrive_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
@@ -147,12 +149,6 @@ def gdrive_doc_sync(
|
||||
document_external_accesses = []
|
||||
for slim_doc_batch in slim_doc_generator:
|
||||
for slim_doc in slim_doc_batch:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("gdrive_doc_sync: Stop signal detected")
|
||||
|
||||
callback.progress("gdrive_doc_sync", 1)
|
||||
|
||||
ext_access = _get_permissions_from_slim_doc(
|
||||
google_drive_connector=google_drive_connector,
|
||||
slim_doc=slim_doc,
|
||||
|
||||
@@ -7,7 +7,6 @@ from onyx.connectors.slack.connector import get_channels
|
||||
from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
|
||||
from onyx.connectors.slack.connector import SlackPollConnector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -15,7 +14,7 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def _get_slack_document_ids_and_channels(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> dict[str, list[str]]:
|
||||
slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
|
||||
slack_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
@@ -25,14 +24,6 @@ def _get_slack_document_ids_and_channels(
|
||||
channel_doc_map: dict[str, list[str]] = {}
|
||||
for doc_metadata_batch in slim_doc_generator:
|
||||
for doc_metadata in doc_metadata_batch:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"_get_slack_document_ids_and_channels: Stop signal detected"
|
||||
)
|
||||
|
||||
callback.progress("_get_slack_document_ids_and_channels", 1)
|
||||
|
||||
if doc_metadata.perm_sync_data is None:
|
||||
continue
|
||||
channel_id = doc_metadata.perm_sync_data["channel_id"]
|
||||
@@ -123,7 +114,7 @@ def _fetch_channel_permissions(
|
||||
|
||||
|
||||
def slack_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
@@ -136,7 +127,7 @@ def slack_doc_sync(
|
||||
)
|
||||
user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
|
||||
channel_doc_map = _get_slack_document_ids_and_channels(
|
||||
cc_pair=cc_pair, callback=callback
|
||||
cc_pair=cc_pair,
|
||||
)
|
||||
workspace_permissions = _fetch_workspace_permissions(
|
||||
user_id_to_email_map=user_id_to_email_map,
|
||||
|
||||
@@ -15,13 +15,11 @@ from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
|
||||
# Defining the input/output types for the sync functions
|
||||
DocSyncFuncType = Callable[
|
||||
[
|
||||
ConnectorCredentialPair,
|
||||
IndexingHeartbeatInterface | None,
|
||||
],
|
||||
list[DocExternalAccess],
|
||||
]
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from fastapi import FastAPI
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
from httpx_oauth.clients.openid import BASE_SCOPES
|
||||
from httpx_oauth.clients.openid import OpenID
|
||||
|
||||
from ee.onyx.configs.app_configs import OIDC_SCOPE_OVERRIDE
|
||||
from ee.onyx.configs.app_configs import OPENID_CONFIG_URL
|
||||
from ee.onyx.server.analytics.api import router as analytics_router
|
||||
from ee.onyx.server.auth_check import check_ee_router_auth
|
||||
@@ -90,13 +88,7 @@ def get_application() -> FastAPI:
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
create_onyx_oauth_router(
|
||||
OpenID(
|
||||
OAUTH_CLIENT_ID,
|
||||
OAUTH_CLIENT_SECRET,
|
||||
OPENID_CONFIG_URL,
|
||||
# BASE_SCOPES is the same as not setting this
|
||||
base_scopes=OIDC_SCOPE_OVERRIDE or BASE_SCOPES,
|
||||
),
|
||||
OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL),
|
||||
auth_backend,
|
||||
USER_AUTH_SECRET,
|
||||
associate_by_email=True,
|
||||
|
||||
@@ -179,6 +179,7 @@ def handle_simplified_chat_message(
|
||||
chunks_below=0,
|
||||
full_doc=chat_message_req.full_doc,
|
||||
structured_response_format=chat_message_req.structured_response_format,
|
||||
use_agentic_search=chat_message_req.use_agentic_search,
|
||||
)
|
||||
|
||||
packets = stream_chat_message_objects(
|
||||
@@ -301,6 +302,7 @@ def handle_send_message_simple_with_history(
|
||||
chunks_below=0,
|
||||
full_doc=req.full_doc,
|
||||
structured_response_format=req.structured_response_format,
|
||||
use_agentic_search=req.use_agentic_search,
|
||||
)
|
||||
|
||||
packets = stream_chat_message_objects(
|
||||
|
||||
@@ -57,6 +57,9 @@ class BasicCreateChatMessageRequest(ChunkContext):
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
structured_response_format: dict | None = None
|
||||
|
||||
# If True, uses agentic search instead of basic search
|
||||
use_agentic_search: bool = False
|
||||
|
||||
|
||||
class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
|
||||
# Last element is the new query. All previous elements are historical context
|
||||
@@ -71,6 +74,8 @@ class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
|
||||
# only works if using an OpenAI model. See the following for more details:
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
structured_response_format: dict | None = None
|
||||
# If True, uses agentic search instead of basic search
|
||||
use_agentic_search: bool = False
|
||||
|
||||
|
||||
class SimpleDoc(BaseModel):
|
||||
@@ -123,6 +128,9 @@ class OneShotQARequest(ChunkContext):
|
||||
# If True, skips generative an AI response to the search query
|
||||
skip_gen_ai_answer_generation: bool = False
|
||||
|
||||
# If True, uses pro search instead of basic search
|
||||
use_agentic_search: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_persona_fields(self) -> "OneShotQARequest":
|
||||
if self.persona_override_config is None and self.persona_id is None:
|
||||
|
||||
@@ -196,6 +196,8 @@ def get_answer_stream(
|
||||
retrieval_details=query_request.retrieval_options,
|
||||
rerank_settings=query_request.rerank_settings,
|
||||
db_session=db_session,
|
||||
use_agentic_search=query_request.use_agentic_search,
|
||||
skip_gen_ai_answer_generation=query_request.skip_gen_ai_answer_generation,
|
||||
)
|
||||
|
||||
packets = stream_chat_message_objects(
|
||||
|
||||
@@ -111,7 +111,6 @@ async def login_as_anonymous_user(
|
||||
token = generate_anonymous_user_jwt_token(tenant_id)
|
||||
|
||||
response = Response()
|
||||
response.delete_cookie("fastapiusersauth")
|
||||
response.set_cookie(
|
||||
key=ANONYMOUS_USER_COOKIE_NAME,
|
||||
value=token,
|
||||
|
||||
@@ -58,7 +58,6 @@ class UserGroup(BaseModel):
|
||||
credential=CredentialSnapshot.from_credential_db_model(
|
||||
cc_pair_relationship.cc_pair.credential
|
||||
),
|
||||
access_type=cc_pair_relationship.cc_pair.access_type,
|
||||
)
|
||||
for cc_pair_relationship in user_group_model.cc_pair_relationships
|
||||
if cc_pair_relationship.is_current
|
||||
|
||||
71 changes: backend/onyx/agents/agent_search/basic/graph_builder.py (new file)
@@ -0,0 +1,71 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.basic.nodes.basic_use_tool_response import (
|
||||
basic_use_tool_response,
|
||||
)
|
||||
from onyx.agents.agent_search.basic.nodes.llm_tool_choice import llm_tool_choice
|
||||
from onyx.agents.agent_search.basic.nodes.tool_call import tool_call
|
||||
from onyx.agents.agent_search.basic.states import BasicInput
|
||||
from onyx.agents.agent_search.basic.states import BasicOutput
|
||||
from onyx.agents.agent_search.basic.states import BasicState
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def basic_graph_builder() -> StateGraph:
|
||||
graph = StateGraph(
|
||||
state_schema=BasicState,
|
||||
input=BasicInput,
|
||||
output=BasicOutput,
|
||||
)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
graph.add_node(
|
||||
node="llm_tool_choice",
|
||||
action=llm_tool_choice,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="tool_call",
|
||||
action=tool_call,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="basic_use_tool_response",
|
||||
action=basic_use_tool_response,
|
||||
)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_edge(start_key=START, end_key="llm_tool_choice")
|
||||
|
||||
graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])
|
||||
|
||||
graph.add_edge(
|
||||
start_key="tool_call",
|
||||
end_key="basic_use_tool_response",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="basic_use_tool_response",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
def should_continue(state: BasicState) -> str:
|
||||
return (
|
||||
# If there are no tool calls, basic graph already streamed the answer
|
||||
END
|
||||
if state["tool_choice"] is None
|
||||
else "tool_call"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
@@ -0,0 +1,63 @@
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.basic.states import BasicOutput
|
||||
from onyx.agents.agent_search.basic.states import BasicState
|
||||
from onyx.agents.agent_search.basic.utils import process_llm_stream
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_DOC_CONTENT_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search_like_tool_utils import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
|
||||
|
||||
def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicOutput:
|
||||
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
structured_response_format = agent_config.structured_response_format
|
||||
llm = agent_config.primary_llm
|
||||
tool_choice = state["tool_choice"]
|
||||
if tool_choice is None:
|
||||
raise ValueError("Tool choice is None")
|
||||
tool = tool_choice["tool"]
|
||||
prompt_builder = agent_config.prompt_builder
|
||||
tool_call_summary = state["tool_call_summary"]
|
||||
tool_call_responses = state["tool_call_responses"]
|
||||
state["tool_call_final_result"]
|
||||
new_prompt_builder = tool.build_next_prompt(
|
||||
prompt_builder=prompt_builder,
|
||||
tool_call_summary=tool_call_summary,
|
||||
tool_responses=tool_call_responses,
|
||||
using_tool_calling_llm=agent_config.using_tool_calling_llm,
|
||||
)
|
||||
|
||||
final_search_results = []
|
||||
initial_search_results = []
|
||||
for yield_item in tool_call_responses:
|
||||
if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
final_search_results = cast(list[LlmDoc], yield_item.response)
|
||||
elif yield_item.id == SEARCH_DOC_CONTENT_ID:
|
||||
search_contexts = yield_item.response.contexts
|
||||
for doc in search_contexts:
|
||||
if doc.document_id not in initial_search_results:
|
||||
initial_search_results.append(doc)
|
||||
|
||||
initial_search_results = cast(list[LlmDoc], initial_search_results)
|
||||
|
||||
stream = llm.stream(
|
||||
prompt=new_prompt_builder.build(),
|
||||
structured_response_format=structured_response_format,
|
||||
)
|
||||
|
||||
# For now, we don't do multiple tool calls, so we ignore the tool_message
|
||||
process_llm_stream(
|
||||
stream,
|
||||
True,
|
||||
final_search_results=final_search_results,
|
||||
displayed_search_results=initial_search_results,
|
||||
)
|
||||
|
||||
return BasicOutput()
|
||||
134 changes: backend/onyx/agents/agent_search/basic/nodes/llm_tool_choice.py (new file)
@@ -0,0 +1,134 @@
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
from langchain_core.messages import ToolCall
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.basic.states import BasicState
|
||||
from onyx.agents.agent_search.basic.states import ToolChoice
|
||||
from onyx.agents.agent_search.basic.states import ToolChoiceUpdate
|
||||
from onyx.agents.agent_search.basic.utils import process_llm_stream
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
|
||||
from onyx.chat.tool_handling.tool_response_handler import (
|
||||
get_tool_call_for_non_tool_calling_llm_impl,
|
||||
)
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# TODO: break this out into an implementation function
|
||||
# and a function that handles extracting the necessary fields
|
||||
# from the state and config
|
||||
# TODO: fan-out to multiple tool call nodes? Make this configurable?
|
||||
def llm_tool_choice(state: BasicState, config: RunnableConfig) -> ToolChoiceUpdate:
|
||||
"""
|
||||
This node is responsible for calling the LLM to choose a tool. If no tool is chosen,
|
||||
The node MAY emit an answer, depending on whether state["should_stream_answer"] is set.
|
||||
"""
|
||||
should_stream_answer = state["should_stream_answer"]
|
||||
|
||||
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
using_tool_calling_llm = agent_config.using_tool_calling_llm
|
||||
prompt_builder = agent_config.prompt_builder
|
||||
llm = agent_config.primary_llm
|
||||
skip_gen_ai_answer_generation = agent_config.skip_gen_ai_answer_generation
|
||||
|
||||
structured_response_format = agent_config.structured_response_format
|
||||
tools = agent_config.tools or []
|
||||
force_use_tool = agent_config.force_use_tool
|
||||
|
||||
tool, tool_args = None, None
|
||||
if force_use_tool.force_use and force_use_tool.args is not None:
|
||||
tool_name, tool_args = (
|
||||
force_use_tool.tool_name,
|
||||
force_use_tool.args,
|
||||
)
|
||||
tool = get_tool_by_name(tools, tool_name)
|
||||
|
||||
# special pre-logic for non-tool calling LLM case
|
||||
elif not using_tool_calling_llm and tools:
|
||||
chosen_tool_and_args = get_tool_call_for_non_tool_calling_llm_impl(
|
||||
force_use_tool=force_use_tool,
|
||||
tools=tools,
|
||||
prompt_builder=prompt_builder,
|
||||
llm=llm,
|
||||
)
|
||||
if chosen_tool_and_args:
|
||||
tool, tool_args = chosen_tool_and_args
|
||||
|
||||
# If we have a tool and tool args, we are ready to request a tool call.
|
||||
# This only happens if the tool call was forced or we are using a non-tool calling LLM.
|
||||
if tool and tool_args:
|
||||
return ToolChoiceUpdate(
|
||||
tool_choice=ToolChoice(
|
||||
tool=tool,
|
||||
tool_args=tool_args,
|
||||
id=str(uuid4()),
|
||||
),
|
||||
)
|
||||
|
||||
# if we're skipping gen ai answer generation, we should only
|
||||
# continue if we're forcing a tool call (which will be emitted by
|
||||
# the tool calling llm in the stream() below)
|
||||
if skip_gen_ai_answer_generation and not force_use_tool.force_use:
|
||||
return ToolChoiceUpdate(
|
||||
tool_choice=None,
|
||||
)
|
||||
|
||||
# At this point, we are either using a tool calling LLM or we are skipping the tool call.
|
||||
# DEBUG: good breakpoint
|
||||
stream = llm.stream(
|
||||
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
|
||||
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
|
||||
prompt=prompt_builder.build(),
|
||||
tools=[tool.tool_definition() for tool in tools] or None,
|
||||
tool_choice=("required" if tools and force_use_tool.force_use else None),
|
||||
structured_response_format=structured_response_format,
|
||||
)
|
||||
|
||||
tool_message = process_llm_stream(stream, should_stream_answer)
|
||||
|
||||
# If no tool calls are emitted by the LLM, we should not choose a tool
|
||||
if len(tool_message.tool_calls) == 0:
|
||||
return ToolChoiceUpdate(
|
||||
tool_choice=None,
|
||||
)
|
||||
|
||||
# TODO: here we could handle parallel tool calls. Right now
|
||||
# we just pick the first one that matches.
|
||||
selected_tool: Tool | None = None
|
||||
selected_tool_call_request: ToolCall | None = None
|
||||
for tool_call_request in tool_message.tool_calls:
|
||||
known_tools_by_name = [
|
||||
tool for tool in tools if tool.name == tool_call_request["name"]
|
||||
]
|
||||
|
||||
if known_tools_by_name:
|
||||
selected_tool = known_tools_by_name[0]
|
||||
selected_tool_call_request = tool_call_request
|
||||
break
|
||||
|
||||
logger.error(
|
||||
"Tool call requested with unknown name field. \n"
|
||||
f"tools: {tools}"
|
||||
f"tool_call_request: {tool_call_request}"
|
||||
)
|
||||
|
||||
if not selected_tool or not selected_tool_call_request:
|
||||
raise ValueError(
|
||||
f"Tool call attempted with tool {selected_tool}, request {selected_tool_call_request}"
|
||||
)
|
||||
|
||||
logger.info(f"Selected tool: {selected_tool.name}")
|
||||
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
|
||||
|
||||
return ToolChoiceUpdate(
|
||||
tool_choice=ToolChoice(
|
||||
tool=selected_tool,
|
||||
tool_args=selected_tool_call_request["args"],
|
||||
id=selected_tool_call_request["id"],
|
||||
),
|
||||
)
|
||||
69 changes: backend/onyx/agents/agent_search/basic/nodes/tool_call.py (new file)
@@ -0,0 +1,69 @@
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages.tool import ToolCall
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.basic.states import BasicState
|
||||
from onyx.agents.agent_search.basic.states import ToolCallUpdate
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.chat.models import AnswerPacket
|
||||
from onyx.tools.message import build_tool_message
|
||||
from onyx.tools.message import ToolCallSummary
|
||||
from onyx.tools.tool_runner import ToolRunner
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def emit_packet(packet: AnswerPacket) -> None:
|
||||
dispatch_custom_event("basic_response", packet)
|
||||
|
||||
|
||||
# TODO: handle is_cancelled
|
||||
def tool_call(state: BasicState, config: RunnableConfig) -> ToolCallUpdate:
|
||||
"""Calls the tool specified in the state and updates the state with the result"""
|
||||
# TODO: implement
|
||||
|
||||
cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
# Unnecessary now, node should only be called if there is a tool call
|
||||
# if not self.tool_call_chunk or not self.tool_call_chunk.tool_calls:
|
||||
# return
|
||||
|
||||
tool_choice = state["tool_choice"]
|
||||
if tool_choice is None:
|
||||
raise ValueError("Cannot invoke tool call node without a tool choice")
|
||||
|
||||
tool = tool_choice["tool"]
|
||||
tool_args = tool_choice["tool_args"]
|
||||
tool_id = tool_choice["id"]
|
||||
tool_runner = ToolRunner(tool, tool_args)
|
||||
tool_kickoff = tool_runner.kickoff()
|
||||
|
||||
# TODO: custom events for yields
|
||||
emit_packet(tool_kickoff)
|
||||
|
||||
tool_responses = []
|
||||
for response in tool_runner.tool_responses():
|
||||
tool_responses.append(response)
|
||||
emit_packet(response)
|
||||
|
||||
tool_final_result = tool_runner.tool_final_result()
|
||||
emit_packet(tool_final_result)
|
||||
|
||||
tool_call = ToolCall(name=tool.name, args=tool_args, id=tool_id)
|
||||
tool_call_summary = ToolCallSummary(
|
||||
tool_call_request=AIMessageChunk(content="", tool_calls=[tool_call]),
|
||||
tool_call_result=build_tool_message(
|
||||
tool_call, tool_runner.tool_message_content()
|
||||
),
|
||||
)
|
||||
|
||||
return ToolCallUpdate(
|
||||
tool_call_summary=tool_call_summary,
|
||||
tool_call_kickoff=tool_kickoff,
|
||||
tool_call_responses=tool_responses,
|
||||
tool_call_final_result=tool_final_result,
|
||||
)
|
||||
55 changes: backend/onyx/agents/agent_search/basic/states.py (new file)

@@ -0,0 +1,55 @@
from typing import TypedDict

from onyx.tools.message import ToolCallSummary
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool

# States contain values that change over the course of graph execution,
# Config is for values that are set at the start and never change.
# If you are using a value from the config and realize it needs to change,
# you should add it to the state and use/update the version in the state.

## Graph Input State


class BasicInput(TypedDict):
    should_stream_answer: bool


## Graph Output State


class BasicOutput(TypedDict):
    pass


## Update States
class ToolCallUpdate(TypedDict):
    tool_call_summary: ToolCallSummary
    tool_call_kickoff: ToolCallKickoff
    tool_call_responses: list[ToolResponse]
    tool_call_final_result: ToolCallFinalResult


class ToolChoice(TypedDict):
    tool: Tool
    tool_args: dict
    id: str | None


class ToolChoiceUpdate(TypedDict):
    tool_choice: ToolChoice | None


## Graph State


class BasicState(
    BasicInput,
    ToolCallUpdate,
    ToolChoiceUpdate,
    BasicOutput,
):
    pass
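The comment at the top of states.py spells out the pattern used by the new graph nodes: long-lived settings travel in the RunnableConfig, anything that changes during execution lives in the state, and each node returns only the small *Update TypedDict it owns, which the graph merges back into BasicState. A minimal sketch of a node following that contract is shown below; the toy node itself is an illustrative assumption, not code from this diff.

```python
# Illustrative sketch of the update-state pattern described above.
# The imports mirror states.py; the toy node below is an assumption, not repo code.
from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.states import ToolChoiceUpdate


def example_node(state: BasicState, config: RunnableConfig) -> ToolChoiceUpdate:
    # Read immutable settings from the config, mutable values from the state,
    # and return only the keys this node is responsible for updating.
    if not state["should_stream_answer"]:
        return ToolChoiceUpdate(tool_choice=None)
    # ... decide on a tool here, as llm_tool_choice does ...
    return ToolChoiceUpdate(tool_choice=None)
```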
67 changes: backend/onyx/agents/agent_search/basic/utils.py (new file)
@@ -0,0 +1,67 @@
|
||||
from collections.abc import Iterator
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
|
||||
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
|
||||
from onyx.chat.stream_processing.answer_response_handler import (
|
||||
PassThroughAnswerResponseHandler,
|
||||
)
|
||||
from onyx.chat.stream_processing.utils import map_document_id_order
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# TODO: handle citations here; below is what was previously passed in
|
||||
# see basic_use_tool_response.py for where these variables come from
|
||||
# answer_handler = CitationResponseHandler(
|
||||
# context_docs=final_search_results,
|
||||
# final_doc_id_to_rank_map=map_document_id_order(final_search_results),
|
||||
# display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
|
||||
# )
|
||||
|
||||
|
||||
def process_llm_stream(
|
||||
stream: Iterator[BaseMessage],
|
||||
should_stream_answer: bool,
|
||||
final_search_results: list[LlmDoc] | None = None,
|
||||
displayed_search_results: list[LlmDoc] | None = None,
|
||||
) -> AIMessageChunk:
|
||||
tool_call_chunk = AIMessageChunk(content="")
|
||||
# for response in response_handler_manager.handle_llm_response(stream):
|
||||
|
||||
if final_search_results and displayed_search_results:
|
||||
answer_handler: AnswerResponseHandler = CitationResponseHandler(
|
||||
context_docs=final_search_results,
|
||||
final_doc_id_to_rank_map=map_document_id_order(final_search_results),
|
||||
display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
|
||||
)
|
||||
else:
|
||||
answer_handler = PassThroughAnswerResponseHandler()
|
||||
|
||||
# This stream will be the llm answer if no tool is chosen. When a tool is chosen,
|
||||
# the stream will contain AIMessageChunks with tool call information.
|
||||
for response in stream:
|
||||
answer_piece = response.content
|
||||
if not isinstance(answer_piece, str):
|
||||
# TODO: handle non-string content
|
||||
logger.warning(f"Received non-string content: {type(answer_piece)}")
|
||||
answer_piece = str(answer_piece)
|
||||
|
||||
if isinstance(response, AIMessageChunk) and (
|
||||
response.tool_call_chunks or response.tool_calls
|
||||
):
|
||||
tool_call_chunk += response # type: ignore
|
||||
elif should_stream_answer:
|
||||
# TODO: handle emitting of CitationInfo
|
||||
for response_part in answer_handler.handle_response_part(response, []):
|
||||
dispatch_custom_event(
|
||||
"basic_response",
|
||||
response_part,
|
||||
)
|
||||
|
||||
return cast(AIMessageChunk, tool_call_chunk)
|
||||
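# --- Editorial sketch (not part of the diff): minimal use of process_llm_stream. ---
# With should_stream_answer=False nothing is dispatched as a custom event, so this runs
# outside a LangGraph invocation; the call just accumulates tool-call chunks and returns
# them as one AIMessageChunk. The chunks below are hand-built stand-ins for an LLM stream.
if __name__ == "__main__":
    fake_stream = iter(
        [
            AIMessageChunk(content="Hello"),
            AIMessageChunk(content=" world"),
        ]
    )
    merged_tool_calls = process_llm_stream(fake_stream, should_stream_answer=False)
    logger.info(f"Accumulated tool-call content: {merged_tool_calls.content!r}")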
21
backend/onyx/agents/agent_search/core_state.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CoreState(BaseModel):
|
||||
"""
|
||||
This is the core state that is shared across the main graph and all subgraphs.
|
||||
"""
|
||||
|
||||
base_question: str = ""
|
||||
log_messages: Annotated[list[str], add] = []
|
||||
|
||||
|
||||
class SubgraphCoreState(BaseModel):
|
||||
"""
|
||||
This is the core state that is shared across all subgraphs.
|
||||
"""
|
||||
|
||||
log_messages: Annotated[list[str], add]
|
||||
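# --- Editorial sketch (not part of the diff): what the Annotated[..., add] reducer does. ---
# LangGraph applies the second argument of Annotated as a reducer when it merges a node's
# partial update into the state, so log_messages from different nodes (or parallel
# branches) are concatenated rather than overwritten. The plain-Python equivalent:
if __name__ == "__main__":
    merged_logs = add(["node_a: started"], ["node_b: started"])
    assert merged_logs == ["node_a: started", "node_b: started"]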
66
backend/onyx/agents/agent_search/db_operations.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import AgentSubQuery
|
||||
from onyx.db.models import AgentSubQuestion
|
||||
|
||||
|
||||
def create_sub_question(
|
||||
db_session: Session,
|
||||
chat_session_id: UUID,
|
||||
primary_message_id: int,
|
||||
sub_question: str,
|
||||
sub_answer: str,
|
||||
) -> AgentSubQuestion:
|
||||
"""Create a new sub-question record in the database."""
|
||||
sub_q = AgentSubQuestion(
|
||||
chat_session_id=chat_session_id,
|
||||
primary_question_id=primary_message_id,
|
||||
sub_question=sub_question,
|
||||
sub_answer=sub_answer,
|
||||
)
|
||||
db_session.add(sub_q)
|
||||
db_session.flush()
|
||||
return sub_q
|
||||
|
||||
|
||||
def create_sub_query(
|
||||
db_session: Session,
|
||||
chat_session_id: UUID,
|
||||
parent_question_id: int,
|
||||
sub_query: str,
|
||||
) -> AgentSubQuery:
|
||||
"""Create a new sub-query record in the database."""
|
||||
sub_q = AgentSubQuery(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_question_id=parent_question_id,
|
||||
sub_query=sub_query,
|
||||
)
|
||||
db_session.add(sub_q)
|
||||
db_session.flush()
|
||||
return sub_q
|
||||
|
||||
|
||||
def get_sub_questions_for_message(
|
||||
db_session: Session,
|
||||
primary_message_id: int,
|
||||
) -> list[AgentSubQuestion]:
|
||||
"""Get all sub-questions for a given primary message."""
|
||||
return (
|
||||
db_session.query(AgentSubQuestion)
|
||||
.filter(AgentSubQuestion.primary_question_id == primary_message_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_sub_queries_for_question(
|
||||
db_session: Session,
|
||||
sub_question_id: int,
|
||||
) -> list[AgentSubQuery]:
|
||||
"""Get all sub-queries for a given sub-question."""
|
||||
return (
|
||||
db_session.query(AgentSubQuery)
|
||||
.filter(AgentSubQuery.parent_question_id == sub_question_id)
|
||||
.all()
|
||||
)
|
||||
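# --- Editorial sketch (not part of the diff): combining the helpers above in one session. ---
# The UUID and message id below are placeholders, and `sub_question.id` assumes the
# AgentSubQuestion model exposes its primary key as `id` (not shown in this file).
if __name__ == "__main__":
    from uuid import uuid4

    from onyx.db.engine import get_session_context_manager

    with get_session_context_manager() as db_session:
        sub_question = create_sub_question(
            db_session=db_session,
            chat_session_id=uuid4(),
            primary_message_id=1,
            sub_question="What connectors does Onyx support?",
            sub_answer="",
        )
        # the flush() inside the helpers makes sub_question.id usable before commit
        create_sub_query(
            db_session=db_session,
            chat_session_id=sub_question.chat_session_id,
            parent_question_id=sub_question.id,
            sub_query="onyx supported connectors list",
        )
        db_session.commit()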
@@ -0,0 +1,29 @@
|
||||
from collections.abc import Hashable
|
||||
from datetime import datetime
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
AnswerQuestionInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalInput,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def send_to_expanded_retrieval(state: AnswerQuestionInput) -> Send | Hashable:
|
||||
logger.debug("sending to expanded retrieval via edge")
|
||||
now_start = datetime.now()
|
||||
|
||||
return Send(
|
||||
"initial_sub_question_expanded_retrieval",
|
||||
ExpandedRetrievalInput(
|
||||
question=state.question,
|
||||
base_search=False,
|
||||
sub_question_id=state.question_id,
|
||||
log_messages=[f"{now_start} -- Sending to expanded retrieval"],
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,126 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.edges import (
|
||||
send_to_expanded_retrieval,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.answer_check import (
|
||||
answer_check,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.answer_generation import (
|
||||
answer_generation,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.format_answer import (
|
||||
format_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.ingest_retrieval import (
|
||||
ingest_retrieval,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
AnswerQuestionInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
AnswerQuestionOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
AnswerQuestionState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.graph_builder import (
|
||||
expanded_retrieval_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def answer_query_graph_builder() -> StateGraph:
|
||||
graph = StateGraph(
|
||||
state_schema=AnswerQuestionState,
|
||||
input=AnswerQuestionInput,
|
||||
output=AnswerQuestionOutput,
|
||||
)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
expanded_retrieval = expanded_retrieval_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="initial_sub_question_expanded_retrieval",
|
||||
action=expanded_retrieval,
|
||||
)
|
||||
graph.add_node(
|
||||
node="answer_check",
|
||||
action=answer_check,
|
||||
)
|
||||
graph.add_node(
|
||||
node="answer_generation",
|
||||
action=answer_generation,
|
||||
)
|
||||
graph.add_node(
|
||||
node="format_answer",
|
||||
action=format_answer,
|
||||
)
|
||||
graph.add_node(
|
||||
node="ingest_retrieval",
|
||||
action=ingest_retrieval,
|
||||
)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source=START,
|
||||
path=send_to_expanded_retrieval,
|
||||
path_map=["initial_sub_question_expanded_retrieval"],
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="initial_sub_question_expanded_retrieval",
|
||||
end_key="ingest_retrieval",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="ingest_retrieval",
|
||||
end_key="answer_generation",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="answer_generation",
|
||||
end_key="answer_check",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="answer_check",
|
||||
end_key="format_answer",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="format_answer",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.context.search.models import SearchRequest
|
||||
|
||||
graph = answer_query_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
search_request = SearchRequest(
|
||||
query="what can you do with onyx or danswer?",
|
||||
)
|
||||
with get_session_context_manager() as db_session:
|
||||
agent_search_config, search_tool = get_test_config(
|
||||
db_session, primary_llm, fast_llm, search_request
|
||||
)
|
||||
inputs = AnswerQuestionInput(
|
||||
question="what can you do with onyx?",
|
||||
question_id="0_0",
|
||||
log_messages=[],
|
||||
)
|
||||
for thing in compiled_graph.stream(
|
||||
input=inputs,
|
||||
config={"configurable": {"config": agent_search_config}},
|
||||
# debug=True,
|
||||
# subgraphs=True,
|
||||
):
|
||||
logger.debug(thing)
|
||||
@@ -0,0 +1,8 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
### Models ###
|
||||
|
||||
|
||||
class AnswerRetrievalStats(BaseModel):
|
||||
answer_retrieval_stats: dict[str, float | int]
|
||||
@@ -0,0 +1,59 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_message_runs
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
AnswerQuestionState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
QACheckUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_CHECK_NO
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
|
||||
|
||||
def answer_check(state: AnswerQuestionState, config: RunnableConfig) -> QACheckUpdate:
|
||||
now_start = datetime.now()
|
||||
|
||||
level, question_num = parse_question_id(state.question_id)
|
||||
if state.answer == UNKNOWN_ANSWER:
|
||||
now_end = datetime.now()
|
||||
return QACheckUpdate(
|
||||
answer_quality=SUB_CHECK_NO,
|
||||
log_messages=[
|
||||
f"{now_end} -- Answer check SQ-{level}-{question_num} - unknown answer, Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=SUB_CHECK_PROMPT.format(
|
||||
question=state.question,
|
||||
base_answer=state.answer,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
fast_llm = agent_search_config.fast_llm
|
||||
response = list(
|
||||
fast_llm.stream(
|
||||
prompt=msg,
|
||||
)
|
||||
)
|
||||
|
||||
quality_str = merge_message_runs(response, chunk_separator="")[0].content
|
||||
|
||||
now_end = datetime.now()
|
||||
return QACheckUpdate(
|
||||
answer_quality=quality_str,
|
||||
log_messages=[
|
||||
f"""{now_end} -- Answer check SQ-{level}-{question_num} - Answer quality: {quality_str},
|
||||
Time taken: {now_end - now_start}"""
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,116 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
from langchain_core.messages import merge_message_runs
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
AnswerQuestionState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
QAGenerationUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_sub_question_answer_prompt,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
ASSISTANT_SYSTEM_PROMPT_PERSONA,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import NO_RECOVERED_DOCS
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_prompt
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def answer_generation(
|
||||
state: AnswerQuestionState, config: RunnableConfig
|
||||
) -> QAGenerationUpdate:
|
||||
now_start = datetime.now()
|
||||
logger.debug(f"--------{now_start}--------START ANSWER GENERATION---")
|
||||
|
||||
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = state.question
|
||||
docs = state.documents
|
||||
level, question_nr = parse_question_id(state.question_id)
|
||||
context_docs = state.context_documents
|
||||
persona_prompt = get_persona_prompt(agent_search_config.search_request.persona)
|
||||
|
||||
if len(context_docs) == 0:
|
||||
answer_str = NO_RECOVERED_DOCS
|
||||
dispatch_custom_event(
|
||||
"sub_answers",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=answer_str,
|
||||
level=level,
|
||||
level_question_nr=question_nr,
|
||||
answer_type="agent_sub_answer",
|
||||
),
|
||||
)
|
||||
else:
|
||||
# use the persona-specific system prompt only when a persona prompt is configured
if len(persona_prompt) > 0:
|
||||
persona_specification = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
persona_prompt=persona_prompt
)
|
||||
else:
|
||||
persona_specification = ASSISTANT_SYSTEM_PROMPT_DEFAULT
|
||||
|
||||
logger.debug(f"Number of verified retrieval docs: {len(docs)}")
|
||||
|
||||
fast_llm = agent_search_config.fast_llm
|
||||
msg = build_sub_question_answer_prompt(
|
||||
question=question,
|
||||
original_question=agent_search_config.search_request.query,
|
||||
docs=docs,
|
||||
persona_specification=persona_specification,
|
||||
config=fast_llm.config,
|
||||
)
|
||||
|
||||
response: list[str | list[str | dict[str, Any]]] = []
|
||||
for message in fast_llm.stream(
|
||||
prompt=msg,
|
||||
):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
)
|
||||
dispatch_custom_event(
|
||||
"sub_answers",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=content,
|
||||
level=level,
|
||||
level_question_nr=question_nr,
|
||||
answer_type="agent_sub_answer",
|
||||
),
|
||||
)
|
||||
response.append(content)
|
||||
|
||||
answer_str = merge_message_runs(response, chunk_separator="")[0].content
|
||||
|
||||
stop_event = StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
stream_type="sub_answer",
|
||||
level=level,
|
||||
level_question_nr=question_nr,
|
||||
)
|
||||
dispatch_custom_event("stream_finished", stop_event)
|
||||
|
||||
now_end = datetime.now()
|
||||
return QAGenerationUpdate(
|
||||
answer=answer_str,
|
||||
log_messages=[
|
||||
f"{now_end} -- Answer generation SQ-{level} - Q{question_nr} - Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
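# --- Editorial sketch (not part of the diff): one way to consume the "sub_answers" events. ---
# Assumes a compiled graph that runs answer_generation and LangChain's astream_events v2
# API, where dispatch_custom_event payloads surface as "on_custom_event" entries; both the
# helper name and the untyped compiled_graph parameter are illustrative only.
async def print_sub_answers(compiled_graph, inputs, config) -> None:
    async for event in compiled_graph.astream_events(inputs, config=config, version="v2"):
        if event["event"] == "on_custom_event" and event["name"] == "sub_answers":
            piece = event["data"]
            if isinstance(piece, AgentAnswerPiece):
                print(piece.answer_piece, end="", flush=True)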
@@ -0,0 +1,28 @@
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
AnswerQuestionOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
AnswerQuestionState,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
QuestionAnswerResults,
|
||||
)
|
||||
|
||||
|
||||
def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput:
|
||||
return AnswerQuestionOutput(
|
||||
answer_results=[
|
||||
QuestionAnswerResults(
|
||||
question=state.question,
|
||||
question_id=state.question_id,
|
||||
quality=state.answer_quality
|
||||
if hasattr(state, "answer_quality")
|
||||
else "No",
|
||||
answer=state.answer,
|
||||
expanded_retrieval_results=state.expanded_retrieval_results,
|
||||
documents=state.documents,
|
||||
context_documents=state.context_documents,
|
||||
sub_question_retrieval_stats=state.sub_question_retrieval_stats,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,22 @@
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
RetrievalIngestionUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
|
||||
|
||||
|
||||
def ingest_retrieval(state: ExpandedRetrievalOutput) -> RetrievalIngestionUpdate:
|
||||
sub_question_retrieval_stats = (
|
||||
state.expanded_retrieval_result.sub_question_retrieval_stats
|
||||
)
|
||||
if sub_question_retrieval_stats is None:
|
||||
sub_question_retrieval_stats = [AgentChunkStats()]
|
||||
|
||||
return RetrievalIngestionUpdate(
|
||||
expanded_retrieval_results=state.expanded_retrieval_result.expanded_queries_results,
|
||||
documents=state.expanded_retrieval_result.all_documents,
|
||||
context_documents=state.expanded_retrieval_result.context_documents,
|
||||
sub_question_retrieval_stats=sub_question_retrieval_stats,
|
||||
)
|
||||
@@ -0,0 +1,71 @@
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.core_state import SubgraphCoreState
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
QuestionAnswerResults,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
## Update States
|
||||
class QACheckUpdate(BaseModel):
|
||||
answer_quality: str = ""
|
||||
log_messages: list[str] = []
|
||||
|
||||
|
||||
class QAGenerationUpdate(BaseModel):
|
||||
answer: str = ""
|
||||
log_messages: list[str] = []
|
||||
# answer_stat: AnswerStats
|
||||
|
||||
|
||||
class RetrievalIngestionUpdate(BaseModel):
|
||||
expanded_retrieval_results: list[QueryResult] = []
|
||||
documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
|
||||
context_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
|
||||
sub_question_retrieval_stats: AgentChunkStats = AgentChunkStats()
|
||||
|
||||
|
||||
## Graph Input State
|
||||
|
||||
|
||||
class AnswerQuestionInput(SubgraphCoreState):
|
||||
question: str = ""
|
||||
question_id: str = (
|
||||
"" # 0_0 is original question, everything else is <level>_<question_num>.
|
||||
)
|
||||
# level 0 is original question and first decomposition, level 1 is follow up, etc
|
||||
# question_num is a unique number per original question per level.
|
||||
|
||||
|
||||
## Graph State
|
||||
|
||||
|
||||
class AnswerQuestionState(
|
||||
AnswerQuestionInput,
|
||||
QAGenerationUpdate,
|
||||
QACheckUpdate,
|
||||
RetrievalIngestionUpdate,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
## Graph Output State
|
||||
|
||||
|
||||
class AnswerQuestionOutput(BaseModel):
|
||||
"""
|
||||
This is a list of results even though each call of this subgraph only returns one result.
|
||||
This is because if we parallelize the answer query subgraph, there will be multiple
|
||||
results in a list so the add operator is used to add them together.
|
||||
"""
|
||||
|
||||
answer_results: Annotated[list[QuestionAnswerResults], add] = []
|
||||
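# --- Editorial sketch (not part of the diff): why answer_results is an add-reduced list. ---
# Each invocation of the subgraph contributes a one-element list; when several sub-questions
# run in parallel, LangGraph applies `add` to concatenate those lists on the parent state.
# With real QuestionAnswerResults objects each branch would contribute one entry; the empty
# defaults below just show the merge mechanics.
if __name__ == "__main__":
    combined = add(
        AnswerQuestionOutput().answer_results,
        AnswerQuestionOutput().answer_results,
    )
    assert combined == []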
@@ -0,0 +1,28 @@
|
||||
from collections.abc import Hashable
|
||||
from datetime import datetime
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
AnswerQuestionInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalInput,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def send_to_expanded_refined_retrieval(state: AnswerQuestionInput) -> Send | Hashable:
|
||||
logger.debug("sending to expanded retrieval for follow up question via edge")
|
||||
now_start = datetime.now()
|
||||
return Send(
|
||||
"refined_sub_question_expanded_retrieval",
|
||||
ExpandedRetrievalInput(
|
||||
question=state.question,
|
||||
sub_question_id=state.question_id,
|
||||
base_search=False,
|
||||
log_messages=[f"{datetime.now()} -- Sending to expanded retrieval"],
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,123 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.answer_check import (
|
||||
answer_check,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.answer_generation import (
|
||||
answer_generation,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.format_answer import (
|
||||
format_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.nodes.ingest_retrieval import (
|
||||
ingest_retrieval,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
AnswerQuestionInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
AnswerQuestionOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
|
||||
AnswerQuestionState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_refinement_sub_question.edges import (
|
||||
send_to_expanded_refined_retrieval,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.graph_builder import (
|
||||
expanded_retrieval_graph_builder,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def answer_refined_query_graph_builder() -> StateGraph:
|
||||
graph = StateGraph(
|
||||
state_schema=AnswerQuestionState,
|
||||
input=AnswerQuestionInput,
|
||||
output=AnswerQuestionOutput,
|
||||
)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
expanded_retrieval = expanded_retrieval_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="refined_sub_question_expanded_retrieval",
|
||||
action=expanded_retrieval,
|
||||
)
|
||||
graph.add_node(
|
||||
node="refined_sub_answer_check",
|
||||
action=answer_check,
|
||||
)
|
||||
graph.add_node(
|
||||
node="refined_sub_answer_generation",
|
||||
action=answer_generation,
|
||||
)
|
||||
graph.add_node(
|
||||
node="format_refined_sub_answer",
|
||||
action=format_answer,
|
||||
)
|
||||
graph.add_node(
|
||||
node="ingest_refined_retrieval",
|
||||
action=ingest_retrieval,
|
||||
)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source=START,
|
||||
path=send_to_expanded_refined_retrieval,
|
||||
path_map=["refined_sub_question_expanded_retrieval"],
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="refined_sub_question_expanded_retrieval",
|
||||
end_key="ingest_refined_retrieval",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="ingest_refined_retrieval",
|
||||
end_key="refined_sub_answer_generation",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="refined_sub_answer_generation",
|
||||
end_key="refined_sub_answer_check",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="refined_sub_answer_check",
|
||||
end_key="format_refined_sub_answer",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="format_refined_sub_answer",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.context.search.models import SearchRequest
|
||||
|
||||
graph = answer_refined_query_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
search_request = SearchRequest(
|
||||
query="what can you do with onyx or danswer?",
|
||||
)
|
||||
with get_session_context_manager() as db_session:
|
||||
inputs = AnswerQuestionInput(
|
||||
question="what can you do with onyx?",
|
||||
question_id="0_0",
|
||||
log_messages=[],
|
||||
)
|
||||
for thing in compiled_graph.stream(
|
||||
input=inputs,
|
||||
# debug=True,
|
||||
# subgraphs=True,
|
||||
):
|
||||
logger.debug(thing)
|
||||
# output = compiled_graph.invoke(inputs)
|
||||
# logger.debug(output)
|
||||
@@ -0,0 +1,19 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
### Models ###
|
||||
|
||||
|
||||
class AnswerRetrievalStats(BaseModel):
|
||||
answer_retrieval_stats: dict[str, float | int]
|
||||
|
||||
|
||||
class QuestionAnswerResults(BaseModel):
|
||||
question: str
|
||||
answer: str
|
||||
quality: str
|
||||
# expanded_retrieval_results: list[QueryResult]
|
||||
documents: list[InferenceSection]
|
||||
sub_question_retrieval_stats: AgentChunkStats
|
||||
@@ -0,0 +1,76 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.base_raw_search.nodes.format_raw_search_results import (
|
||||
format_raw_search_results,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.base_raw_search.nodes.generate_raw_search_data import (
|
||||
generate_raw_search_data,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.base_raw_search.states import (
|
||||
BaseRawSearchInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.base_raw_search.states import (
|
||||
BaseRawSearchOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.base_raw_search.states import (
|
||||
BaseRawSearchState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.graph_builder import (
|
||||
expanded_retrieval_graph_builder,
|
||||
)
|
||||
|
||||
|
||||
def base_raw_search_graph_builder() -> StateGraph:
|
||||
graph = StateGraph(
|
||||
state_schema=BaseRawSearchState,
|
||||
input=BaseRawSearchInput,
|
||||
output=BaseRawSearchOutput,
|
||||
)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
graph.add_node(
|
||||
node="generate_raw_search_data",
|
||||
action=generate_raw_search_data,
|
||||
)
|
||||
|
||||
expanded_retrieval = expanded_retrieval_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="expanded_retrieval_base_search",
|
||||
action=expanded_retrieval,
|
||||
)
|
||||
graph.add_node(
|
||||
node="format_raw_search_results",
|
||||
action=format_raw_search_results,
|
||||
)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_edge(start_key=START, end_key="generate_raw_search_data")
|
||||
|
||||
graph.add_edge(
|
||||
start_key="generate_raw_search_data",
|
||||
end_key="expanded_retrieval_base_search",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="expanded_retrieval_base_search",
|
||||
end_key="format_raw_search_results",
|
||||
)
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="expanded_retrieval_base_search",
|
||||
# end_key=END,
|
||||
# )
|
||||
|
||||
graph.add_edge(
|
||||
start_key="format_raw_search_results",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
@@ -0,0 +1,20 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
### Models ###
|
||||
|
||||
|
||||
class AnswerRetrievalStats(BaseModel):
|
||||
answer_retrieval_stats: dict[str, float | int]
|
||||
|
||||
|
||||
class QuestionAnswerResults(BaseModel):
|
||||
question: str
|
||||
answer: str
|
||||
quality: str
|
||||
expanded_retrieval_results: list[QueryResult]
|
||||
documents: list[InferenceSection]
|
||||
sub_question_retrieval_stats: list[AgentChunkStats]
|
||||
@@ -0,0 +1,18 @@
|
||||
from onyx.agents.agent_search.deep_search_a.base_raw_search.states import (
|
||||
BaseRawSearchOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalOutput,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def format_raw_search_results(state: ExpandedRetrievalOutput) -> BaseRawSearchOutput:
|
||||
logger.debug("format_raw_search_results")
|
||||
return BaseRawSearchOutput(
|
||||
base_expanded_retrieval_result=state.expanded_retrieval_result,
|
||||
# base_retrieval_results=[state.expanded_retrieval_result],
|
||||
# base_search_documents=[],
|
||||
)
|
||||
@@ -0,0 +1,25 @@
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.core_state import CoreState
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalInput,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def generate_raw_search_data(
|
||||
state: CoreState, config: RunnableConfig
|
||||
) -> ExpandedRetrievalInput:
|
||||
logger.debug("generate_raw_search_data")
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
return ExpandedRetrievalInput(
|
||||
question=agent_a_config.search_request.query,
|
||||
base_search=True,
|
||||
sub_question_id=None, # This graph is always and only used for the original question
|
||||
log_messages=[],
|
||||
)
|
||||
@@ -0,0 +1,43 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import (
|
||||
ExpandedRetrievalResult,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalInput,
|
||||
)
|
||||
|
||||
|
||||
## Update States
|
||||
|
||||
|
||||
## Graph Input State
|
||||
|
||||
|
||||
class BaseRawSearchInput(ExpandedRetrievalInput):
|
||||
pass
|
||||
|
||||
|
||||
## Graph Output State
|
||||
|
||||
|
||||
class BaseRawSearchOutput(BaseModel):
|
||||
"""
|
||||
This output carries the single expanded-retrieval result produced for the original
|
||||
question (the base search over the raw user query).
|
||||
"""
|
||||
|
||||
# base_search_documents: Annotated[list[InferenceSection], dedup_inference_sections]
|
||||
# base_retrieval_results: Annotated[list[ExpandedRetrievalResult], add]
|
||||
base_expanded_retrieval_result: ExpandedRetrievalResult = ExpandedRetrievalResult()
|
||||
|
||||
|
||||
## Graph State
|
||||
|
||||
|
||||
class BaseRawSearchState(
|
||||
BaseRawSearchInput,
|
||||
BaseRawSearchOutput,
|
||||
):
|
||||
pass
|
||||
@@ -0,0 +1,37 @@
|
||||
from collections.abc import Hashable
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
RetrievalInput,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
|
||||
|
||||
def parallel_retrieval_edge(
|
||||
state: ExpandedRetrievalState, config: RunnableConfig
|
||||
) -> list[Send | Hashable]:
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = state.question if state.question else agent_a_config.search_request.query
|
||||
|
||||
# fall back to the question itself when no expanded queries were produced
query_expansions = (
|
||||
state.expanded_queries if state.expanded_queries else [question]
|
||||
)
|
||||
return [
|
||||
Send(
|
||||
"doc_retrieval",
|
||||
RetrievalInput(
|
||||
query_to_retrieve=query,
|
||||
question=question,
|
||||
base_search=False,
|
||||
sub_question_id=state.sub_question_id,
|
||||
log_messages=[],
|
||||
),
|
||||
)
|
||||
for query in query_expansions
|
||||
]
|
||||
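# --- Editorial note (not part of the diff): what the Send fan-out above produces. ---
# For expanded_queries == ["q1", "q2"], parallel_retrieval_edge returns two Send objects,
# each targeting the "doc_retrieval" node with its own RetrievalInput, roughly:
#
#   [
#       Send("doc_retrieval", RetrievalInput(query_to_retrieve="q1", question=question, ...)),
#       Send("doc_retrieval", RetrievalInput(query_to_retrieve="q2", question=question, ...)),
#   ]
#
# LangGraph then runs one doc_retrieval invocation per Send in parallel and merges the
# resulting DocRetrievalUpdate outputs back into ExpandedRetrievalState via its reducers.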
@@ -0,0 +1,147 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.edges import (
|
||||
parallel_retrieval_edge,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.doc_reranking import (
|
||||
doc_reranking,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.doc_retrieval import (
|
||||
doc_retrieval,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.doc_verification import (
|
||||
doc_verification,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.dummy import (
|
||||
dummy,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.expand_queries import (
|
||||
expand_queries,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.format_results import (
|
||||
format_results,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.nodes.verification_kickoff import (
|
||||
verification_kickoff,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def expanded_retrieval_graph_builder() -> StateGraph:
|
||||
graph = StateGraph(
|
||||
state_schema=ExpandedRetrievalState,
|
||||
input=ExpandedRetrievalInput,
|
||||
output=ExpandedRetrievalOutput,
|
||||
)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
graph.add_node(
|
||||
node="expand_queries",
|
||||
action=expand_queries,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="dummy",
|
||||
action=dummy,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="doc_retrieval",
|
||||
action=doc_retrieval,
|
||||
)
|
||||
graph.add_node(
|
||||
node="verification_kickoff",
|
||||
action=verification_kickoff,
|
||||
)
|
||||
graph.add_node(
|
||||
node="doc_verification",
|
||||
action=doc_verification,
|
||||
)
|
||||
graph.add_node(
|
||||
node="doc_reranking",
|
||||
action=doc_reranking,
|
||||
)
|
||||
graph.add_node(
|
||||
node="format_results",
|
||||
action=format_results,
|
||||
)
|
||||
|
||||
### Add edges ###
|
||||
graph.add_edge(
|
||||
start_key=START,
|
||||
end_key="expand_queries",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="expand_queries",
|
||||
end_key="dummy",
|
||||
)
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source="dummy",
|
||||
path=parallel_retrieval_edge,
|
||||
path_map=["doc_retrieval"],
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="doc_retrieval",
|
||||
end_key="verification_kickoff",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="doc_verification",
|
||||
end_key="doc_reranking",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="doc_reranking",
|
||||
end_key="format_results",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="format_results",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.context.search.models import SearchRequest
|
||||
|
||||
graph = expanded_retrieval_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
search_request = SearchRequest(
|
||||
query="what can you do with onyx or danswer?",
|
||||
)
|
||||
|
||||
with get_session_context_manager() as db_session:
|
||||
agent_a_config, search_tool = get_test_config(
|
||||
db_session, primary_llm, fast_llm, search_request
|
||||
)
|
||||
inputs = ExpandedRetrievalInput(
|
||||
question="what can you do with onyx?",
|
||||
base_search=False,
|
||||
sub_question_id=None,
|
||||
log_messages=[],
|
||||
)
|
||||
for thing in compiled_graph.stream(
|
||||
input=inputs,
|
||||
config={"configurable": {"config": agent_a_config}},
|
||||
# debug=True,
|
||||
subgraphs=True,
|
||||
):
|
||||
logger.debug(thing)
|
||||
@@ -0,0 +1,12 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
class ExpandedRetrievalResult(BaseModel):
|
||||
expanded_queries_results: list[QueryResult] = []
|
||||
all_documents: list[InferenceSection] = []
|
||||
context_documents: list[InferenceSection] = []
|
||||
sub_question_retrieval_stats: AgentChunkStats = AgentChunkStats()
|
||||
@@ -0,0 +1,74 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
DocRerankingUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.calculations import get_fit_scores
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
|
||||
from onyx.configs.dev_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS
|
||||
from onyx.configs.dev_configs import AGENT_RERANKING_STATS
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.context.search.pipeline import retrieval_preprocessing
|
||||
from onyx.context.search.postprocessing.postprocessing import rerank_sections
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
|
||||
|
||||
def doc_reranking(
|
||||
state: ExpandedRetrievalState, config: RunnableConfig
|
||||
) -> DocRerankingUpdate:
|
||||
now_start = datetime.now()
|
||||
verified_documents = state.verified_documents
|
||||
|
||||
# Rerank post retrieval and verification. First, create a search query
|
||||
# then create the list of reranked sections
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = state.question if state.question else agent_a_config.search_request.query
|
||||
with get_session_context_manager() as db_session:
|
||||
_search_query = retrieval_preprocessing(
|
||||
search_request=SearchRequest(query=question),
|
||||
user=agent_a_config.search_tool.user, # bit of a hack
|
||||
llm=agent_a_config.fast_llm,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# skip section filtering
|
||||
|
||||
if (
|
||||
_search_query.rerank_settings
|
||||
and _search_query.rerank_settings.rerank_model_name
|
||||
and _search_query.rerank_settings.num_rerank > 0
|
||||
):
|
||||
reranked_documents = rerank_sections(
|
||||
_search_query,
|
||||
verified_documents,
|
||||
)
|
||||
else:
|
||||
logger.warning("No reranking settings found, using unranked documents")
|
||||
reranked_documents = verified_documents
|
||||
|
||||
if AGENT_RERANKING_STATS:
|
||||
fit_scores = get_fit_scores(verified_documents, reranked_documents)
|
||||
else:
|
||||
fit_scores = RetrievalFitStats(fit_score_lift=0, rerank_effect=0, fit_scores={})
|
||||
|
||||
# TODO: stream deduped docs here, or decide to use search tool ranking/verification
|
||||
now_end = datetime.now()
|
||||
return DocRerankingUpdate(
|
||||
reranked_documents=[
|
||||
doc for doc in reranked_documents if isinstance(doc, InferenceSection)
|
||||
][:AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS],
|
||||
sub_question_retrieval_stats=fit_scores,
|
||||
log_messages=[
|
||||
f"{now_end} -- Expanded Retrieval - Reranking - Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,103 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
DocRetrievalUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
RetrievalInput,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.calculations import get_fit_scores
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
|
||||
from onyx.configs.dev_configs import AGENT_MAX_QUERY_RETRIEVAL_RESULTS
|
||||
from onyx.configs.dev_configs import AGENT_RETRIEVAL_STATS
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.tools.models import SearchQueryInfo
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
|
||||
|
||||
def doc_retrieval(state: RetrievalInput, config: RunnableConfig) -> DocRetrievalUpdate:
|
||||
"""
|
||||
Retrieve documents
|
||||
|
||||
Args:
|
||||
state (RetrievalInput): Primary state + the query to retrieve
|
||||
config (RunnableConfig): Configuration containing ProSearchConfig
|
||||
|
||||
Updates:
|
||||
expanded_retrieval_results: list[ExpandedRetrievalResult]
|
||||
retrieved_documents: list[InferenceSection]
|
||||
"""
|
||||
now_start = datetime.now()
|
||||
query_to_retrieve = state.query_to_retrieve
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
search_tool = agent_a_config.search_tool
|
||||
|
||||
retrieved_docs: list[InferenceSection] = []
|
||||
if not query_to_retrieve.strip():
|
||||
logger.warning("Empty query, skipping retrieval")
|
||||
now_end = datetime.now()
|
||||
return DocRetrievalUpdate(
|
||||
expanded_retrieval_results=[],
|
||||
retrieved_documents=[],
|
||||
log_messages=[
|
||||
f"{now_end} -- Expanded Retrieval - Retrieval - Empty Query - Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
|
||||
query_info = None
|
||||
# new db session to avoid concurrency issues
|
||||
with get_session_context_manager() as db_session:
|
||||
for tool_response in search_tool.run(
|
||||
query=query_to_retrieve,
|
||||
force_no_rerank=True,
|
||||
alternate_db_session=db_session,
|
||||
):
|
||||
# get retrieved docs to send to the rest of the graph
|
||||
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
response = cast(SearchResponseSummary, tool_response.response)
|
||||
retrieved_docs = response.top_sections
|
||||
query_info = SearchQueryInfo(
|
||||
predicted_search=response.predicted_search,
|
||||
final_filters=response.final_filters,
|
||||
recency_bias_multiplier=response.recency_bias_multiplier,
|
||||
)
|
||||
break
|
||||
|
||||
retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS]
|
||||
pre_rerank_docs = retrieved_docs
|
||||
if search_tool.search_pipeline is not None:
|
||||
pre_rerank_docs = (
|
||||
search_tool.search_pipeline._retrieved_sections or retrieved_docs
|
||||
)
|
||||
|
||||
if AGENT_RETRIEVAL_STATS:
|
||||
fit_scores = get_fit_scores(
|
||||
pre_rerank_docs,
|
||||
retrieved_docs,
|
||||
)
|
||||
else:
|
||||
fit_scores = None
|
||||
|
||||
expanded_retrieval_result = QueryResult(
|
||||
query=query_to_retrieve,
|
||||
search_results=retrieved_docs,
|
||||
stats=fit_scores,
|
||||
query_info=query_info,
|
||||
)
|
||||
now_end = datetime.now()
|
||||
return DocRetrievalUpdate(
|
||||
expanded_retrieval_results=[expanded_retrieval_result],
|
||||
retrieved_documents=retrieved_docs,
|
||||
log_messages=[
|
||||
f"{now_end} -- Expanded Retrieval - Retrieval - Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,60 @@
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
DocVerificationInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
DocVerificationUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
|
||||
|
||||
|
||||
def doc_verification(
|
||||
state: DocVerificationInput, config: RunnableConfig
|
||||
) -> DocVerificationUpdate:
|
||||
"""
|
||||
Check whether the document is relevant for the original user question
|
||||
|
||||
Args:
|
||||
state (DocVerificationInput): The current state
|
||||
config (RunnableConfig): Configuration containing ProSearchConfig
|
||||
|
||||
Updates:
|
||||
verified_documents: list[InferenceSection]
|
||||
"""
|
||||
|
||||
question = state.question
|
||||
doc_to_verify = state.doc_to_verify
|
||||
document_content = doc_to_verify.combined_content
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
fast_llm = agent_a_config.fast_llm
|
||||
|
||||
document_content = trim_prompt_piece(
|
||||
fast_llm.config, document_content, VERIFIER_PROMPT + question
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=VERIFIER_PROMPT.format(
|
||||
question=question, document_content=document_content
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
response = fast_llm.invoke(msg)
|
||||
|
||||
verified_documents = []
|
||||
if isinstance(response.content, str) and "yes" in response.content.lower():
|
||||
verified_documents.append(doc_to_verify)
|
||||
|
||||
return DocVerificationUpdate(
|
||||
verified_documents=verified_documents,
|
||||
)
|
||||
@@ -0,0 +1,16 @@
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
QueryExpansionUpdate,
|
||||
)
|
||||
|
||||
|
||||
def dummy(
|
||||
state: ExpandedRetrievalState, config: RunnableConfig
|
||||
) -> QueryExpansionUpdate:
|
||||
return QueryExpansionUpdate(
|
||||
expanded_queries=state.expanded_queries,
|
||||
)
|
||||
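# --- Editorial note (assumption, not stated in the diff): `dummy` is a pure pass-through. ---
# It appears to exist so that the expanded-retrieval graph can attach its conditional Send
# fan-out (parallel_retrieval_edge) to a single node that runs after expand_queries has
# populated expanded_queries in the state.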
@@ -0,0 +1,68 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_message_runs
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.operations import (
|
||||
dispatch_subquery,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
QueryExpansionUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
REWRITE_PROMPT_MULTI_ORIGINAL,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
|
||||
|
||||
def expand_queries(
|
||||
state: ExpandedRetrievalInput, config: RunnableConfig
|
||||
) -> QueryExpansionUpdate:
|
||||
# Sometimes we want to expand the original question, sometimes we want to expand a sub-question.
|
||||
# When we are running this node on the original question, no question is explicitly passed in.
|
||||
# Instead, we use the original question from the search request.
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
now_start = datetime.now()
|
||||
question = (
|
||||
state.question
|
||||
if hasattr(state, "question")
|
||||
else agent_a_config.search_request.query
|
||||
)
|
||||
llm = agent_a_config.fast_llm
|
||||
chat_session_id = agent_a_config.chat_session_id
|
||||
sub_question_id = state.sub_question_id
|
||||
if sub_question_id is None:
|
||||
level, question_nr = 0, 0
|
||||
else:
|
||||
level, question_nr = parse_question_id(sub_question_id)
|
||||
|
||||
if chat_session_id is None:
|
||||
raise ValueError("chat_session_id must be provided for agent search")
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question),
|
||||
)
|
||||
]
|
||||
|
||||
llm_response_list = dispatch_separated(
|
||||
llm.stream(prompt=msg), dispatch_subquery(level, question_nr)
|
||||
)
|
||||
|
||||
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
|
||||
|
||||
rewritten_queries = llm_response.split("\n")
|
||||
now_end = datetime.now()
|
||||
return QueryExpansionUpdate(
|
||||
expanded_queries=rewritten_queries,
|
||||
log_messages=[
|
||||
f"{now_end} -- Expanded Retrieval - Query Expansion - Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,82 @@
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import (
|
||||
ExpandedRetrievalResult,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.operations import (
|
||||
calculate_sub_question_retrieval_stats,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
|
||||
|
||||
|
||||
def format_results(
|
||||
state: ExpandedRetrievalState, config: RunnableConfig
|
||||
) -> ExpandedRetrievalUpdate:
|
||||
level, question_nr = parse_question_id(state.sub_question_id or "0_0")
|
||||
query_infos = [
|
||||
result.query_info
|
||||
for result in state.expanded_retrieval_results
|
||||
if result.query_info is not None
|
||||
]
|
||||
if len(query_infos) == 0:
|
||||
raise ValueError("No query info found")
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
# main question docs will be sent later after aggregation and deduping with sub-question docs
|
||||
|
||||
stream_documents = state.reranked_documents
|
||||
|
||||
if not (level == 0 and question_nr == 0):
|
||||
if len(stream_documents) == 0:
|
||||
# The sub-question is used as the last query. If no verified documents are found, stream
|
||||
# the top 3 for that one. We may want to revisit this.
|
||||
stream_documents = state.expanded_retrieval_results[-1].search_results[:3]
|
||||
|
||||
for tool_response in yield_search_responses(
|
||||
query=state.question,
|
||||
reranked_sections=state.retrieved_documents, # TODO: rename params. (sections pre-merging here.)
|
||||
final_context_sections=stream_documents,
|
||||
search_query_info=query_infos[0], # TODO: handle differing query infos?
|
||||
get_section_relevance=lambda: None, # TODO: add relevance
|
||||
search_tool=agent_a_config.search_tool,
|
||||
):
|
||||
dispatch_custom_event(
|
||||
"tool_response",
|
||||
ExtendedToolResponse(
|
||||
id=tool_response.id,
|
||||
response=tool_response.response,
|
||||
level=level,
|
||||
level_question_nr=question_nr,
|
||||
),
|
||||
)
|
||||
sub_question_retrieval_stats = calculate_sub_question_retrieval_stats(
|
||||
verified_documents=state.verified_documents,
|
||||
expanded_retrieval_results=state.expanded_retrieval_results,
|
||||
)
|
||||
|
||||
if sub_question_retrieval_stats is None:
|
||||
sub_question_retrieval_stats = AgentChunkStats()
|
||||
# else:
|
||||
# sub_question_retrieval_stats = [sub_question_retrieval_stats]
|
||||
|
||||
return ExpandedRetrievalUpdate(
|
||||
expanded_retrieval_result=ExpandedRetrievalResult(
|
||||
expanded_queries_results=state.expanded_retrieval_results,
|
||||
all_documents=stream_documents,
|
||||
context_documents=state.reranked_documents,
|
||||
sub_question_retrieval_stats=sub_question_retrieval_stats,
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,44 @@
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langgraph.types import Command
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
DocVerificationInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.states import (
|
||||
ExpandedRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
|
||||
|
||||
def verification_kickoff(
|
||||
state: ExpandedRetrievalState,
|
||||
config: RunnableConfig,
|
||||
) -> Command[Literal["doc_verification"]]:
|
||||
documents = state.retrieved_documents
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
verification_question = (
|
||||
state.question
|
||||
if hasattr(state, "question")
|
||||
else agent_a_config.search_request.query
|
||||
)
|
||||
sub_question_id = state.sub_question_id
|
||||
return Command(
|
||||
update={},
|
||||
goto=[
|
||||
Send(
|
||||
node="doc_verification",
|
||||
arg=DocVerificationInput(
|
||||
doc_to_verify=doc,
|
||||
question=verification_question,
|
||||
base_search=False,
|
||||
sub_question_id=sub_question_id,
|
||||
log_messages=[],
|
||||
),
|
||||
)
|
||||
for doc in documents
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,97 @@
from collections import defaultdict
from collections.abc import Callable

import numpy as np
from langchain_core.callbacks.manager import dispatch_custom_event

from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.chat.models import SubQueryPiece
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger

logger = setup_logger()


def dispatch_subquery(level: int, question_nr: int) -> Callable[[str, int], None]:
    def helper(token: str, num: int) -> None:
        dispatch_custom_event(
            "subqueries",
            SubQueryPiece(
                sub_query=token,
                level=level,
                level_question_nr=question_nr,
                query_id=num,
            ),
        )

    return helper


def calculate_sub_question_retrieval_stats(
    verified_documents: list[InferenceSection],
    expanded_retrieval_results: list[QueryResult],
) -> AgentChunkStats:
    chunk_scores: dict[str, dict[str, list[int | float]]] = defaultdict(
        lambda: defaultdict(list)
    )

    for expanded_retrieval_result in expanded_retrieval_results:
        for doc in expanded_retrieval_result.search_results:
            doc_chunk_id = f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}"
            if doc.center_chunk.score is not None:
                chunk_scores[doc_chunk_id]["score"].append(doc.center_chunk.score)

    verified_doc_chunk_ids = [
        f"{verified_document.center_chunk.document_id}_{verified_document.center_chunk.chunk_id}"
        for verified_document in verified_documents
    ]
    dismissed_doc_chunk_ids = []

    raw_chunk_stats_counts: dict[str, int] = defaultdict(int)
    raw_chunk_stats_scores: dict[str, float] = defaultdict(float)
    for doc_chunk_id, chunk_data in chunk_scores.items():
        if doc_chunk_id in verified_doc_chunk_ids:
            raw_chunk_stats_counts["verified_count"] += 1

            valid_chunk_scores = [
                score for score in chunk_data["score"] if score is not None
            ]
            raw_chunk_stats_scores["verified_scores"] += float(
                np.mean(valid_chunk_scores)
            )
        else:
            raw_chunk_stats_counts["rejected_count"] += 1
            valid_chunk_scores = [
                score for score in chunk_data["score"] if score is not None
            ]
            raw_chunk_stats_scores["rejected_scores"] += float(
                np.mean(valid_chunk_scores)
            )
            dismissed_doc_chunk_ids.append(doc_chunk_id)

    if raw_chunk_stats_counts["verified_count"] == 0:
        verified_avg_scores = 0.0
    else:
        verified_avg_scores = raw_chunk_stats_scores["verified_scores"] / float(
            raw_chunk_stats_counts["verified_count"]
        )

    rejected_scores = raw_chunk_stats_scores.get("rejected_scores", None)
    if rejected_scores is not None:
        rejected_avg_scores = rejected_scores / float(
            raw_chunk_stats_counts["rejected_count"]
        )
    else:
        rejected_avg_scores = None

    chunk_stats = AgentChunkStats(
        verified_count=raw_chunk_stats_counts["verified_count"],
        verified_avg_scores=verified_avg_scores,
        rejected_count=raw_chunk_stats_counts["rejected_count"],
        rejected_avg_scores=rejected_avg_scores,
        verified_doc_chunk_ids=verified_doc_chunk_ids,
        dismissed_doc_chunk_ids=dismissed_doc_chunk_ids,
    )

    return chunk_stats
@@ -0,0 +1,91 @@
from operator import add
from typing import Annotated

from pydantic import BaseModel

from onyx.agents.agent_search.core_state import SubgraphCoreState
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import (
    ExpandedRetrievalResult,
)
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
    dedup_inference_sections,
)
from onyx.context.search.models import InferenceSection


### States ###

## Graph Input State


class ExpandedRetrievalInput(SubgraphCoreState):
    question: str = ""
    base_search: bool = False
    sub_question_id: str | None = None


## Update/Return States


class QueryExpansionUpdate(BaseModel):
    expanded_queries: list[str] = ["aaa", "bbb"]
    log_messages: list[str] = []


class DocVerificationUpdate(BaseModel):
    verified_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []


class DocRetrievalUpdate(BaseModel):
    expanded_retrieval_results: Annotated[list[QueryResult], add] = []
    retrieved_documents: Annotated[
        list[InferenceSection], dedup_inference_sections
    ] = []
    log_messages: list[str] = []


class DocRerankingUpdate(BaseModel):
    reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
    sub_question_retrieval_stats: RetrievalFitStats | None = None
    log_messages: list[str] = []


class ExpandedRetrievalUpdate(BaseModel):
    expanded_retrieval_result: ExpandedRetrievalResult


## Graph Output State


class ExpandedRetrievalOutput(BaseModel):
    expanded_retrieval_result: ExpandedRetrievalResult = ExpandedRetrievalResult()
    base_expanded_retrieval_result: ExpandedRetrievalResult = ExpandedRetrievalResult()
    log_messages: list[str] = []


## Graph State


class ExpandedRetrievalState(
    # This includes the core state
    ExpandedRetrievalInput,
    QueryExpansionUpdate,
    DocRetrievalUpdate,
    DocVerificationUpdate,
    DocRerankingUpdate,
    ExpandedRetrievalOutput,
):
    pass


## Conditional Input States


class DocVerificationInput(ExpandedRetrievalInput):
    doc_to_verify: InferenceSection


class RetrievalInput(ExpandedRetrievalInput):
    query_to_retrieve: str = ""
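For reference, a minimal sketch of how the Annotated reducers above behave when parallel nodes write to the same state key; the node and field names here are hypothetical and not part of this change, and standard langgraph state-merging semantics are assumed:

from operator import add
from typing import Annotated

from langgraph.graph import END, START, StateGraph
from pydantic import BaseModel


class SketchState(BaseModel):
    # `add` concatenates list updates from parallel branches instead of overwriting them,
    # the same way expanded_retrieval_results is merged above
    results: Annotated[list[str], add] = []


def branch_a(state: SketchState) -> dict:
    return {"results": ["hit from branch a"]}


def branch_b(state: SketchState) -> dict:
    return {"results": ["hit from branch b"]}


sketch = StateGraph(state_schema=SketchState)
sketch.add_node("a", branch_a)
sketch.add_node("b", branch_b)
sketch.add_edge(START, "a")
sketch.add_edge(START, "b")
sketch.add_edge("a", END)
sketch.add_edge("b", END)
# after compiling and invoking, `results` contains the entries from both branches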
backend/onyx/agents/agent_search/deep_search_a/main/edges.py (new file, 100 lines)
@@ -0,0 +1,100 @@
from collections.abc import Hashable
from datetime import datetime
from typing import Literal

from langgraph.types import Send

from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
    AnswerQuestionInput,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
    AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.deep_search_a.main.states import (
    RequireRefinedAnswerUpdate,
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.utils.logger import setup_logger

logger = setup_logger()


def parallelize_initial_sub_question_answering(
    state: MainState,
) -> list[Send | Hashable]:
    now_start = datetime.now()
    if len(state.initial_decomp_questions) > 0:
        # sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]]
        # if len(state["sub_question_records"]) == 0:
        #     if state["config"].use_persistence:
        #         raise ValueError("No sub-questions found for initial decomposed questions")
        #     else:
        #         # in this case, we are doing retrieval on the original question.
        #         # to make all the logic consistent, we create a new sub-question
        #         # with the same content as the original question
        #         sub_question_record_ids = [1] * len(state["initial_decomp_questions"])

        return [
            Send(
                "answer_query_subgraph",
                AnswerQuestionInput(
                    question=question,
                    question_id=make_question_id(0, question_nr + 1),
                    log_messages=[
                        f"{now_start} -- Main Edge - Parallelize Initial Sub-question Answering"
                    ],
                ),
            )
            for question_nr, question in enumerate(state.initial_decomp_questions)
        ]

    else:
        return [
            Send(
                "ingest_answers",
                AnswerQuestionOutput(
                    answer_results=[],
                ),
            )
        ]


# Define the function that determines whether to continue or not
def continue_to_refined_answer_or_end(
    state: RequireRefinedAnswerUpdate,
) -> Literal["refined_sub_question_creation", "logging_node"]:
    if state.require_refined_answer:
        return "refined_sub_question_creation"
    else:
        return "logging_node"


def parallelize_refined_sub_question_answering(
    state: MainState,
) -> list[Send | Hashable]:
    now_start = datetime.now()
    if len(state.refined_sub_questions) > 0:
        return [
            Send(
                "answer_refined_question",
                AnswerQuestionInput(
                    question=question_data.sub_question,
                    question_id=make_question_id(1, question_nr),
                    log_messages=[
                        f"{now_start} -- Main Edge - Parallelize Refined Sub-question Answering"
                    ],
                ),
            )
            for question_nr, question_data in state.refined_sub_questions.items()
        ]

    else:
        return [
            Send(
                "ingest_refined_sub_answers",
                AnswerQuestionOutput(
                    answer_results=[],
                ),
            )
        ]
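For reference, the Send objects returned by these edge functions are what add_conditional_edges consumes in the graph builder below: each Send spawns one parallel invocation of the target node with its own input. A minimal sketch of that pattern, with hypothetical node and field names and assuming standard langgraph semantics:

from langgraph.types import Send


def fan_out(state: dict) -> list[Send]:
    # one parallel "answer_one_question" run per question currently in the state
    return [
        Send("answer_one_question", {"question": question})
        for question in state["questions"]
    ]


# wired up roughly as:
# graph.add_conditional_edges(
#     source="make_questions",
#     path=fan_out,
#     path_map=["answer_one_question"],
# )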
@@ -0,0 +1,375 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.graph_builder import (
|
||||
answer_query_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.answer_refinement_sub_question.graph_builder import (
|
||||
answer_refined_query_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.base_raw_search.graph_builder import (
|
||||
base_raw_search_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.edges import (
|
||||
continue_to_refined_answer_or_end,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.edges import (
|
||||
parallelize_initial_sub_question_answering,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.edges import (
|
||||
parallelize_refined_sub_question_answering,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.agent_logging import (
|
||||
agent_logging,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.agent_path_decision import (
|
||||
agent_path_decision,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.agent_path_routing import (
|
||||
agent_path_routing,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.agent_search_start import (
|
||||
agent_search_start,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.answer_comparison import (
|
||||
answer_comparison,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.direct_llm_handling import (
|
||||
direct_llm_handling,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.entity_term_extraction_llm import (
|
||||
entity_term_extraction_llm,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.generate_initial_answer import (
|
||||
generate_initial_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.generate_refined_answer import (
|
||||
generate_refined_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.ingest_initial_base_retrieval import (
|
||||
ingest_initial_base_retrieval,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.ingest_initial_sub_question_answers import (
|
||||
ingest_initial_sub_question_answers,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.ingest_refined_answers import (
|
||||
ingest_refined_answers,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.initial_answer_quality_check import (
|
||||
initial_answer_quality_check,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.initial_sub_question_creation import (
|
||||
initial_sub_question_creation,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.refined_answer_decision import (
|
||||
refined_answer_decision,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.refined_sub_question_creation import (
|
||||
refined_sub_question_creation,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.nodes.retrieval_consolidation import (
|
||||
retrieval_consolidation,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainInput
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainState
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
test_mode = False
|
||||
|
||||
|
||||
def main_graph_builder(test_mode: bool = False) -> StateGraph:
|
||||
graph = StateGraph(
|
||||
state_schema=MainState,
|
||||
input=MainInput,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="agent_path_decision",
|
||||
action=agent_path_decision,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="agent_path_routing",
|
||||
action=agent_path_routing,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="LLM",
|
||||
action=direct_llm_handling,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="agent_search_start",
|
||||
action=agent_search_start,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="initial_sub_question_creation",
|
||||
action=initial_sub_question_creation,
|
||||
)
|
||||
answer_query_subgraph = answer_query_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="answer_query_subgraph",
|
||||
action=answer_query_subgraph,
|
||||
)
|
||||
|
||||
base_raw_search_subgraph = base_raw_search_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="base_raw_search_subgraph",
|
||||
action=base_raw_search_subgraph,
|
||||
)
|
||||
|
||||
# refined_answer_subgraph = refined_answers_graph_builder().compile()
|
||||
# graph.add_node(
|
||||
# node="refined_answer_subgraph",
|
||||
# action=refined_answer_subgraph,
|
||||
# )
|
||||
|
||||
graph.add_node(
|
||||
node="refined_sub_question_creation",
|
||||
action=refined_sub_question_creation,
|
||||
)
|
||||
|
||||
answer_refined_question = answer_refined_query_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="answer_refined_question",
|
||||
action=answer_refined_question,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="ingest_refined_answers",
|
||||
action=ingest_refined_answers,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="generate_refined_answer",
|
||||
action=generate_refined_answer,
|
||||
)
|
||||
|
||||
# graph.add_node(
|
||||
# node="check_refined_answer",
|
||||
# action=check_refined_answer,
|
||||
# )
|
||||
|
||||
graph.add_node(
|
||||
node="ingest_initial_retrieval",
|
||||
action=ingest_initial_base_retrieval,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="retrieval_consolidation",
|
||||
action=retrieval_consolidation,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="ingest_initial_sub_question_answers",
|
||||
action=ingest_initial_sub_question_answers,
|
||||
)
|
||||
graph.add_node(
|
||||
node="generate_initial_answer",
|
||||
action=generate_initial_answer,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="initial_answer_quality_check",
|
||||
action=initial_answer_quality_check,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="entity_term_extraction_llm",
|
||||
action=entity_term_extraction_llm,
|
||||
)
|
||||
graph.add_node(
|
||||
node="refined_answer_decision",
|
||||
action=refined_answer_decision,
|
||||
)
|
||||
graph.add_node(
|
||||
node="answer_comparison",
|
||||
action=answer_comparison,
|
||||
)
|
||||
graph.add_node(
|
||||
node="logging_node",
|
||||
action=agent_logging,
|
||||
)
|
||||
# if test_mode:
|
||||
# graph.add_node(
|
||||
# node="generate_initial_base_answer",
|
||||
# action=generate_initial_base_answer,
|
||||
# )
|
||||
|
||||
### Add edges ###
|
||||
|
||||
# graph.add_edge(start_key=START, end_key="base_raw_search_subgraph")
|
||||
|
||||
graph.add_edge(
|
||||
start_key=START,
|
||||
end_key="agent_path_decision",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="agent_path_decision",
|
||||
end_key="agent_path_routing",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="agent_search_start",
|
||||
end_key="base_raw_search_subgraph",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="agent_search_start",
|
||||
end_key="initial_sub_question_creation",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="base_raw_search_subgraph",
|
||||
end_key="ingest_initial_retrieval",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key=["ingest_initial_retrieval", "ingest_initial_sub_question_answers"],
|
||||
end_key="retrieval_consolidation",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="retrieval_consolidation",
|
||||
end_key="entity_term_extraction_llm",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="retrieval_consolidation",
|
||||
end_key="generate_initial_answer",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="LLM",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key=START,
|
||||
# end_key="initial_sub_question_creation",
|
||||
# )
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source="initial_sub_question_creation",
|
||||
path=parallelize_initial_sub_question_answering,
|
||||
path_map=["answer_query_subgraph"],
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="answer_query_subgraph",
|
||||
end_key="ingest_initial_sub_question_answers",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="retrieval_consolidation",
|
||||
end_key="generate_initial_answer",
|
||||
)
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="generate_initial_answer",
|
||||
# end_key="entity_term_extraction_llm",
|
||||
# )
|
||||
|
||||
graph.add_edge(
|
||||
start_key="generate_initial_answer",
|
||||
end_key="initial_answer_quality_check",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key=["initial_answer_quality_check", "entity_term_extraction_llm"],
|
||||
end_key="refined_answer_decision",
|
||||
)
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source="refined_answer_decision",
|
||||
path=continue_to_refined_answer_or_end,
|
||||
path_map=["refined_sub_question_creation", "logging_node"],
|
||||
)
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source="refined_sub_question_creation", # DONE
|
||||
path=parallelize_refined_sub_question_answering,
|
||||
path_map=["answer_refined_question"],
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="answer_refined_question", # HERE
|
||||
end_key="ingest_refined_answers",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="ingest_refined_answers",
|
||||
end_key="generate_refined_answer",
|
||||
)
|
||||
|
||||
# graph.add_conditional_edges(
|
||||
# source="refined_answer_decision",
|
||||
# path=continue_to_refined_answer_or_end,
|
||||
# path_map=["refined_answer_subgraph", END],
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="refined_answer_subgraph",
|
||||
# end_key="generate_refined_answer",
|
||||
# )
|
||||
|
||||
graph.add_edge(
|
||||
start_key="generate_refined_answer",
|
||||
end_key="answer_comparison",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="answer_comparison",
|
||||
end_key="logging_node",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="logging_node",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="generate_refined_answer",
|
||||
# end_key="check_refined_answer",
|
||||
# )
|
||||
|
||||
# graph.add_edge(
|
||||
# start_key="check_refined_answer",
|
||||
# end_key=END,
|
||||
# )
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.context.search.models import SearchRequest
|
||||
|
||||
graph = main_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
|
||||
with get_session_context_manager() as db_session:
|
||||
search_request = SearchRequest(query="Who created Excel?")
|
||||
agent_a_config, search_tool = get_test_config(
|
||||
db_session, primary_llm, fast_llm, search_request
|
||||
)
|
||||
|
||||
inputs = MainInput(
|
||||
base_question=agent_a_config.search_request.query, log_messages=[]
|
||||
)
|
||||
|
||||
for thing in compiled_graph.stream(
|
||||
input=inputs,
|
||||
config={"configurable": {"config": agent_a_config}},
|
||||
# stream_mode="debug",
|
||||
# debug=True,
|
||||
subgraphs=True,
|
||||
):
|
||||
logger.debug(thing)
|
||||
@@ -0,0 +1,36 @@
from pydantic import BaseModel


class FollowUpSubQuestion(BaseModel):
    sub_question: str
    sub_question_id: str
    verified: bool
    answered: bool
    answer: str


class AgentTimings(BaseModel):
    base_duration__s: float | None
    refined_duration__s: float | None
    full_duration__s: float | None


class AgentBaseMetrics(BaseModel):
    num_verified_documents_total: int | None
    num_verified_documents_core: int | None
    verified_avg_score_core: float | None
    num_verified_documents_base: int | float | None
    verified_avg_score_base: float | None = None
    base_doc_boost_factor: float | None = None
    support_boost_factor: float | None = None
    duration__s: float | None = None


class AgentRefinedMetrics(BaseModel):
    refined_doc_boost_factor: float | None = None
    refined_question_boost_factor: float | None = None
    duration__s: float | None = None


class AgentAdditionalMetrics(BaseModel):
    pass
@@ -0,0 +1,113 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.main.models import AgentAdditionalMetrics
|
||||
from onyx.agents.agent_search.deep_search_a.main.models import AgentTimings
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainOutput
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainState
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import CombinedAgentMetrics
|
||||
from onyx.db.chat import log_agent_metrics
|
||||
from onyx.db.chat import log_agent_sub_question_results
|
||||
|
||||
|
||||
def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput:
|
||||
now_start = datetime.now()
|
||||
|
||||
logger.debug(f"--------{now_start}--------LOGGING NODE---")
|
||||
|
||||
agent_start_time = state.agent_start_time
|
||||
agent_base_end_time = state.agent_base_end_time
|
||||
agent_refined_start_time = state.agent_refined_start_time or None
|
||||
agent_refined_end_time = state.agent_refined_end_time or None
|
||||
agent_end_time = agent_refined_end_time or agent_base_end_time
|
||||
|
||||
agent_base_duration = None
|
||||
if agent_base_end_time:
|
||||
agent_base_duration = (agent_base_end_time - agent_start_time).total_seconds()
|
||||
|
||||
agent_refined_duration = None
|
||||
if agent_refined_start_time and agent_refined_end_time:
|
||||
agent_refined_duration = (
|
||||
agent_refined_end_time - agent_refined_start_time
|
||||
).total_seconds()
|
||||
|
||||
agent_full_duration = None
|
||||
if agent_end_time:
|
||||
agent_full_duration = (agent_end_time - agent_start_time).total_seconds()
|
||||
|
||||
agent_type = "refined" if agent_refined_duration else "base"
|
||||
|
||||
agent_base_metrics = state.agent_base_metrics
|
||||
agent_refined_metrics = state.agent_refined_metrics
|
||||
|
||||
combined_agent_metrics = CombinedAgentMetrics(
|
||||
timings=AgentTimings(
|
||||
base_duration__s=agent_base_duration,
|
||||
refined_duration__s=agent_refined_duration,
|
||||
full_duration__s=agent_full_duration,
|
||||
),
|
||||
base_metrics=agent_base_metrics,
|
||||
refined_metrics=agent_refined_metrics,
|
||||
additional_metrics=AgentAdditionalMetrics(),
|
||||
)
|
||||
|
||||
persona_id = None
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
if agent_a_config.search_request.persona:
|
||||
persona_id = agent_a_config.search_request.persona.id
|
||||
|
||||
user_id = None
|
||||
user = agent_a_config.search_tool.user
|
||||
if user:
|
||||
user_id = user.id
|
||||
|
||||
# log the agent metrics
|
||||
if agent_a_config.db_session is not None:
|
||||
log_agent_metrics(
|
||||
db_session=agent_a_config.db_session,
|
||||
user_id=user_id,
|
||||
persona_id=persona_id,
|
||||
agent_type=agent_type,
|
||||
start_time=agent_start_time,
|
||||
agent_metrics=combined_agent_metrics,
|
||||
)
|
||||
|
||||
if agent_a_config.use_persistence:
|
||||
# Persist the sub-answer in the database
|
||||
db_session = agent_a_config.db_session
|
||||
chat_session_id = agent_a_config.chat_session_id
|
||||
primary_message_id = agent_a_config.message_id
|
||||
sub_question_answer_results = state.decomp_answer_results
|
||||
|
||||
log_agent_sub_question_results(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
primary_message_id=primary_message_id,
|
||||
sub_question_answer_results=sub_question_answer_results,
|
||||
)
|
||||
|
||||
# if chat_session_id is not None and primary_message_id is not None and sub_question_id is not None:
|
||||
# create_sub_answer(
|
||||
# db_session=db_session,
|
||||
# chat_session_id=chat_session_id,
|
||||
# primary_message_id=primary_message_id,
|
||||
# sub_question_id=sub_question_id,
|
||||
# answer=answer_str,
|
||||
# # )
|
||||
# pass
|
||||
|
||||
now_end = datetime.now()
|
||||
main_output = MainOutput(
|
||||
log_messages=[
|
||||
f"{now_end} -- Main - Logging, Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
|
||||
logger.debug(f"--------{now_end}--{now_end - now_start}--------LOGGING NODE END---")
|
||||
logger.debug(f"--------{now_end}--{now_end - now_start}--------LOGGING NODE END---")
|
||||
|
||||
return main_output
|
||||
@@ -0,0 +1,99 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainState
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import RoutingDecision
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_history_prompt,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import AGENT_DECISION_PROMPT
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
AGENT_DECISION_PROMPT_AFTER_SEARCH,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.llm.utils import check_number_of_tokens
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
|
||||
|
||||
def agent_path_decision(state: MainState, config: RunnableConfig) -> RoutingDecision:
|
||||
now_start = datetime.now()
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = agent_a_config.search_request.query
|
||||
perform_initial_search_path_decision = (
|
||||
agent_a_config.perform_initial_search_path_decision
|
||||
)
|
||||
|
||||
history = build_history_prompt(agent_a_config.prompt_builder)
|
||||
|
||||
logger.debug(f"--------{now_start}--------DECIDING TO SEARCH OR GO TO LLM---")
|
||||
|
||||
if perform_initial_search_path_decision:
|
||||
search_tool = agent_a_config.search_tool
|
||||
retrieved_docs: list[InferenceSection] = []
|
||||
|
||||
# new db session to avoid concurrency issues
|
||||
with get_session_context_manager() as db_session:
|
||||
for tool_response in search_tool.run(
|
||||
query=question,
|
||||
force_no_rerank=True,
|
||||
alternate_db_session=db_session,
|
||||
):
|
||||
# get retrieved docs to send to the rest of the graph
|
||||
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
response = cast(SearchResponseSummary, tool_response.response)
|
||||
retrieved_docs = response.top_sections
|
||||
break
|
||||
|
||||
sample_doc_str = "\n\n".join(
|
||||
[doc.combined_content for _, doc in enumerate(retrieved_docs[:3])]
|
||||
)
|
||||
|
||||
agent_decision_prompt = AGENT_DECISION_PROMPT_AFTER_SEARCH.format(
|
||||
question=question, sample_doc_str=sample_doc_str, history=history
|
||||
)
|
||||
|
||||
else:
|
||||
sample_doc_str = ""
|
||||
agent_decision_prompt = AGENT_DECISION_PROMPT.format(
|
||||
question=question, history=history
|
||||
)
|
||||
|
||||
msg = [HumanMessage(content=agent_decision_prompt)]
|
||||
|
||||
# Get the rewritten queries in a defined format
|
||||
model = agent_a_config.fast_llm
|
||||
|
||||
# no need to stream this
|
||||
resp = model.invoke(msg)
|
||||
|
||||
if isinstance(resp.content, str) and "research" in resp.content.lower():
|
||||
routing = "agent_search"
|
||||
else:
|
||||
routing = "LLM"
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.debug(
|
||||
f"--------{now_end}--{now_end - now_start}--------DECIDING TO SEARCH OR GO TO LLM END---"
|
||||
)
|
||||
|
||||
check_number_of_tokens(agent_decision_prompt)
|
||||
|
||||
return RoutingDecision(
|
||||
# Decide which route to take
|
||||
routing=routing,
|
||||
sample_doc_str=sample_doc_str,
|
||||
log_messages=[
|
||||
f"{now_end} -- Path decision: {routing}, Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,31 @@
from datetime import datetime
from typing import Literal

from langgraph.types import Command

from onyx.agents.agent_search.deep_search_a.main.states import MainState


def agent_path_routing(
    state: MainState,
) -> Command[Literal["agent_search_start", "LLM"]]:
    now_start = datetime.now()
    routing = state.routing if hasattr(state, "routing") else "agent_search"

    if routing == "agent_search":
        agent_path = "agent_search_start"
    else:
        agent_path = "LLM"

    now_end = datetime.now()

    return Command(
        # state update
        update={
            "log_messages": [
                f"{now_end} -- Main - Path routing: {agent_path}, Time taken: {now_end - now_start}"
            ]
        },
        # control flow
        goto=agent_path,
    )
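For reference, Command lets a node return a state update and the next node to visit in a single object, instead of pairing a plain update with a separate conditional edge. A minimal sketch with hypothetical names, assuming the langgraph version used here:

from typing import Literal

from langgraph.types import Command


def route(state: dict) -> Command[Literal["search", "llm"]]:
    # pick the next node and record the decision in the state
    target = "search" if state.get("needs_search") else "llm"
    return Command(
        update={"log_messages": [f"routed to {target}"]},
        goto=target,
    )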
@@ -0,0 +1,8 @@
from datetime import datetime

from onyx.agents.agent_search.core_state import CoreState


def agent_search_start(state: CoreState) -> CoreState:
    now_end = datetime.now()
    return CoreState(log_messages=[f"{now_end} -- Main - Agent search start"])
@@ -0,0 +1,60 @@
from datetime import datetime
from typing import cast

from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig

from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import AnswerComparison
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.prompts import ANSWER_COMPARISON_PROMPT
from onyx.chat.models import RefinedAnswerImprovement


def answer_comparison(state: MainState, config: RunnableConfig) -> AnswerComparison:
    now_start = datetime.now()

    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    question = agent_a_config.search_request.query
    initial_answer = state.initial_answer
    refined_answer = state.refined_answer

    logger.debug(f"--------{now_start}--------ANSWER COMPARISON STARTED--")

    answer_comparison_prompt = ANSWER_COMPARISON_PROMPT.format(
        question=question, initial_answer=initial_answer, refined_answer=refined_answer
    )

    msg = [HumanMessage(content=answer_comparison_prompt)]

    # Get the rewritten queries in a defined format
    model = agent_a_config.fast_llm

    # no need to stream this
    resp = model.invoke(msg)

    refined_answer_improvement = (
        isinstance(resp.content, str) and "yes" in resp.content.lower()
    )

    dispatch_custom_event(
        "refined_answer_improvement",
        RefinedAnswerImprovement(
            refined_answer_improvement=refined_answer_improvement,
        ),
    )

    now_end = datetime.now()

    logger.debug(
        f"--------{now_end}--{now_end - now_start}--------ANSWER COMPARISON COMPLETED---"
    )

    return AnswerComparison(
        refined_answer_improvement=refined_answer_improvement,
        log_messages=[
            f"{now_start} -- Answer comparison: {refined_answer_improvement}, Time taken: {now_end - now_start}"
        ],
    )
@@ -0,0 +1,89 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import InitialAnswerUpdate
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainState
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
ASSISTANT_SYSTEM_PROMPT_PERSONA,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import DIRECT_LLM_PROMPT
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_prompt
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
|
||||
|
||||
def direct_llm_handling(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> InitialAnswerUpdate:
|
||||
now_start = datetime.now()
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = agent_a_config.search_request.query
|
||||
persona_prompt = get_persona_prompt(agent_a_config.search_request.persona)
|
||||
|
||||
if len(persona_prompt) == 0:
|
||||
persona_specification = ASSISTANT_SYSTEM_PROMPT_DEFAULT
|
||||
else:
|
||||
persona_specification = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
|
||||
persona_prompt=persona_prompt
|
||||
)
|
||||
|
||||
logger.debug(f"--------{now_start}--------LLM HANDLING START---")
|
||||
|
||||
model = agent_a_config.fast_llm
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=DIRECT_LLM_PROMPT.format(
|
||||
persona_specification=persona_specification, question=question
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
|
||||
|
||||
for message in model.stream(msg):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
)
|
||||
dispatch_custom_event(
|
||||
"initial_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=content,
|
||||
level=0,
|
||||
level_question_nr=0,
|
||||
answer_type="agent_level_answer",
|
||||
),
|
||||
)
|
||||
streamed_tokens.append(content)
|
||||
|
||||
response = merge_content(*streamed_tokens)
|
||||
answer = cast(str, response)
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.debug(f"--------{now_end}--{now_end - now_start}--------LLM HANDLING END---")
|
||||
|
||||
return InitialAnswerUpdate(
|
||||
initial_answer=answer,
|
||||
initial_agent_stats=None,
|
||||
generated_sub_questions=[],
|
||||
agent_base_end_time=now_end,
|
||||
agent_base_metrics=None,
|
||||
log_messages=[
|
||||
f"{now_end} -- Main - LLM handling: {answer}, Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,137 @@
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_message_runs
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import (
|
||||
EntityTermExtractionUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainState
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import Entity
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
EntityRelationshipTermExtraction,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import Relationship
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import Term
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROMPT
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
|
||||
|
||||
def entity_term_extraction_llm(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> EntityTermExtractionUpdate:
|
||||
now_start = datetime.now()
|
||||
|
||||
logger.debug(f"--------{now_start}--------GENERATE ENTITIES & TERMS---")
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
if not agent_a_config.allow_refinement:
|
||||
now_end = datetime.now()
|
||||
return EntityTermExtractionUpdate(
|
||||
entity_retlation_term_extractions=EntityRelationshipTermExtraction(
|
||||
entities=[],
|
||||
relationships=[],
|
||||
terms=[],
|
||||
),
|
||||
log_messages=[
|
||||
f"{now_end} -- Main - ETR Extraction, Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
|
||||
# first four lines duplicates from generate_initial_answer
|
||||
question = agent_a_config.search_request.query
|
||||
sub_question_docs = state.documents
|
||||
all_original_question_documents = state.all_original_question_documents
|
||||
relevant_docs = dedup_inference_sections(
|
||||
sub_question_docs, all_original_question_documents
|
||||
)
|
||||
|
||||
# start with the entity/term/extraction
|
||||
doc_context = format_docs(relevant_docs)
|
||||
|
||||
doc_context = trim_prompt_piece(
|
||||
agent_a_config.fast_llm.config, doc_context, ENTITY_TERM_PROMPT + question
|
||||
)
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=ENTITY_TERM_PROMPT.format(question=question, context=doc_context),
|
||||
)
|
||||
]
|
||||
fast_llm = agent_a_config.fast_llm
|
||||
# Grader
|
||||
llm_response_list = list(
|
||||
fast_llm.stream(
|
||||
prompt=msg,
|
||||
)
|
||||
)
|
||||
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
|
||||
|
||||
cleaned_response = re.sub(r"```json\n|\n```", "", llm_response)
|
||||
parsed_response = json.loads(cleaned_response)
|
||||
|
||||
entities = []
|
||||
relationships = []
|
||||
terms = []
|
||||
for entity in parsed_response.get("retrieved_entities_relationships", {}).get(
|
||||
"entities", {}
|
||||
):
|
||||
entity_name = entity.get("entity_name", "")
|
||||
entity_type = entity.get("entity_type", "")
|
||||
entities.append(Entity(entity_name=entity_name, entity_type=entity_type))
|
||||
|
||||
for relationship in parsed_response.get("retrieved_entities_relationships", {}).get(
|
||||
"relationships", {}
|
||||
):
|
||||
relationship_name = relationship.get("relationship_name", "")
|
||||
relationship_type = relationship.get("relationship_type", "")
|
||||
relationship_entities = relationship.get("relationship_entities", [])
|
||||
relationships.append(
|
||||
Relationship(
|
||||
relationship_name=relationship_name,
|
||||
relationship_type=relationship_type,
|
||||
relationship_entities=relationship_entities,
|
||||
)
|
||||
)
|
||||
|
||||
for term in parsed_response.get("retrieved_entities_relationships", {}).get(
|
||||
"terms", {}
|
||||
):
|
||||
term_name = term.get("term_name", "")
|
||||
term_type = term.get("term_type", "")
|
||||
term_similar_to = term.get("term_similar_to", [])
|
||||
terms.append(
|
||||
Term(
|
||||
term_name=term_name,
|
||||
term_type=term_type,
|
||||
term_similar_to=term_similar_to,
|
||||
)
|
||||
)
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.debug(
|
||||
f"--------{now_end}--{now_end - now_start}--------ENTITY TERM EXTRACTION END---"
|
||||
)
|
||||
|
||||
return EntityTermExtractionUpdate(
|
||||
entity_retlation_term_extractions=EntityRelationshipTermExtraction(
|
||||
entities=entities,
|
||||
relationships=relationships,
|
||||
terms=terms,
|
||||
),
|
||||
log_messages=[
|
||||
f"{now_end} -- Main - ETR Extraction, Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,264 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.main.models import AgentBaseMetrics
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import (
|
||||
calculate_initial_agent_stats,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import get_query_info
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import (
|
||||
remove_document_citations,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import InitialAnswerUpdate
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainState
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_history_prompt,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
ASSISTANT_SYSTEM_PROMPT_PERSONA,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import INITIAL_RAG_PROMPT
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
SUB_QUESTION_ANSWER_TEMPLATE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_prompt
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_today_prompt
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
|
||||
|
||||
|
||||
def generate_initial_answer(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> InitialAnswerUpdate:
|
||||
now_start = datetime.now()
|
||||
|
||||
logger.debug(f"--------{now_start}--------GENERATE INITIAL---")
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = agent_a_config.search_request.query
|
||||
persona_prompt = get_persona_prompt(agent_a_config.search_request.persona)
|
||||
|
||||
history = build_history_prompt(agent_a_config.prompt_builder)
|
||||
date_str = get_today_prompt()
|
||||
|
||||
sub_question_docs = state.context_documents
|
||||
all_original_question_documents = state.all_original_question_documents
|
||||
|
||||
relevant_docs = dedup_inference_sections(
|
||||
sub_question_docs, all_original_question_documents
|
||||
)
|
||||
decomp_questions = []
|
||||
|
||||
if len(relevant_docs) == 0:
|
||||
dispatch_custom_event(
|
||||
"initial_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=UNKNOWN_ANSWER,
|
||||
level=0,
|
||||
level_question_nr=0,
|
||||
answer_type="agent_level_answer",
|
||||
),
|
||||
)
|
||||
|
||||
answer = UNKNOWN_ANSWER
|
||||
initial_agent_stats = InitialAgentResultStats(
|
||||
sub_questions={},
|
||||
original_question={},
|
||||
agent_effectiveness={},
|
||||
)
|
||||
|
||||
else:
|
||||
# Use the query info from the base document retrieval
|
||||
query_info = get_query_info(state.original_question_retrieval_results)
|
||||
|
||||
for tool_response in yield_search_responses(
|
||||
query=question,
|
||||
reranked_sections=relevant_docs,
|
||||
final_context_sections=relevant_docs,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: None, # TODO: add relevance
|
||||
search_tool=agent_a_config.search_tool,
|
||||
):
|
||||
dispatch_custom_event(
|
||||
"tool_response",
|
||||
ExtendedToolResponse(
|
||||
id=tool_response.id,
|
||||
response=tool_response.response,
|
||||
level=0,
|
||||
level_question_nr=0, # 0, 0 is the base question
|
||||
),
|
||||
)
|
||||
|
||||
net_new_original_question_docs = []
|
||||
for all_original_question_doc in all_original_question_documents:
|
||||
if all_original_question_doc not in sub_question_docs:
|
||||
net_new_original_question_docs.append(all_original_question_doc)
|
||||
|
||||
decomp_answer_results = state.decomp_answer_results
|
||||
|
||||
good_qa_list: list[str] = []
|
||||
|
||||
sub_question_nr = 1
|
||||
|
||||
for decomp_answer_result in decomp_answer_results:
|
||||
decomp_questions.append(decomp_answer_result.question)
|
||||
_, question_nr = parse_question_id(decomp_answer_result.question_id)
|
||||
if (
|
||||
decomp_answer_result.quality.lower().startswith("yes")
|
||||
and len(decomp_answer_result.answer) > 0
|
||||
and decomp_answer_result.answer != UNKNOWN_ANSWER
|
||||
):
|
||||
good_qa_list.append(
|
||||
SUB_QUESTION_ANSWER_TEMPLATE.format(
|
||||
sub_question=decomp_answer_result.question,
|
||||
sub_answer=decomp_answer_result.answer,
|
||||
sub_question_nr=sub_question_nr,
|
||||
)
|
||||
)
|
||||
sub_question_nr += 1
|
||||
|
||||
if len(good_qa_list) > 0:
|
||||
sub_question_answer_str = "\n\n------\n\n".join(good_qa_list)
|
||||
else:
|
||||
sub_question_answer_str = ""
|
||||
|
||||
# Determine which persona-specification prompt to use
|
||||
|
||||
if len(persona_prompt) == 0:
|
||||
persona_specification = ASSISTANT_SYSTEM_PROMPT_DEFAULT
|
||||
else:
|
||||
persona_specification = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
|
||||
persona_prompt=persona_prompt
|
||||
)
|
||||
|
||||
# Determine which base prompt to use given the sub-question information
|
||||
if len(good_qa_list) > 0:
|
||||
base_prompt = INITIAL_RAG_PROMPT
|
||||
else:
|
||||
base_prompt = INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS
|
||||
|
||||
model = agent_a_config.fast_llm
|
||||
|
||||
doc_context = format_docs(relevant_docs)
|
||||
doc_context = trim_prompt_piece(
|
||||
model.config,
|
||||
doc_context,
|
||||
base_prompt
|
||||
+ sub_question_answer_str
|
||||
+ persona_specification
|
||||
+ history
|
||||
+ date_str,
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=base_prompt.format(
|
||||
question=question,
|
||||
answered_sub_questions=remove_document_citations(
|
||||
sub_question_answer_str
|
||||
),
|
||||
relevant_docs=format_docs(relevant_docs),
|
||||
persona_specification=persona_specification,
|
||||
history=history,
|
||||
date_prompt=date_str,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
|
||||
for message in model.stream(msg):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
)
|
||||
dispatch_custom_event(
|
||||
"initial_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=content,
|
||||
level=0,
|
||||
level_question_nr=0,
|
||||
answer_type="agent_level_answer",
|
||||
),
|
||||
)
|
||||
streamed_tokens.append(content)
|
||||
|
||||
response = merge_content(*streamed_tokens)
|
||||
answer = cast(str, response)
|
||||
|
||||
initial_agent_stats = calculate_initial_agent_stats(
|
||||
state.decomp_answer_results, state.original_question_retrieval_stats
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"\n\nYYYYY--Sub-Questions:\n\n{sub_question_answer_str}\n\nStats:\n\n"
|
||||
)
|
||||
|
||||
if initial_agent_stats:
|
||||
logger.debug(initial_agent_stats.original_question)
|
||||
logger.debug(initial_agent_stats.sub_questions)
|
||||
logger.debug(initial_agent_stats.agent_effectiveness)
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.debug(
|
||||
f"--------{now_end}--{now_end - now_start}--------INITIAL AGENT ANSWER END---\n\n"
|
||||
)
|
||||
|
||||
agent_base_end_time = datetime.now()
|
||||
|
||||
agent_base_metrics = AgentBaseMetrics(
|
||||
num_verified_documents_total=len(relevant_docs),
|
||||
num_verified_documents_core=state.original_question_retrieval_stats.verified_count,
|
||||
verified_avg_score_core=state.original_question_retrieval_stats.verified_avg_scores,
|
||||
num_verified_documents_base=initial_agent_stats.sub_questions.get(
|
||||
"num_verified_documents", None
|
||||
),
|
||||
verified_avg_score_base=initial_agent_stats.sub_questions.get(
|
||||
"verified_avg_score", None
|
||||
),
|
||||
base_doc_boost_factor=initial_agent_stats.agent_effectiveness.get(
|
||||
"utilized_chunk_ratio", None
|
||||
),
|
||||
support_boost_factor=initial_agent_stats.agent_effectiveness.get(
|
||||
"support_ratio", None
|
||||
),
|
||||
duration__s=(agent_base_end_time - state.agent_start_time).total_seconds(),
|
||||
)
|
||||
|
||||
return InitialAnswerUpdate(
|
||||
initial_answer=answer,
|
||||
initial_agent_stats=initial_agent_stats,
|
||||
generated_sub_questions=decomp_questions,
|
||||
agent_base_end_time=agent_base_end_time,
|
||||
agent_base_metrics=agent_base_metrics,
|
||||
log_messages=[
|
||||
f"{now_end} -- Main - Initial Answer generation, Time taken: {now_end - now_start}"
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,56 @@
from datetime import datetime
from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig

from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import InitialAnswerBASEUpdate
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import INITIAL_RAG_BASE_PROMPT
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs


def generate_initial_base_search_only_answer(
    state: MainState,
    config: RunnableConfig,
) -> InitialAnswerBASEUpdate:
    now_start = datetime.now()

    logger.debug(f"--------{now_start}--------GENERATE INITIAL BASE ANSWER---")

    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    question = agent_a_config.search_request.query
    original_question_docs = state.all_original_question_documents

    model = agent_a_config.fast_llm

    doc_context = format_docs(original_question_docs)
    doc_context = trim_prompt_piece(
        model.config, doc_context, INITIAL_RAG_BASE_PROMPT + question
    )

    msg = [
        HumanMessage(
            content=INITIAL_RAG_BASE_PROMPT.format(
                question=question,
                context=doc_context,
            )
        )
    ]

    # Grader
    response = model.invoke(msg)
    answer = response.pretty_repr()

    now_end = datetime.now()

    logger.debug(
        f"--------{now_end}--{now_end - now_start}--------INITIAL BASE ANSWER END---\n\n"
    )

    return InitialAnswerBASEUpdate(initial_base_answer=answer)
@@ -0,0 +1,326 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.main.models import AgentRefinedMetrics
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import get_query_info
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import (
|
||||
remove_document_citations,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainState
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import RefinedAnswerUpdate
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_history_prompt,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
ASSISTANT_SYSTEM_PROMPT_PERSONA,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import REVISED_RAG_PROMPT
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
REVISED_RAG_PROMPT_NO_SUB_QUESTIONS,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
SUB_QUESTION_ANSWER_TEMPLATE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_prompt
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_today_prompt
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
|
||||
|
||||
|
||||
def generate_refined_answer(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> RefinedAnswerUpdate:
|
||||
now_start = datetime.now()
|
||||
|
||||
logger.debug(f"--------{now_start}--------GENERATE REFINED ANSWER---")
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = agent_a_config.search_request.query
|
||||
persona_prompt = get_persona_prompt(agent_a_config.search_request.persona)
|
||||
|
||||
history = build_history_prompt(agent_a_config.prompt_builder)
|
||||
date_str = get_today_prompt()
|
||||
initial_documents = state.documents
|
||||
revised_documents = state.refined_documents
|
||||
|
||||
combined_documents = dedup_inference_sections(initial_documents, revised_documents)
|
||||
|
||||
query_info = get_query_info(state.original_question_retrieval_results)
|
||||
# stream refined answer docs
|
||||
for tool_response in yield_search_responses(
|
||||
query=question,
|
||||
reranked_sections=combined_documents,
|
||||
final_context_sections=combined_documents,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: None, # TODO: add relevance
|
||||
search_tool=agent_a_config.search_tool,
|
||||
):
|
||||
dispatch_custom_event(
|
||||
"tool_response",
|
||||
ExtendedToolResponse(
|
||||
id=tool_response.id,
|
||||
response=tool_response.response,
|
||||
level=1,
|
||||
level_question_nr=0, # 0, 0 is the base question
|
||||
),
|
||||
)
|
||||
|
||||
if len(initial_documents) > 0:
|
||||
revision_doc_effectiveness = len(combined_documents) / len(initial_documents)
|
||||
elif len(revised_documents) == 0:
|
||||
revision_doc_effectiveness = 0.0
|
||||
else:
|
||||
revision_doc_effectiveness = 10.0
|
||||
|
||||
decomp_answer_results = state.decomp_answer_results
|
||||
# revised_answer_results = state.refined_decomp_answer_results
|
||||
|
||||
good_qa_list: list[str] = []
|
||||
decomp_questions = []
|
||||
|
||||
initial_good_sub_questions: list[str] = []
|
||||
new_revised_good_sub_questions: list[str] = []
|
||||
|
||||
sub_question_nr = 1
|
||||
|
||||
for decomp_answer_result in decomp_answer_results:
|
||||
question_level, question_nr = parse_question_id(
|
||||
decomp_answer_result.question_id
|
||||
)
|
||||
|
||||
decomp_questions.append(decomp_answer_result.question)
|
||||
if (
|
||||
decomp_answer_result.quality.lower().startswith("yes")
|
||||
and len(decomp_answer_result.answer) > 0
|
||||
and decomp_answer_result.answer != UNKNOWN_ANSWER
|
||||
):
|
||||
good_qa_list.append(
|
||||
SUB_QUESTION_ANSWER_TEMPLATE.format(
|
||||
sub_question=decomp_answer_result.question,
|
||||
sub_answer=decomp_answer_result.answer,
|
||||
sub_question_nr=sub_question_nr,
|
||||
)
|
||||
)
|
||||
if question_level == 0:
|
||||
initial_good_sub_questions.append(decomp_answer_result.question)
|
||||
else:
|
||||
new_revised_good_sub_questions.append(decomp_answer_result.question)
|
||||
|
||||
sub_question_nr += 1
|
||||
|
||||
initial_good_sub_questions = list(set(initial_good_sub_questions))
|
||||
new_revised_good_sub_questions = list(set(new_revised_good_sub_questions))
|
||||
total_good_sub_questions = list(
|
||||
set(initial_good_sub_questions + new_revised_good_sub_questions)
|
||||
)
|
||||
if len(initial_good_sub_questions) > 0:
|
||||
revision_question_efficiency: float = len(total_good_sub_questions) / len(
|
||||
initial_good_sub_questions
|
||||
)
|
||||
elif len(new_revised_good_sub_questions) > 0:
|
||||
revision_question_efficiency = 10.0
|
||||
else:
|
||||
revision_question_efficiency = 1.0
|
||||
|
||||
sub_question_answer_str = "\n\n------\n\n".join(list(set(good_qa_list)))
|
||||
|
||||
# original answer
|
||||
|
||||
initial_answer = state.initial_answer
|
||||
|
||||
# Determine which persona-specification prompt to use
|
||||
|
||||
if len(persona_prompt) == 0:
|
||||
persona_specification = ASSISTANT_SYSTEM_PROMPT_DEFAULT
|
||||
else:
|
||||
persona_specification = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
|
||||
persona_prompt=persona_prompt
|
||||
)
|
||||
|
||||
# Determine which base prompt to use given the sub-question information
|
||||
if len(good_qa_list) > 0:
|
||||
base_prompt = REVISED_RAG_PROMPT
|
||||
else:
|
||||
base_prompt = REVISED_RAG_PROMPT_NO_SUB_QUESTIONS
|
||||
|
||||
model = agent_a_config.fast_llm
|
||||
relevant_docs = format_docs(combined_documents)
|
||||
relevant_docs = trim_prompt_piece(
|
||||
model.config,
|
||||
relevant_docs,
|
||||
base_prompt
|
||||
+ question
|
||||
+ sub_question_answer_str
|
||||
+ relevant_docs
|
||||
+ initial_answer
|
||||
+ persona_specification
|
||||
+ history,
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=base_prompt.format(
|
||||
question=question,
|
||||
history=history,
|
||||
answered_sub_questions=remove_document_citations(
|
||||
sub_question_answer_str
|
||||
),
|
||||
relevant_docs=relevant_docs,
|
||||
initial_answer=remove_document_citations(initial_answer),
|
||||
persona_specification=persona_specification,
|
||||
date_prompt=date_str,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
# Grader
|
||||
|
||||
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
|
||||
for message in model.stream(msg):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
)
|
||||
dispatch_custom_event(
|
||||
"refined_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=content,
|
||||
level=1,
|
||||
level_question_nr=0,
|
||||
answer_type="agent_level_answer",
|
||||
),
|
||||
)
|
||||
streamed_tokens.append(content)
|
||||
|
||||
response = merge_content(*streamed_tokens)
|
||||
answer = cast(str, response)
|
||||
|
||||
# refined_agent_stats = _calculate_refined_agent_stats(
|
||||
# state.decomp_answer_results, state.original_question_retrieval_stats
|
||||
# )
|
||||
|
||||
initial_good_sub_questions_str = "\n".join(list(set(initial_good_sub_questions)))
|
||||
new_revised_good_sub_questions_str = "\n".join(
|
||||
list(set(new_revised_good_sub_questions))
|
||||
)
|
||||
|
||||
refined_agent_stats = RefinedAgentStats(
|
||||
revision_doc_efficiency=revision_doc_effectiveness,
|
||||
revision_question_efficiency=revision_question_efficiency,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"\n\n---INITIAL ANSWER START---\n\n Answer:\n Agent: {initial_answer}"
|
||||
)
|
||||
logger.debug("-" * 10)
|
||||
logger.debug(f"\n\n---REVISED AGENT ANSWER START---\n\n Answer:\n Agent: {answer}")
|
||||
|
||||
logger.debug("-" * 100)
|
||||
    logger.debug(f"\n\nINITIAL Sub-Questions\n\n{initial_good_sub_questions_str}\n\n")
|
||||
logger.debug("-" * 10)
|
||||
logger.debug(
|
||||
f"\n\nNEW REVISED Sub-Questions\n\n{new_revised_good_sub_questions_str}\n\n"
|
||||
)
|
||||
|
||||
logger.debug("-" * 100)
|
||||
|
||||
logger.debug(
|
||||
        f"\n\nINITIAL & REVISED Sub-Questions & Answers:\n\n{sub_question_answer_str}\n\nStats:\n\n"
|
||||
)
|
||||
|
||||
logger.debug("-" * 100)
|
||||
|
||||
if state.initial_agent_stats:
|
||||
initial_doc_boost_factor = state.initial_agent_stats.agent_effectiveness.get(
|
||||
"utilized_chunk_ratio", "--"
|
||||
)
|
||||
initial_support_boost_factor = (
|
||||
state.initial_agent_stats.agent_effectiveness.get("support_ratio", "--")
|
||||
)
|
||||
num_initial_verified_docs = state.initial_agent_stats.original_question.get(
|
||||
"num_verified_documents", "--"
|
||||
)
|
||||
initial_verified_docs_avg_score = (
|
||||
state.initial_agent_stats.original_question.get("verified_avg_score", "--")
|
||||
)
|
||||
initial_sub_questions_verified_docs = (
|
||||
state.initial_agent_stats.sub_questions.get("num_verified_documents", "--")
|
||||
)
|
||||
|
||||
logger.debug("INITIAL AGENT STATS")
|
||||
logger.debug(f"Document Boost Factor: {initial_doc_boost_factor}")
|
||||
logger.debug(f"Support Boost Factor: {initial_support_boost_factor}")
|
||||
logger.debug(f"Originally Verified Docs: {num_initial_verified_docs}")
|
||||
logger.debug(
|
||||
f"Originally Verified Docs Avg Score: {initial_verified_docs_avg_score}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Sub-Questions Verified Docs: {initial_sub_questions_verified_docs}"
|
||||
)
|
||||
if refined_agent_stats:
|
||||
logger.debug("-" * 10)
|
||||
logger.debug("REFINED AGENT STATS")
|
||||
logger.debug(
|
||||
f"Revision Doc Factor: {refined_agent_stats.revision_doc_efficiency}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Revision Question Factor: {refined_agent_stats.revision_question_efficiency}"
|
||||
)
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.debug(
|
||||
        f"--------{now_end}--{now_end - now_start}--------REFINED AGENT ANSWER END---\n\n"
|
||||
)
|
||||
|
||||
agent_refined_end_time = datetime.now()
|
||||
if state.agent_refined_start_time:
|
||||
agent_refined_duration = (
|
||||
agent_refined_end_time - state.agent_refined_start_time
|
||||
).total_seconds()
|
||||
else:
|
||||
agent_refined_duration = None
|
||||
|
||||
agent_refined_metrics = AgentRefinedMetrics(
|
||||
refined_doc_boost_factor=refined_agent_stats.revision_doc_efficiency,
|
||||
refined_question_boost_factor=refined_agent_stats.revision_question_efficiency,
|
||||
duration__s=agent_refined_duration,
|
||||
)
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.debug(
|
||||
f"--------{now_end}--{now_end - now_start}--------REFINED ANSWER UPDATE END---"
|
||||
)
|
||||
|
||||
return RefinedAnswerUpdate(
|
||||
refined_answer=answer,
|
||||
refined_answer_quality=True, # TODO: replace this with the actual check value
|
||||
refined_agent_stats=refined_agent_stats,
|
||||
agent_refined_end_time=agent_refined_end_time,
|
||||
agent_refined_metrics=agent_refined_metrics,
|
||||
)
|
||||
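For orientation, the two ratios computed above are plain counts. A standalone sketch with made-up document and sub-question lists (none of these values come from a real run) shows how revision_doc_efficiency and revision_question_efficiency behave:

# Minimal sketch of the efficiency ratios computed above, with made-up inputs.
initial_documents = ["doc_a", "doc_b", "doc_c"]             # docs behind the initial answer
revised_documents = ["doc_c", "doc_d"]                      # docs found for the refined pass
combined_documents = ["doc_a", "doc_b", "doc_c", "doc_d"]   # deduped union of the two

if len(initial_documents) > 0:
    revision_doc_effectiveness = len(combined_documents) / len(initial_documents)
elif len(revised_documents) == 0:
    revision_doc_effectiveness = 0.0
else:
    revision_doc_effectiveness = 10.0

initial_good_sub_questions = ["q1", "q2"]
new_revised_good_sub_questions = ["q3"]
total_good_sub_questions = list(
    set(initial_good_sub_questions + new_revised_good_sub_questions)
)

if len(initial_good_sub_questions) > 0:
    revision_question_efficiency = len(total_good_sub_questions) / len(
        initial_good_sub_questions
    )
elif len(new_revised_good_sub_questions) > 0:
    revision_question_efficiency = 10.0
else:
    revision_question_efficiency = 1.0

print(revision_doc_effectiveness)    # 1.33... -> the refined pass added one new document
print(revision_question_efficiency)  # 1.5 -> one genuinely new good sub-question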
@@ -0,0 +1,39 @@
from datetime import datetime

from onyx.agents.agent_search.deep_search_a.base_raw_search.states import (
    BaseRawSearchOutput,
)
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import ExpandedRetrievalUpdate
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats


def ingest_initial_base_retrieval(
    state: BaseRawSearchOutput,
) -> ExpandedRetrievalUpdate:
    now_start = datetime.now()

    logger.debug(f"--------{now_start}--------INGEST INITIAL RETRIEVAL---")

    sub_question_retrieval_stats = (
        state.base_expanded_retrieval_result.sub_question_retrieval_stats
    )
    if sub_question_retrieval_stats is None:
        sub_question_retrieval_stats = AgentChunkStats()

    now_end = datetime.now()

    logger.debug(
        f"--------{now_end}--{now_end - now_start}--------INGEST INITIAL RETRIEVAL END---"
    )

    return ExpandedRetrievalUpdate(
        original_question_retrieval_results=state.base_expanded_retrieval_result.expanded_queries_results,
        all_original_question_documents=state.base_expanded_retrieval_result.context_documents,
        original_question_retrieval_stats=sub_question_retrieval_stats,
        log_messages=[
            f"{now_end} -- Main - Ingestion base retrieval, Time taken: {now_end - now_start}"
        ],
    )
@@ -0,0 +1,41 @@
from datetime import datetime

from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
    AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import DecompAnswersUpdate
from onyx.agents.agent_search.shared_graph_utils.operators import (
    dedup_inference_sections,
)


def ingest_initial_sub_question_answers(
    state: AnswerQuestionOutput,
) -> DecompAnswersUpdate:
    now_start = datetime.now()

    logger.debug(f"--------{now_start}--------INGEST ANSWERS---")
    documents = []
    context_documents = []
    answer_results = state.answer_results if hasattr(state, "answer_results") else []
    for answer_result in answer_results:
        documents.extend(answer_result.documents)
        context_documents.extend(answer_result.context_documents)

    now_end = datetime.now()

    logger.debug(
        f"--------{now_end}--{now_end - now_start}--------INGEST ANSWERS END---"
    )

    return DecompAnswersUpdate(
        # Deduping is done by the documents operator for the main graph
        # so we might not need to dedup here
        documents=dedup_inference_sections(documents, []),
        context_documents=dedup_inference_sections(context_documents, []),
        decomp_answer_results=answer_results,
        log_messages=[
            f"{now_end} -- Main - Ingest initial processed sub questions, Time taken: {now_end - now_start}"
        ],
    )
@@ -0,0 +1,39 @@
from datetime import datetime

from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
    AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import DecompAnswersUpdate
from onyx.agents.agent_search.shared_graph_utils.operators import (
    dedup_inference_sections,
)


def ingest_refined_answers(
    state: AnswerQuestionOutput,
) -> DecompAnswersUpdate:
    now_start = datetime.now()

    logger.debug(f"--------{now_start}--------INGEST FOLLOW UP ANSWERS---")

    documents = []
    answer_results = state.answer_results if hasattr(state, "answer_results") else []
    for answer_result in answer_results:
        documents.extend(answer_result.documents)

    now_end = datetime.now()

    logger.debug(
        f"--------{now_end}--{now_end - now_start}--------INGEST FOLLOW UP ANSWERS END---"
    )

    return DecompAnswersUpdate(
        # Deduping is done by the documents operator for the main graph
        # so we might not need to dedup here
        documents=dedup_inference_sections(documents, []),
        decomp_answer_results=answer_results,
        log_messages=[
            f"{now_end} -- Main - Ingest refined answers, Time taken: {now_end - now_start}"
        ],
    )
@@ -0,0 +1,40 @@
from datetime import datetime

from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import (
    InitialAnswerQualityUpdate,
)
from onyx.agents.agent_search.deep_search_a.main.states import MainState


def initial_answer_quality_check(state: MainState) -> InitialAnswerQualityUpdate:
    """
    Check whether the final output satisfies the original user question

    Args:
        state (MainState): The current state

    Returns:
        InitialAnswerQualityUpdate
    """

    now_start = datetime.now()

    logger.debug(
        f"--------{now_start}--------Checking for base answer validity - for now set True/False manually"
    )

    verdict = True

    now_end = datetime.now()

    logger.debug(
        f"--------{now_end}--{now_end - now_start}--------INITIAL ANSWER QUALITY CHECK END---"
    )

    return InitialAnswerQualityUpdate(
        initial_answer_quality=verdict,
        log_messages=[
            f"{now_end} -- Main - Initial answer quality check, Time taken: {now_end - now_start}"
        ],
    )
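The verdict above is a placeholder that always passes. As a hedged illustration only, an LLM-backed version of the same check could look roughly like the sketch below; the prompt constant and the yes/no convention are assumptions made for this sketch and are not part of the change set.

# Hypothetical sketch: wire the verdict to a fast LLM instead of hard-coding True.
from langchain_core.messages import HumanMessage

# Assumed prompt text; not a constant that exists in the repository.
BASE_ANSWER_QUALITY_CHECK_PROMPT = (
    "Does the answer below fully address the question? Reply 'yes' or 'no'.\n\n"
    "Question: {question}\n\nAnswer: {initial_answer}"
)


def llm_answer_quality_verdict(model, question: str, initial_answer: str) -> bool:
    msg = [
        HumanMessage(
            content=BASE_ANSWER_QUALITY_CHECK_PROMPT.format(
                question=question, initial_answer=initial_answer
            )
        )
    ]
    response = model.invoke(msg)
    # Treat anything that does not start with "yes" as a failed check.
    return str(response.content).strip().lower().startswith("yes")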
@@ -0,0 +1,150 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.main.models import AgentRefinedMetrics
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import dispatch_subquestion
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import BaseDecompUpdate
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainState
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_history_prompt,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
|
||||
|
||||
def initial_sub_question_creation(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> BaseDecompUpdate:
|
||||
now_start = datetime.now()
|
||||
|
||||
logger.debug(f"--------{now_start}--------BASE DECOMP START---")
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = agent_a_config.search_request.query
|
||||
chat_session_id = agent_a_config.chat_session_id
|
||||
primary_message_id = agent_a_config.message_id
|
||||
perform_initial_search_decomposition = (
|
||||
agent_a_config.perform_initial_search_decomposition
|
||||
)
|
||||
perform_initial_search_path_decision = (
|
||||
agent_a_config.perform_initial_search_path_decision
|
||||
)
|
||||
history = build_history_prompt(agent_a_config.prompt_builder)
|
||||
|
||||
# Use the initial search results to inform the decomposition
|
||||
sample_doc_str = state.sample_doc_str if hasattr(state, "sample_doc_str") else ""
|
||||
|
||||
if not chat_session_id or not primary_message_id:
|
||||
raise ValueError(
|
||||
"chat_session_id and message_id must be provided for agent search"
|
||||
)
|
||||
agent_start_time = datetime.now()
|
||||
|
||||
# Initial search to inform decomposition. Just get top 3 fits
|
||||
|
||||
if perform_initial_search_decomposition:
|
||||
if not perform_initial_search_path_decision:
|
||||
search_tool = agent_a_config.search_tool
|
||||
retrieved_docs: list[InferenceSection] = []
|
||||
|
||||
# new db session to avoid concurrency issues
|
||||
with get_session_context_manager() as db_session:
|
||||
for tool_response in search_tool.run(
|
||||
query=question,
|
||||
force_no_rerank=True,
|
||||
alternate_db_session=db_session,
|
||||
):
|
||||
# get retrieved docs to send to the rest of the graph
|
||||
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
response = cast(SearchResponseSummary, tool_response.response)
|
||||
retrieved_docs = response.top_sections
|
||||
break
|
||||
|
||||
sample_doc_str = "\n\n".join(
|
||||
[doc.combined_content for _, doc in enumerate(retrieved_docs[:3])]
|
||||
)
|
||||
|
||||
decomposition_prompt = (
|
||||
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH.format(
|
||||
question=question, sample_doc_str=sample_doc_str, history=history
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
decomposition_prompt = INITIAL_DECOMPOSITION_PROMPT_QUESTIONS.format(
|
||||
question=question, history=history
|
||||
)
|
||||
|
||||
# Start decomposition
|
||||
|
||||
msg = [HumanMessage(content=decomposition_prompt)]
|
||||
|
||||
# Get the rewritten queries in a defined format
|
||||
model = agent_a_config.fast_llm
|
||||
|
||||
# Send the initial question as a subquestion with number 0
|
||||
dispatch_custom_event(
|
||||
"decomp_qs",
|
||||
SubQuestionPiece(
|
||||
sub_question=question,
|
||||
level=0,
|
||||
level_question_nr=0,
|
||||
),
|
||||
)
|
||||
# dispatches custom events for subquestion tokens, adding in subquestion ids.
|
||||
streamed_tokens = dispatch_separated(model.stream(msg), dispatch_subquestion(0))
|
||||
|
||||
stop_event = StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
stream_type="sub_questions",
|
||||
level=0,
|
||||
)
|
||||
dispatch_custom_event("stream_finished", stop_event)
|
||||
|
||||
    decomposition_response = merge_content(*streamed_tokens)
|
||||
|
||||
# this call should only return strings. Commenting out for efficiency
|
||||
# assert [type(tok) == str for tok in streamed_tokens]
|
||||
|
||||
# use no-op cast() instead of str() which runs code
|
||||
# list_of_subquestions = clean_and_parse_list_string(cast(str, response))
|
||||
    list_of_subqs = cast(str, decomposition_response).split("\n")
|
||||
|
||||
decomp_list: list[str] = [sq.strip() for sq in list_of_subqs if sq.strip() != ""]
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.debug(f"--------{now_end}--{now_end - now_start}--------BASE DECOMP END---")
|
||||
|
||||
return BaseDecompUpdate(
|
||||
initial_decomp_questions=decomp_list,
|
||||
agent_start_time=agent_start_time,
|
||||
agent_refined_start_time=None,
|
||||
agent_refined_end_time=None,
|
||||
agent_refined_metrics=AgentRefinedMetrics(
|
||||
refined_doc_boost_factor=None,
|
||||
refined_question_boost_factor=None,
|
||||
duration__s=None,
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,47 @@
from datetime import datetime
from typing import cast

from langchain_core.runnables import RunnableConfig

from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.deep_search_a.main.states import (
    RequireRefinedAnswerUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig


def refined_answer_decision(
    state: MainState, config: RunnableConfig
) -> RequireRefinedAnswerUpdate:
    now_start = datetime.now()

    logger.debug(f"--------{now_start}--------REFINED ANSWER DECISION---")

    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    if "?" in agent_a_config.search_request.query:
        decision = False
    else:
        decision = True

    decision = True

    now_end = datetime.now()

    logger.debug(
        f"--------{now_end}--{now_end - now_start}--------REFINED ANSWER DECISION END---"
    )
    log_messages = [
        f"{now_end} -- Main - Refined answer decision: {decision}, Time taken: {now_end - now_start}"
    ]
    if agent_a_config.allow_refinement:
        return RequireRefinedAnswerUpdate(
            require_refined_answer=decision,
            log_messages=log_messages,
        )

    else:
        return RequireRefinedAnswerUpdate(
            require_refined_answer=False,
            log_messages=log_messages,
        )
@@ -0,0 +1,116 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search_a.main.models import FollowUpSubQuestion
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import dispatch_subquestion
|
||||
from onyx.agents.agent_search.deep_search_a.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import (
|
||||
FollowUpSubQuestionsUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainState
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_history_prompt,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import DEEP_DECOMPOSE_PROMPT
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
format_entity_term_extraction,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
|
||||
|
||||
def refined_sub_question_creation(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> FollowUpSubQuestionsUpdate:
|
||||
""" """
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
dispatch_custom_event(
|
||||
"start_refined_answer_creation",
|
||||
ToolCallKickoff(
|
||||
tool_name="agent_search_1",
|
||||
tool_args={
|
||||
"query": agent_a_config.search_request.query,
|
||||
"answer": state.initial_answer,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
now_start = datetime.now()
|
||||
|
||||
logger.debug(f"--------{now_start}--------FOLLOW UP DECOMPOSE---")
|
||||
|
||||
agent_refined_start_time = datetime.now()
|
||||
|
||||
question = agent_a_config.search_request.query
|
||||
base_answer = state.initial_answer
|
||||
history = build_history_prompt(agent_a_config.prompt_builder)
|
||||
# get the entity term extraction dict and properly format it
|
||||
entity_retlation_term_extractions = state.entity_retlation_term_extractions
|
||||
|
||||
entity_term_extraction_str = format_entity_term_extraction(
|
||||
entity_retlation_term_extractions
|
||||
)
|
||||
|
||||
initial_question_answers = state.decomp_answer_results
|
||||
|
||||
addressed_question_list = [
|
||||
x.question for x in initial_question_answers if "yes" in x.quality.lower()
|
||||
]
|
||||
|
||||
failed_question_list = [
|
||||
x.question for x in initial_question_answers if "no" in x.quality.lower()
|
||||
]
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=DEEP_DECOMPOSE_PROMPT.format(
|
||||
question=question,
|
||||
history=history,
|
||||
entity_term_extraction_str=entity_term_extraction_str,
|
||||
base_answer=base_answer,
|
||||
answered_sub_questions="\n - ".join(addressed_question_list),
|
||||
failed_sub_questions="\n - ".join(failed_question_list),
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
# Grader
|
||||
model = agent_a_config.fast_llm
|
||||
|
||||
streamed_tokens = dispatch_separated(model.stream(msg), dispatch_subquestion(1))
|
||||
response = merge_content(*streamed_tokens)
|
||||
|
||||
if isinstance(response, str):
|
||||
parsed_response = [q for q in response.split("\n") if q.strip() != ""]
|
||||
else:
|
||||
raise ValueError("LLM response is not a string")
|
||||
|
||||
refined_sub_question_dict = {}
|
||||
for sub_question_nr, sub_question in enumerate(parsed_response):
|
||||
refined_sub_question = FollowUpSubQuestion(
|
||||
sub_question=sub_question,
|
||||
sub_question_id=make_question_id(1, sub_question_nr + 1),
|
||||
verified=False,
|
||||
answered=False,
|
||||
answer="",
|
||||
)
|
||||
|
||||
refined_sub_question_dict[sub_question_nr + 1] = refined_sub_question
|
||||
|
||||
now_end = datetime.now()
|
||||
|
||||
logger.debug(
|
||||
f"--------{now_end}--{now_end - now_start}--------FOLLOW UP DECOMPOSE END---"
|
||||
)
|
||||
|
||||
return FollowUpSubQuestionsUpdate(
|
||||
refined_sub_questions=refined_sub_question_dict,
|
||||
agent_refined_start_time=agent_refined_start_time,
|
||||
)
|
||||
@@ -0,0 +1,12 @@
from datetime import datetime

from onyx.agents.agent_search.deep_search_a.main.states import LoggerUpdate
from onyx.agents.agent_search.deep_search_a.main.states import MainState


def retrieval_consolidation(
    state: MainState,
) -> LoggerUpdate:
    now_start = datetime.now()

    return LoggerUpdate(log_messages=[f"{now_start} -- Retrieval consolidation"])
@@ -0,0 +1,145 @@
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
|
||||
from langchain_core.callbacks.manager import dispatch_custom_event
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
QuestionAnswerResults,
|
||||
)
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.tools.models import SearchQueryInfo
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def remove_document_citations(text: str) -> str:
    """
    Removes citation expressions of the form '[[D1]]()' or '[[Q1]]()' from text.
    The number after D or Q can vary.

    Args:
        text: Input text containing citations

    Returns:
        Text with citations removed
    """
    # Pattern explanation:
    # \[\[(?:D|Q)\d+\]\]\(\) matches:
    #   \[\[    - literal [[ characters
    #   (?:D|Q) - a literal D or Q character
    #   \d+     - one or more digits
    #   \]\]    - literal ]] characters
    #   \(\)    - literal () characters
    return re.sub(r"\[\[(?:D|Q)\d+\]\]\(\)", "", text)
|
||||
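A quick usage check of the citation stripping (the sample sentence is made up):

# Both document and sub-question citation markers are removed; only the marker
# itself is dropped, so the surrounding whitespace is left behind.
sample = "Onyx indexes Confluence [[D1]]() and answers questions about it [[Q2]]()."
print(remove_document_citations(sample))
# -> "Onyx indexes Confluence  and answers questions about it ."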
|
||||
|
||||
def dispatch_subquestion(level: int) -> Callable[[str, int], None]:
|
||||
def _helper(sub_question_part: str, num: int) -> None:
|
||||
dispatch_custom_event(
|
||||
"decomp_qs",
|
||||
SubQuestionPiece(
|
||||
sub_question=sub_question_part,
|
||||
level=level,
|
||||
level_question_nr=num,
|
||||
),
|
||||
)
|
||||
|
||||
return _helper
|
||||
|
||||
|
||||
def calculate_initial_agent_stats(
|
||||
decomp_answer_results: list[QuestionAnswerResults],
|
||||
original_question_stats: AgentChunkStats,
|
||||
) -> InitialAgentResultStats:
|
||||
initial_agent_result_stats: InitialAgentResultStats = InitialAgentResultStats(
|
||||
sub_questions={},
|
||||
original_question={},
|
||||
agent_effectiveness={},
|
||||
)
|
||||
|
||||
orig_verified = original_question_stats.verified_count
|
||||
orig_support_score = original_question_stats.verified_avg_scores
|
||||
|
||||
verified_document_chunk_ids = []
|
||||
support_scores = 0.0
|
||||
|
||||
for decomp_answer_result in decomp_answer_results:
|
||||
verified_document_chunk_ids += (
|
||||
decomp_answer_result.sub_question_retrieval_stats.verified_doc_chunk_ids
|
||||
)
|
||||
if (
|
||||
decomp_answer_result.sub_question_retrieval_stats.verified_avg_scores
|
||||
is not None
|
||||
):
|
||||
support_scores += (
|
||||
decomp_answer_result.sub_question_retrieval_stats.verified_avg_scores
|
||||
)
|
||||
|
||||
verified_document_chunk_ids = list(set(verified_document_chunk_ids))
|
||||
|
||||
# Calculate sub-question stats
|
||||
if (
|
||||
verified_document_chunk_ids
|
||||
and len(verified_document_chunk_ids) > 0
|
||||
and support_scores is not None
|
||||
):
|
||||
sub_question_stats: dict[str, float | int | None] = {
|
||||
"num_verified_documents": len(verified_document_chunk_ids),
|
||||
"verified_avg_score": float(support_scores / len(decomp_answer_results)),
|
||||
}
|
||||
else:
|
||||
sub_question_stats = {"num_verified_documents": 0, "verified_avg_score": None}
|
||||
|
||||
initial_agent_result_stats.sub_questions.update(sub_question_stats)
|
||||
|
||||
# Get original question stats
|
||||
initial_agent_result_stats.original_question.update(
|
||||
{
|
||||
"num_verified_documents": original_question_stats.verified_count,
|
||||
"verified_avg_score": original_question_stats.verified_avg_scores,
|
||||
}
|
||||
)
|
||||
|
||||
# Calculate chunk utilization ratio
|
||||
sub_verified = initial_agent_result_stats.sub_questions["num_verified_documents"]
|
||||
|
||||
chunk_ratio: float | None = None
|
||||
if sub_verified is not None and orig_verified is not None and orig_verified > 0:
|
||||
chunk_ratio = (float(sub_verified) / orig_verified) if sub_verified > 0 else 0.0
|
||||
elif sub_verified is not None and sub_verified > 0:
|
||||
chunk_ratio = 10.0
|
||||
|
||||
initial_agent_result_stats.agent_effectiveness["utilized_chunk_ratio"] = chunk_ratio
|
||||
|
||||
    if (
        orig_support_score is None or orig_support_score == 0.0
    ) and initial_agent_result_stats.sub_questions["verified_avg_score"] is None:
|
||||
initial_agent_result_stats.agent_effectiveness["support_ratio"] = None
|
||||
elif orig_support_score is None or orig_support_score == 0.0:
|
||||
initial_agent_result_stats.agent_effectiveness["support_ratio"] = 10
|
||||
elif initial_agent_result_stats.sub_questions["verified_avg_score"] is None:
|
||||
initial_agent_result_stats.agent_effectiveness["support_ratio"] = 0
|
||||
else:
|
||||
initial_agent_result_stats.agent_effectiveness["support_ratio"] = (
|
||||
initial_agent_result_stats.sub_questions["verified_avg_score"]
|
||||
/ orig_support_score
|
||||
)
|
||||
|
||||
return initial_agent_result_stats
|
||||
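As a rough numeric illustration of the two effectiveness figures computed above (all numbers are hypothetical):

# Hypothetical worked example of the agent effectiveness ratios.
sub_verified = 12        # verified chunks across all sub-questions
orig_verified = 8        # verified chunks for the original question
sub_avg_score = 0.72     # average verified score across sub-questions
orig_avg_score = 0.60    # average verified score for the original question

utilized_chunk_ratio = sub_verified / orig_verified  # 1.5 -> sub-questions surfaced more usable chunks
support_ratio = sub_avg_score / orig_avg_score       # 1.2 -> sub-question evidence scored slightly higher

print(utilized_chunk_ratio, support_ratio)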
|
||||
|
||||
def get_query_info(results: list[QueryResult]) -> SearchQueryInfo:
|
||||
# Use the query info from the base document retrieval
|
||||
# TODO: see if this is the right way to do this
|
||||
query_infos = [
|
||||
result.query_info for result in results if result.query_info is not None
|
||||
]
|
||||
if len(query_infos) == 0:
|
||||
raise ValueError("No query info found")
|
||||
return query_infos[0]
|
||||
backend/onyx/agents/agent_search/deep_search_a/main/states.py (new file, 160 lines)
@@ -0,0 +1,160 @@
|
||||
from datetime import datetime
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.core_state import CoreState
|
||||
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import (
|
||||
ExpandedRetrievalResult,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.models import AgentBaseMetrics
|
||||
from onyx.agents.agent_search.deep_search_a.main.models import AgentRefinedMetrics
|
||||
from onyx.agents.agent_search.deep_search_a.main.models import FollowUpSubQuestion
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
EntityRelationshipTermExtraction,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
QuestionAnswerResults,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_question_answer_results,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
### States ###
|
||||
|
||||
## Update States
|
||||
|
||||
|
||||
class LoggerUpdate(BaseModel):
|
||||
log_messages: Annotated[list[str], add] = []
|
||||
|
||||
|
||||
class RefinedAgentStartStats(BaseModel):
|
||||
agent_refined_start_time: datetime | None = None
|
||||
|
||||
|
||||
class RefinedAgentEndStats(BaseModel):
|
||||
agent_refined_end_time: datetime | None = None
|
||||
agent_refined_metrics: AgentRefinedMetrics = AgentRefinedMetrics()
|
||||
|
||||
|
||||
class BaseDecompUpdate(RefinedAgentStartStats, RefinedAgentEndStats):
|
||||
agent_start_time: datetime = datetime.now()
|
||||
initial_decomp_questions: list[str] = []
|
||||
|
||||
|
||||
class AnswerComparison(LoggerUpdate):
|
||||
refined_answer_improvement: bool = False
|
||||
|
||||
|
||||
class RoutingDecision(LoggerUpdate):
|
||||
routing: str = ""
|
||||
sample_doc_str: str = ""
|
||||
|
||||
|
||||
class InitialAnswerBASEUpdate(BaseModel):
|
||||
initial_base_answer: str = ""
|
||||
|
||||
|
||||
class InitialAnswerUpdate(LoggerUpdate):
|
||||
initial_answer: str = ""
|
||||
initial_agent_stats: InitialAgentResultStats | None = None
|
||||
generated_sub_questions: list[str] = []
|
||||
agent_base_end_time: datetime | None = None
|
||||
agent_base_metrics: AgentBaseMetrics | None = None
|
||||
|
||||
|
||||
class RefinedAnswerUpdate(RefinedAgentEndStats):
|
||||
refined_answer: str = ""
|
||||
refined_agent_stats: RefinedAgentStats | None = None
|
||||
refined_answer_quality: bool = False
|
||||
|
||||
|
||||
class InitialAnswerQualityUpdate(LoggerUpdate):
|
||||
initial_answer_quality: bool = False
|
||||
|
||||
|
||||
class RequireRefinedAnswerUpdate(LoggerUpdate):
|
||||
require_refined_answer: bool = True
|
||||
|
||||
|
||||
class DecompAnswersUpdate(LoggerUpdate):
|
||||
documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
|
||||
context_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
|
||||
decomp_answer_results: Annotated[
|
||||
list[QuestionAnswerResults], dedup_question_answer_results
|
||||
] = []
|
||||
|
||||
|
||||
class FollowUpDecompAnswersUpdate(LoggerUpdate):
|
||||
refined_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
|
||||
refined_decomp_answer_results: Annotated[list[QuestionAnswerResults], add] = []
|
||||
|
||||
|
||||
class ExpandedRetrievalUpdate(LoggerUpdate):
|
||||
all_original_question_documents: Annotated[
|
||||
list[InferenceSection], dedup_inference_sections
|
||||
]
|
||||
original_question_retrieval_results: list[QueryResult] = []
|
||||
original_question_retrieval_stats: AgentChunkStats = AgentChunkStats()
|
||||
|
||||
|
||||
class EntityTermExtractionUpdate(LoggerUpdate):
|
||||
entity_retlation_term_extractions: EntityRelationshipTermExtraction = (
|
||||
EntityRelationshipTermExtraction()
|
||||
)
|
||||
|
||||
|
||||
class FollowUpSubQuestionsUpdate(RefinedAgentStartStats):
|
||||
refined_sub_questions: dict[int, FollowUpSubQuestion] = {}
|
||||
|
||||
|
||||
## Graph Input State
|
||||
|
||||
|
||||
|
||||
class MainInput(CoreState):
|
||||
pass
|
||||
|
||||
|
||||
## Graph State
|
||||
|
||||
|
||||
class MainState(
|
||||
# This includes the core state
|
||||
MainInput,
|
||||
BaseDecompUpdate,
|
||||
InitialAnswerUpdate,
|
||||
InitialAnswerBASEUpdate,
|
||||
DecompAnswersUpdate,
|
||||
ExpandedRetrievalUpdate,
|
||||
EntityTermExtractionUpdate,
|
||||
InitialAnswerQualityUpdate,
|
||||
RequireRefinedAnswerUpdate,
|
||||
FollowUpSubQuestionsUpdate,
|
||||
FollowUpDecompAnswersUpdate,
|
||||
RefinedAnswerUpdate,
|
||||
RefinedAgentStartStats,
|
||||
RefinedAgentEndStats,
|
||||
RoutingDecision,
|
||||
AnswerComparison,
|
||||
):
|
||||
# expanded_retrieval_result: Annotated[list[ExpandedRetrievalResult], add]
|
||||
base_raw_search_result: Annotated[list[ExpandedRetrievalResult], add]
|
||||
|
||||
|
||||
## Graph Output State - presently not used
|
||||
|
||||
|
||||
class MainOutput(TypedDict):
|
||||
log_messages: list[str]
|
||||
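The Annotated[..., reducer] fields above are how LangGraph merges partial updates returned by different nodes: the second element of the annotation is the merge function applied when two updates target the same key. A framework-free sketch of that contract, with a simple stand-in for dedup_inference_sections:

from operator import add
from typing import Annotated

# Stand-in reducer with the same shape as dedup_inference_sections: it merges
# two lists and keeps only one copy of each item.
def dedup_strings(existing: list[str], new: list[str]) -> list[str]:
    merged: list[str] = []
    for item in existing + new:
        if item not in merged:
            merged.append(item)
    return merged

# Annotated fields carry their reducer as metadata; the graph runtime looks it
# up when combining node outputs into the shared state.
LogMessages = Annotated[list[str], add]
Documents = Annotated[list[str], dedup_strings]

# What the framework effectively does when merging two updates:
logs = add(["node_a done"], ["node_b done"])                   # both messages kept
docs = dedup_strings(["doc_1", "doc_2"], ["doc_2", "doc_3"])   # deduped union
print(logs)
print(docs)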
backend/onyx/agents/agent_search/models.py (new file, 86 lines)
@@ -0,0 +1,86 @@
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.file_store.utils import InMemoryChatFile
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentSearchConfig:
|
||||
"""
|
||||
Configuration for the Agent Search feature.
|
||||
"""
|
||||
|
||||
# The search request that was used to generate the Pro Search
|
||||
search_request: SearchRequest
|
||||
|
||||
primary_llm: LLM
|
||||
fast_llm: LLM
|
||||
search_tool: SearchTool
|
||||
|
||||
# Whether to force use of a tool, or to
|
||||
# force tool args IF the tool is used
|
||||
force_use_tool: ForceUseTool
|
||||
|
||||
# contains message history for the current chat session
|
||||
# has the following (at most one is non-None)
|
||||
# message_history: list[PreviousMessage] | None = None
|
||||
# single_message_history: str | None = None
|
||||
prompt_builder: AnswerPromptBuilder
|
||||
|
||||
use_agentic_search: bool = False
|
||||
|
||||
# For persisting agent search data
|
||||
chat_session_id: UUID | None = None
|
||||
|
||||
# The message ID of the user message that triggered the Pro Search
|
||||
message_id: int | None = None
|
||||
|
||||
    # Whether to persist data for the Pro Search (turned off for testing)
|
||||
use_persistence: bool = True
|
||||
|
||||
# The database session for the Pro Search
|
||||
db_session: Session | None = None
|
||||
|
||||
    # Whether to perform an initial search path decision
|
||||
perform_initial_search_path_decision: bool = True
|
||||
|
||||
# Whether to perform initial search to inform decomposition
|
||||
perform_initial_search_decomposition: bool = True
|
||||
|
||||
# Whether to allow creation of refinement questions (and entity extraction, etc.)
|
||||
allow_refinement: bool = True
|
||||
|
||||
# Tools available for use
|
||||
tools: list[Tool] | None = None
|
||||
|
||||
using_tool_calling_llm: bool = False
|
||||
|
||||
files: list[InMemoryChatFile] | None = None
|
||||
|
||||
structured_response_format: dict | None = None
|
||||
|
||||
skip_gen_ai_answer_generation: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_db_session(self) -> "AgentSearchConfig":
|
||||
if self.use_persistence and self.db_session is None:
|
||||
raise ValueError(
|
||||
"db_session must be provided for pro search when using persistence"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class AgentDocumentCitations(BaseModel):
|
||||
document_id: str
|
||||
document_title: str
|
||||
link: str
|
||||
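The validate_db_session check means persistence requires a live database session. The toy pydantic model below mirrors that contract; it is a stand-in, since the real AgentSearchConfig also needs LLMs, a SearchTool, and a prompt builder to construct:

from typing import Any

from pydantic import BaseModel
from pydantic import ValidationError
from pydantic import model_validator


class ToyPersistenceConfig(BaseModel):
    """Toy stand-in mirroring the use_persistence / db_session contract."""

    use_persistence: bool = True
    db_session: Any = None

    @model_validator(mode="after")
    def validate_db_session(self) -> "ToyPersistenceConfig":
        if self.use_persistence and self.db_session is None:
            raise ValueError(
                "db_session must be provided for pro search when using persistence"
            )
        return self


ToyPersistenceConfig(use_persistence=False)  # fine: persistence disabled
try:
    ToyPersistenceConfig(use_persistence=True)  # no session supplied
except ValidationError as err:
    print(err)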
backend/onyx/agents/agent_search/run_graph.py (new file, 272 lines)
@@ -0,0 +1,272 @@
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable
|
||||
from collections.abc import Iterable
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables.schema import StreamEvent
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from onyx.agents.agent_search.basic.graph_builder import basic_graph_builder
|
||||
from onyx.agents.agent_search.basic.states import BasicInput
|
||||
from onyx.agents.agent_search.deep_search_a.main.graph_builder import (
|
||||
main_graph_builder as main_graph_builder_a,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainInput as MainInput_a
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import AnswerPacket
|
||||
from onyx.chat.models import AnswerStream
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.chat.models import RefinedAnswerImprovement
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import SubQueryPiece
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.chat.models import ToolResponse
|
||||
from onyx.configs.dev_configs import GRAPH_NAME
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.tools.tool_runner import ToolCallKickoff
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_COMPILED_GRAPH: CompiledStateGraph | None = None
|
||||
|
||||
|
||||
def _set_combined_token_value(
|
||||
combined_token: str, parsed_object: AgentAnswerPiece
|
||||
) -> AgentAnswerPiece:
|
||||
parsed_object.answer_piece = combined_token
|
||||
|
||||
return parsed_object
|
||||
|
||||
|
||||
def _parse_agent_event(
|
||||
event: StreamEvent,
|
||||
) -> AnswerPacket | None:
|
||||
"""
|
||||
Parse the event into a typed object.
|
||||
Return None if we are not interested in the event.
|
||||
"""
|
||||
|
||||
event_type = event["event"]
|
||||
|
||||
# We always just yield the event data, but this piece is useful for two development reasons:
|
||||
# 1. It's a list of the names of every place we dispatch a custom event
|
||||
# 2. We maintain the intended types yielded by each event
|
||||
if event_type == "on_custom_event":
|
||||
# TODO: different AnswerStream types for different events
|
||||
if event["name"] == "decomp_qs":
|
||||
return cast(SubQuestionPiece, event["data"])
|
||||
elif event["name"] == "subqueries":
|
||||
return cast(SubQueryPiece, event["data"])
|
||||
elif event["name"] == "sub_answers":
|
||||
return cast(AgentAnswerPiece, event["data"])
|
||||
elif event["name"] == "stream_finished":
|
||||
return cast(StreamStopInfo, event["data"])
|
||||
elif event["name"] == "initial_agent_answer":
|
||||
return cast(AgentAnswerPiece, event["data"])
|
||||
elif event["name"] == "refined_agent_answer":
|
||||
return cast(AgentAnswerPiece, event["data"])
|
||||
elif event["name"] == "start_refined_answer_creation":
|
||||
return cast(ToolCallKickoff, event["data"])
|
||||
elif event["name"] == "tool_response":
|
||||
return cast(ToolResponse, event["data"])
|
||||
elif event["name"] == "basic_response":
|
||||
return cast(AnswerPacket, event["data"])
|
||||
elif event["name"] == "refined_answer_improvement":
|
||||
return cast(RefinedAnswerImprovement, event["data"])
|
||||
return None
|
||||
|
||||
|
||||
# https://stackoverflow.com/questions/60226557/how-to-forcefully-close-an-async-generator
|
||||
# https://stackoverflow.com/questions/40897428/please-explain-task-was-destroyed-but-it-is-pending-after-cancelling-tasks
|
||||
task_references: set[asyncio.Task[StreamEvent]] = set()
|
||||
|
||||
|
||||
def _manage_async_event_streaming(
|
||||
compiled_graph: CompiledStateGraph,
|
||||
config: AgentSearchConfig | None,
|
||||
graph_input: MainInput_a | BasicInput,
|
||||
) -> Iterable[StreamEvent]:
|
||||
async def _run_async_event_stream() -> AsyncIterable[StreamEvent]:
|
||||
message_id = config.message_id if config else None
|
||||
async for event in compiled_graph.astream_events(
|
||||
input=graph_input,
|
||||
config={"metadata": {"config": config, "thread_id": str(message_id)}},
|
||||
# debug=True,
|
||||
# indicating v2 here deserves further scrutiny
|
||||
version="v2",
|
||||
):
|
||||
yield event
|
||||
|
||||
# This might be able to be simplified
|
||||
def _yield_async_to_sync() -> Iterable[StreamEvent]:
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
# Get the async generator
|
||||
async_gen = _run_async_event_stream()
|
||||
# Convert to AsyncIterator
|
||||
async_iter = async_gen.__aiter__()
|
||||
while True:
|
||||
try:
|
||||
# Create a coroutine by calling anext with the async iterator
|
||||
next_coro = anext(async_iter)
|
||||
task = asyncio.ensure_future(next_coro, loop=loop)
|
||||
task_references.add(task)
|
||||
# Run the coroutine to get the next event
|
||||
event = loop.run_until_complete(task)
|
||||
yield event
|
||||
except (StopAsyncIteration, GeneratorExit):
|
||||
break
|
||||
finally:
|
||||
            for task in task_references:
                task.cancel()
            task_references.clear()
|
||||
loop.close()
|
||||
|
||||
return _yield_async_to_sync()
|
||||
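The helper above pumps an async event stream from synchronous code by repeatedly driving anext() on a private event loop. A stripped-down version of the same bridging pattern, independent of LangGraph (the toy async generator stands in for compiled_graph.astream_events; anext requires Python 3.10+):

import asyncio
from collections.abc import AsyncIterator, Iterator


async def _aiter_numbers() -> AsyncIterator[int]:
    # Stand-in for the graph's async event stream.
    for i in range(3):
        await asyncio.sleep(0)
        yield i


def iter_sync(async_iterable: AsyncIterator[int]) -> Iterator[int]:
    """Drive an async iterator to completion from synchronous code."""
    loop = asyncio.new_event_loop()
    try:
        while True:
            try:
                # anext() hands back an awaitable; the loop runs it to get one item.
                yield loop.run_until_complete(anext(async_iterable))
            except StopAsyncIteration:
                break
    finally:
        loop.close()


print(list(iter_sync(_aiter_numbers())))  # -> [0, 1, 2]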
|
||||
|
||||
def run_graph(
|
||||
compiled_graph: CompiledStateGraph,
|
||||
config: AgentSearchConfig,
|
||||
input: BasicInput | MainInput_a,
|
||||
) -> AnswerStream:
|
||||
# TODO: add these to the environment
|
||||
config.perform_initial_search_path_decision = False
|
||||
config.perform_initial_search_decomposition = True
|
||||
config.allow_refinement = True
|
||||
|
||||
for event in _manage_async_event_streaming(
|
||||
compiled_graph=compiled_graph, config=config, graph_input=input
|
||||
):
|
||||
if not (parsed_object := _parse_agent_event(event)):
|
||||
continue
|
||||
|
||||
yield parsed_object
|
||||
|
||||
|
||||
# TODO: call this once on startup, TBD where and if it should be gated based
|
||||
# on dev mode or not
|
||||
def load_compiled_graph(graph_name: str) -> CompiledStateGraph:
|
||||
main_graph_builder = (
|
||||
main_graph_builder_a if graph_name == "a" else main_graph_builder_a
|
||||
)
|
||||
global _COMPILED_GRAPH
|
||||
if _COMPILED_GRAPH is None:
|
||||
graph = main_graph_builder()
|
||||
_COMPILED_GRAPH = graph.compile()
|
||||
return _COMPILED_GRAPH
|
||||
|
||||
|
||||
def run_main_graph(
|
||||
config: AgentSearchConfig,
|
||||
graph_name: str = "a",
|
||||
) -> AnswerStream:
|
||||
compiled_graph = load_compiled_graph(graph_name)
|
||||
if graph_name == "a":
|
||||
input = MainInput_a(base_question=config.search_request.query, log_messages=[])
|
||||
else:
|
||||
input = MainInput_a(base_question=config.search_request.query, log_messages=[])
|
||||
|
||||
# Agent search is not a Tool per se, but this is helpful for the frontend
|
||||
yield ToolCallKickoff(
|
||||
tool_name="agent_search_0",
|
||||
tool_args={"query": config.search_request.query},
|
||||
)
|
||||
yield from run_graph(compiled_graph, config, input)
|
||||
|
||||
|
||||
# TODO: unify input types, especially prosearchconfig
|
||||
def run_basic_graph(
|
||||
config: AgentSearchConfig,
|
||||
) -> AnswerStream:
|
||||
graph = basic_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
# TODO: unify basic input
|
||||
input = BasicInput(
|
||||
should_stream_answer=True,
|
||||
)
|
||||
return run_graph(compiled_graph, config, input)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from onyx.llm.factory import get_default_llms
|
||||
|
||||
now_start = datetime.now()
|
||||
logger.debug(f"Start at {now_start}")
|
||||
|
||||
if GRAPH_NAME == "a":
|
||||
graph = main_graph_builder_a()
|
||||
else:
|
||||
graph = main_graph_builder_a()
|
||||
compiled_graph = graph.compile()
|
||||
now_end = datetime.now()
|
||||
logger.debug(f"Graph compiled in {now_end - now_start} seconds")
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
search_request = SearchRequest(
|
||||
# query="what can you do with gitlab?",
|
||||
# query="What are the guiding principles behind the development of cockroachDB",
|
||||
# query="What are the temperatures in Munich, Hawaii, and New York?",
|
||||
# query="When was Washington born?",
|
||||
query="What is Onyx?",
|
||||
)
|
||||
# Joachim custom persona
|
||||
|
||||
with get_session_context_manager() as db_session:
|
||||
config, search_tool = get_test_config(
|
||||
db_session, primary_llm, fast_llm, search_request
|
||||
)
|
||||
# search_request.persona = get_persona_by_id(1, None, db_session)
|
||||
config.use_persistence = True
|
||||
config.perform_initial_search_path_decision = False
|
||||
config.perform_initial_search_decomposition = True
|
||||
if GRAPH_NAME == "a":
|
||||
input = MainInput_a(
|
||||
base_question=config.search_request.query, log_messages=[]
|
||||
)
|
||||
else:
|
||||
input = MainInput_a(
|
||||
base_question=config.search_request.query, log_messages=[]
|
||||
)
|
||||
# with open("output.txt", "w") as f:
|
||||
tool_responses: list = []
|
||||
for output in run_graph(compiled_graph, config, input):
|
||||
# pass
|
||||
|
||||
if isinstance(output, ToolCallKickoff):
|
||||
pass
|
||||
elif isinstance(output, ExtendedToolResponse):
|
||||
tool_responses.append(output.response)
|
||||
logger.info(
|
||||
f" ---- ET {output.level} - {output.level_question_nr} | "
|
||||
)
|
||||
elif isinstance(output, SubQueryPiece):
|
||||
logger.info(
|
||||
f"Sq {output.level} - {output.level_question_nr} - {output.sub_query} | "
|
||||
)
|
||||
elif isinstance(output, SubQuestionPiece):
|
||||
logger.info(
|
||||
f"SQ {output.level} - {output.level_question_nr} - {output.sub_question} | "
|
||||
)
|
||||
elif (
|
||||
isinstance(output, AgentAnswerPiece)
|
||||
and output.answer_type == "agent_sub_answer"
|
||||
):
|
||||
logger.info(
|
||||
f" ---- SA {output.level} - {output.level_question_nr} {output.answer_piece} | "
|
||||
)
|
||||
elif (
|
||||
isinstance(output, AgentAnswerPiece)
|
||||
and output.answer_type == "agent_level_answer"
|
||||
):
|
||||
logger.info(
|
||||
f" ---------- FA {output.level} - {output.level_question_nr} {output.answer_piece} | "
|
||||
)
|
||||
elif isinstance(output, RefinedAnswerImprovement):
|
||||
logger.info(f" ---------- RE {output.refined_answer_improvement} | ")
|
||||
|
||||
# for tool_response in tool_responses:
|
||||
# logger.debug(tool_response)
|
||||
@@ -0,0 +1,100 @@
from langchain.schema import AIMessage
from langchain.schema import HumanMessage
from langchain.schema import SystemMessage
from langchain_core.messages.tool import ToolMessage

from onyx.agents.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT_v2
from onyx.agents.agent_search.shared_graph_utils.prompts import HISTORY_PROMPT
from onyx.agents.agent_search.shared_graph_utils.utils import get_today_prompt
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.context.search.models import InferenceSection
from onyx.llm.interfaces import LLMConfig
from onyx.llm.utils import get_max_input_tokens
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.natural_language_processing.utils import tokenizer_trim_content


def build_sub_question_answer_prompt(
    question: str,
    original_question: str,
    docs: list[InferenceSection],
    persona_specification: str,
    config: LLMConfig,
) -> list[SystemMessage | HumanMessage | AIMessage | ToolMessage]:
    system_message = SystemMessage(
        content=persona_specification,
    )

    date_str = get_today_prompt()

    docs_format_list = [
        f"""Document Number: [D{doc_nr + 1}]\n
        Content: {doc.combined_content}\n\n"""
        for doc_nr, doc in enumerate(docs)
    ]

    docs_str = "\n\n".join(docs_format_list)

    docs_str = trim_prompt_piece(
        config, docs_str, BASE_RAG_PROMPT_v2 + question + original_question + date_str
    )
    human_message = HumanMessage(
        content=BASE_RAG_PROMPT_v2.format(
            question=question,
            original_question=original_question,
            context=docs_str,
            date_prompt=date_str,
        )
    )

    return [system_message, human_message]


def trim_prompt_piece(config: LLMConfig, prompt_piece: str, reserved_str: str) -> str:
    # TODO: this truncating might add latency. We could do a rougher + faster check
    # first to determine whether truncation is needed

    # TODO: maybe save the tokenizer and max input tokens if this is getting called multiple times?
    llm_tokenizer = get_tokenizer(
        provider_type=config.model_provider,
        model_name=config.model_name,
    )

    max_tokens = get_max_input_tokens(
        model_provider=config.model_provider,
        model_name=config.model_name,
    )

    # slightly conservative trimming
    return tokenizer_trim_content(
        content=prompt_piece,
        desired_length=max_tokens - len(llm_tokenizer.encode(reserved_str)),
        tokenizer=llm_tokenizer,
    )
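What the trimming amounts to in numbers (hypothetical budget; the real values come from get_max_input_tokens and the model's tokenizer):

# Assumed figures, purely for illustration of the budget arithmetic.
max_input_tokens = 8192   # model input budget, normally from get_max_input_tokens(...)
reserved_tokens = 900     # tokens taken by the prompt template, question, etc.
allowed_for_piece = max_input_tokens - reserved_tokens
print(allowed_for_piece)  # -> 7292; tokenizer_trim_content keeps roughly this many tokens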
|
||||
|
||||
def build_history_prompt(prompt_builder: AnswerPromptBuilder | None) -> str:
|
||||
if prompt_builder is None:
|
||||
return ""
|
||||
|
||||
if prompt_builder.single_message_history is not None:
|
||||
history = prompt_builder.single_message_history
|
||||
else:
|
||||
history_components = []
|
||||
previous_message_type = None
|
||||
for message in prompt_builder.raw_message_history:
|
||||
if "user" in message.message_type:
|
||||
history_components.append(f"User: {message.message}\n")
|
||||
previous_message_type = "user"
|
||||
elif "assistant" in message.message_type:
|
||||
# only use the last agent answer for the history
|
||||
if previous_message_type != "assistant":
|
||||
history_components.append(f"You/Agent: {message.message}\n")
|
||||
else:
|
||||
history_components = history_components[:-1]
|
||||
history_components.append(f"You/Agent: {message.message}\n")
|
||||
previous_message_type = "assistant"
|
||||
else:
|
||||
continue
|
||||
history = "\n".join(history_components)
|
||||
return HISTORY_PROMPT.format(history=history) if history else ""
|
||||
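A minimal, self-contained sketch of the token-budget idea behind trim_prompt_piece: the reserved string (prompt template plus questions plus date) is tokenized first, and only whatever budget remains is granted to the document context. The whitespace "tokenizer" and the 50-token budget below are assumptions for illustration only, not the tokenizers or limits Onyx actually uses.

# Illustrative toy only: a whitespace split stands in for the real model tokenizer.
def toy_trim(content: str, reserved: str, max_input_tokens: int = 50) -> str:
    reserved_tokens = len(reserved.split())               # tokens already claimed by the template
    budget = max(max_input_tokens - reserved_tokens, 0)   # what is left for the context
    return " ".join(content.split()[:budget])

docs_str = "word " * 200
print(toy_trim(docs_str, reserved="question plus template text"))  # trimmed to the leftover budget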
@@ -0,0 +1,98 @@
import numpy as np

from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitScoreMetrics
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.chat.models import SectionRelevancePiece
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger

logger = setup_logger()


def unique_chunk_id(doc: InferenceSection) -> str:
    return f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}"


def calculate_rank_shift(list1: list, list2: list, top_n: int = 20) -> float:
    shift = 0
    for rank_first, doc_id in enumerate(list1[:top_n], 1):
        try:
            rank_second = list2.index(doc_id) + 1
        except ValueError:
            rank_second = len(list2)  # Document not found in second list

        shift += np.abs(rank_first - rank_second) / np.log(1 + rank_first * rank_second)

    return shift / top_n


def get_fit_scores(
    pre_reranked_results: list[InferenceSection],
    post_reranked_results: list[InferenceSection] | list[SectionRelevancePiece],
) -> RetrievalFitStats | None:
    """
    Calculate retrieval metrics for search purposes
    """

    if len(pre_reranked_results) == 0 or len(post_reranked_results) == 0:
        return None

    ranked_sections = {
        "initial": pre_reranked_results,
        "reranked": post_reranked_results,
    }

    fit_eval: RetrievalFitStats = RetrievalFitStats(
        fit_score_lift=0,
        rerank_effect=0,
        fit_scores={
            "initial": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
            "reranked": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
        },
    )

    for rank_type, docs in ranked_sections.items():
        logger.debug(f"rank_type: {rank_type}")

        for i in [1, 5, 10]:
            fit_eval.fit_scores[rank_type].scores[str(i)] = (
                sum(
                    [
                        float(doc.center_chunk.score)
                        for doc in docs[:i]
                        if type(doc) == InferenceSection
                        and doc.center_chunk.score is not None
                    ]
                )
                / i
            )

        fit_eval.fit_scores[rank_type].scores["fit_score"] = (
            1
            / 3
            * (
                fit_eval.fit_scores[rank_type].scores["1"]
                + fit_eval.fit_scores[rank_type].scores["5"]
                + fit_eval.fit_scores[rank_type].scores["10"]
            )
        )

        fit_eval.fit_scores[rank_type].scores["fit_score"] = fit_eval.fit_scores[
            rank_type
        ].scores["1"]

        fit_eval.fit_scores[rank_type].chunk_ids = [
            unique_chunk_id(doc) for doc in docs if type(doc) == InferenceSection
        ]

    fit_eval.fit_score_lift = (
        fit_eval.fit_scores["reranked"].scores["fit_score"]
        / fit_eval.fit_scores["initial"].scores["fit_score"]
    )

    fit_eval.rerank_effect = calculate_rank_shift(
        fit_eval.fit_scores["initial"].chunk_ids,
        fit_eval.fit_scores["reranked"].chunk_ids,
    )

    return fit_eval
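Two observations on the calculations above. First, get_fit_scores assigns fit_score twice: the 1/3-weighted average of the top-1/5/10 scores is immediately overwritten by the top-1 score, so only the second assignment takes effect. Second, calculate_rank_shift rewards stability: identical rankings give 0, and a document that moves contributes |rank_first - rank_second| / log(1 + rank_first * rank_second), so movement near the top of the list is penalized more than movement deep in the list. A small worked example with made-up chunk ids, assuming the function above is in scope:

ids_before = ["a", "b", "c", "d"]
ids_after = ["d", "b", "a", "c"]

print(calculate_rank_shift(ids_before, ids_before, top_n=4))  # 0.0 -- nothing moved
print(calculate_rank_shift(ids_before, ids_after, top_n=4))   # ~0.92 -- reranking shuffled the order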
113
backend/onyx/agents/agent_search/shared_graph_utils/models.py
Normal file
@@ -0,0 +1,113 @@
from typing import Literal

from pydantic import BaseModel

from onyx.agents.agent_search.deep_search_a.main.models import AgentAdditionalMetrics
from onyx.agents.agent_search.deep_search_a.main.models import AgentBaseMetrics
from onyx.agents.agent_search.deep_search_a.main.models import AgentRefinedMetrics
from onyx.agents.agent_search.deep_search_a.main.models import AgentTimings
from onyx.context.search.models import InferenceSection
from onyx.tools.models import SearchQueryInfo


# Pydantic models for structured outputs
class RewrittenQueries(BaseModel):
    rewritten_queries: list[str]


class BinaryDecision(BaseModel):
    decision: Literal["yes", "no"]


class BinaryDecisionWithReasoning(BaseModel):
    reasoning: str
    decision: Literal["yes", "no"]


class RetrievalFitScoreMetrics(BaseModel):
    scores: dict[str, float]
    chunk_ids: list[str]


class RetrievalFitStats(BaseModel):
    fit_score_lift: float
    rerank_effect: float
    fit_scores: dict[str, RetrievalFitScoreMetrics]


class AgentChunkScores(BaseModel):
    scores: dict[str, dict[str, list[int | float]]]


class AgentChunkStats(BaseModel):
    verified_count: int | None = None
    verified_avg_scores: float | None = None
    rejected_count: int | None = None
    rejected_avg_scores: float | None = None
    verified_doc_chunk_ids: list[str] = []
    dismissed_doc_chunk_ids: list[str] = []


class InitialAgentResultStats(BaseModel):
    sub_questions: dict[str, float | int | None]
    original_question: dict[str, float | int | None]
    agent_effectiveness: dict[str, float | int | None]


class RefinedAgentStats(BaseModel):
    revision_doc_efficiency: float | None
    revision_question_efficiency: float | None


class Term(BaseModel):
    term_name: str = ""
    term_type: str = ""
    term_similar_to: list[str] = []


### Models ###


class Entity(BaseModel):
    entity_name: str = ""
    entity_type: str = ""


class Relationship(BaseModel):
    relationship_name: str = ""
    relationship_type: str = ""
    relationship_entities: list[str] = []


class EntityRelationshipTermExtraction(BaseModel):
    entities: list[Entity] = []
    relationships: list[Relationship] = []
    terms: list[Term] = []


### Models ###


class QueryResult(BaseModel):
    query: str
    search_results: list[InferenceSection]
    stats: RetrievalFitStats | None
    query_info: SearchQueryInfo | None


class QuestionAnswerResults(BaseModel):
    question: str
    question_id: str
    answer: str
    quality: str
    expanded_retrieval_results: list[QueryResult]
    documents: list[InferenceSection]
    context_documents: list[InferenceSection]
    sub_question_retrieval_stats: AgentChunkStats


class CombinedAgentMetrics(BaseModel):
    timings: AgentTimings
    base_metrics: AgentBaseMetrics | None
    refined_metrics: AgentRefinedMetrics
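One property of these Pydantic models worth noting: mutable list defaults such as verified_doc_chunk_ids: list[str] = [] are copied for each instance by Pydantic, so two AgentChunkStats objects never share the same underlying list the way a plain Python mutable default argument would. A quick check, assuming the models are importable from onyx.agents.agent_search.shared_graph_utils.models:

from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats

stats_a = AgentChunkStats()
stats_b = AgentChunkStats()
stats_a.verified_doc_chunk_ids.append("doc_1_0")

print(stats_a.verified_doc_chunk_ids)  # ['doc_1_0']
print(stats_b.verified_doc_chunk_ids)  # [] -- each instance gets its own copy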
@@ -0,0 +1,31 @@
from onyx.agents.agent_search.shared_graph_utils.models import (
    QuestionAnswerResults,
)
from onyx.chat.prune_and_merge import _merge_sections
from onyx.context.search.models import InferenceSection


def dedup_inference_sections(
    list1: list[InferenceSection], list2: list[InferenceSection]
) -> list[InferenceSection]:
    deduped = _merge_sections(list1 + list2)
    return deduped


def dedup_question_answer_results(
    question_answer_results_1: list[QuestionAnswerResults],
    question_answer_results_2: list[QuestionAnswerResults],
) -> list[QuestionAnswerResults]:
    deduped_question_answer_results: list[
        QuestionAnswerResults
    ] = question_answer_results_1
    utilized_question_ids: set[str] = set(
        [x.question_id for x in question_answer_results_1]
    )

    for question_answer_result in question_answer_results_2:
        if question_answer_result.question_id not in utilized_question_ids:
            deduped_question_answer_results.append(question_answer_result)
            utilized_question_ids.add(question_answer_result.question_id)

    return deduped_question_answer_results
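Note that dedup_question_answer_results aliases question_answer_results_1 rather than copying it, so the first list is extended in place; callers that need it untouched should pass a copy. A small sketch of the same id-based dedup pattern, using a hypothetical Item stand-in (the function only touches question_id, so duck typing suffices for illustration):

from dataclasses import dataclass

@dataclass
class Item:  # hypothetical stand-in for QuestionAnswerResults
    question_id: str

first = [Item("q1"), Item("q2")]
second = [Item("q2"), Item("q3")]

merged = dedup_question_answer_results(first, second)  # type: ignore[arg-type]
print([item.question_id for item in merged])  # ['q1', 'q2', 'q3']
print(len(first))                             # 3 -- the first list was mutated in place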
Some files were not shown because too many files have changed in this diff