Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-17 15:55:45 +00:00)

Compare commits: 279 commits (l...random_doc)
477f8eeb68 737e37170d c58a7ef819 bd08e6d787 47e6192b99 29f5f4edfa
b469a7eff4 d1e9760b92 7153cb09f1 78153e5012 b1ee1efecb 526932a7f6
6889152d81 4affc259a6 0ec065f1fb 8eb4320f76 1c12ab31f9 49fd76b336
5854b39dd4 c0271a948a aff4ee5ebf 675d2f3539 2974b57ef4 679bdd5e04
e6cb47fcb8 a514818e13 89021cde90 32ecc282a2 59b1d4673f ec0c655c8d
42a0f45a96 125e5eaab1 f2dab9ba89 02a068a68b 91f0650071 b97819189b
b928201397 b500c914b0 4b0d22fae3 b46c09ac6c 3ce8923086 7ac6d3ed50
3cd057d7a2 4834ee6223 cb85be41b1 eb227c0acc 25a57e2292 3f3b04a4ee
3f6de7968a 024207e2d9 8f7db9212c b1e9e03aa4 87a53d6d80 59c65a4192
c984c6c7f2 9a3ce504bc 16265d27f5 570fe43efb 506a9f1b94 a067b32467
9b6e51b4fe e23dd0a3fa 71304e4228 2adeaaeded a96728ff4d eaffdee0dc
feaa3b653f 9438f9df05 b90e0834a5 29440f5482 5a95a5c9fd 118e8afbef
8342168658 d5661baf98 95fcc0019c 0ccd83e809 732861a940 d53dd1e356
1a2760edee 23ae4547ca 385b344a43 a340529de3 4a0b2a6c09 756a1cbf8f
8af4f1da8e 4b82440915 bb6d55783e 2b8cd63b34 b0c3098693 2517aa39b2
ceaaa05af0 3b13380051 ef6e6f9556 0a6808c4c1 6442c56d82 e191e514b9
f33a2ffb01 0578c31522 8cbdc6d8fe 60fb06da4e 55ed6e2294 42780d5f97
f050d281fd 3ca4d532b4 e3e855c526 23bf50b90a c43c2320e7 01e6e9a2ba
bd3b1943c4 1dbf561db0 a43a6627eb 5bff8bc8ce 7879ba6a77 a63b341913
c062097b2a 48e42af8e7 6c7f8eaefb 3d99ad7bc4 8fea571f6e d70bbcc2ce
73769c6cae 7e98936c58 4e17fc06ff ff4df6f3bf 91b929d466 6bef5ca7a4
4817fa0bd1 da4a086398 69e8c5f0fc 12d1186888 325892a21c 18d92559b5
f2aeeb7b3c 110c9f7e1b 1a22af4f27 efa32a8c04 9bad12968f f1d96343a9
0496ec3bb8 568f927b9b f842e15d64 3a07093663 1fe966d0f7 812172f1bd
9e9bd440f4 7487b15522 de5ce8a613 8c9577aa95 4baf3dc484 50ef5115e7
a2247363af a0af8ee91c 25f6543443 d52a0b96ac f14b282f0f 7d494cd65e
139374966f bf06710215 d4e0d0db05 f96a3ee29a 3bf6b77319 3b3b0c8a87
aa8cb44a33 fc60fd0322 46402a97c7 5bf6a47948 2d8486bac4 eea6f2749a
5e9b2e41ae 2bbe20edc3 db2004542e ddbfc65ad0 982040c792 4b0a4a2741
28ba01b361 d32d1c6079 dd494d2daa eb6dbf49a1 e5fa411092 1ced8924b3
3c3900fac6 3b298e19bc 71eafe04a8 80d248e02d 2032fb10da ca1f176c61
3ced9bc28b deea9c8c3c 4e47c81ed8 00cee71c18 470c4d15dd 50bacc03b3
dd260140b2 8aa82be12a b7f9e431a5 b9bd2ea4e2 e4c93bed8b 4fd6e36c2f
715359c120 6f018d75ee fd947aadea e061ba2b93 87bccc13cc 3a950721b9
569639eb90 68cb1f3409 11da0d9889 6a7e2a8036 035f83c464 3c34ddcc4f
bbee2865e9 a82cac5361 83e5cb2d2f a5d2f0d9ac d3cf18160e 618e4addd8
69f16cc972 2676d40065 b64545c7c7 7bc8554e01 5232aeacad 261150e81a
3e0d24a3f6 ffe8ac168f 17b280e59e 5edba4a7f3 d842fed37e 14981162fd
288daa4e90 30e8fb12e4 d8578bc1cb 5e21dc6cb3 39b3a503b4 a70d472b5c
0ed2886ad0 6b31e2f622 aabf8a99bc 7ccfe85ee5 95701db1bd 24105254ac
4fe99d05fd d35f93b233 766b0f35df a0470a96eb b82123563b 787e25cd78
c6375f8abf 58e5deba01 028e877342 47bff2b6a9 1502bcea12 2701f83634
601037abb5 7e9b12403a d903e5912a d2aea63573 57b4639709 1308b6cbe8
98abd7d3fa e4180cefba f67b5356fa 9bdb581220 42d6d935ae 8d62b992ef
2ad86aa9a6 74a472ece7 b2ce848b53 519ec20d05 3b1e26d0d4 118d2b52e6
e625884702 fa78f50fe3 05ab94945b 7a64a25ff4 7f10494bbe f2d4024783
70795a4047 d8a17a7238 cbf98c0128
2  .github/workflows/pr-chromatic-tests.yml  (vendored)

@@ -8,6 +8,8 @@ on: push
env:
  OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
  SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
  GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
  MOCK_LLM_RESPONSE: true

jobs:
  playwright-tests:
22  .github/workflows/pr-helm-chart-testing.yml  (vendored)

@@ -21,10 +21,10 @@ jobs:
      - name: Set up Helm
        uses: azure/setup-helm@v4.2.0
        with:
          version: v3.14.4
          version: v3.17.0

      - name: Set up chart-testing
        uses: helm/chart-testing-action@v2.6.1
        uses: helm/chart-testing-action@v2.7.0

      # even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
      - name: Run chart-testing (list-changed)
@@ -37,22 +37,6 @@ jobs:
            echo "changed=true" >> "$GITHUB_OUTPUT"
          fi

      # rkuo: I don't think we need python?
      # - name: Set up Python
      #   uses: actions/setup-python@v5
      #   with:
      #     python-version: '3.11'
      #     cache: 'pip'
      #     cache-dependency-path: |
      #       backend/requirements/default.txt
      #       backend/requirements/dev.txt
      #       backend/requirements/model_server.txt
      # - run: |
      #     python -m pip install --upgrade pip
      #     pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
      #     pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
      #     pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt

      # lint all charts if any changes were detected
      - name: Run chart-testing (lint)
        if: steps.list-changed.outputs.changed == 'true'
@@ -62,7 +46,7 @@ jobs:

      - name: Create kind cluster
        if: steps.list-changed.outputs.changed == 'true'
        uses: helm/kind-action@v1.10.0
        uses: helm/kind-action@v1.12.0

      - name: Run chart-testing (install)
        if: steps.list-changed.outputs.changed == 'true'
@@ -39,6 +39,12 @@ env:
  AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
  AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
  AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
  # Sharepoint
  SHAREPOINT_CLIENT_ID: ${{ secrets.SHAREPOINT_CLIENT_ID }}
  SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
  SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
  SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}

jobs:
  connectors-check:
    # See https://runs-on.com/runners/linux/
4  .gitignore  (vendored)

@@ -7,4 +7,6 @@
.vscode/
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml
/web/test-results/
/web/test-results/
backend/onyx/agent_search/main/test_data.json
backend/tests/regression/answer_quality/test_data.json
6  .vscode/env_template.txt  (vendored)

@@ -52,3 +52,9 @@ BING_API_KEY=<REPLACE THIS>
# Enable the full set of Danswer Enterprise Edition features
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False

# Agent Search configs # TODO: Remove give proper namings
AGENT_RETRIEVAL_STATS=False # Note: This setting will incur substantial re-ranking effort
AGENT_RERANKING_STATS=True
AGENT_MAX_QUERY_RETRIEVAL_RESULTS=20
AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS=20
@@ -124,7 +124,7 @@ There are two editions of Onyx:
To try the Onyx Enterprise Edition:

1. Checkout our [Cloud product](https://cloud.onyx.app/signup).
2. For self-hosting, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/danswer/founders).
2. For self-hosting, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/onyx/founders).

## 💡 Contributing
@@ -0,0 +1,36 @@
"""add chat session specific temperature override

Revision ID: 2f80c6a2550f
Revises: 33ea50e88f24
Create Date: 2025-01-31 10:30:27.289646

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "2f80c6a2550f"
down_revision = "33ea50e88f24"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "chat_session", sa.Column("temperature_override", sa.Float(), nullable=True)
    )
    op.add_column(
        "user",
        sa.Column(
            "temperature_override_enabled",
            sa.Boolean(),
            nullable=False,
            server_default=sa.false(),
        ),
    )


def downgrade() -> None:
    op.drop_column("chat_session", "temperature_override")
    op.drop_column("user", "temperature_override_enabled")
@@ -0,0 +1,80 @@
"""foreign key input prompts

Revision ID: 33ea50e88f24
Revises: a6df6b88ef81
Create Date: 2025-01-29 10:54:22.141765

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "33ea50e88f24"
down_revision = "a6df6b88ef81"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Safely drop constraints if exists
    op.execute(
        """
        ALTER TABLE inputprompt__user
        DROP CONSTRAINT IF EXISTS inputprompt__user_input_prompt_id_fkey
        """
    )
    op.execute(
        """
        ALTER TABLE inputprompt__user
        DROP CONSTRAINT IF EXISTS inputprompt__user_user_id_fkey
        """
    )

    # Recreate with ON DELETE CASCADE
    op.create_foreign_key(
        "inputprompt__user_input_prompt_id_fkey",
        "inputprompt__user",
        "inputprompt",
        ["input_prompt_id"],
        ["id"],
        ondelete="CASCADE",
    )

    op.create_foreign_key(
        "inputprompt__user_user_id_fkey",
        "inputprompt__user",
        "user",
        ["user_id"],
        ["id"],
        ondelete="CASCADE",
    )


def downgrade() -> None:
    # Drop the new FKs with ondelete
    op.drop_constraint(
        "inputprompt__user_input_prompt_id_fkey",
        "inputprompt__user",
        type_="foreignkey",
    )
    op.drop_constraint(
        "inputprompt__user_user_id_fkey",
        "inputprompt__user",
        type_="foreignkey",
    )

    # Recreate them without cascading
    op.create_foreign_key(
        "inputprompt__user_input_prompt_id_fkey",
        "inputprompt__user",
        "inputprompt",
        ["input_prompt_id"],
        ["id"],
    )
    op.create_foreign_key(
        "inputprompt__user_user_id_fkey",
        "inputprompt__user",
        "user",
        ["user_id"],
        ["id"],
    )
@@ -0,0 +1,37 @@
"""lowercase_user_emails

Revision ID: 4d58345da04a
Revises: f1ca58b2f2ec
Create Date: 2025-01-29 07:48:46.784041

"""
from alembic import op
from sqlalchemy.sql import text


# revision identifiers, used by Alembic.
revision = "4d58345da04a"
down_revision = "f1ca58b2f2ec"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Get database connection
    connection = op.get_bind()

    # Update all user emails to lowercase
    connection.execute(
        text(
            """
            UPDATE "user"
            SET email = LOWER(email)
            WHERE email != LOWER(email)
            """
        )
    )


def downgrade() -> None:
    # Cannot restore original case of emails
    pass
107  backend/alembic/versions/98a5008d8711_agent_tracking.py  (new file)

@@ -0,0 +1,107 @@
"""agent_tracking

Revision ID: 98a5008d8711
Revises: 2f80c6a2550f
Create Date: 2025-01-29 17:00:00.000001

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import UUID

# revision identifiers, used by Alembic.
revision = "98a5008d8711"
down_revision = "2f80c6a2550f"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "agent__search_metrics",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=True),
        sa.Column("persona_id", sa.Integer(), nullable=True),
        sa.Column("agent_type", sa.String(), nullable=False),
        sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
        sa.Column("base_duration_s", sa.Float(), nullable=False),
        sa.Column("full_duration_s", sa.Float(), nullable=False),
        sa.Column("base_metrics", postgresql.JSONB(), nullable=True),
        sa.Column("refined_metrics", postgresql.JSONB(), nullable=True),
        sa.Column("all_metrics", postgresql.JSONB(), nullable=True),
        sa.ForeignKeyConstraint(
            ["persona_id"],
            ["persona.id"],
        ),
        sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
    )

    # Create sub_question table
    op.create_table(
        "agent__sub_question",
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("primary_question_id", sa.Integer, sa.ForeignKey("chat_message.id")),
        sa.Column(
            "chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id")
        ),
        sa.Column("sub_question", sa.Text),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.func.now()
        ),
        sa.Column("sub_answer", sa.Text),
        sa.Column("sub_question_doc_results", postgresql.JSONB(), nullable=True),
        sa.Column("level", sa.Integer(), nullable=False),
        sa.Column("level_question_num", sa.Integer(), nullable=False),
    )

    # Create sub_query table
    op.create_table(
        "agent__sub_query",
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column(
            "parent_question_id", sa.Integer, sa.ForeignKey("agent__sub_question.id")
        ),
        sa.Column(
            "chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id")
        ),
        sa.Column("sub_query", sa.Text),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.func.now()
        ),
    )

    # Create sub_query__search_doc association table
    op.create_table(
        "agent__sub_query__search_doc",
        sa.Column(
            "sub_query_id",
            sa.Integer,
            sa.ForeignKey("agent__sub_query.id"),
            primary_key=True,
        ),
        sa.Column(
            "search_doc_id",
            sa.Integer,
            sa.ForeignKey("search_doc.id"),
            primary_key=True,
        ),
    )

    op.add_column(
        "chat_message",
        sa.Column(
            "refined_answer_improvement",
            sa.Boolean(),
            nullable=True,
        ),
    )


def downgrade() -> None:
    op.drop_column("chat_message", "refined_answer_improvement")
    op.drop_table("agent__sub_query__search_doc")
    op.drop_table("agent__sub_query")
    op.drop_table("agent__sub_question")
    op.drop_table("agent__search_metrics")
@@ -0,0 +1,29 @@
"""remove recent assistants

Revision ID: a6df6b88ef81
Revises: 4d58345da04a
Create Date: 2025-01-29 10:25:52.790407

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "a6df6b88ef81"
down_revision = "4d58345da04a"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.drop_column("user", "recent_assistants")


def downgrade() -> None:
    op.add_column(
        "user",
        sa.Column(
            "recent_assistants", postgresql.JSONB(), server_default="[]", nullable=False
        ),
    )
@@ -0,0 +1,76 @@
"""add default slack channel config

Revision ID: eaa3b5593925
Revises: 98a5008d8711
Create Date: 2025-02-03 18:07:56.552526

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "eaa3b5593925"
down_revision = "98a5008d8711"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add is_default column
    op.add_column(
        "slack_channel_config",
        sa.Column("is_default", sa.Boolean(), nullable=False, server_default="false"),
    )

    op.create_index(
        "ix_slack_channel_config_slack_bot_id_default",
        "slack_channel_config",
        ["slack_bot_id", "is_default"],
        unique=True,
        postgresql_where=sa.text("is_default IS TRUE"),
    )

    # Create default channel configs for existing slack bots without one
    conn = op.get_bind()
    slack_bots = conn.execute(sa.text("SELECT id FROM slack_bot")).fetchall()

    for slack_bot in slack_bots:
        slack_bot_id = slack_bot[0]
        existing_default = conn.execute(
            sa.text(
                "SELECT id FROM slack_channel_config WHERE slack_bot_id = :bot_id AND is_default = TRUE"
            ),
            {"bot_id": slack_bot_id},
        ).fetchone()

        if not existing_default:
            conn.execute(
                sa.text(
                    """
                    INSERT INTO slack_channel_config (
                        slack_bot_id, persona_id, channel_config, enable_auto_filters, is_default
                    ) VALUES (
                        :bot_id, NULL,
                        '{"channel_name": null, "respond_member_group_list": [], "answer_filters": [], "follow_up_tags": []}',
                        FALSE, TRUE
                    )
                    """
                ),
                {"bot_id": slack_bot_id},
            )


def downgrade() -> None:
    # Delete default slack channel configs
    conn = op.get_bind()
    conn.execute(sa.text("DELETE FROM slack_channel_config WHERE is_default = TRUE"))

    # Remove index
    op.drop_index(
        "ix_slack_channel_config_slack_bot_id_default",
        table_name="slack_channel_config",
    )

    # Remove is_default column
    op.drop_column("slack_channel_config", "is_default")
@@ -32,6 +32,7 @@ def perform_ttl_management_task(

@celery_app.task(
    name="check_ttl_management_task",
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
)
def check_ttl_management_task(*, tenant_id: str | None) -> None:
@@ -56,6 +57,7 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:

@celery_app.task(
    name="autogenerate_usage_report_task",
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
)
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
@@ -13,6 +13,7 @@ from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.confluence.utils import get_user_email_from_username__server
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger

logger = setup_logger()
@@ -257,6 +258,7 @@ def _fetch_all_page_restrictions(
    slim_docs: list[SlimDocument],
    space_permissions_by_space_key: dict[str, ExternalAccess],
    is_cloud: bool,
    callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
    """
    For all pages, if a page has restrictions, then use those restrictions.
@@ -265,6 +267,12 @@ def _fetch_all_page_restrictions(
    document_restrictions: list[DocExternalAccess] = []

    for slim_doc in slim_docs:
        if callback:
            if callback.should_stop():
                raise RuntimeError("confluence_doc_sync: Stop signal detected")

            callback.progress("confluence_doc_sync:fetch_all_page_restrictions", 1)

        if slim_doc.perm_sync_data is None:
            raise ValueError(
                f"No permission sync data found for document {slim_doc.id}"
@@ -334,7 +342,7 @@ def _fetch_all_page_restrictions(


def confluence_doc_sync(
    cc_pair: ConnectorCredentialPair,
    cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> list[DocExternalAccess]:
    """
    Adds the external permissions to the documents in postgres
@@ -359,6 +367,12 @@ def confluence_doc_sync(
    logger.debug("Fetching all slim documents from confluence")
    for doc_batch in confluence_connector.retrieve_all_slim_documents():
        logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
        if callback:
            if callback.should_stop():
                raise RuntimeError("confluence_doc_sync: Stop signal detected")

            callback.progress("confluence_doc_sync", 1)

        slim_docs.extend(doc_batch)

    logger.debug("Fetching all page restrictions for space")
@@ -367,4 +381,5 @@ def confluence_doc_sync(
        slim_docs=slim_docs,
        space_permissions_by_space_key=space_permissions_by_space_key,
        is_cloud=is_cloud,
        callback=callback,
    )
@@ -14,6 +14,8 @@ def _build_group_member_email_map(
) -> dict[str, set[str]]:
    group_member_emails: dict[str, set[str]] = {}
    for user_result in confluence_client.paginated_cql_user_retrieval():
        logger.debug(f"Processing groups for user: {user_result}")

        user = user_result.get("user", {})
        if not user:
            logger.warning(f"user result missing user field: {user_result}")
@@ -33,10 +35,17 @@ def _build_group_member_email_map(
            logger.warning(f"user result missing email field: {user_result}")
            continue

        all_users_groups: set[str] = set()
        for group in confluence_client.paginated_groups_by_user_retrieval(user):
            # group name uniqueness is enforced by Confluence, so we can use it as a group ID
            group_id = group["name"]
            group_member_emails.setdefault(group_id, set()).add(email)
            all_users_groups.add(group_id)

        if not group_member_emails:
            logger.warning(f"No groups found for user with email: {email}")
        else:
            logger.debug(f"Found groups {all_users_groups} for user with email {email}")

    return group_member_emails
@@ -6,6 +6,7 @@ from onyx.access.models import ExternalAccess
from onyx.connectors.gmail.connector import GmailConnector
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger

logger = setup_logger()
@@ -28,7 +29,7 @@ def _get_slim_doc_generator(


def gmail_doc_sync(
    cc_pair: ConnectorCredentialPair,
    cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> list[DocExternalAccess]:
    """
    Adds the external permissions to the documents in postgres
@@ -44,6 +45,12 @@ def gmail_doc_sync(
    document_external_access: list[DocExternalAccess] = []
    for slim_doc_batch in slim_doc_generator:
        for slim_doc in slim_doc_batch:
            if callback:
                if callback.should_stop():
                    raise RuntimeError("gmail_doc_sync: Stop signal detected")

                callback.progress("gmail_doc_sync", 1)

            if slim_doc.perm_sync_data is None:
                logger.warning(f"No permissions found for document {slim_doc.id}")
                continue
@@ -10,6 +10,7 @@ from onyx.connectors.google_utils.resources import get_drive_service
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger

logger = setup_logger()
@@ -42,24 +43,22 @@ def _fetch_permissions_for_permission_ids(
    if not permission_info or not doc_id:
        return []

    # Check cache first for all permission IDs
    permissions = [
        _PERMISSION_ID_PERMISSION_MAP[pid]
        for pid in permission_ids
        if pid in _PERMISSION_ID_PERMISSION_MAP
    ]

    # If we found all permissions in cache, return them
    if len(permissions) == len(permission_ids):
        return permissions

    owner_email = permission_info.get("owner_email")

    drive_service = get_drive_service(
        creds=google_drive_connector.creds,
        user_email=(owner_email or google_drive_connector.primary_admin_email),
    )

    # Otherwise, fetch all permissions and update cache
    fetched_permissions = execute_paginated_retrieval(
        retrieval_function=drive_service.permissions().list,
        list_key="permissions",
@@ -69,7 +68,6 @@ def _fetch_permissions_for_permission_ids(
    )

    permissions_for_doc_id = []
    # Update cache and return all permissions
    for permission in fetched_permissions:
        permissions_for_doc_id.append(permission)
        _PERMISSION_ID_PERMISSION_MAP[permission["id"]] = permission
@@ -131,7 +129,7 @@ def _get_permissions_from_slim_doc(


def gdrive_doc_sync(
    cc_pair: ConnectorCredentialPair,
    cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> list[DocExternalAccess]:
    """
    Adds the external permissions to the documents in postgres
@@ -149,6 +147,12 @@ def gdrive_doc_sync(
    document_external_accesses = []
    for slim_doc_batch in slim_doc_generator:
        for slim_doc in slim_doc_batch:
            if callback:
                if callback.should_stop():
                    raise RuntimeError("gdrive_doc_sync: Stop signal detected")

                callback.progress("gdrive_doc_sync", 1)

            ext_access = _get_permissions_from_slim_doc(
                google_drive_connector=google_drive_connector,
                slim_doc=slim_doc,
@@ -7,6 +7,7 @@ from onyx.connectors.slack.connector import get_channels
from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
from onyx.connectors.slack.connector import SlackPollConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger

@@ -14,7 +15,7 @@ logger = setup_logger()


def _get_slack_document_ids_and_channels(
    cc_pair: ConnectorCredentialPair,
    cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> dict[str, list[str]]:
    slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
    slack_connector.load_credentials(cc_pair.credential.credential_json)
@@ -24,6 +25,14 @@ def _get_slack_document_ids_and_channels(
    channel_doc_map: dict[str, list[str]] = {}
    for doc_metadata_batch in slim_doc_generator:
        for doc_metadata in doc_metadata_batch:
            if callback:
                if callback.should_stop():
                    raise RuntimeError(
                        "_get_slack_document_ids_and_channels: Stop signal detected"
                    )

                callback.progress("_get_slack_document_ids_and_channels", 1)

            if doc_metadata.perm_sync_data is None:
                continue
            channel_id = doc_metadata.perm_sync_data["channel_id"]
@@ -114,7 +123,7 @@ def _fetch_channel_permissions(


def slack_doc_sync(
    cc_pair: ConnectorCredentialPair,
    cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> list[DocExternalAccess]:
    """
    Adds the external permissions to the documents in postgres
@@ -127,7 +136,7 @@ def slack_doc_sync(
    )
    user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
    channel_doc_map = _get_slack_document_ids_and_channels(
        cc_pair=cc_pair,
        cc_pair=cc_pair, callback=callback
    )
    workspace_permissions = _fetch_workspace_permissions(
        user_id_to_email_map=user_id_to_email_map,
@@ -15,11 +15,13 @@ from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface

# Defining the input/output types for the sync functions
DocSyncFuncType = Callable[
    [
        ConnectorCredentialPair,
        IndexingHeartbeatInterface | None,
    ],
    list[DocExternalAccess],
]
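The shared `DocSyncFuncType` signature above is what lets each connector's doc-sync function be registered and dispatched uniformly. A minimal sketch of how such a dispatch might look, assuming a hypothetical `DOC_SYNC_FUNC_MAP` registry keyed by `DocumentSource` (the real module may organize this differently):

```python
# Hypothetical dispatch sketch, not the actual Onyx registry: map each source to a
# function matching DocSyncFuncType, then call it with the cc_pair and an optional
# heartbeat callback. confluence_doc_sync / slack_doc_sync are the functions
# imported in the hunk above.
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface

# Assumed mapping for illustration only; name and contents may differ in the repo.
DOC_SYNC_FUNC_MAP: dict[DocumentSource, DocSyncFuncType] = {
    DocumentSource.CONFLUENCE: confluence_doc_sync,
    DocumentSource.SLACK: slack_doc_sync,
}


def run_doc_sync(
    cc_pair: ConnectorCredentialPair,
    callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
    # Every sync function takes (cc_pair, callback) and returns document-level
    # external-access entries, so the caller needs no per-source logic.
    # `cc_pair.connector.source` is assumed to carry the DocumentSource.
    sync_func = DOC_SYNC_FUNC_MAP[cc_pair.connector.source]
    return sync_func(cc_pair, callback)
```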
@@ -80,7 +80,7 @@ def oneoff_standard_answers(
def _handle_standard_answers(
    message_info: SlackMessageInfo,
    receiver_ids: list[str] | None,
    slack_channel_config: SlackChannelConfig | None,
    slack_channel_config: SlackChannelConfig,
    prompt: Prompt | None,
    logger: OnyxLoggingAdapter,
    client: WebClient,
@@ -94,13 +94,10 @@ def _handle_standard_answers(
    Returns True if standard answers are found to match the user's message and therefore,
    we still need to respond to the users.
    """
    # if no channel config, then no standard answers are configured
    if not slack_channel_config:
        return False

    slack_thread_id = message_info.thread_to_respond
    configured_standard_answer_categories = (
        slack_channel_config.standard_answer_categories if slack_channel_config else []
        slack_channel_config.standard_answer_categories
    )
    configured_standard_answers = set(
        [
@@ -10,6 +10,7 @@ from fastapi import Response
from ee.onyx.auth.users import decode_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from onyx.auth.api_key import extract_tenant_from_api_key_header
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.db.engine import is_valid_schema_name
from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis
from shared_configs.configs import MULTI_TENANT
@@ -43,6 +44,7 @@ async def _get_tenant_id_from_request(
    Attempt to extract tenant_id from:
    1) The API key header
    2) The Redis-based token (stored in Cookie: fastapiusersauth)
    3) Reset token cookie
    Fallback: POSTGRES_DEFAULT_SCHEMA
    """
    # Check for API key
@@ -90,3 +92,12 @@ async def _get_tenant_id_from_request(
    except Exception as e:
        logger.error(f"Unexpected error in _get_tenant_id_from_request: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")

    finally:
        # As a final step, check for explicit tenant_id cookie
        tenant_id_cookie = request.cookies.get(TENANT_ID_COOKIE_NAME)
        if tenant_id_cookie and is_valid_schema_name(tenant_id_cookie):
            return tenant_id_cookie

        # If we've reached this point, return the default schema
        return POSTGRES_DEFAULT_SCHEMA
@@ -286,6 +286,7 @@ def prepare_authorization_request(
    oauth_state = (
        base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
    )
    session: str

    if connector == DocumentSource.SLACK:
        oauth_url = SlackOAuth.generate_oauth_url(oauth_state)
@@ -554,6 +555,7 @@ def handle_google_drive_oauth_callback(
    )

    session_json = session_json_bytes.decode("utf-8")
    session: GoogleDriveOAuth.OAuthSession
    try:
        session = GoogleDriveOAuth.parse_session(session_json)
@@ -179,6 +179,7 @@ def handle_simplified_chat_message(
        chunks_below=0,
        full_doc=chat_message_req.full_doc,
        structured_response_format=chat_message_req.structured_response_format,
        use_agentic_search=chat_message_req.use_agentic_search,
    )

    packets = stream_chat_message_objects(
@@ -301,6 +302,7 @@ def handle_send_message_simple_with_history(
        chunks_below=0,
        full_doc=req.full_doc,
        structured_response_format=req.structured_response_format,
        use_agentic_search=req.use_agentic_search,
    )

    packets = stream_chat_message_objects(
@@ -57,6 +57,9 @@ class BasicCreateChatMessageRequest(ChunkContext):
    # https://platform.openai.com/docs/guides/structured-outputs/introduction
    structured_response_format: dict | None = None

    # If True, uses agentic search instead of basic search
    use_agentic_search: bool = False


class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
    # Last element is the new query. All previous elements are historical context
@@ -71,6 +74,8 @@ class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
    # only works if using an OpenAI model. See the following for more details:
    # https://platform.openai.com/docs/guides/structured-outputs/introduction
    structured_response_format: dict | None = None
    # If True, uses agentic search instead of basic search
    use_agentic_search: bool = False


class SimpleDoc(BaseModel):
@@ -120,9 +125,12 @@ class OneShotQARequest(ChunkContext):
    # will also disable Thread-based Rewording if specified
    query_override: str | None = None

    # If True, skips generative an AI response to the search query
    # If True, skips generating an AI response to the search query
    skip_gen_ai_answer_generation: bool = False

    # If True, uses agentic search instead of basic search
    use_agentic_search: bool = False

    @model_validator(mode="after")
    def check_persona_fields(self) -> "OneShotQARequest":
        if self.persona_override_config is None and self.persona_id is None:
@@ -196,6 +196,8 @@ def get_answer_stream(
        retrieval_details=query_request.retrieval_options,
        rerank_settings=query_request.rerank_settings,
        db_session=db_session,
        use_agentic_search=query_request.use_agentic_search,
        skip_gen_ai_answer_generation=query_request.skip_gen_ai_answer_generation,
    )

    packets = stream_chat_message_objects(
@@ -34,6 +34,7 @@ from onyx.auth.users import get_redis_strategy
from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.auth import get_user_count
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
@@ -111,6 +112,7 @@ async def login_as_anonymous_user(
    token = generate_anonymous_user_jwt_token(tenant_id)

    response = Response()
    response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
    response.set_cookie(
        key=ANONYMOUS_USER_COOKIE_NAME,
        value=token,

@@ -58,6 +58,7 @@ class UserGroup(BaseModel):
                credential=CredentialSnapshot.from_credential_db_model(
                    cc_pair_relationship.cc_pair.credential
                ),
                access_type=cc_pair_relationship.cc_pair.access_type,
            )
            for cc_pair_relationship in user_group_model.cc_pair_relationships
            if cc_pair_relationship.is_current
97  backend/onyx/agents/agent_search/basic/graph_builder.py  (new file)

@@ -0,0 +1,97 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agents.agent_search.basic.states import BasicInput
from onyx.agents.agent_search.basic.states import BasicOutput
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
    basic_use_tool_response,
)
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
    prepare_tool_input,
)
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
from onyx.utils.logger import setup_logger

logger = setup_logger()


def basic_graph_builder() -> StateGraph:
    graph = StateGraph(
        state_schema=BasicState,
        input=BasicInput,
        output=BasicOutput,
    )

    ### Add nodes ###

    graph.add_node(
        node="prepare_tool_input",
        action=prepare_tool_input,
    )

    graph.add_node(
        node="llm_tool_choice",
        action=llm_tool_choice,
    )

    graph.add_node(
        node="tool_call",
        action=tool_call,
    )

    graph.add_node(
        node="basic_use_tool_response",
        action=basic_use_tool_response,
    )

    ### Add edges ###

    graph.add_edge(start_key=START, end_key="prepare_tool_input")

    graph.add_edge(start_key="prepare_tool_input", end_key="llm_tool_choice")

    graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])

    graph.add_edge(
        start_key="tool_call",
        end_key="basic_use_tool_response",
    )

    graph.add_edge(
        start_key="basic_use_tool_response",
        end_key=END,
    )

    return graph


def should_continue(state: BasicState) -> str:
    return (
        # If there are no tool calls, basic graph already streamed the answer
        END
        if state.tool_choice is None
        else "tool_call"
    )


if __name__ == "__main__":
    from onyx.db.engine import get_session_context_manager
    from onyx.context.search.models import SearchRequest
    from onyx.llm.factory import get_default_llms
    from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config

    graph = basic_graph_builder()
    compiled_graph = graph.compile()
    input = BasicInput(_unused=True)
    primary_llm, fast_llm = get_default_llms()
    with get_session_context_manager() as db_session:
        config, _ = get_test_config(
            db_session=db_session,
            primary_llm=primary_llm,
            fast_llm=fast_llm,
            search_request=SearchRequest(query="How does onyx use FastAPI?"),
        )
        compiled_graph.invoke(input, config={"metadata": {"config": config}})
35  backend/onyx/agents/agent_search/basic/states.py  (new file)

@@ -0,0 +1,35 @@
from typing import TypedDict

from langchain_core.messages import AIMessageChunk
from pydantic import BaseModel

from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate

# States contain values that change over the course of graph execution,
# Config is for values that are set at the start and never change.
# If you are using a value from the config and realize it needs to change,
# you should add it to the state and use/update the version in the state.


## Graph Input State
class BasicInput(BaseModel):
    # Langgraph needs a nonempty input, but we pass in all static
    # data through a RunnableConfig.
    _unused: bool = True


## Graph Output State
class BasicOutput(TypedDict):
    tool_call_chunk: AIMessageChunk


## Graph State
class BasicState(
    BasicInput,
    ToolChoiceInput,
    ToolCallUpdate,
    ToolChoiceUpdate,
):
    pass
64  backend/onyx/agents/agent_search/basic/utils.py  (new file)

@@ -0,0 +1,64 @@
from collections.abc import Iterator
from typing import cast

from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langgraph.types import StreamWriter

from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
    PassThroughAnswerResponseHandler,
)
from onyx.chat.stream_processing.utils import map_document_id_order
from onyx.utils.logger import setup_logger

logger = setup_logger()


def process_llm_stream(
    messages: Iterator[BaseMessage],
    should_stream_answer: bool,
    writer: StreamWriter,
    final_search_results: list[LlmDoc] | None = None,
    displayed_search_results: list[OnyxContext] | list[LlmDoc] | None = None,
) -> AIMessageChunk:
    tool_call_chunk = AIMessageChunk(content="")

    if final_search_results and displayed_search_results:
        answer_handler: AnswerResponseHandler = CitationResponseHandler(
            context_docs=final_search_results,
            final_doc_id_to_rank_map=map_document_id_order(final_search_results),
            display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
        )
    else:
        answer_handler = PassThroughAnswerResponseHandler()

    full_answer = ""
    # This stream will be the llm answer if no tool is chosen. When a tool is chosen,
    # the stream will contain AIMessageChunks with tool call information.
    for message in messages:
        answer_piece = message.content
        if not isinstance(answer_piece, str):
            # this is only used for logging, so fine to
            # just add the string representation
            answer_piece = str(answer_piece)
        full_answer += answer_piece

        if isinstance(message, AIMessageChunk) and (
            message.tool_call_chunks or message.tool_calls
        ):
            tool_call_chunk += message  # type: ignore
        elif should_stream_answer:
            for response_part in answer_handler.handle_response_part(message, []):
                write_custom_event(
                    "basic_response",
                    response_part,
                    writer,
                )

    logger.debug(f"Full answer: {full_answer}")
    return cast(AIMessageChunk, tool_call_chunk)
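To show how `process_llm_stream` is meant to be consumed, here is a small sketch of a caller that feeds an LLM token stream through it and keeps only the accumulated tool-call chunk; the node and LLM wiring shown here is illustrative, not the exact Onyx call site:

```python
# Illustrative caller (not the actual Onyx node): answer tokens are emitted to the
# writer as "basic_response" events, while any tool-call deltas are merged into a
# single AIMessageChunk returned to the graph.
from langchain_core.messages import AIMessageChunk
from langgraph.types import StreamWriter

from onyx.agents.agent_search.basic.utils import process_llm_stream


def stream_answer_or_tool_call(llm, prompt, writer: StreamWriter) -> AIMessageChunk:
    # `llm.stream(prompt=...)` is assumed to yield BaseMessage chunks, matching how
    # the fast/primary LLMs are streamed elsewhere in this diff.
    return process_llm_stream(
        messages=llm.stream(prompt=prompt),
        should_stream_answer=True,  # emit answer tokens to the writer
        writer=writer,
        final_search_results=None,  # no citation mapping in this minimal example
        displayed_search_results=None,
    )
```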
21  backend/onyx/agents/agent_search/core_state.py  (new file)

@@ -0,0 +1,21 @@
from operator import add
from typing import Annotated

from pydantic import BaseModel


class CoreState(BaseModel):
    """
    This is the core state that is shared across all subgraphs.
    """

    base_question: str = ""
    log_messages: Annotated[list[str], add] = []


class SubgraphCoreState(BaseModel):
    """
    This is the core state that is shared across all subgraphs.
    """

    log_messages: Annotated[list[str], add]
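The `Annotated[list[str], add]` annotation on `log_messages` marks the field as a LangGraph reducer: when several nodes return `log_messages` updates, the lists are concatenated with `operator.add` rather than overwritten. A small standalone sketch of that merge behavior, for illustration only:

```python
# Minimal illustration of the reducer semantics used by CoreState.log_messages:
# Annotated[list[str], add] tells LangGraph to combine partial updates from
# different nodes with operator.add, appending instead of replacing the list.
from operator import add
from typing import Annotated

from pydantic import BaseModel


class ExampleState(BaseModel):
    log_messages: Annotated[list[str], add] = []


# What the graph effectively does when merging two node updates:
merged = add(["node_a: started"], ["node_b: finished"])
print(merged)  # ['node_a: started', 'node_b: finished']
```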
@@ -0,0 +1,31 @@
from collections.abc import Hashable
from datetime import datetime

from langgraph.types import Send

from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
    SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
    ExpandedRetrievalInput,
)
from onyx.utils.logger import setup_logger

logger = setup_logger()


def send_to_expanded_retrieval(state: SubQuestionAnsweringInput) -> Send | Hashable:
    """
    LangGraph edge to send a sub-question to the expanded retrieval.
    """
    edge_start_time = datetime.now()

    return Send(
        "initial_sub_question_expanded_retrieval",
        ExpandedRetrievalInput(
            question=state.question,
            base_search=False,
            sub_question_id=state.question_id,
            log_messages=[f"{edge_start_time} -- Sending to expanded retrieval"],
        ),
    )
@@ -0,0 +1,137 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.edges import (
    send_to_expanded_retrieval,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.check_sub_answer import (
    check_sub_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.format_sub_answer import (
    format_sub_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.generate_sub_answer import (
    generate_sub_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.ingest_retrieved_documents import (
    ingest_retrieved_documents,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
    AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
    AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
    SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.graph_builder import (
    expanded_retrieval_graph_builder,
)
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger

logger = setup_logger()


def answer_query_graph_builder() -> StateGraph:
    """
    LangGraph sub-graph builder for the initial individual sub-answer generation.
    """
    graph = StateGraph(
        state_schema=AnswerQuestionState,
        input=SubQuestionAnsweringInput,
        output=AnswerQuestionOutput,
    )

    ### Add nodes ###

    # The sub-graph that executes the expanded retrieval process for a sub-question
    expanded_retrieval = expanded_retrieval_graph_builder().compile()
    graph.add_node(
        node="initial_sub_question_expanded_retrieval",
        action=expanded_retrieval,
    )

    # The node that ingests the retrieved documents and puts them into the proper
    # state keys.
    graph.add_node(
        node="ingest_retrieval",
        action=ingest_retrieved_documents,
    )

    # The node that generates the sub-answer
    graph.add_node(
        node="generate_sub_answer",
        action=generate_sub_answer,
    )

    # The node that checks the sub-answer
    graph.add_node(
        node="answer_check",
        action=check_sub_answer,
    )

    # The node that formats the sub-answer for the following initial answer generation
    graph.add_node(
        node="format_answer",
        action=format_sub_answer,
    )

    ### Add edges ###

    graph.add_conditional_edges(
        source=START,
        path=send_to_expanded_retrieval,
        path_map=["initial_sub_question_expanded_retrieval"],
    )
    graph.add_edge(
        start_key="initial_sub_question_expanded_retrieval",
        end_key="ingest_retrieval",
    )
    graph.add_edge(
        start_key="ingest_retrieval",
        end_key="generate_sub_answer",
    )
    graph.add_edge(
        start_key="generate_sub_answer",
        end_key="answer_check",
    )
    graph.add_edge(
        start_key="answer_check",
        end_key="format_answer",
    )
    graph.add_edge(
        start_key="format_answer",
        end_key=END,
    )

    return graph


if __name__ == "__main__":
    from onyx.db.engine import get_session_context_manager
    from onyx.llm.factory import get_default_llms
    from onyx.context.search.models import SearchRequest

    graph = answer_query_graph_builder()
    compiled_graph = graph.compile()
    primary_llm, fast_llm = get_default_llms()
    search_request = SearchRequest(
        query="what can you do with onyx or danswer?",
    )
    with get_session_context_manager() as db_session:
        graph_config, search_tool = get_test_config(
            db_session, primary_llm, fast_llm, search_request
        )
        inputs = SubQuestionAnsweringInput(
            question="what can you do with onyx?",
            question_id="0_0",
            log_messages=[],
        )
        for thing in compiled_graph.stream(
            input=inputs,
            config={"configurable": {"config": graph_config}},
        ):
            logger.debug(thing)
@@ -0,0 +1,75 @@
from datetime import datetime
from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
    AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
    SubQuestionAnswerCheckUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.prompts.agent_search import SUB_ANSWER_CHECK_PROMPT
from onyx.prompts.agent_search import UNKNOWN_ANSWER


def check_sub_answer(
    state: AnswerQuestionState, config: RunnableConfig
) -> SubQuestionAnswerCheckUpdate:
    """
    LangGraph node to check the quality of the sub-answer. The answer
    is represented as a boolean value.
    """
    node_start_time = datetime.now()

    level, question_num = parse_question_id(state.question_id)
    if state.answer == UNKNOWN_ANSWER:
        return SubQuestionAnswerCheckUpdate(
            answer_quality=False,
            log_messages=[
                get_langgraph_node_log_string(
                    graph_component="initial - generate individual sub answer",
                    node_name="check sub answer",
                    node_start_time=node_start_time,
                    result="unknown answer",
                )
            ],
        )
    msg = [
        HumanMessage(
            content=SUB_ANSWER_CHECK_PROMPT.format(
                question=state.question,
                base_answer=state.answer,
            )
        )
    ]

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    fast_llm = graph_config.tooling.fast_llm
    response = list(
        fast_llm.stream(
            prompt=msg,
        )
    )

    quality_str: str = merge_message_runs(response, chunk_separator="")[0].content
    answer_quality = "yes" in quality_str.lower()

    return SubQuestionAnswerCheckUpdate(
        answer_quality=answer_quality,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="initial - generate individual sub answer",
                node_name="check sub answer",
                node_start_time=node_start_time,
                result=f"Answer quality: {quality_str}",
            )
        ],
    )
@@ -0,0 +1,30 @@
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
    AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
    AnswerQuestionState,
)
from onyx.agents.agent_search.shared_graph_utils.models import (
    SubQuestionAnswerResults,
)


def format_sub_answer(state: AnswerQuestionState) -> AnswerQuestionOutput:
    """
    LangGraph node to generate the sub-answer format.
    """
    return AnswerQuestionOutput(
        answer_results=[
            SubQuestionAnswerResults(
                question=state.question,
                question_id=state.question_id,
                verified_high_quality=state.answer_quality,
                answer=state.answer,
                sub_query_retrieval_results=state.expanded_retrieval_results,
                verified_reranked_documents=state.verified_reranked_documents,
                context_documents=state.context_documents,
                cited_documents=state.cited_documents,
                sub_question_retrieval_stats=state.sub_question_retrieval_stats,
            )
        ],
    )
@@ -0,0 +1,137 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import merge_message_runs
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
AnswerQuestionState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
SubQuestionAnswerGenerationUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_sub_question_answer_prompt,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_answer_citation_ids
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_persona_agent_prompt_expressions,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
from onyx.prompts.agent_search import NO_RECOVERED_DOCS
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def generate_sub_answer(
|
||||
state: AnswerQuestionState,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> SubQuestionAnswerGenerationUpdate:
|
||||
"""
|
||||
LangGraph node to generate a sub-answer.
|
||||
"""
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = state.question
|
||||
level, question_num = parse_question_id(state.question_id)
|
||||
context_docs = state.context_documents[:AGENT_MAX_ANSWER_CONTEXT_DOCS]
|
||||
persona_contextualized_prompt = get_persona_agent_prompt_expressions(
|
||||
graph_config.inputs.search_request.persona
|
||||
).contextualized_prompt
|
||||
|
||||
if len(context_docs) == 0:
|
||||
answer_str = NO_RECOVERED_DOCS
|
||||
write_custom_event(
|
||||
"sub_answers",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=answer_str,
|
||||
level=level,
|
||||
level_question_num=question_num,
|
||||
answer_type="agent_sub_answer",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
else:
|
||||
fast_llm = graph_config.tooling.fast_llm
|
||||
msg = build_sub_question_answer_prompt(
|
||||
question=question,
|
||||
original_question=graph_config.inputs.search_request.query,
|
||||
docs=context_docs,
|
||||
persona_specification=persona_contextualized_prompt,
|
||||
config=fast_llm.config,
|
||||
)
|
||||
|
||||
response: list[str | list[str | dict[str, Any]]] = []
|
||||
dispatch_timings: list[float] = []
|
||||
for message in fast_llm.stream(
|
||||
prompt=msg,
|
||||
):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
)
|
||||
start_stream_token = datetime.now()
|
||||
write_custom_event(
|
||||
"sub_answers",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=content,
|
||||
level=level,
|
||||
level_question_num=question_num,
|
||||
answer_type="agent_sub_answer",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
end_stream_token = datetime.now()
|
||||
dispatch_timings.append(
|
||||
(end_stream_token - start_stream_token).microseconds
|
||||
)
|
||||
response.append(content)
|
||||
|
||||
answer_str = merge_message_runs(response, chunk_separator="")[0].content
|
||||
logger.debug(
|
||||
f"Average dispatch time: {sum(dispatch_timings) / len(dispatch_timings)}"
|
||||
)
|
||||
|
||||
answer_citation_ids = get_answer_citation_ids(answer_str)
|
||||
cited_documents = [
|
||||
context_docs[id] for id in answer_citation_ids if id < len(context_docs)
|
||||
]
|
||||
|
||||
stop_event = StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
stream_type=StreamType.SUB_ANSWER,
|
||||
level=level,
|
||||
level_question_num=question_num,
|
||||
)
|
||||
write_custom_event("stream_finished", stop_event, writer)
|
||||
|
||||
return SubQuestionAnswerGenerationUpdate(
|
||||
answer=answer_str,
|
||||
cited_documents=cited_documents,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="initial - generate individual sub answer",
|
||||
node_name="generate sub answer",
|
||||
node_start_time=node_start_time,
|
||||
result="",
|
||||
)
|
||||
],
|
||||
)
|
||||
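A small sketch of how the citation mapping in generate_sub_answer above behaves. The bracket-number citation format and the 1-based-to-0-based shift are assumptions for illustration only; the real parsing lives in get_answer_citation_ids.

import re

def _example_citation_ids(answer: str) -> list[int]:
    # hypothetical stand-in for get_answer_citation_ids, assuming "[n]" style citations
    return [int(num) - 1 for num in re.findall(r"\[(\d+)\]", answer)]

context_docs_example = ["doc_a", "doc_b"]
answer_example = "Onyx supports connectors [1] and agent search [5]."
cited = [
    context_docs_example[i]
    for i in _example_citation_ids(answer_example)
    if i < len(context_docs_example)
]
assert cited == ["doc_a"]  # out-of-range citation ids are dropped, as in the node above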
@@ -0,0 +1,25 @@
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
SubQuestionRetrievalIngestionUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
|
||||
ExpandedRetrievalOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
|
||||
|
||||
|
||||
def ingest_retrieved_documents(
|
||||
state: ExpandedRetrievalOutput,
|
||||
) -> SubQuestionRetrievalIngestionUpdate:
|
||||
"""
|
||||
LangGraph node to ingest the retrieved documents and format them for the sub-answer.
|
||||
"""
|
||||
sub_question_retrieval_stats = state.expanded_retrieval_result.retrieval_stats
|
||||
if sub_question_retrieval_stats is None:
|
||||
sub_question_retrieval_stats = AgentChunkRetrievalStats()
|
||||
|
||||
return SubQuestionRetrievalIngestionUpdate(
|
||||
expanded_retrieval_results=state.expanded_retrieval_result.expanded_query_results,
|
||||
verified_reranked_documents=state.expanded_retrieval_result.verified_reranked_documents,
|
||||
context_documents=state.expanded_retrieval_result.context_documents,
|
||||
sub_question_retrieval_stats=sub_question_retrieval_stats,
|
||||
)
|
||||
@@ -0,0 +1,75 @@
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.core_state import SubgraphCoreState
|
||||
from onyx.agents.agent_search.deep_search.main.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
SubQuestionAnswerResults,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
## Update States
|
||||
class SubQuestionAnswerCheckUpdate(LoggerUpdate, BaseModel):
|
||||
answer_quality: bool = False
|
||||
log_messages: list[str] = []
|
||||
|
||||
|
||||
class SubQuestionAnswerGenerationUpdate(LoggerUpdate, BaseModel):
|
||||
answer: str = ""
|
||||
log_messages: list[str] = []
|
||||
cited_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
|
||||
# answer_stat: AnswerStats
|
||||
|
||||
|
||||
class SubQuestionRetrievalIngestionUpdate(LoggerUpdate, BaseModel):
|
||||
expanded_retrieval_results: list[QueryRetrievalResult] = []
|
||||
verified_reranked_documents: Annotated[
|
||||
list[InferenceSection], dedup_inference_sections
|
||||
] = []
|
||||
context_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
|
||||
sub_question_retrieval_stats: AgentChunkRetrievalStats = AgentChunkRetrievalStats()
|
||||
|
||||
|
||||
## Graph Input State
|
||||
|
||||
|
||||
class SubQuestionAnsweringInput(SubgraphCoreState):
|
||||
question: str = ""
|
||||
question_id: str = (
|
||||
"" # 0_0 is original question, everything else is <level>_<question_num>.
|
||||
)
|
||||
# level 0 is original question and first decomposition, level 1 is follow up, etc
|
||||
# question_num is a unique number per original question per level.
|
||||
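A minimal sketch of the "<level>_<question_num>" convention described in the comments above. The real make_question_id / parse_question_id helpers live in onyx.agents.agent_search.shared_graph_utils.utils; the implementations below are assumptions for illustration only.

def make_question_id(level: int, question_num: int) -> str:
    # assumed encoding: "0_0" is the original question
    return f"{level}_{question_num}"

def parse_question_id(question_id: str) -> tuple[int, int]:
    level_str, question_num_str = question_id.split("_")
    return int(level_str), int(question_num_str)

assert make_question_id(0, 3) == "0_3"
assert parse_question_id("1_2") == (1, 2)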
|
||||
|
||||
## Graph State
|
||||
|
||||
|
||||
class AnswerQuestionState(
|
||||
SubQuestionAnsweringInput,
|
||||
SubQuestionAnswerGenerationUpdate,
|
||||
SubQuestionAnswerCheckUpdate,
|
||||
SubQuestionRetrievalIngestionUpdate,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
## Graph Output State
|
||||
|
||||
|
||||
class AnswerQuestionOutput(LoggerUpdate, BaseModel):
|
||||
"""
|
||||
This is a list of results even though each call of this subgraph only returns one result.
When the answer-query subgraph is parallelized, each parallel run contributes a result,
so the add operator is used to combine them into a single list.
|
||||
"""
|
||||
|
||||
answer_results: Annotated[list[SubQuestionAnswerResults], add] = []
|
||||
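To illustrate the docstring above: when the answer-query subgraph is fanned out, LangGraph applies the annotated reducer (operator.add here) to merge the per-branch lists into one state value. The values below are made up.

from operator import add

branch_a_results = ["result for sub-question 1"]
branch_b_results = ["result for sub-question 2"]

# LangGraph calls the reducer with the existing value and each branch's update
merged = add(branch_a_results, branch_b_results)
assert merged == ["result for sub-question 1", "result for sub-question 2"]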
@@ -0,0 +1,50 @@
|
||||
from collections.abc import Hashable
|
||||
from datetime import datetime
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
AnswerQuestionOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
SubQuestionAnsweringInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
|
||||
SubQuestionRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
|
||||
|
||||
|
||||
def parallelize_initial_sub_question_answering(
|
||||
state: SubQuestionRetrievalState,
|
||||
) -> list[Send | Hashable]:
|
||||
"""
|
||||
LangGraph edge to parallelize the initial sub-question answering. If there are no sub-questions,
|
||||
we send an empty answer list to the initial answer generation, and the initial answer is then
generated solely from the documents retrieved for the original question.
|
||||
"""
|
||||
edge_start_time = datetime.now()
|
||||
if len(state.initial_sub_questions) > 0:
|
||||
return [
|
||||
Send(
|
||||
"answer_query_subgraph",
|
||||
SubQuestionAnsweringInput(
|
||||
question=question,
|
||||
question_id=make_question_id(0, question_num + 1),
|
||||
log_messages=[
|
||||
f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
|
||||
],
|
||||
),
|
||||
)
|
||||
for question_num, question in enumerate(state.initial_sub_questions)
|
||||
]
|
||||
|
||||
else:
|
||||
return [
|
||||
Send(
|
||||
"ingest_answers",
|
||||
AnswerQuestionOutput(
|
||||
answer_results=[],
|
||||
),
|
||||
)
|
||||
]
|
||||
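A rough sketch of what the fan-out above produces for two sub-questions; the payload dicts stand in for SubQuestionAnsweringInput and are illustrative only.

from langgraph.types import Send

sub_questions = ["What connectors exist?", "How is indexing scheduled?"]
sends = [
    Send(
        "answer_query_subgraph",
        {"question": question, "question_id": f"0_{question_num + 1}"},
    )
    for question_num, question in enumerate(sub_questions)
]
# LangGraph runs the target node once per Send object, in parallel
assert len(sends) == 2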
@@ -0,0 +1,96 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.nodes.generate_initial_answer import (
|
||||
generate_initial_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.nodes.validate_initial_answer import (
|
||||
validate_initial_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
|
||||
SubQuestionRetrievalInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
|
||||
SubQuestionRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.graph_builder import (
|
||||
generate_sub_answers_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.graph_builder import (
|
||||
retrieve_orig_question_docs_graph_builder,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def generate_initial_answer_graph_builder(test_mode: bool = False) -> StateGraph:
|
||||
"""
|
||||
LangGraph graph builder for the initial answer generation.
|
||||
"""
|
||||
graph = StateGraph(
|
||||
state_schema=SubQuestionRetrievalState,
|
||||
input=SubQuestionRetrievalInput,
|
||||
)
|
||||
|
||||
# The sub-graph that generates the initial sub-answers
|
||||
generate_sub_answers = generate_sub_answers_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="generate_sub_answers_subgraph",
|
||||
action=generate_sub_answers,
|
||||
)
|
||||
|
||||
# The sub-graph that retrieves the original question documents. This is run
|
||||
# in parallel with the sub-answer generation process
|
||||
retrieve_orig_question_docs = retrieve_orig_question_docs_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="retrieve_orig_question_docs_subgraph_wrapper",
|
||||
action=retrieve_orig_question_docs,
|
||||
)
|
||||
|
||||
# Node that generates the initial answer using the results of the previous
|
||||
# two sub-graphs
|
||||
graph.add_node(
|
||||
node="generate_initial_answer",
|
||||
action=generate_initial_answer,
|
||||
)
|
||||
|
||||
# Node that validates the initial answer
|
||||
graph.add_node(
|
||||
node="validate_initial_answer",
|
||||
action=validate_initial_answer,
|
||||
)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_edge(
|
||||
start_key=START,
|
||||
end_key="retrieve_orig_question_docs_subgraph_wrapper",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key=START,
|
||||
end_key="generate_sub_answers_subgraph",
|
||||
)
|
||||
|
||||
# Wait for both the original question docs and the sub-answers to be generated before proceeding
|
||||
graph.add_edge(
|
||||
start_key=[
|
||||
"retrieve_orig_question_docs_subgraph_wrapper",
|
||||
"generate_sub_answers_subgraph",
|
||||
],
|
||||
end_key="generate_initial_answer",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="generate_initial_answer",
|
||||
end_key="validate_initial_answer",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="validate_initial_answer",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
return graph
|
||||
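A usage sketch (not part of the diff): the builder above is compiled before being added as a node in the parent graph or invoked; constructing the SubQuestionRetrievalInput is assumed to happen upstream.

# compile the initial-answer graph defined above
initial_answer_graph = generate_initial_answer_graph_builder().compile()

# compiled LangGraph graphs expose their structure for inspection/debugging
print(sorted(initial_answer_graph.get_graph().nodes))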
@@ -0,0 +1,313 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
|
||||
SubQuestionRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.models import AgentBaseMetrics
|
||||
from onyx.agents.agent_search.deep_search.main.operations import (
|
||||
calculate_initial_agent_stats,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.operations import get_query_info
|
||||
from onyx.agents.agent_search.deep_search.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
InitialAnswerUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
get_prompt_enrichment_components,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
dispatch_main_answer_stop_info,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import remove_document_citations
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.prompts.agent_search import (
|
||||
INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS,
|
||||
)
|
||||
from onyx.prompts.agent_search import (
|
||||
INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS,
|
||||
)
|
||||
from onyx.prompts.agent_search import (
|
||||
SUB_QUESTION_ANSWER_TEMPLATE,
|
||||
)
|
||||
from onyx.prompts.agent_search import UNKNOWN_ANSWER
|
||||
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
|
||||
|
||||
|
||||
def generate_initial_answer(
|
||||
state: SubQuestionRetrievalState,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> InitialAnswerUpdate:
|
||||
"""
|
||||
LangGraph node to generate the initial answer, using the initial sub-questions/sub-answers and the
|
||||
documents retrieved for the original question.
|
||||
"""
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.search_request.query
|
||||
prompt_enrichment_components = get_prompt_enrichment_components(graph_config)
|
||||
|
||||
sub_questions_cited_documents = state.cited_documents
|
||||
orig_question_retrieval_documents = state.orig_question_retrieved_documents
|
||||
|
||||
consolidated_context_docs: list[InferenceSection] = sub_questions_cited_documents
|
||||
counter = 0
|
||||
for original_doc_number, original_doc in enumerate(
|
||||
orig_question_retrieval_documents
|
||||
):
|
||||
if original_doc_number not in sub_questions_cited_documents:
|
||||
if (
|
||||
counter <= AGENT_MIN_ORIG_QUESTION_DOCS
|
||||
or len(consolidated_context_docs) < AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
):
|
||||
consolidated_context_docs.append(original_doc)
|
||||
counter += 1
|
||||
|
||||
# sort docs by their scores - though the scores refer to different questions
|
||||
relevant_docs = dedup_inference_sections(
|
||||
consolidated_context_docs, consolidated_context_docs
|
||||
)
|
||||
|
||||
sub_questions: list[str] = []
|
||||
streamed_documents = (
|
||||
relevant_docs
|
||||
if len(relevant_docs) > 0
|
||||
else state.orig_question_retrieved_documents[:15]
|
||||
)
|
||||
|
||||
# Use the query info from the base document retrieval
|
||||
query_info = get_query_info(state.orig_question_sub_query_retrieval_results)
|
||||
|
||||
assert (
|
||||
graph_config.tooling.search_tool
|
||||
), "search_tool must be provided for agentic search"
|
||||
|
||||
relevance_list = relevance_from_docs(relevant_docs)
|
||||
for tool_response in yield_search_responses(
|
||||
query=question,
|
||||
reranked_sections=streamed_documents,
|
||||
final_context_sections=streamed_documents,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
):
|
||||
write_custom_event(
|
||||
"tool_response",
|
||||
ExtendedToolResponse(
|
||||
id=tool_response.id,
|
||||
response=tool_response.response,
|
||||
level=0,
|
||||
level_question_num=0, # 0, 0 is the base question
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
if len(relevant_docs) == 0:
|
||||
write_custom_event(
|
||||
"initial_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=UNKNOWN_ANSWER,
|
||||
level=0,
|
||||
level_question_num=0,
|
||||
answer_type="agent_level_answer",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
dispatch_main_answer_stop_info(0, writer)
|
||||
|
||||
answer = UNKNOWN_ANSWER
|
||||
initial_agent_stats = InitialAgentResultStats(
|
||||
sub_questions={},
|
||||
original_question={},
|
||||
agent_effectiveness={},
|
||||
)
|
||||
|
||||
else:
|
||||
sub_question_answer_results = state.sub_question_results
|
||||
|
||||
# Collect the sub-questions and sub-answers and construct an appropriate
|
||||
# prompt string.
|
||||
# Consider replacing this block with a helper function.
|
||||
answered_sub_questions: list[str] = []
|
||||
all_sub_questions: list[str] = [] # Separate list for tracking all questions
|
||||
|
||||
for idx, sub_question_answer_result in enumerate(
|
||||
sub_question_answer_results, start=1
|
||||
):
|
||||
all_sub_questions.append(sub_question_answer_result.question)
|
||||
|
||||
is_valid_answer = (
|
||||
sub_question_answer_result.verified_high_quality
|
||||
and sub_question_answer_result.answer
|
||||
and sub_question_answer_result.answer != UNKNOWN_ANSWER
|
||||
)
|
||||
|
||||
if is_valid_answer:
|
||||
answered_sub_questions.append(
|
||||
SUB_QUESTION_ANSWER_TEMPLATE.format(
|
||||
sub_question=sub_question_answer_result.question,
|
||||
sub_answer=sub_question_answer_result.answer,
|
||||
sub_question_num=idx,
|
||||
)
|
||||
)
|
||||
|
||||
sub_question_answer_str = (
|
||||
"\n\n------\n\n".join(answered_sub_questions)
|
||||
if answered_sub_questions
|
||||
else ""
|
||||
)
|
||||
|
||||
# Use the appropriate prompt based on whether there are sub-questions.
|
||||
base_prompt = (
|
||||
INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS
|
||||
if answered_sub_questions
|
||||
else INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS
|
||||
)
|
||||
|
||||
sub_questions = all_sub_questions # Replace the original assignment
|
||||
|
||||
model = graph_config.tooling.fast_llm
|
||||
|
||||
doc_context = format_docs(relevant_docs)
|
||||
doc_context = trim_prompt_piece(
|
||||
config=model.config,
|
||||
prompt_piece=doc_context,
|
||||
reserved_str=(
|
||||
base_prompt
|
||||
+ sub_question_answer_str
|
||||
+ prompt_enrichment_components.persona_prompts.contextualized_prompt
|
||||
+ prompt_enrichment_components.history
|
||||
+ prompt_enrichment_components.date_str
|
||||
),
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=base_prompt.format(
|
||||
question=question,
|
||||
answered_sub_questions=remove_document_citations(
|
||||
sub_question_answer_str
|
||||
),
|
||||
relevant_docs=doc_context,
|
||||
persona_specification=prompt_enrichment_components.persona_prompts.contextualized_prompt,
|
||||
history=prompt_enrichment_components.history,
|
||||
date_prompt=prompt_enrichment_components.date_str,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
|
||||
dispatch_timings: list[float] = []
|
||||
for message in model.stream(msg):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
)
|
||||
start_stream_token = datetime.now()
|
||||
|
||||
write_custom_event(
|
||||
"initial_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=content,
|
||||
level=0,
|
||||
level_question_num=0,
|
||||
answer_type="agent_level_answer",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
end_stream_token = datetime.now()
|
||||
dispatch_timings.append(
|
||||
(end_stream_token - start_stream_token).microseconds
|
||||
)
|
||||
streamed_tokens.append(content)
|
||||
|
||||
logger.debug(
|
||||
f"Average dispatch time for initial answer: {sum(dispatch_timings) / len(dispatch_timings)}"
|
||||
)
|
||||
|
||||
dispatch_main_answer_stop_info(0, writer)
|
||||
response = merge_content(*streamed_tokens)
|
||||
answer = cast(str, response)
|
||||
|
||||
initial_agent_stats = calculate_initial_agent_stats(
|
||||
state.sub_question_results, state.orig_question_retrieval_stats
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"\n\nYYYYY--Sub-Questions:\n\n{sub_question_answer_str}\n\nStats:\n\n"
|
||||
)
|
||||
|
||||
if initial_agent_stats:
|
||||
logger.debug(initial_agent_stats.original_question)
|
||||
logger.debug(initial_agent_stats.sub_questions)
|
||||
logger.debug(initial_agent_stats.agent_effectiveness)
|
||||
|
||||
agent_base_end_time = datetime.now()
|
||||
|
||||
if agent_base_end_time and state.agent_start_time:
|
||||
duration_s = (agent_base_end_time - state.agent_start_time).total_seconds()
|
||||
else:
|
||||
duration_s = None
|
||||
|
||||
agent_base_metrics = AgentBaseMetrics(
|
||||
num_verified_documents_total=len(relevant_docs),
|
||||
num_verified_documents_core=state.orig_question_retrieval_stats.verified_count,
|
||||
verified_avg_score_core=state.orig_question_retrieval_stats.verified_avg_scores,
|
||||
num_verified_documents_base=initial_agent_stats.sub_questions.get(
|
||||
"num_verified_documents"
|
||||
),
|
||||
verified_avg_score_base=initial_agent_stats.sub_questions.get(
|
||||
"verified_avg_score"
|
||||
),
|
||||
base_doc_boost_factor=initial_agent_stats.agent_effectiveness.get(
|
||||
"utilized_chunk_ratio"
|
||||
),
|
||||
support_boost_factor=initial_agent_stats.agent_effectiveness.get(
|
||||
"support_ratio"
|
||||
),
|
||||
duration_s=duration_s,
|
||||
)
|
||||
|
||||
return InitialAnswerUpdate(
|
||||
initial_answer=answer,
|
||||
initial_agent_stats=initial_agent_stats,
|
||||
generated_sub_questions=sub_questions,
|
||||
agent_base_end_time=agent_base_end_time,
|
||||
agent_base_metrics=agent_base_metrics,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="initial - generate initial answer",
|
||||
node_name="generate initial answer",
|
||||
node_start_time=node_start_time,
|
||||
result="",
|
||||
)
|
||||
],
|
||||
)
|
||||
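A simplified sketch of the document consolidation rule in generate_initial_answer above. The cap values are placeholders, not the real configured AGENT_MIN_ORIG_QUESTION_DOCS / AGENT_MAX_ANSWER_CONTEXT_DOCS, and the doc names are made up.

AGENT_MIN_ORIG_QUESTION_DOCS_EXAMPLE = 2
AGENT_MAX_ANSWER_CONTEXT_DOCS_EXAMPLE = 5

cited_docs = ["doc_a", "doc_b", "doc_c", "doc_d"]
orig_docs = ["doc_e", "doc_f", "doc_g"]

consolidated = list(cited_docs)
counter = 0
for doc in orig_docs:
    if (
        counter <= AGENT_MIN_ORIG_QUESTION_DOCS_EXAMPLE
        or len(consolidated) < AGENT_MAX_ANSWER_CONTEXT_DOCS_EXAMPLE
    ):
        consolidated.append(doc)
        counter += 1

# the min-orig-doc condition keeps admitting docs even past the overall cap,
# mirroring the `or` in the node above
assert len(consolidated) == 7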
@@ -0,0 +1,40 @@
|
||||
from datetime import datetime
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
|
||||
SubQuestionRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
InitialAnswerQualityUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
|
||||
|
||||
def validate_initial_answer(
|
||||
state: SubQuestionRetrievalState,
|
||||
) -> InitialAnswerQualityUpdate:
|
||||
"""
|
||||
Check whether the initial answer sufficiently addresses the original user question.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
logger.debug(
|
||||
f"--------{node_start_time}--------Checking for base answer validity - for not set True/False manually"
|
||||
)
|
||||
|
||||
verdict = True
|
||||
|
||||
return InitialAnswerQualityUpdate(
|
||||
initial_answer_quality_eval=verdict,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="initial - generate initial answer",
|
||||
node_name="validate initial answer",
|
||||
node_start_time=node_start_time,
|
||||
result="",
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,51 @@
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
from typing import TypedDict
|
||||
|
||||
from onyx.agents.agent_search.core_state import CoreState
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
ExploratorySearchUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
InitialAnswerQualityUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
InitialAnswerUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
InitialQuestionDecompositionUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
OrigQuestionRetrievalUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
SubQuestionResultsUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.models import (
|
||||
QuestionRetrievalResult,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
### States ###
|
||||
class SubQuestionRetrievalInput(CoreState):
|
||||
exploratory_search_results: list[InferenceSection]
|
||||
|
||||
|
||||
## Graph State
|
||||
class SubQuestionRetrievalState(
|
||||
# This includes the core state
|
||||
SubQuestionRetrievalInput,
|
||||
InitialQuestionDecompositionUpdate,
|
||||
InitialAnswerUpdate,
|
||||
SubQuestionResultsUpdate,
|
||||
OrigQuestionRetrievalUpdate,
|
||||
InitialAnswerQualityUpdate,
|
||||
ExploratorySearchUpdate,
|
||||
):
|
||||
base_raw_search_result: Annotated[list[QuestionRetrievalResult], add]
|
||||
|
||||
|
||||
## Graph Output State
|
||||
class SubQuestionRetrievalOutput(TypedDict):
|
||||
log_messages: list[str]
|
||||
@@ -0,0 +1,48 @@
|
||||
from collections.abc import Hashable
|
||||
from datetime import datetime
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
AnswerQuestionOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
SubQuestionAnsweringInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
|
||||
SubQuestionRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
|
||||
|
||||
|
||||
def parallelize_initial_sub_question_answering(
|
||||
state: SubQuestionRetrievalState,
|
||||
) -> list[Send | Hashable]:
|
||||
"""
|
||||
LangGraph edge to parallelize the initial sub-question answering.
|
||||
"""
|
||||
edge_start_time = datetime.now()
|
||||
if len(state.initial_sub_questions) > 0:
|
||||
return [
|
||||
Send(
|
||||
"answer_sub_question_subgraphs",
|
||||
SubQuestionAnsweringInput(
|
||||
question=question,
|
||||
question_id=make_question_id(0, question_num + 1),
|
||||
log_messages=[
|
||||
f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
|
||||
],
|
||||
),
|
||||
)
|
||||
for question_num, question in enumerate(state.initial_sub_questions)
|
||||
]
|
||||
|
||||
else:
|
||||
return [
|
||||
Send(
|
||||
"ingest_answers",
|
||||
AnswerQuestionOutput(
|
||||
answer_results=[],
|
||||
),
|
||||
)
|
||||
]
|
||||
@@ -0,0 +1,81 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.graph_builder import (
|
||||
answer_query_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.edges import (
|
||||
parallelize_initial_sub_question_answering,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.nodes.decompose_orig_question import (
|
||||
decompose_orig_question,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.nodes.format_initial_sub_answers import (
|
||||
format_initial_sub_answers,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.states import (
|
||||
SubQuestionAnsweringInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.states import (
|
||||
SubQuestionAnsweringState,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
test_mode = False
|
||||
|
||||
|
||||
def generate_sub_answers_graph_builder() -> StateGraph:
|
||||
"""
|
||||
LangGraph graph builder for the initial sub-answer generation process.
|
||||
It generates the initial sub-questions and produces the answers.
|
||||
"""
|
||||
|
||||
graph = StateGraph(
|
||||
state_schema=SubQuestionAnsweringState,
|
||||
input=SubQuestionAnsweringInput,
|
||||
)
|
||||
|
||||
# Decompose the original question into sub-questions
|
||||
graph.add_node(
|
||||
node="decompose_orig_question",
|
||||
action=decompose_orig_question,
|
||||
)
|
||||
|
||||
# The sub-graph that executes the initial sub-question answering for
|
||||
# each of the sub-questions.
|
||||
answer_sub_question_subgraphs = answer_query_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="answer_sub_question_subgraphs",
|
||||
action=answer_sub_question_subgraphs,
|
||||
)
|
||||
|
||||
# Node that collects and formats the initial sub-question answers
|
||||
graph.add_node(
|
||||
node="format_initial_sub_question_answers",
|
||||
action=format_initial_sub_answers,
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key=START,
|
||||
end_key="decompose_orig_question",
|
||||
)
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source="decompose_orig_question",
|
||||
path=parallelize_initial_sub_question_answering,
|
||||
path_map=["answer_sub_question_subgraphs"],
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key=["answer_sub_question_subgraphs"],
|
||||
end_key="format_initial_sub_question_answers",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="format_initial_sub_question_answers",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
return graph
|
||||
@@ -0,0 +1,153 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
|
||||
SubQuestionRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.models import (
|
||||
AgentRefinedMetrics,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.operations import (
|
||||
dispatch_subquestion,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
InitialQuestionDecompositionUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_history_prompt,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
|
||||
from onyx.prompts.agent_search import (
|
||||
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH,
|
||||
)
|
||||
from onyx.prompts.agent_search import (
|
||||
INITIAL_QUESTION_DECOMPOSITION_PROMPT,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def decompose_orig_question(
|
||||
state: SubQuestionRetrievalState,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> InitialQuestionDecompositionUpdate:
|
||||
"""
|
||||
LangGraph node to decompose the original question into sub-questions.
|
||||
"""
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.search_request.query
|
||||
perform_initial_search_decomposition = (
|
||||
graph_config.behavior.perform_initial_search_decomposition
|
||||
)
|
||||
# Get the rewritten queries in a defined format
|
||||
model = graph_config.tooling.fast_llm
|
||||
|
||||
history = build_history_prompt(graph_config, question)
|
||||
|
||||
# Use the initial search results to inform the decomposition
|
||||
agent_start_time = datetime.now()
|
||||
|
||||
# Initial search to inform decomposition. Only the top AGENT_NUM_DOCS_FOR_DECOMPOSITION documents are used.
|
||||
|
||||
if perform_initial_search_decomposition:
|
||||
# Due to the unfortunate state representation in LangGraph, we need to double-check here that the
# retrieval has happened prior to this point. We allow a silent failure since the exploratory
# search is not critical for decomposition for all queries.
|
||||
if not state.exploratory_search_results:
|
||||
logger.error("Initial search for decomposition failed")
|
||||
|
||||
sample_doc_str = "\n\n".join(
|
||||
[
|
||||
doc.combined_content
|
||||
for doc in state.exploratory_search_results[
|
||||
:AGENT_NUM_DOCS_FOR_DECOMPOSITION
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
decomposition_prompt = (
|
||||
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH.format(
|
||||
question=question, sample_doc_str=sample_doc_str, history=history
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
decomposition_prompt = INITIAL_QUESTION_DECOMPOSITION_PROMPT.format(
|
||||
question=question, history=history
|
||||
)
|
||||
|
||||
# Start decomposition
|
||||
|
||||
msg = [HumanMessage(content=decomposition_prompt)]
|
||||
|
||||
# Send the initial question as a subquestion with number 0
|
||||
write_custom_event(
|
||||
"decomp_qs",
|
||||
SubQuestionPiece(
|
||||
sub_question=question,
|
||||
level=0,
|
||||
level_question_num=0,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
# dispatches custom events for subquestion tokens, adding in subquestion ids.
|
||||
streamed_tokens = dispatch_separated(
|
||||
model.stream(msg), dispatch_subquestion(0, writer)
|
||||
)
|
||||
|
||||
stop_event = StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
stream_type=StreamType.SUB_QUESTIONS,
|
||||
level=0,
|
||||
)
|
||||
write_custom_event("stream_finished", stop_event, writer)
|
||||
|
||||
decomposition_response = merge_content(*streamed_tokens)
|
||||
|
||||
# this call should only return strings. Commenting out for efficiency
|
||||
# assert [type(tok) == str for tok in streamed_tokens]
|
||||
|
||||
# use the no-op cast() instead of str(), which would actually run a conversion
|
||||
# list_of_subquestions = clean_and_parse_list_string(cast(str, response))
|
||||
list_of_subqs = cast(str, decomposition_response).split("\n")
|
||||
|
||||
decomp_list: list[str] = [sq.strip() for sq in list_of_subqs if sq.strip() != ""]
|
||||
|
||||
return InitialQuestionDecompositionUpdate(
|
||||
initial_sub_questions=decomp_list,
|
||||
agent_start_time=agent_start_time,
|
||||
agent_refined_start_time=None,
|
||||
agent_refined_end_time=None,
|
||||
agent_refined_metrics=AgentRefinedMetrics(
|
||||
refined_doc_boost_factor=None,
|
||||
refined_question_boost_factor=None,
|
||||
duration_s=None,
|
||||
),
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="initial - generate sub answers",
|
||||
node_name="decompose original question",
|
||||
node_start_time=node_start_time,
|
||||
result=f"decomposed original question into {len(decomp_list)} subquestions",
|
||||
)
|
||||
],
|
||||
)
|
||||
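For illustration: the decomposition response is treated as a plain newline-separated list, which decompose_orig_question above splits and strips. The example text is made up.

example_response = "What products does the company sell?\n\nWho are its main competitors?\n"
decomp_list = [sq.strip() for sq in example_response.split("\n") if sq.strip() != ""]
assert decomp_list == [
    "What products does the company sell?",
    "Who are its main competitors?",
]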
@@ -0,0 +1,50 @@
|
||||
from datetime import datetime
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
AnswerQuestionOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
SubQuestionResultsUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
|
||||
|
||||
def format_initial_sub_answers(
|
||||
state: AnswerQuestionOutput,
|
||||
) -> SubQuestionResultsUpdate:
|
||||
"""
|
||||
LangGraph node to format the answers to the initial sub-questions, including
|
||||
deduping verified documents and context documents.
|
||||
"""
|
||||
node_start_time = datetime.now()
|
||||
|
||||
documents = []
|
||||
context_documents = []
|
||||
cited_documents = []
|
||||
answer_results = state.answer_results
|
||||
for answer_result in answer_results:
|
||||
documents.extend(answer_result.verified_reranked_documents)
|
||||
context_documents.extend(answer_result.context_documents)
|
||||
cited_documents.extend(answer_result.cited_documents)
|
||||
|
||||
return SubQuestionResultsUpdate(
|
||||
# Deduping is done by the documents operator for the main graph
|
||||
# so we might not need to dedup here
|
||||
verified_reranked_documents=dedup_inference_sections(documents, []),
|
||||
context_documents=dedup_inference_sections(context_documents, []),
|
||||
cited_documents=dedup_inference_sections(cited_documents, []),
|
||||
sub_question_results=answer_results,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="initial - generate sub answers",
|
||||
node_name="format initial sub answers",
|
||||
node_start_time=node_start_time,
|
||||
result="",
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,34 @@
|
||||
from typing import TypedDict
|
||||
|
||||
from onyx.agents.agent_search.core_state import CoreState
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
InitialAnswerUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
InitialQuestionDecompositionUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
SubQuestionResultsUpdate,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
### States ###
|
||||
class SubQuestionAnsweringInput(CoreState):
|
||||
exploratory_search_results: list[InferenceSection]
|
||||
|
||||
|
||||
## Graph State
|
||||
class SubQuestionAnsweringState(
|
||||
# This includes the core state
|
||||
SubQuestionAnsweringInput,
|
||||
InitialQuestionDecompositionUpdate,
|
||||
InitialAnswerUpdate,
|
||||
SubQuestionResultsUpdate,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
## Graph Output State
|
||||
class SubQuestionAnsweringOutput(TypedDict):
|
||||
log_messages: list[str]
|
||||
@@ -0,0 +1,81 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.nodes.format_orig_question_search_input import (
|
||||
format_orig_question_search_input,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.nodes.format_orig_question_search_output import (
|
||||
format_orig_question_search_output,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.states import (
|
||||
BaseRawSearchInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.states import (
|
||||
BaseRawSearchOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.states import (
|
||||
BaseRawSearchState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.graph_builder import (
|
||||
expanded_retrieval_graph_builder,
|
||||
)
|
||||
|
||||
|
||||
def retrieve_orig_question_docs_graph_builder() -> StateGraph:
|
||||
"""
|
||||
LangGraph graph builder for the retrieval of documents
|
||||
that are relevant to the original question. This is
|
||||
largely a wrapper around the expanded retrieval process to
|
||||
ensure parallelism with the sub-question answer process.
|
||||
"""
|
||||
graph = StateGraph(
|
||||
state_schema=BaseRawSearchState,
|
||||
input=BaseRawSearchInput,
|
||||
output=BaseRawSearchOutput,
|
||||
)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
# Format the original question search output
|
||||
graph.add_node(
|
||||
node="format_orig_question_search_output",
|
||||
action=format_orig_question_search_output,
|
||||
)
|
||||
|
||||
# The sub-graph that executes the expanded retrieval process
|
||||
expanded_retrieval = expanded_retrieval_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="retrieve_orig_question_docs_subgraph",
|
||||
action=expanded_retrieval,
|
||||
)
|
||||
|
||||
# Format the original question search input
|
||||
graph.add_node(
|
||||
node="format_orig_question_search_input",
|
||||
action=format_orig_question_search_input,
|
||||
)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_edge(start_key=START, end_key="format_orig_question_search_input")
|
||||
|
||||
graph.add_edge(
|
||||
start_key="format_orig_question_search_input",
|
||||
end_key="retrieve_orig_question_docs_subgraph",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="retrieve_orig_question_docs_subgraph",
|
||||
end_key="format_orig_question_search_output",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="format_orig_question_search_output",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
@@ -0,0 +1,28 @@
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.core_state import CoreState
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
|
||||
ExpandedRetrievalInput,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def format_orig_question_search_input(
|
||||
state: CoreState, config: RunnableConfig
|
||||
) -> ExpandedRetrievalInput:
|
||||
"""
|
||||
LangGraph node to format the search input for the original question.
|
||||
"""
|
||||
logger.debug("generate_raw_search_data")
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
return ExpandedRetrievalInput(
|
||||
question=graph_config.inputs.search_request.query,
|
||||
base_search=True,
|
||||
sub_question_id=None, # This graph is always and only used for the original question
|
||||
log_messages=[],
|
||||
)
|
||||
@@ -0,0 +1,30 @@
|
||||
from onyx.agents.agent_search.deep_search.main.states import OrigQuestionRetrievalUpdate
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
|
||||
ExpandedRetrievalOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def format_orig_question_search_output(
|
||||
state: ExpandedRetrievalOutput,
|
||||
) -> OrigQuestionRetrievalUpdate:
|
||||
"""
|
||||
LangGraph node to format the search results for the original question into the
expected update structure.
|
||||
"""
|
||||
    sub_question_retrieval_stats = state.expanded_retrieval_result.retrieval_stats
    if sub_question_retrieval_stats is None:
        sub_question_retrieval_stats = AgentChunkRetrievalStats()
|
||||
|
||||
return OrigQuestionRetrievalUpdate(
|
||||
orig_question_verified_reranked_documents=state.expanded_retrieval_result.verified_reranked_documents,
|
||||
orig_question_sub_query_retrieval_results=state.expanded_retrieval_result.expanded_query_results,
|
||||
orig_question_retrieved_documents=state.retrieved_documents,
|
||||
orig_question_retrieval_stats=sub_question_retrieval_stats,
|
||||
log_messages=[],
|
||||
)
|
||||
@@ -0,0 +1,29 @@
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
OrigQuestionRetrievalUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
|
||||
ExpandedRetrievalInput,
|
||||
)
|
||||
|
||||
|
||||
## Graph Input State
|
||||
class BaseRawSearchInput(ExpandedRetrievalInput):
|
||||
pass
|
||||
|
||||
|
||||
## Graph Output State
|
||||
class BaseRawSearchOutput(OrigQuestionRetrievalUpdate):
|
||||
"""
|
||||
This is a list of results even though each call of this subgraph only returns one result.
|
||||
This is because if we parallelize the answer query subgraph, there will be multiple
|
||||
results in a list so the add operator is used to add them together.
|
||||
"""
|
||||
|
||||
# base_expanded_retrieval_result: QuestionRetrievalResult = QuestionRetrievalResult()
|
||||
|
||||
|
||||
## Graph State
|
||||
class BaseRawSearchState(
|
||||
BaseRawSearchInput, BaseRawSearchOutput, OrigQuestionRetrievalUpdate
|
||||
):
|
||||
pass
|
||||
113
backend/onyx/agents/agent_search/deep_search/main/edges.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from collections.abc import Hashable
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
AnswerQuestionOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
SubQuestionAnsweringInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import MainState
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
RequireRefinemenEvalUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def route_initial_tool_choice(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> Literal["tool_call", "start_agent_search", "logging_node"]:
|
||||
"""
|
||||
LangGraph edge to route to agent search.
|
||||
"""
|
||||
agent_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
if state.tool_choice is not None:
|
||||
if (
|
||||
agent_config.behavior.use_agentic_search
|
||||
and agent_config.tooling.search_tool is not None
|
||||
and state.tool_choice.tool.name == agent_config.tooling.search_tool.name
|
||||
):
|
||||
return "start_agent_search"
|
||||
else:
|
||||
return "tool_call"
|
||||
else:
|
||||
return "logging_node"
|
||||
|
||||
|
||||
def parallelize_initial_sub_question_answering(
|
||||
state: MainState,
|
||||
) -> list[Send | Hashable]:
|
||||
edge_start_time = datetime.now()
|
||||
if len(state.initial_sub_questions) > 0:
|
||||
return [
|
||||
Send(
|
||||
"answer_query_subgraph",
|
||||
SubQuestionAnsweringInput(
|
||||
question=question,
|
||||
question_id=make_question_id(0, question_num + 1),
|
||||
log_messages=[
|
||||
f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
|
||||
],
|
||||
),
|
||||
)
|
||||
for question_num, question in enumerate(state.initial_sub_questions)
|
||||
]
|
||||
|
||||
else:
|
||||
return [
|
||||
Send(
|
||||
"ingest_answers",
|
||||
AnswerQuestionOutput(
|
||||
answer_results=[],
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
# Define the function that determines whether to continue or not
|
||||
def continue_to_refined_answer_or_end(
|
||||
state: RequireRefinemenEvalUpdate,
|
||||
) -> Literal["create_refined_sub_questions", "logging_node"]:
|
||||
if state.require_refined_answer_eval:
|
||||
return "create_refined_sub_questions"
|
||||
else:
|
||||
return "logging_node"
|
||||
|
||||
|
||||
def parallelize_refined_sub_question_answering(
|
||||
state: MainState,
|
||||
) -> list[Send | Hashable]:
|
||||
edge_start_time = datetime.now()
|
||||
if len(state.refined_sub_questions) > 0:
|
||||
return [
|
||||
Send(
|
||||
"answer_refined_question_subgraphs",
|
||||
SubQuestionAnsweringInput(
|
||||
question=question_data.sub_question,
|
||||
question_id=make_question_id(1, question_num),
|
||||
log_messages=[
|
||||
f"{edge_start_time} -- Main Edge - Parallelize Refined Sub-question Answering"
|
||||
],
|
||||
),
|
||||
)
|
||||
for question_num, question_data in state.refined_sub_questions.items()
|
||||
]
|
||||
|
||||
else:
|
||||
return [
|
||||
Send(
|
||||
"ingest_refined_sub_answers",
|
||||
AnswerQuestionOutput(
|
||||
answer_results=[],
|
||||
),
|
||||
)
|
||||
]
|
||||
@@ -0,0 +1,265 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.graph_builder import (
|
||||
generate_initial_answer_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.edges import (
|
||||
continue_to_refined_answer_or_end,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.edges import (
|
||||
parallelize_refined_sub_question_answering,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.edges import (
|
||||
route_initial_tool_choice,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.nodes.compare_answers import (
|
||||
compare_answers,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.nodes.create_refined_sub_questions import (
|
||||
create_refined_sub_questions,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.nodes.decide_refinement_need import (
|
||||
decide_refinement_need,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.nodes.extract_entities_terms import (
|
||||
extract_entities_terms,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.nodes.generate_refined_answer import (
|
||||
generate_refined_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.nodes.ingest_refined_sub_answers import (
|
||||
ingest_refined_sub_answers,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.nodes.persist_agent_results import (
|
||||
persist_agent_results,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.nodes.start_agent_search import (
|
||||
start_agent_search,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import MainInput
|
||||
from onyx.agents.agent_search.deep_search.main.states import MainState
|
||||
from onyx.agents.agent_search.deep_search.refinement.consolidate_sub_answers.graph_builder import (
|
||||
answer_refined_query_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
|
||||
basic_use_tool_response,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
|
||||
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
|
||||
prepare_tool_input,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
test_mode = False
|
||||
|
||||
|
||||
def main_graph_builder(test_mode: bool = False) -> StateGraph:
|
||||
"""
|
||||
LangGraph graph builder for the main agent search process.
|
||||
"""
|
||||
graph = StateGraph(
|
||||
state_schema=MainState,
|
||||
input=MainInput,
|
||||
)
|
||||
|
||||
# Prepare the tool input
|
||||
graph.add_node(
|
||||
node="prepare_tool_input",
|
||||
action=prepare_tool_input,
|
||||
)
|
||||
|
||||
# Choose the initial tool
|
||||
graph.add_node(
|
||||
node="initial_tool_choice",
|
||||
action=llm_tool_choice,
|
||||
)
|
||||
|
||||
# Call the tool, if required
|
||||
graph.add_node(
|
||||
node="tool_call",
|
||||
action=tool_call,
|
||||
)
|
||||
|
||||
# Use the tool response
|
||||
graph.add_node(
|
||||
node="basic_use_tool_response",
|
||||
action=basic_use_tool_response,
|
||||
)
|
||||
|
||||
# Start the agent search process
|
||||
graph.add_node(
|
||||
node="start_agent_search",
|
||||
action=start_agent_search,
|
||||
)
|
||||
|
||||
# The sub-graph for the initial answer generation
|
||||
generate_initial_answer_subgraph = generate_initial_answer_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="generate_initial_answer_subgraph",
|
||||
action=generate_initial_answer_subgraph,
|
||||
)
|
||||
|
||||
# Create the refined sub-questions
|
||||
graph.add_node(
|
||||
node="create_refined_sub_questions",
|
||||
action=create_refined_sub_questions,
|
||||
)
|
||||
|
||||
# Subgraph for the refined sub-answer generation
|
||||
answer_refined_question = answer_refined_query_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="answer_refined_question_subgraphs",
|
||||
action=answer_refined_question,
|
||||
)
|
||||
|
||||
# Ingest the refined sub-answers
|
||||
graph.add_node(
|
||||
node="ingest_refined_sub_answers",
|
||||
action=ingest_refined_sub_answers,
|
||||
)
|
||||
|
||||
# Node to generate the refined answer
|
||||
graph.add_node(
|
||||
node="generate_refined_answer",
|
||||
action=generate_refined_answer,
|
||||
)
|
||||
|
||||
# Early node to extract the entities and terms from the initial answer.
# This information is used to inform the creation of the refined sub-questions.
|
||||
graph.add_node(
|
||||
node="extract_entity_term",
|
||||
action=extract_entities_terms,
|
||||
)
|
||||
|
||||
# Decide if the answer needs to be refined (currently always true)
|
||||
graph.add_node(
|
||||
node="decide_refinement_need",
|
||||
action=decide_refinement_need,
|
||||
)
|
||||
|
||||
# Compare the initial and refined answers, and determine whether
|
||||
# the refined answer is sufficiently better
|
||||
graph.add_node(
|
||||
node="compare_answers",
|
||||
action=compare_answers,
|
||||
)
|
||||
|
||||
# Log the results. This will log the stats as well as the answers, sub-questions, and sub-answers
|
||||
graph.add_node(
|
||||
node="logging_node",
|
||||
action=persist_agent_results,
|
||||
)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_edge(start_key=START, end_key="prepare_tool_input")
|
||||
|
||||
graph.add_edge(
|
||||
start_key="prepare_tool_input",
|
||||
end_key="initial_tool_choice",
|
||||
)
|
||||
|
||||
graph.add_conditional_edges(
|
||||
"initial_tool_choice",
|
||||
route_initial_tool_choice,
|
||||
["tool_call", "start_agent_search", "logging_node"],
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="tool_call",
|
||||
end_key="basic_use_tool_response",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="basic_use_tool_response",
|
||||
end_key="logging_node",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="start_agent_search",
|
||||
end_key="generate_initial_answer_subgraph",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="start_agent_search",
|
||||
end_key="extract_entity_term",
|
||||
)
|
||||
|
||||
# Wait for the initial answer generation and the entity/term extraction to be complete
|
||||
# before deciding if a refinement is needed.
|
||||
graph.add_edge(
|
||||
start_key=["generate_initial_answer_subgraph", "extract_entity_term"],
|
||||
end_key="decide_refinement_need",
|
||||
)
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source="decide_refinement_need",
|
||||
path=continue_to_refined_answer_or_end,
|
||||
path_map=["create_refined_sub_questions", "logging_node"],
|
||||
)
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source="create_refined_sub_questions",
|
||||
path=parallelize_refined_sub_question_answering,
|
||||
path_map=["answer_refined_question_subgraphs"],
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="answer_refined_question_subgraphs",
|
||||
end_key="ingest_refined_sub_answers",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="ingest_refined_sub_answers",
|
||||
end_key="generate_refined_answer",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="generate_refined_answer",
|
||||
end_key="compare_answers",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="compare_answers",
|
||||
end_key="logging_node",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="logging_node",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.context.search.models import SearchRequest
|
||||
|
||||
graph = main_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
|
||||
with get_session_context_manager() as db_session:
|
||||
search_request = SearchRequest(query="Who created Excel?")
|
||||
graph_config = get_test_config(
|
||||
db_session, primary_llm, fast_llm, search_request
|
||||
)
|
||||
|
||||
inputs = MainInput(
|
||||
base_question=graph_config.inputs.search_request.query, log_messages=[]
|
||||
)
|
||||
|
||||
for thing in compiled_graph.stream(
|
||||
input=inputs,
|
||||
config={"configurable": {"config": graph_config}},
|
||||
stream_mode="custom",
|
||||
subgraphs=True,
|
||||
):
|
||||
logger.debug(thing)
|
||||
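Note: the conditional edge above routes via route_initial_tool_choice (imported earlier in this file) to one of "tool_call", "start_agent_search", or "logging_node". As a reading aid only, here is a minimal, hypothetical sketch of what such a router could look like; the state fields and the tool name are assumptions, not the repo's actual implementation:

```python
# Hypothetical sketch only -- not the actual route_initial_tool_choice implementation.
def route_initial_tool_choice_sketch(state) -> str:
    # No tool chosen: skip straight to logging/persistence.
    if getattr(state, "tool_choice", None) is None:  # "tool_choice" field is assumed
        return "logging_node"
    # A search-type tool kicks off the full agentic flow.
    if state.tool_choice.tool.name == "run_search":  # tool name is an assumption
        return "start_agent_search"
    # Any other tool goes through the basic tool-call path.
    return "tool_call"
```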
36
backend/onyx/agents/agent_search/deep_search/main/models.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RefinementSubQuestion(BaseModel):
|
||||
sub_question: str
|
||||
sub_question_id: str
|
||||
verified: bool
|
||||
answered: bool
|
||||
answer: str
|
||||
|
||||
|
||||
class AgentTimings(BaseModel):
|
||||
base_duration_s: float | None
|
||||
refined_duration_s: float | None
|
||||
full_duration_s: float | None
|
||||
|
||||
|
||||
class AgentBaseMetrics(BaseModel):
|
||||
num_verified_documents_total: int | None
|
||||
num_verified_documents_core: int | None
|
||||
verified_avg_score_core: float | None
|
||||
num_verified_documents_base: int | float | None
|
||||
verified_avg_score_base: float | None = None
|
||||
base_doc_boost_factor: float | None = None
|
||||
support_boost_factor: float | None = None
|
||||
duration_s: float | None = None
|
||||
|
||||
|
||||
class AgentRefinedMetrics(BaseModel):
|
||||
refined_doc_boost_factor: float | None = None
|
||||
refined_question_boost_factor: float | None = None
|
||||
duration_s: float | None = None
|
||||
|
||||
|
||||
class AgentAdditionalMetrics(BaseModel):
|
||||
pass
|
||||
@@ -0,0 +1,71 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
InitialRefinedAnswerComparisonUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import MainState
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import RefinedAnswerImprovement
|
||||
from onyx.prompts.agent_search import (
|
||||
INITIAL_REFINED_ANSWER_COMPARISON_PROMPT,
|
||||
)
|
||||
|
||||
|
||||
def compare_answers(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> InitialRefinedAnswerComparisonUpdate:
|
||||
"""
|
||||
LangGraph node to compare the initial answer and the refined answer and determine if the
|
||||
refined answer is sufficiently better than the initial answer.
|
||||
"""
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.search_request.query
|
||||
initial_answer = state.initial_answer
|
||||
refined_answer = state.refined_answer
|
||||
|
||||
compare_answers_prompt = INITIAL_REFINED_ANSWER_COMPARISON_PROMPT.format(
|
||||
question=question, initial_answer=initial_answer, refined_answer=refined_answer
|
||||
)
|
||||
|
||||
msg = [HumanMessage(content=compare_answers_prompt)]
|
||||
|
||||
# Compare the two answers with the fast LLM; a yes/no style response is expected
|
||||
model = graph_config.tooling.fast_llm
|
||||
|
||||
# no need to stream this
|
||||
resp = model.invoke(msg)
|
||||
|
||||
refined_answer_improvement = (
|
||||
isinstance(resp.content, str) and "yes" in resp.content.lower()
|
||||
)
|
||||
|
||||
write_custom_event(
|
||||
"refined_answer_improvement",
|
||||
RefinedAnswerImprovement(
|
||||
refined_answer_improvement=refined_answer_improvement,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
return InitialRefinedAnswerComparisonUpdate(
|
||||
refined_answer_improvement_eval=refined_answer_improvement,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="compare answers",
|
||||
node_start_time=node_start_time,
|
||||
result=f"Answer comparison: {refined_answer_improvement}",
|
||||
)
|
||||
],
|
||||
)
|
||||
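A small illustration of the comparison check above, assuming the comparison prompt asks the model to reply "yes" or "no":

```python
# Toy example of the substring check used in compare_answers (assumed "yes"/"no" prompt).
resp_content = "Yes - the refined answer covers the sub-questions in more depth."
refined_answer_improvement = isinstance(resp_content, str) and "yes" in resp_content.lower()
print(refined_answer_improvement)  # True
```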
@@ -0,0 +1,131 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.deep_search.main.models import (
|
||||
RefinementSubQuestion,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.operations import (
|
||||
dispatch_subquestion,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import MainState
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
RefinedQuestionDecompositionUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_history_prompt,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
format_entity_term_extraction,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.prompts.agent_search import (
|
||||
REFINEMENT_QUESTION_DECOMPOSITION_PROMPT,
|
||||
)
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
|
||||
|
||||
def create_refined_sub_questions(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> RefinedQuestionDecompositionUpdate:
|
||||
"""
|
||||
LangGraph node to create refined sub-questions based on the initial answer, the history,
|
||||
the entity term extraction results found earlier, and the sub-questions that were answered and failed.
|
||||
"""
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
write_custom_event(
|
||||
"start_refined_answer_creation",
|
||||
ToolCallKickoff(
|
||||
tool_name="agent_search_1",
|
||||
tool_args={
|
||||
"query": graph_config.inputs.search_request.query,
|
||||
"answer": state.initial_answer,
|
||||
},
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
agent_refined_start_time = datetime.now()
|
||||
|
||||
question = graph_config.inputs.search_request.query
|
||||
base_answer = state.initial_answer
|
||||
history = build_history_prompt(graph_config, question)
|
||||
# get the entity term extraction dict and properly format it
|
||||
entity_relation_term_extractions = state.entity_relation_term_extractions

entity_term_extraction_str = format_entity_term_extraction(
entity_relation_term_extractions
)
|
||||
|
||||
initial_question_answers = state.sub_question_results
|
||||
|
||||
addressed_question_list = [
|
||||
x.question for x in initial_question_answers if x.verified_high_quality
|
||||
]
|
||||
|
||||
failed_question_list = [
|
||||
x.question for x in initial_question_answers if not x.verified_high_quality
|
||||
]
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=REFINEMENT_QUESTION_DECOMPOSITION_PROMPT.format(
|
||||
question=question,
|
||||
history=history,
|
||||
entity_term_extraction_str=entity_term_extraction_str,
|
||||
base_answer=base_answer,
|
||||
answered_sub_questions="\n - ".join(addressed_question_list),
|
||||
failed_sub_questions="\n - ".join(failed_question_list),
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
# Use the fast LLM to generate the refined sub-questions
|
||||
model = graph_config.tooling.fast_llm
|
||||
|
||||
streamed_tokens = dispatch_separated(
|
||||
model.stream(msg), dispatch_subquestion(1, writer)
|
||||
)
|
||||
response = merge_content(*streamed_tokens)
|
||||
|
||||
if isinstance(response, str):
|
||||
parsed_response = [q for q in response.split("\n") if q.strip() != ""]
|
||||
else:
|
||||
raise ValueError("LLM response is not a string")
|
||||
|
||||
refined_sub_question_dict = {}
|
||||
for sub_question_num, sub_question in enumerate(parsed_response):
|
||||
refined_sub_question = RefinementSubQuestion(
|
||||
sub_question=sub_question,
|
||||
sub_question_id=make_question_id(1, sub_question_num + 1),
|
||||
verified=False,
|
||||
answered=False,
|
||||
answer="",
|
||||
)
|
||||
|
||||
refined_sub_question_dict[sub_question_num + 1] = refined_sub_question
|
||||
|
||||
return RefinedQuestionDecompositionUpdate(
|
||||
refined_sub_questions=refined_sub_question_dict,
|
||||
agent_refined_start_time=agent_refined_start_time,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="create refined sub questions",
|
||||
node_start_time=node_start_time,
|
||||
result=f"Created {len(refined_sub_question_dict)} refined sub questions",
|
||||
)
|
||||
],
|
||||
)
|
||||
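For orientation, a sketch of how the newline-separated LLM output above becomes numbered refinement sub-questions. The "level_number" id format is inferred from make_question_id(1, n) and the "0_0" ids used elsewhere in this diff; treat it as an assumption:

```python
# Sketch of the parsing/numbering step (id format "level_number" is assumed).
raw_response = "What connectors does Onyx support?\nHow does Onyx rank documents?\n"
parsed = [q for q in raw_response.split("\n") if q.strip() != ""]
refined = {
    num + 1: {"sub_question": q, "sub_question_id": f"1_{num + 1}"}
    for num, q in enumerate(parsed)
}
# {1: {'sub_question': 'What connectors does Onyx support?', 'sub_question_id': '1_1'}, ...}
```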
@@ -0,0 +1,47 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search.main.states import MainState
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
RequireRefinemenEvalUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
|
||||
|
||||
def decide_refinement_need(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> RequireRefinemenEvalUpdate:
|
||||
"""
|
||||
LangGraph node to decide if refinement is needed based on the initial answer and the question.
|
||||
At present, we always refine.
|
||||
"""
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
|
||||
decision = True # TODO: just for current testing purposes
|
||||
|
||||
log_messages = [
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="decide refinement need",
|
||||
node_start_time=node_start_time,
|
||||
result=f"Refinement decision: {decision}",
|
||||
)
|
||||
]
|
||||
|
||||
if graph_config.behavior.allow_refinement:
|
||||
return RequireRefinemenEvalUpdate(
|
||||
require_refined_answer_eval=decision,
|
||||
log_messages=log_messages,
|
||||
)
|
||||
else:
|
||||
return RequireRefinemenEvalUpdate(
|
||||
require_refined_answer_eval=False,
|
||||
log_messages=log_messages,
|
||||
)
|
||||
@@ -0,0 +1,116 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
EntityTermExtractionUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import MainState
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import EntityExtractionResult
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
EntityRelationshipTermExtraction,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.constants import NUM_EXPLORATORY_DOCS
|
||||
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT
|
||||
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE
|
||||
|
||||
|
||||
def extract_entities_terms(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> EntityTermExtractionUpdate:
|
||||
"""
|
||||
LangGraph node to extract entities, relationships, and terms from the initial search results.
|
||||
This data particularly informs the sub-questions that are created for the refined answer.
|
||||
"""
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
if not graph_config.behavior.allow_refinement:
|
||||
return EntityTermExtractionUpdate(
|
||||
entity_relation_term_extractions=EntityRelationshipTermExtraction(
|
||||
entities=[],
|
||||
relationships=[],
|
||||
terms=[],
|
||||
),
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="extract entities terms",
|
||||
node_start_time=node_start_time,
|
||||
result="Refinement is not allowed",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# first four lines are duplicated from generate_initial_answer
|
||||
question = graph_config.inputs.search_request.query
|
||||
initial_search_docs = state.exploratory_search_results[:NUM_EXPLORATORY_DOCS]
|
||||
|
||||
# start with the entity/term/extraction
|
||||
doc_context = format_docs(initial_search_docs)
|
||||
|
||||
# Calculation here is only approximate
|
||||
doc_context = trim_prompt_piece(
|
||||
graph_config.tooling.fast_llm.config,
|
||||
doc_context,
|
||||
ENTITY_TERM_EXTRACTION_PROMPT
|
||||
+ question
|
||||
+ ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE,
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=ENTITY_TERM_EXTRACTION_PROMPT.format(
|
||||
question=question, context=doc_context
|
||||
)
|
||||
+ ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE,
|
||||
)
|
||||
]
|
||||
fast_llm = graph_config.tooling.fast_llm
|
||||
# Use the fast LLM for the entity/term extraction
|
||||
llm_response = fast_llm.invoke(
|
||||
prompt=msg,
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
|
||||
)
|
||||
first_bracket = cleaned_response.find("{")
|
||||
last_bracket = cleaned_response.rfind("}")
|
||||
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
|
||||
|
||||
try:
|
||||
entity_extraction_result = EntityExtractionResult.model_validate_json(
|
||||
cleaned_response
|
||||
)
|
||||
except ValueError:
|
||||
logger.error("Failed to parse LLM response as JSON in Entity-Term Extraction")
|
||||
entity_extraction_result = EntityExtractionResult(
|
||||
retrieved_entities_relationships=EntityRelationshipTermExtraction(
|
||||
entities=[],
|
||||
relationships=[],
|
||||
terms=[],
|
||||
),
|
||||
)
|
||||
|
||||
return EntityTermExtractionUpdate(
|
||||
entity_relation_term_extractions=entity_extraction_result.retrieved_entities_relationships,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="extract entities terms",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
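The JSON clean-up above (stripping code fences, then slicing to the outermost braces) can be seen in isolation:

```python
# Stand-alone illustration of the clean-up performed before model_validate_json.
fence = "`" * 3  # avoid writing a literal fence inside this example
raw = fence + 'json\n{"retrieved_entities_relationships": {"entities": [], "relationships": [], "terms": []}}\n' + fence
cleaned = raw.replace(fence + "json\n", "").replace("\n" + fence, "")
cleaned = cleaned[cleaned.find("{") : cleaned.rfind("}") + 1]
# cleaned is now bare JSON, ready for EntityExtractionResult.model_validate_json(cleaned)
```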
@@ -0,0 +1,339 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.deep_search.main.models import (
|
||||
AgentRefinedMetrics,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.operations import get_query_info
|
||||
from onyx.agents.agent_search.deep_search.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search.main.states import MainState
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
RefinedAnswerUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
get_prompt_enrichment_components,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import InferenceSection
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
dispatch_main_answer_stop_info,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
remove_document_citations,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
|
||||
from onyx.prompts.agent_search import (
|
||||
REFINED_ANSWER_PROMPT_W_SUB_QUESTIONS,
|
||||
)
|
||||
from onyx.prompts.agent_search import (
|
||||
REFINED_ANSWER_PROMPT_WO_SUB_QUESTIONS,
|
||||
)
|
||||
from onyx.prompts.agent_search import (
|
||||
SUB_QUESTION_ANSWER_TEMPLATE_REFINED,
|
||||
)
|
||||
from onyx.prompts.agent_search import UNKNOWN_ANSWER
|
||||
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
|
||||
|
||||
|
||||
def generate_refined_answer(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> RefinedAnswerUpdate:
|
||||
"""
|
||||
LangGraph node to generate the refined answer.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.search_request.query
|
||||
prompt_enrichment_components = get_prompt_enrichment_components(graph_config)
|
||||
|
||||
persona_contextualized_prompt = (
|
||||
prompt_enrichment_components.persona_prompts.contextualized_prompt
|
||||
)
|
||||
|
||||
verified_reranked_documents = state.verified_reranked_documents
|
||||
sub_questions_cited_documents = state.cited_documents
|
||||
original_question_verified_documents = (
|
||||
state.orig_question_verified_reranked_documents
|
||||
)
|
||||
original_question_retrieved_documents = state.orig_question_retrieved_documents
|
||||
|
||||
consolidated_context_docs: list[InferenceSection] = sub_questions_cited_documents
|
||||
|
||||
counter = 0
|
||||
for original_doc_number, original_doc in enumerate(
|
||||
original_question_verified_documents
|
||||
):
|
||||
if original_doc_number not in sub_questions_cited_documents:
|
||||
if (
|
||||
counter <= AGENT_MIN_ORIG_QUESTION_DOCS
|
||||
or len(consolidated_context_docs)
|
||||
< 1.5
|
||||
* AGENT_MAX_ANSWER_CONTEXT_DOCS # allow for larger context in refinement
|
||||
):
|
||||
consolidated_context_docs.append(original_doc)
|
||||
counter += 1
|
||||
|
||||
# sort docs by their scores - though the scores refer to different questions
|
||||
relevant_docs = dedup_inference_sections(
|
||||
consolidated_context_docs, consolidated_context_docs
|
||||
)
|
||||
|
||||
streaming_docs = (
|
||||
relevant_docs
|
||||
if len(relevant_docs) > 0
|
||||
else original_question_retrieved_documents[:15]
|
||||
)
|
||||
|
||||
query_info = get_query_info(state.orig_question_sub_query_retrieval_results)
|
||||
assert (
|
||||
graph_config.tooling.search_tool
|
||||
), "search_tool must be provided for agentic search"
|
||||
# stream refined answer docs, or original question docs if no relevant docs are found
|
||||
relevance_list = relevance_from_docs(relevant_docs)
|
||||
for tool_response in yield_search_responses(
|
||||
query=question,
|
||||
reranked_sections=streaming_docs,
|
||||
final_context_sections=streaming_docs,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
):
|
||||
write_custom_event(
|
||||
"tool_response",
|
||||
ExtendedToolResponse(
|
||||
id=tool_response.id,
|
||||
response=tool_response.response,
|
||||
level=1,
|
||||
level_question_num=0, # 0, 0 is the base question
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
if len(verified_reranked_documents) > 0:
|
||||
refined_doc_effectiveness = len(relevant_docs) / len(
|
||||
verified_reranked_documents
|
||||
)
|
||||
else:
|
||||
refined_doc_effectiveness = 10.0
|
||||
|
||||
sub_question_answer_results = state.sub_question_results
|
||||
|
||||
answered_sub_question_answer_list: list[str] = []
|
||||
sub_questions: list[str] = []
|
||||
initial_answered_sub_questions: set[str] = set()
|
||||
refined_answered_sub_questions: set[str] = set()
|
||||
|
||||
for i, result in enumerate(sub_question_answer_results, 1):
|
||||
question_level, _ = parse_question_id(result.question_id)
|
||||
sub_questions.append(result.question)
|
||||
|
||||
if (
|
||||
result.verified_high_quality
|
||||
and result.answer
|
||||
and result.answer != UNKNOWN_ANSWER
|
||||
):
|
||||
sub_question_type = "initial" if question_level == 0 else "refined"
|
||||
question_set = (
|
||||
initial_answered_sub_questions
|
||||
if question_level == 0
|
||||
else refined_answered_sub_questions
|
||||
)
|
||||
question_set.add(result.question)
|
||||
|
||||
answered_sub_question_answer_list.append(
|
||||
SUB_QUESTION_ANSWER_TEMPLATE_REFINED.format(
|
||||
sub_question=result.question,
|
||||
sub_answer=result.answer,
|
||||
sub_question_num=i,
|
||||
sub_question_type=sub_question_type,
|
||||
)
|
||||
)
|
||||
|
||||
# Calculate efficiency
|
||||
total_answered_questions = (
|
||||
initial_answered_sub_questions | refined_answered_sub_questions
|
||||
)
|
||||
revision_question_efficiency = (
|
||||
len(total_answered_questions) / len(initial_answered_sub_questions)
|
||||
if initial_answered_sub_questions
|
||||
else 10.0
|
||||
if refined_answered_sub_questions
|
||||
else 1.0
|
||||
)
|
||||
|
||||
sub_question_answer_str = "\n\n------\n\n".join(
|
||||
set(answered_sub_question_answer_list)
|
||||
)
|
||||
initial_answer = state.initial_answer or ""
|
||||
|
||||
# Choose appropriate prompt template
|
||||
base_prompt = (
|
||||
REFINED_ANSWER_PROMPT_W_SUB_QUESTIONS
|
||||
if answered_sub_question_answer_list
|
||||
else REFINED_ANSWER_PROMPT_WO_SUB_QUESTIONS
|
||||
)
|
||||
|
||||
model = graph_config.tooling.fast_llm
|
||||
relevant_docs_str = format_docs(relevant_docs)
|
||||
relevant_docs_str = trim_prompt_piece(
|
||||
model.config,
|
||||
relevant_docs_str,
|
||||
base_prompt
|
||||
+ question
|
||||
+ sub_question_answer_str
|
||||
+ initial_answer
|
||||
+ persona_contextualized_prompt
|
||||
+ prompt_enrichment_components.history,
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=base_prompt.format(
|
||||
question=question,
|
||||
history=prompt_enrichment_components.history,
|
||||
answered_sub_questions=remove_document_citations(
|
||||
sub_question_answer_str
|
||||
),
|
||||
relevant_docs=relevant_docs_str,
|
||||
initial_answer=remove_document_citations(initial_answer)
|
||||
if initial_answer
|
||||
else None,
|
||||
persona_specification=persona_contextualized_prompt,
|
||||
date_prompt=prompt_enrichment_components.date_str,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
|
||||
dispatch_timings: list[float] = []
|
||||
for message in model.stream(msg):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
)
|
||||
|
||||
start_stream_token = datetime.now()
|
||||
write_custom_event(
|
||||
"refined_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=content,
|
||||
level=1,
|
||||
level_question_num=0,
|
||||
answer_type="agent_level_answer",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
end_stream_token = datetime.now()
|
||||
dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
|
||||
streamed_tokens.append(content)
|
||||
|
||||
logger.debug(
|
||||
f"Average dispatch time for refined answer: {sum(dispatch_timings) / len(dispatch_timings)}"
|
||||
)
|
||||
dispatch_main_answer_stop_info(1, writer)
|
||||
response = merge_content(*streamed_tokens)
|
||||
answer = cast(str, response)
|
||||
|
||||
refined_agent_stats = RefinedAgentStats(
|
||||
revision_doc_efficiency=refined_doc_effectiveness,
|
||||
revision_question_efficiency=revision_question_efficiency,
|
||||
)
|
||||
|
||||
logger.debug(f"\n\n---INITIAL ANSWER ---\n\n Answer:\n Agent: {initial_answer}")
|
||||
logger.debug("-" * 10)
|
||||
logger.debug(f"\n\n---REVISED AGENT ANSWER ---\n\n Answer:\n Agent: {answer}")
|
||||
|
||||
logger.debug("-" * 100)
|
||||
|
||||
if state.initial_agent_stats:
|
||||
initial_doc_boost_factor = state.initial_agent_stats.agent_effectiveness.get(
|
||||
"utilized_chunk_ratio", "--"
|
||||
)
|
||||
initial_support_boost_factor = (
|
||||
state.initial_agent_stats.agent_effectiveness.get("support_ratio", "--")
|
||||
)
|
||||
num_initial_verified_docs = state.initial_agent_stats.original_question.get(
|
||||
"num_verified_documents", "--"
|
||||
)
|
||||
initial_verified_docs_avg_score = (
|
||||
state.initial_agent_stats.original_question.get("verified_avg_score", "--")
|
||||
)
|
||||
initial_sub_questions_verified_docs = (
|
||||
state.initial_agent_stats.sub_questions.get("num_verified_documents", "--")
|
||||
)
|
||||
|
||||
logger.debug("INITIAL AGENT STATS")
|
||||
logger.debug(f"Document Boost Factor: {initial_doc_boost_factor}")
|
||||
logger.debug(f"Support Boost Factor: {initial_support_boost_factor}")
|
||||
logger.debug(f"Originally Verified Docs: {num_initial_verified_docs}")
|
||||
logger.debug(
|
||||
f"Originally Verified Docs Avg Score: {initial_verified_docs_avg_score}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Sub-Questions Verified Docs: {initial_sub_questions_verified_docs}"
|
||||
)
|
||||
if refined_agent_stats:
|
||||
logger.debug("-" * 10)
|
||||
logger.debug("REFINED AGENT STATS")
|
||||
logger.debug(
|
||||
f"Revision Doc Factor: {refined_agent_stats.revision_doc_efficiency}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Revision Question Factor: {refined_agent_stats.revision_question_efficiency}"
|
||||
)
|
||||
|
||||
agent_refined_end_time = datetime.now()
|
||||
if state.agent_refined_start_time:
|
||||
agent_refined_duration = (
|
||||
agent_refined_end_time - state.agent_refined_start_time
|
||||
).total_seconds()
|
||||
else:
|
||||
agent_refined_duration = None
|
||||
|
||||
agent_refined_metrics = AgentRefinedMetrics(
|
||||
refined_doc_boost_factor=refined_agent_stats.revision_doc_efficiency,
|
||||
refined_question_boost_factor=refined_agent_stats.revision_question_efficiency,
|
||||
duration_s=agent_refined_duration,
|
||||
)
|
||||
|
||||
return RefinedAnswerUpdate(
|
||||
refined_answer=answer,
|
||||
refined_answer_quality=True, # TODO: replace this with the actual check value
|
||||
refined_agent_stats=refined_agent_stats,
|
||||
agent_refined_end_time=agent_refined_end_time,
|
||||
agent_refined_metrics=agent_refined_metrics,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="generate refined answer",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
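The chained conditional expression that computes revision_question_efficiency above is equivalent to the following plain function, shown only as a reading aid:

```python
# Equivalent, spelled-out form of the revision_question_efficiency expression.
def revision_question_efficiency(initial_answered: set[str], refined_answered: set[str]) -> float:
    total_answered = initial_answered | refined_answered
    if initial_answered:
        return len(total_answered) / len(initial_answered)
    return 10.0 if refined_answered else 1.0
```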
@@ -0,0 +1,42 @@
|
||||
from datetime import datetime
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
AnswerQuestionOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
SubQuestionResultsUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
|
||||
|
||||
def ingest_refined_sub_answers(
|
||||
state: AnswerQuestionOutput,
|
||||
) -> SubQuestionResultsUpdate:
|
||||
"""
|
||||
LangGraph node to ingest and format the refined sub-answers and retrieved documents.
|
||||
"""
|
||||
node_start_time = datetime.now()
|
||||
|
||||
documents = []
|
||||
answer_results = state.answer_results
|
||||
for answer_result in answer_results:
|
||||
documents.extend(answer_result.verified_reranked_documents)
|
||||
|
||||
return SubQuestionResultsUpdate(
|
||||
# Deduping is done by the documents operator for the main graph
|
||||
# so we might not need to dedup here
|
||||
verified_reranked_documents=dedup_inference_sections(documents, []),
|
||||
sub_question_results=answer_results,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="ingest refined answers",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,129 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search.main.models import (
|
||||
AgentAdditionalMetrics,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.models import AgentTimings
|
||||
from onyx.agents.agent_search.deep_search.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search.main.states import MainOutput
|
||||
from onyx.agents.agent_search.deep_search.main.states import MainState
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import CombinedAgentMetrics
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.db.chat import log_agent_metrics
|
||||
from onyx.db.chat import log_agent_sub_question_results
|
||||
|
||||
|
||||
def persist_agent_results(state: MainState, config: RunnableConfig) -> MainOutput:
|
||||
"""
|
||||
LangGraph node to persist the agent results, including agent logging data.
|
||||
"""
|
||||
node_start_time = datetime.now()
|
||||
|
||||
agent_start_time = state.agent_start_time
|
||||
agent_base_end_time = state.agent_base_end_time
|
||||
agent_refined_start_time = state.agent_refined_start_time
|
||||
agent_refined_end_time = state.agent_refined_end_time
|
||||
agent_end_time = agent_refined_end_time or agent_base_end_time
|
||||
|
||||
agent_base_duration = None
|
||||
if agent_base_end_time and agent_start_time:
|
||||
agent_base_duration = (agent_base_end_time - agent_start_time).total_seconds()
|
||||
|
||||
agent_refined_duration = None
|
||||
if agent_refined_start_time and agent_refined_end_time:
|
||||
agent_refined_duration = (
|
||||
agent_refined_end_time - agent_refined_start_time
|
||||
).total_seconds()
|
||||
|
||||
agent_full_duration = None
|
||||
if agent_end_time and agent_start_time:
|
||||
agent_full_duration = (agent_end_time - agent_start_time).total_seconds()
|
||||
|
||||
agent_type = "refined" if agent_refined_duration else "base"
|
||||
|
||||
agent_base_metrics = state.agent_base_metrics
|
||||
agent_refined_metrics = state.agent_refined_metrics
|
||||
|
||||
combined_agent_metrics = CombinedAgentMetrics(
|
||||
timings=AgentTimings(
|
||||
base_duration_s=agent_base_duration,
|
||||
refined_duration_s=agent_refined_duration,
|
||||
full_duration_s=agent_full_duration,
|
||||
),
|
||||
base_metrics=agent_base_metrics,
|
||||
refined_metrics=agent_refined_metrics,
|
||||
additional_metrics=AgentAdditionalMetrics(),
|
||||
)
|
||||
|
||||
persona_id = None
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
if graph_config.inputs.search_request.persona:
|
||||
persona_id = graph_config.inputs.search_request.persona.id
|
||||
|
||||
user_id = None
|
||||
assert (
|
||||
graph_config.tooling.search_tool
|
||||
), "search_tool must be provided for agentic search"
|
||||
user = graph_config.tooling.search_tool.user
|
||||
if user:
|
||||
user_id = user.id
|
||||
|
||||
# log the agent metrics
|
||||
if graph_config.persistence:
|
||||
if agent_base_duration is not None:
|
||||
log_agent_metrics(
|
||||
db_session=graph_config.persistence.db_session,
|
||||
user_id=user_id,
|
||||
persona_id=persona_id,
|
||||
agent_type=agent_type,
|
||||
start_time=agent_start_time,
|
||||
agent_metrics=combined_agent_metrics,
|
||||
)
|
||||
|
||||
# Persist the sub-answer in the database
|
||||
db_session = graph_config.persistence.db_session
|
||||
chat_session_id = graph_config.persistence.chat_session_id
|
||||
primary_message_id = graph_config.persistence.message_id
|
||||
sub_question_answer_results = state.sub_question_results
|
||||
|
||||
log_agent_sub_question_results(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
primary_message_id=primary_message_id,
|
||||
sub_question_answer_results=sub_question_answer_results,
|
||||
)
|
||||
|
||||
main_output = MainOutput(
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="persist agent results",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
for log_message in state.log_messages:
|
||||
logger.debug(log_message)
|
||||
|
||||
if state.agent_base_metrics:
|
||||
logger.debug(f"Initial loop: {state.agent_base_metrics.duration_s}")
|
||||
if state.agent_refined_metrics:
|
||||
logger.debug(f"Refined loop: {state.agent_refined_metrics.duration_s}")
|
||||
if (
|
||||
state.agent_base_metrics
|
||||
and state.agent_refined_metrics
|
||||
and state.agent_base_metrics.duration_s
|
||||
and state.agent_refined_metrics.duration_s
|
||||
):
|
||||
logger.debug(
|
||||
f"Total time: {float(state.agent_base_metrics.duration_s) + float(state.agent_refined_metrics.duration_s)}"
|
||||
)
|
||||
|
||||
return main_output
|
||||
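A small worked example of the duration bookkeeping above (all timestamps invented for illustration):

```python
# Worked example of the timing fields persisted by this node.
from datetime import datetime, timedelta

start = datetime(2025, 1, 1, 12, 0, 0)
base_end = start + timedelta(seconds=20)
refined_start = base_end
refined_end = refined_start + timedelta(seconds=35)

base_duration = (base_end - start).total_seconds()                # 20.0
refined_duration = (refined_end - refined_start).total_seconds()  # 35.0
full_duration = (refined_end - start).total_seconds()             # 55.0
agent_type = "refined" if refined_duration else "base"            # "refined"
```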
@@ -0,0 +1,52 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
ExploratorySearchUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import MainState
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_history_prompt,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import retrieve_search_docs
|
||||
from onyx.configs.agent_configs import AGENT_EXPLORATORY_SEARCH_RESULTS
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
def start_agent_search(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> ExploratorySearchUpdate:
|
||||
"""
|
||||
LangGraph node to start the agentic search process.
|
||||
"""
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.search_request.query
|
||||
|
||||
history = build_history_prompt(graph_config, question)
|
||||
|
||||
# Initial search to inform decomposition. Just get the top 3 hits
|
||||
search_tool = graph_config.tooling.search_tool
|
||||
assert search_tool, "search_tool must be provided for agentic search"
|
||||
retrieved_docs: list[InferenceSection] = retrieve_search_docs(search_tool, question)
|
||||
|
||||
exploratory_search_results = retrieved_docs[:AGENT_EXPLORATORY_SEARCH_RESULTS]
|
||||
|
||||
return ExploratorySearchUpdate(
|
||||
exploratory_search_results=exploratory_search_results,
|
||||
previous_history_summary=history,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="start agent search",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
132
backend/onyx/agents/agent_search/deep_search/main/operations.py
Normal file
@@ -0,0 +1,132 @@
|
||||
from collections.abc import Callable
|
||||
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
SubQuestionAnswerResults,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.tools.models import SearchQueryInfo
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def dispatch_subquestion(
|
||||
level: int, writer: StreamWriter
|
||||
) -> Callable[[str, int], None]:
|
||||
def _helper(sub_question_part: str, sep_num: int) -> None:
|
||||
write_custom_event(
|
||||
"decomp_qs",
|
||||
SubQuestionPiece(
|
||||
sub_question=sub_question_part,
|
||||
level=level,
|
||||
level_question_num=sep_num,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
return _helper
|
||||
|
||||
|
||||
def calculate_initial_agent_stats(
|
||||
decomp_answer_results: list[SubQuestionAnswerResults],
|
||||
original_question_stats: AgentChunkRetrievalStats,
|
||||
) -> InitialAgentResultStats:
|
||||
initial_agent_result_stats: InitialAgentResultStats = InitialAgentResultStats(
|
||||
sub_questions={},
|
||||
original_question={},
|
||||
agent_effectiveness={},
|
||||
)
|
||||
|
||||
orig_verified = original_question_stats.verified_count
|
||||
orig_support_score = original_question_stats.verified_avg_scores
|
||||
|
||||
verified_document_chunk_ids = []
|
||||
support_scores = 0.0
|
||||
|
||||
for decomp_answer_result in decomp_answer_results:
|
||||
verified_document_chunk_ids += (
|
||||
decomp_answer_result.sub_question_retrieval_stats.verified_doc_chunk_ids
|
||||
)
|
||||
if (
|
||||
decomp_answer_result.sub_question_retrieval_stats.verified_avg_scores
|
||||
is not None
|
||||
):
|
||||
support_scores += (
|
||||
decomp_answer_result.sub_question_retrieval_stats.verified_avg_scores
|
||||
)
|
||||
|
||||
verified_document_chunk_ids = list(set(verified_document_chunk_ids))
|
||||
|
||||
# Calculate sub-question stats
|
||||
if (
|
||||
verified_document_chunk_ids
|
||||
and len(verified_document_chunk_ids) > 0
|
||||
and support_scores is not None
|
||||
):
|
||||
sub_question_stats: dict[str, float | int | None] = {
|
||||
"num_verified_documents": len(verified_document_chunk_ids),
|
||||
"verified_avg_score": float(support_scores / len(decomp_answer_results)),
|
||||
}
|
||||
else:
|
||||
sub_question_stats = {"num_verified_documents": 0, "verified_avg_score": None}
|
||||
|
||||
initial_agent_result_stats.sub_questions.update(sub_question_stats)
|
||||
|
||||
# Get original question stats
|
||||
initial_agent_result_stats.original_question.update(
|
||||
{
|
||||
"num_verified_documents": original_question_stats.verified_count,
|
||||
"verified_avg_score": original_question_stats.verified_avg_scores,
|
||||
}
|
||||
)
|
||||
|
||||
# Calculate chunk utilization ratio
|
||||
sub_verified = initial_agent_result_stats.sub_questions["num_verified_documents"]
|
||||
|
||||
chunk_ratio: float | None = None
|
||||
if sub_verified is not None and orig_verified is not None and orig_verified > 0:
|
||||
chunk_ratio = (float(sub_verified) / orig_verified) if sub_verified > 0 else 0.0
|
||||
elif sub_verified is not None and sub_verified > 0:
|
||||
chunk_ratio = 10.0
|
||||
|
||||
initial_agent_result_stats.agent_effectiveness["utilized_chunk_ratio"] = chunk_ratio
|
||||
|
||||
if (
|
||||
(orig_support_score is None or orig_support_score == 0.0)
and initial_agent_result_stats.sub_questions["verified_avg_score"] is None
|
||||
):
|
||||
initial_agent_result_stats.agent_effectiveness["support_ratio"] = None
|
||||
elif orig_support_score is None or orig_support_score == 0.0:
|
||||
initial_agent_result_stats.agent_effectiveness["support_ratio"] = 10
|
||||
elif initial_agent_result_stats.sub_questions["verified_avg_score"] is None:
|
||||
initial_agent_result_stats.agent_effectiveness["support_ratio"] = 0
|
||||
else:
|
||||
initial_agent_result_stats.agent_effectiveness["support_ratio"] = (
|
||||
initial_agent_result_stats.sub_questions["verified_avg_score"]
|
||||
/ orig_support_score
|
||||
)
|
||||
|
||||
return initial_agent_result_stats
|
||||
|
||||
|
||||
def get_query_info(results: list[QueryRetrievalResult]) -> SearchQueryInfo:
|
||||
# Use the query info from the base document retrieval
|
||||
# this is used for some fields that are the same across the searches done
|
||||
query_info = None
|
||||
for result in results:
|
||||
if result.query_info is not None:
|
||||
query_info = result.query_info
|
||||
break
|
||||
return query_info or SearchQueryInfo(
|
||||
predicted_search=None,
|
||||
final_filters=IndexFilters(access_control_list=None),
|
||||
recency_bias_multiplier=1.0,
|
||||
)
|
||||
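A quick numeric example of the utilized_chunk_ratio computed above:

```python
# Worked example: 6 unique verified chunks across sub-questions vs. 4 for the original question.
sub_verified, orig_verified = 6, 4
chunk_ratio = (float(sub_verified) / orig_verified) if sub_verified > 0 else 0.0
print(chunk_ratio)  # 1.5 -> sub-questions surfaced 1.5x the verified chunks of the base search
```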
172
backend/onyx/agents/agent_search/deep_search/main/states.py
Normal file
@@ -0,0 +1,172 @@
|
||||
from datetime import datetime
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.core_state import CoreState
|
||||
from onyx.agents.agent_search.deep_search.main.models import AgentBaseMetrics
|
||||
from onyx.agents.agent_search.deep_search.main.models import (
|
||||
AgentRefinedMetrics,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.models import (
|
||||
RefinementSubQuestion,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
EntityRelationshipTermExtraction,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
SubQuestionAnswerResults,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_question_answer_results,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
### States ###
|
||||
class LoggerUpdate(BaseModel):
|
||||
log_messages: Annotated[list[str], add] = []
|
||||
|
||||
|
||||
class RefinedAgentStartStats(BaseModel):
|
||||
agent_refined_start_time: datetime | None = None
|
||||
|
||||
|
||||
class RefinedAgentEndStats(BaseModel):
|
||||
agent_refined_end_time: datetime | None = None
|
||||
agent_refined_metrics: AgentRefinedMetrics = AgentRefinedMetrics()
|
||||
|
||||
|
||||
class InitialQuestionDecompositionUpdate(
|
||||
RefinedAgentStartStats, RefinedAgentEndStats, LoggerUpdate
|
||||
):
|
||||
agent_start_time: datetime | None = None
|
||||
previous_history: str | None = None
|
||||
initial_sub_questions: list[str] = []
|
||||
|
||||
|
||||
class ExploratorySearchUpdate(LoggerUpdate):
|
||||
exploratory_search_results: list[InferenceSection] = []
|
||||
previous_history_summary: str | None = None
|
||||
|
||||
|
||||
class InitialRefinedAnswerComparisonUpdate(LoggerUpdate):
|
||||
"""
|
||||
Evaluation of whether the refined answer is better than the initial answer
|
||||
"""
|
||||
|
||||
refined_answer_improvement_eval: bool = False
|
||||
|
||||
|
||||
class InitialAnswerUpdate(LoggerUpdate):
|
||||
"""
|
||||
Initial answer information
|
||||
"""
|
||||
|
||||
initial_answer: str | None = None
|
||||
initial_agent_stats: InitialAgentResultStats | None = None
|
||||
generated_sub_questions: list[str] = []
|
||||
agent_base_end_time: datetime | None = None
|
||||
agent_base_metrics: AgentBaseMetrics | None = None
|
||||
|
||||
|
||||
class RefinedAnswerUpdate(RefinedAgentEndStats, LoggerUpdate):
|
||||
"""
|
||||
Refined answer information
|
||||
"""
|
||||
|
||||
refined_answer: str | None = None
|
||||
refined_agent_stats: RefinedAgentStats | None = None
|
||||
refined_answer_quality: bool = False
|
||||
|
||||
|
||||
class InitialAnswerQualityUpdate(LoggerUpdate):
|
||||
"""
|
||||
Initial answer quality evaluation
|
||||
"""
|
||||
|
||||
initial_answer_quality_eval: bool = False
|
||||
|
||||
|
||||
class RequireRefinemenEvalUpdate(LoggerUpdate):
|
||||
require_refined_answer_eval: bool = True
|
||||
|
||||
|
||||
class SubQuestionResultsUpdate(LoggerUpdate):
|
||||
verified_reranked_documents: Annotated[
|
||||
list[InferenceSection], dedup_inference_sections
|
||||
] = []
|
||||
context_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
|
||||
cited_documents: Annotated[
|
||||
list[InferenceSection], dedup_inference_sections
|
||||
] = [] # cited docs from sub-answers are used for answer context
|
||||
sub_question_results: Annotated[
|
||||
list[SubQuestionAnswerResults], dedup_question_answer_results
|
||||
] = []
|
||||
|
||||
|
||||
class OrigQuestionRetrievalUpdate(LoggerUpdate):
|
||||
orig_question_retrieved_documents: Annotated[
|
||||
list[InferenceSection], dedup_inference_sections
|
||||
]
|
||||
orig_question_verified_reranked_documents: Annotated[
|
||||
list[InferenceSection], dedup_inference_sections
|
||||
]
|
||||
orig_question_sub_query_retrieval_results: list[QueryRetrievalResult] = []
|
||||
orig_question_retrieval_stats: AgentChunkRetrievalStats = AgentChunkRetrievalStats()
|
||||
|
||||
|
||||
class EntityTermExtractionUpdate(LoggerUpdate):
|
||||
entity_relation_term_extractions: EntityRelationshipTermExtraction = (
|
||||
EntityRelationshipTermExtraction()
|
||||
)
|
||||
|
||||
|
||||
class RefinedQuestionDecompositionUpdate(RefinedAgentStartStats, LoggerUpdate):
|
||||
refined_sub_questions: dict[int, RefinementSubQuestion] = {}
|
||||
|
||||
|
||||
## Graph Input State
|
||||
class MainInput(CoreState):
|
||||
pass
|
||||
|
||||
|
||||
## Graph State
|
||||
class MainState(
|
||||
# This includes the core state
|
||||
MainInput,
|
||||
ToolChoiceInput,
|
||||
ToolCallUpdate,
|
||||
ToolChoiceUpdate,
|
||||
InitialQuestionDecompositionUpdate,
|
||||
InitialAnswerUpdate,
|
||||
SubQuestionResultsUpdate,
|
||||
OrigQuestionRetrievalUpdate,
|
||||
EntityTermExtractionUpdate,
|
||||
InitialAnswerQualityUpdate,
|
||||
RequireRefinemenEvalUpdate,
|
||||
RefinedQuestionDecompositionUpdate,
|
||||
RefinedAnswerUpdate,
|
||||
RefinedAgentStartStats,
|
||||
RefinedAgentEndStats,
|
||||
InitialRefinedAnswerComparisonUpdate,
|
||||
ExploratorySearchUpdate,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
## Graph Output State - presently not used
|
||||
class MainOutput(TypedDict):
|
||||
log_messages: list[str]
|
||||
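The Annotated[..., reducer] fields above tell LangGraph how to merge updates that arrive from parallel nodes. A minimal sketch of that idea, using the same operator.add reducer as log_messages (simplified; LangGraph performs this merge internally):

```python
# Sketch of reducer-based merging, mirroring log_messages: Annotated[list[str], add].
from operator import add

current_log_messages = ["main -- decide refinement need -- Refinement decision: True"]
incoming_update = ["main -- compare answers -- Answer comparison: True"]
merged = add(current_log_messages, incoming_update)
# merged now contains both entries; dedup_inference_sections plays the same role for documents.
```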
@@ -0,0 +1,33 @@
|
||||
from collections.abc import Hashable
|
||||
from datetime import datetime
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
SubQuestionAnsweringInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
|
||||
ExpandedRetrievalInput,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def send_to_expanded_refined_retrieval(
|
||||
state: SubQuestionAnsweringInput,
|
||||
) -> Send | Hashable:
|
||||
"""
|
||||
LangGraph edge to send a refined sub-question to expanded retrieval.
|
||||
"""
|
||||
logger.debug("sending to expanded retrieval for follow up question via edge")
|
||||
datetime.now()
|
||||
return Send(
|
||||
"refined_sub_question_expanded_retrieval",
|
||||
ExpandedRetrievalInput(
|
||||
question=state.question,
|
||||
sub_question_id=state.question_id,
|
||||
base_search=False,
|
||||
log_messages=[f"{datetime.now()} -- Sending to expanded retrieval"],
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,132 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.check_sub_answer import (
|
||||
check_sub_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.format_sub_answer import (
|
||||
format_sub_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.generate_sub_answer import (
|
||||
generate_sub_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.ingest_retrieved_documents import (
|
||||
ingest_retrieved_documents,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
AnswerQuestionOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
AnswerQuestionState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
SubQuestionAnsweringInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.refinement.consolidate_sub_answers.edges import (
|
||||
send_to_expanded_refined_retrieval,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.graph_builder import (
|
||||
expanded_retrieval_graph_builder,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def answer_refined_query_graph_builder() -> StateGraph:
|
||||
"""
|
||||
LangGraph graph builder for the refined sub-answer generation process.
|
||||
"""
|
||||
graph = StateGraph(
|
||||
state_schema=AnswerQuestionState,
|
||||
input=SubQuestionAnsweringInput,
|
||||
output=AnswerQuestionOutput,
|
||||
)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
# Subgraph for the expanded retrieval process
|
||||
expanded_retrieval = expanded_retrieval_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="refined_sub_question_expanded_retrieval",
|
||||
action=expanded_retrieval,
|
||||
)
|
||||
|
||||
# Ingest the retrieved documents
|
||||
graph.add_node(
|
||||
node="ingest_refined_retrieval",
|
||||
action=ingest_retrieved_documents,
|
||||
)
|
||||
|
||||
# Generate the refined sub-answer
|
||||
graph.add_node(
|
||||
node="generate_refined_sub_answer",
|
||||
action=generate_sub_answer,
|
||||
)
|
||||
|
||||
# Check if the refined sub-answer is correct
|
||||
graph.add_node(
|
||||
node="refined_sub_answer_check",
|
||||
action=check_sub_answer,
|
||||
)
|
||||
|
||||
# Format the refined sub-answer
|
||||
graph.add_node(
|
||||
node="format_refined_sub_answer",
|
||||
action=format_sub_answer,
|
||||
)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source=START,
|
||||
path=send_to_expanded_refined_retrieval,
|
||||
path_map=["refined_sub_question_expanded_retrieval"],
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="refined_sub_question_expanded_retrieval",
|
||||
end_key="ingest_refined_retrieval",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="ingest_refined_retrieval",
|
||||
end_key="generate_refined_sub_answer",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="generate_refined_sub_answer",
|
||||
end_key="refined_sub_answer_check",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="refined_sub_answer_check",
|
||||
end_key="format_refined_sub_answer",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="format_refined_sub_answer",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.context.search.models import SearchRequest
|
||||
|
||||
graph = answer_refined_query_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
search_request = SearchRequest(
|
||||
query="what can you do with onyx or danswer?",
|
||||
)
|
||||
with get_session_context_manager() as db_session:
|
||||
inputs = SubQuestionAnsweringInput(
|
||||
question="what can you do with onyx?",
|
||||
question_id="0_0",
|
||||
log_messages=[],
|
||||
)
|
||||
for thing in compiled_graph.stream(
|
||||
input=inputs,
|
||||
stream_mode="custom",
|
||||
):
|
||||
logger.debug(thing)
|
||||
@@ -0,0 +1,42 @@
from collections.abc import Hashable
from typing import cast

from langchain_core.runnables.config import RunnableConfig
from langgraph.types import Send

from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
    ExpandedRetrievalState,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
    RetrievalInput,
)
from onyx.agents.agent_search.models import GraphConfig


def parallel_retrieval_edge(
    state: ExpandedRetrievalState, config: RunnableConfig
) -> list[Send | Hashable]:
    """
    LangGraph edge to parallelize the retrieval process for each of the
    generated sub-queries and the original question.
    """
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = (
        state.question if state.question else graph_config.inputs.search_request.query
    )

    query_expansions = state.expanded_queries + [question]

    return [
        Send(
            "retrieve_documents",
            RetrievalInput(
                query_to_retrieve=query,
                question=question,
                base_search=False,
                sub_question_id=state.sub_question_id,
                log_messages=[],
            ),
        )
        for query in query_expansions
    ]
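Illustrative aside (not part of this diff): the edge above uses LangGraph's Send API to fan the retrieval node out once per expanded query plus the original question. A minimal, self-contained sketch of that map-style fan-out, assuming langgraph with pydantic state schemas and using toy names:

# Illustrative sketch only (toy names, not from the onyx codebase): Send-based fan-out.
from operator import add
from typing import Annotated

from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from langgraph.types import Send
from pydantic import BaseModel


class _ToyState(BaseModel):
    queries: list[str] = []
    # The "add" reducer merges results written back by the parallel branches.
    results: Annotated[list[str], add] = []


class _ToyRetrievalInput(BaseModel):
    query: str = ""


def _fan_out(state: _ToyState) -> list[Send]:
    # One Send per query -> one parallel invocation of "retrieve" per query.
    return [Send("retrieve", _ToyRetrievalInput(query=query)) for query in state.queries]


def _retrieve(state: _ToyRetrievalInput) -> dict:
    # The returned partial update is merged into _ToyState.results via the reducer.
    return {"results": [f"docs for: {state.query}"]}


_toy_graph = StateGraph(state_schema=_ToyState)
_toy_graph.add_node(node="retrieve", action=_retrieve)
_toy_graph.add_conditional_edges(source=START, path=_fan_out, path_map=["retrieve"])
_toy_graph.add_edge(start_key="retrieve", end_key=END)

if __name__ == "__main__":
    print(_toy_graph.compile().invoke({"queries": ["onyx features", "onyx connectors"]}))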
@@ -0,0 +1,161 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.edges import (
|
||||
parallel_retrieval_edge,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.expand_queries import (
|
||||
expand_queries,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.format_queries import (
|
||||
format_queries,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.format_results import (
|
||||
format_results,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.kickoff_verification import (
|
||||
kickoff_verification,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.rerank_documents import (
|
||||
rerank_documents,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.retrieve_documents import (
|
||||
retrieve_documents,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.verify_documents import (
|
||||
verify_documents,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
|
||||
ExpandedRetrievalInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
|
||||
ExpandedRetrievalOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
|
||||
ExpandedRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def expanded_retrieval_graph_builder() -> StateGraph:
|
||||
"""
|
||||
LangGraph graph builder for the expanded retrieval process.
|
||||
"""
|
||||
graph = StateGraph(
|
||||
state_schema=ExpandedRetrievalState,
|
||||
input=ExpandedRetrievalInput,
|
||||
output=ExpandedRetrievalOutput,
|
||||
)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
# Convert the question into multiple sub-queries
|
||||
graph.add_node(
|
||||
node="expand_queries",
|
||||
action=expand_queries,
|
||||
)
|
||||
|
||||
# Format the sub-queries into a list of strings
|
||||
graph.add_node(
|
||||
node="format_queries",
|
||||
action=format_queries,
|
||||
)
|
||||
|
||||
# Retrieve the documents for each sub-query
|
||||
graph.add_node(
|
||||
node="retrieve_documents",
|
||||
action=retrieve_documents,
|
||||
)
|
||||
|
||||
# Start verification process that the documents are relevant to the question (not the query)
|
||||
graph.add_node(
|
||||
node="kickoff_verification",
|
||||
action=kickoff_verification,
|
||||
)
|
||||
|
||||
# Verify that a given document is relevant to the question (not the query)
|
||||
graph.add_node(
|
||||
node="verify_documents",
|
||||
action=verify_documents,
|
||||
)
|
||||
|
||||
# Rerank the documents that have been verified
|
||||
graph.add_node(
|
||||
node="rerank_documents",
|
||||
action=rerank_documents,
|
||||
)
|
||||
|
||||
# Format the results into a list of strings
|
||||
graph.add_node(
|
||||
node="format_results",
|
||||
action=format_results,
|
||||
)
|
||||
|
||||
### Add edges ###
|
||||
graph.add_edge(
|
||||
start_key=START,
|
||||
end_key="expand_queries",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="expand_queries",
|
||||
end_key="format_queries",
|
||||
)
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source="format_queries",
|
||||
path=parallel_retrieval_edge,
|
||||
path_map=["retrieve_documents"],
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="retrieve_documents",
|
||||
end_key="kickoff_verification",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="verify_documents",
|
||||
end_key="rerank_documents",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="rerank_documents",
|
||||
end_key="format_results",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="format_results",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.context.search.models import SearchRequest
|
||||
|
||||
graph = expanded_retrieval_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
search_request = SearchRequest(
|
||||
query="what can you do with onyx or danswer?",
|
||||
)
|
||||
|
||||
with get_session_context_manager() as db_session:
|
||||
graph_config, search_tool = get_test_config(
|
||||
db_session, primary_llm, fast_llm, search_request
|
||||
)
|
||||
inputs = ExpandedRetrievalInput(
|
||||
question="what can you do with onyx?",
|
||||
base_search=False,
|
||||
sub_question_id=None,
|
||||
log_messages=[],
|
||||
)
|
||||
for thing in compiled_graph.stream(
|
||||
input=inputs,
|
||||
config={"configurable": {"config": graph_config}},
|
||||
stream_mode="custom",
|
||||
subgraphs=True,
|
||||
):
|
||||
logger.debug(thing)
|
||||
@@ -0,0 +1,13 @@
from pydantic import BaseModel

from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
from onyx.context.search.models import InferenceSection


class QuestionRetrievalResult(BaseModel):
    expanded_query_results: list[QueryRetrievalResult] = []
    retrieved_documents: list[InferenceSection] = []
    verified_reranked_documents: list[InferenceSection] = []
    context_documents: list[InferenceSection] = []
    retrieval_stats: AgentChunkRetrievalStats = AgentChunkRetrievalStats()
@@ -0,0 +1,75 @@
from datetime import datetime
from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.operations import (
    dispatch_subquery,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
    ExpandedRetrievalInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
    QueryExpansionUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.prompts.agent_search import (
    QUERY_REWRITING_PROMPT,
)


def expand_queries(
    state: ExpandedRetrievalInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> QueryExpansionUpdate:
    """
    LangGraph node to expand a question into multiple search queries.
    """
    # Sometimes we want to expand the original question, sometimes we want to expand a sub-question.
    # When we are running this node on the original question, no question is explicitly passed in.
    # Instead, we use the original question from the search request.
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    node_start_time = datetime.now()
    question = state.question

    llm = graph_config.tooling.fast_llm
    sub_question_id = state.sub_question_id
    if sub_question_id is None:
        level, question_num = 0, 0
    else:
        level, question_num = parse_question_id(sub_question_id)

    msg = [
        HumanMessage(
            content=QUERY_REWRITING_PROMPT.format(question=question),
        )
    ]

    llm_response_list = dispatch_separated(
        llm.stream(prompt=msg), dispatch_subquery(level, question_num, writer)
    )

    llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content

    rewritten_queries = llm_response.split("\n")

    return QueryExpansionUpdate(
        expanded_queries=rewritten_queries,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="shared - expanded retrieval",
                node_name="expand queries",
                node_start_time=node_start_time,
                result=f"Number of expanded queries: {len(rewritten_queries)}",
            )
        ],
    )
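Illustrative aside (not part of this diff): the rewriting prompt is expected to return one search query per line, which is why the raw LLM response above is simply split on newlines. A tiny standalone illustration of that contract, using a stubbed response:

# Illustrative sketch only: the one-query-per-line format assumed by expand_queries.
_stub_llm_response = "onyx feature overview\nonyx connector list\nonyx permission model"
_expanded_queries = _stub_llm_response.split("\n")
assert _expanded_queries == [
    "onyx feature overview",
    "onyx connector list",
    "onyx permission model",
]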
@@ -0,0 +1,19 @@
from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
    ExpandedRetrievalState,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
    QueryExpansionUpdate,
)


def format_queries(
    state: ExpandedRetrievalState, config: RunnableConfig
) -> QueryExpansionUpdate:
    """
    LangGraph node to format the expanded queries into a list of strings.
    """
    return QueryExpansionUpdate(
        expanded_queries=state.expanded_queries,
    )
||||
@@ -0,0 +1,91 @@
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.deep_search.main.operations import get_query_info
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.models import (
|
||||
QuestionRetrievalResult,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.operations import (
|
||||
calculate_sub_question_retrieval_stats,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
|
||||
ExpandedRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
|
||||
ExpandedRetrievalUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
|
||||
|
||||
|
||||
def format_results(
|
||||
state: ExpandedRetrievalState,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> ExpandedRetrievalUpdate:
|
||||
"""
|
||||
LangGraph node that constructs the proper expanded retrieval format.
|
||||
"""
|
||||
level, question_num = parse_question_id(state.sub_question_id or "0_0")
|
||||
query_info = get_query_info(state.query_retrieval_results)
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
|
||||
# Main question docs will be sent later after aggregation and deduping with sub-question docs
|
||||
reranked_documents = state.reranked_documents
|
||||
|
||||
if not (level == 0 and question_num == 0):
|
||||
if len(reranked_documents) == 0:
|
||||
# The sub-question is used as the last query. If no verified documents are found, stream
|
||||
# the top 3 for that one. We may want to revisit this.
|
||||
reranked_documents = state.query_retrieval_results[-1].retrieved_documents[
|
||||
:3
|
||||
]
|
||||
|
||||
assert (
|
||||
graph_config.tooling.search_tool
|
||||
), "search_tool must be provided for agentic search"
|
||||
|
||||
relevance_list = relevance_from_docs(reranked_documents)
|
||||
for tool_response in yield_search_responses(
|
||||
query=state.question,
|
||||
reranked_sections=state.retrieved_documents,
|
||||
final_context_sections=reranked_documents,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
):
|
||||
write_custom_event(
|
||||
"tool_response",
|
||||
ExtendedToolResponse(
|
||||
id=tool_response.id,
|
||||
response=tool_response.response,
|
||||
level=level,
|
||||
level_question_num=question_num,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
sub_question_retrieval_stats = calculate_sub_question_retrieval_stats(
|
||||
verified_documents=state.verified_documents,
|
||||
expanded_retrieval_results=state.query_retrieval_results,
|
||||
)
|
||||
|
||||
if sub_question_retrieval_stats is None:
|
||||
sub_question_retrieval_stats = AgentChunkRetrievalStats()
|
||||
|
||||
return ExpandedRetrievalUpdate(
|
||||
expanded_retrieval_result=QuestionRetrievalResult(
|
||||
expanded_query_results=state.query_retrieval_results,
|
||||
retrieved_documents=state.retrieved_documents,
|
||||
verified_reranked_documents=reranked_documents,
|
||||
context_documents=state.reranked_documents,
|
||||
retrieval_stats=sub_question_retrieval_stats,
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,44 @@
from typing import Literal

from langchain_core.runnables.config import RunnableConfig
from langgraph.types import Command
from langgraph.types import Send

from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
    DocVerificationInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
    ExpandedRetrievalState,
)


def kickoff_verification(
    state: ExpandedRetrievalState,
    config: RunnableConfig,
) -> Command[Literal["verify_documents"]]:
    """
    LangGraph node (Command node!) that kicks off the verification process for the retrieved documents.
    Note that this is a Command node and does the routing as well. (At present, no state updates
    are done here, so this could be replaced with an edge. But we may choose to make state
    updates later.)
    """
    retrieved_documents = state.retrieved_documents
    verification_question = state.question

    sub_question_id = state.sub_question_id
    return Command(
        update={},
        goto=[
            Send(
                node="verify_documents",
                arg=DocVerificationInput(
                    retrieved_document_to_verify=document,
                    question=verification_question,
                    base_search=False,
                    sub_question_id=sub_question_id,
                    log_messages=[],
                ),
            )
            for document in retrieved_documents
        ],
    )
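Illustrative aside (not part of this diff): as the docstring notes, this Command node makes no state update, so the same fan-out could be expressed as a plain conditional edge. A sketch of that equivalent edge, reusing only the states already defined in this package:

# Illustrative sketch only: the edge-based equivalent of kickoff_verification.
from collections.abc import Hashable

from langgraph.types import Send

from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
    DocVerificationInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
    ExpandedRetrievalState,
)


def kickoff_verification_edge(state: ExpandedRetrievalState) -> list[Send | Hashable]:
    # Same fan-out as the Command node above, expressed as a conditional edge.
    return [
        Send(
            node="verify_documents",
            arg=DocVerificationInput(
                retrieved_document_to_verify=document,
                question=state.question,
                base_search=False,
                sub_question_id=state.sub_question_id,
                log_messages=[],
            ),
        )
        for document in state.retrieved_documents
    ]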
@@ -0,0 +1,105 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.operations import (
|
||||
logger,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
|
||||
DocRerankingUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
|
||||
ExpandedRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.calculations import get_fit_scores
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.agent_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS
|
||||
from onyx.configs.agent_configs import AGENT_RERANKING_STATS
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.context.search.pipeline import retrieval_preprocessing
|
||||
from onyx.context.search.postprocessing.postprocessing import rerank_sections
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
|
||||
|
||||
def rerank_documents(
|
||||
state: ExpandedRetrievalState, config: RunnableConfig
|
||||
) -> DocRerankingUpdate:
|
||||
"""
|
||||
LangGraph node to rerank the retrieved and verified documents. A part of the
|
||||
pre-existing pipeline is used here.
|
||||
"""
|
||||
node_start_time = datetime.now()
|
||||
verified_documents = state.verified_documents
|
||||
|
||||
# Rerank post retrieval and verification. First, create a search query
|
||||
# then create the list of reranked sections
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = (
|
||||
state.question if state.question else graph_config.inputs.search_request.query
|
||||
)
|
||||
assert (
|
||||
graph_config.tooling.search_tool
|
||||
), "search_tool must be provided for agentic search"
|
||||
with get_session_context_manager() as db_session:
|
||||
# we ignore some of the user specified fields since this search is
|
||||
# internal to agentic search, but we still want to pass through
|
||||
# persona (for stuff like document sets) and rerank settings
|
||||
# (to not make an unnecessary db call).
|
||||
search_request = SearchRequest(
|
||||
query=question,
|
||||
persona=graph_config.inputs.search_request.persona,
|
||||
rerank_settings=graph_config.inputs.search_request.rerank_settings,
|
||||
)
|
||||
_search_query = retrieval_preprocessing(
|
||||
search_request=search_request,
|
||||
user=graph_config.tooling.search_tool.user, # bit of a hack
|
||||
llm=graph_config.tooling.fast_llm,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# skip section filtering
|
||||
|
||||
if (
|
||||
_search_query.rerank_settings
|
||||
and _search_query.rerank_settings.rerank_model_name
|
||||
and _search_query.rerank_settings.num_rerank > 0
|
||||
and len(verified_documents) > 0
|
||||
):
|
||||
if len(verified_documents) > 1:
|
||||
reranked_documents = rerank_sections(
|
||||
_search_query,
|
||||
verified_documents,
|
||||
)
|
||||
else:
|
||||
num = "No" if len(verified_documents) == 0 else "One"
|
||||
logger.warning(f"{num} verified document(s) found, skipping reranking")
|
||||
reranked_documents = verified_documents
|
||||
else:
|
||||
logger.warning("No reranking settings found, using unranked documents")
|
||||
reranked_documents = verified_documents
|
||||
|
||||
if AGENT_RERANKING_STATS:
|
||||
fit_scores = get_fit_scores(verified_documents, reranked_documents)
|
||||
else:
|
||||
fit_scores = RetrievalFitStats(fit_score_lift=0, rerank_effect=0, fit_scores={})
|
||||
|
||||
return DocRerankingUpdate(
|
||||
reranked_documents=[
|
||||
doc for doc in reranked_documents if isinstance(doc, InferenceSection)
|
||||
][:AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS],
|
||||
sub_question_retrieval_stats=fit_scores,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="shared - expanded retrieval",
|
||||
node_name="rerank documents",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,113 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.operations import (
|
||||
logger,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
|
||||
DocRetrievalUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
|
||||
RetrievalInput,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.calculations import get_fit_scores
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.agent_configs import AGENT_MAX_QUERY_RETRIEVAL_RESULTS
|
||||
from onyx.configs.agent_configs import AGENT_RETRIEVAL_STATS
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.tools.models import SearchQueryInfo
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
|
||||
|
||||
def retrieve_documents(
|
||||
state: RetrievalInput, config: RunnableConfig
|
||||
) -> DocRetrievalUpdate:
|
||||
"""
|
||||
LangGraph node to retrieve documents from the search tool.
|
||||
"""
|
||||
node_start_time = datetime.now()
|
||||
query_to_retrieve = state.query_to_retrieve
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
search_tool = graph_config.tooling.search_tool
|
||||
|
||||
retrieved_docs: list[InferenceSection] = []
|
||||
if not query_to_retrieve.strip():
|
||||
logger.warning("Empty query, skipping retrieval")
|
||||
|
||||
return DocRetrievalUpdate(
|
||||
query_retrieval_results=[],
|
||||
retrieved_documents=[],
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="shared - expanded retrieval",
|
||||
node_name="retrieve documents",
|
||||
node_start_time=node_start_time,
|
||||
result="Empty query, skipping retrieval",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
query_info = None
|
||||
if search_tool is None:
|
||||
raise ValueError("search_tool must be provided for agentic search")
|
||||
|
||||
callback_container: list[list[InferenceSection]] = []
|
||||
|
||||
# new db session to avoid concurrency issues
|
||||
with get_session_context_manager() as db_session:
|
||||
for tool_response in search_tool.run(
|
||||
query=query_to_retrieve,
|
||||
force_no_rerank=True,
|
||||
alternate_db_session=db_session,
|
||||
retrieved_sections_callback=callback_container.append,
|
||||
):
|
||||
# get retrieved docs to send to the rest of the graph
|
||||
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
response = cast(SearchResponseSummary, tool_response.response)
|
||||
retrieved_docs = response.top_sections
|
||||
query_info = SearchQueryInfo(
|
||||
predicted_search=response.predicted_search,
|
||||
final_filters=response.final_filters,
|
||||
recency_bias_multiplier=response.recency_bias_multiplier,
|
||||
)
|
||||
break
|
||||
|
||||
retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS]
|
||||
|
||||
if AGENT_RETRIEVAL_STATS:
|
||||
pre_rerank_docs = callback_container[0]
|
||||
fit_scores = get_fit_scores(
|
||||
pre_rerank_docs,
|
||||
retrieved_docs,
|
||||
)
|
||||
else:
|
||||
fit_scores = None
|
||||
|
||||
expanded_retrieval_result = QueryRetrievalResult(
|
||||
query=query_to_retrieve,
|
||||
retrieved_documents=retrieved_docs,
|
||||
stats=fit_scores,
|
||||
query_info=query_info,
|
||||
)
|
||||
|
||||
return DocRetrievalUpdate(
|
||||
query_retrieval_results=[expanded_retrieval_result],
|
||||
retrieved_documents=retrieved_docs,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="shared - expanded retrieval",
|
||||
node_name="retrieve documents",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,62 @@
from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
    DocVerificationInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
    DocVerificationUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    trim_prompt_piece,
)
from onyx.prompts.agent_search import (
    DOCUMENT_VERIFICATION_PROMPT,
)


def verify_documents(
    state: DocVerificationInput, config: RunnableConfig
) -> DocVerificationUpdate:
    """
    LangGraph node to check whether the document is relevant for the original user question

    Args:
        state (DocVerificationInput): The current state
        config (RunnableConfig): Configuration containing the GraphConfig

    Updates:
        verified_documents: list[InferenceSection]
    """

    question = state.question
    retrieved_document_to_verify = state.retrieved_document_to_verify
    document_content = retrieved_document_to_verify.combined_content

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    fast_llm = graph_config.tooling.fast_llm

    document_content = trim_prompt_piece(
        fast_llm.config, document_content, DOCUMENT_VERIFICATION_PROMPT + question
    )

    msg = [
        HumanMessage(
            content=DOCUMENT_VERIFICATION_PROMPT.format(
                question=question, document_content=document_content
            )
        )
    ]

    response = fast_llm.invoke(msg)

    verified_documents = []
    if isinstance(response.content, str) and "yes" in response.content.lower():
        verified_documents.append(retrieved_document_to_verify)

    return DocVerificationUpdate(
        verified_documents=verified_documents,
    )
||||
@@ -0,0 +1,93 @@
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
|
||||
import numpy as np
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import SubQueryPiece
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def dispatch_subquery(
|
||||
level: int, question_num: int, writer: StreamWriter
|
||||
) -> Callable[[str, int], None]:
|
||||
def helper(token: str, num: int) -> None:
|
||||
write_custom_event(
|
||||
"subqueries",
|
||||
SubQueryPiece(
|
||||
sub_query=token,
|
||||
level=level,
|
||||
level_question_num=question_num,
|
||||
query_id=num,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
return helper
|
||||
|
||||
|
||||
def calculate_sub_question_retrieval_stats(
|
||||
verified_documents: list[InferenceSection],
|
||||
expanded_retrieval_results: list[QueryRetrievalResult],
|
||||
) -> AgentChunkRetrievalStats:
|
||||
chunk_scores: dict[str, dict[str, list[int | float]]] = defaultdict(
|
||||
lambda: defaultdict(list)
|
||||
)
|
||||
|
||||
for expanded_retrieval_result in expanded_retrieval_results:
|
||||
for doc in expanded_retrieval_result.retrieved_documents:
|
||||
doc_chunk_id = f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}"
|
||||
if doc.center_chunk.score is not None:
|
||||
chunk_scores[doc_chunk_id]["score"].append(doc.center_chunk.score)
|
||||
|
||||
verified_doc_chunk_ids = [
|
||||
f"{verified_document.center_chunk.document_id}_{verified_document.center_chunk.chunk_id}"
|
||||
for verified_document in verified_documents
|
||||
]
|
||||
dismissed_doc_chunk_ids = []
|
||||
|
||||
raw_chunk_stats_counts: dict[str, int] = defaultdict(int)
|
||||
raw_chunk_stats_scores: dict[str, float] = defaultdict(float)
|
||||
for doc_chunk_id, chunk_data in chunk_scores.items():
|
||||
valid_chunk_scores = [
|
||||
score for score in chunk_data["score"] if score is not None
|
||||
]
|
||||
key = "verified" if doc_chunk_id in verified_doc_chunk_ids else "rejected"
|
||||
raw_chunk_stats_counts[f"{key}_count"] += 1
|
||||
|
||||
raw_chunk_stats_scores[f"{key}_scores"] += float(np.mean(valid_chunk_scores))
|
||||
|
||||
if key == "rejected":
|
||||
dismissed_doc_chunk_ids.append(doc_chunk_id)
|
||||
|
||||
if raw_chunk_stats_counts["verified_count"] == 0:
|
||||
verified_avg_scores = 0.0
|
||||
else:
|
||||
verified_avg_scores = raw_chunk_stats_scores["verified_scores"] / float(
|
||||
raw_chunk_stats_counts["verified_count"]
|
||||
)
|
||||
|
||||
rejected_scores = raw_chunk_stats_scores.get("rejected_scores")
|
||||
if rejected_scores is not None:
|
||||
rejected_avg_scores = rejected_scores / float(
|
||||
raw_chunk_stats_counts["rejected_count"]
|
||||
)
|
||||
else:
|
||||
rejected_avg_scores = None
|
||||
|
||||
chunk_stats = AgentChunkRetrievalStats(
|
||||
verified_count=raw_chunk_stats_counts["verified_count"],
|
||||
verified_avg_scores=verified_avg_scores,
|
||||
rejected_count=raw_chunk_stats_counts["rejected_count"],
|
||||
rejected_avg_scores=rejected_avg_scores,
|
||||
verified_doc_chunk_ids=verified_doc_chunk_ids,
|
||||
dismissed_doc_chunk_ids=dismissed_doc_chunk_ids,
|
||||
)
|
||||
|
||||
return chunk_stats
|
||||
@@ -0,0 +1,91 @@
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.core_state import SubgraphCoreState
|
||||
from onyx.agents.agent_search.deep_search.main.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.models import (
|
||||
QuestionRetrievalResult,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
### States ###
|
||||
|
||||
## Graph Input State
|
||||
|
||||
|
||||
class ExpandedRetrievalInput(SubgraphCoreState):
|
||||
question: str = ""
|
||||
base_search: bool = False
|
||||
sub_question_id: str | None = None
|
||||
|
||||
|
||||
## Update/Return States
|
||||
|
||||
|
||||
class QueryExpansionUpdate(LoggerUpdate, BaseModel):
|
||||
expanded_queries: list[str] = []
|
||||
log_messages: list[str] = []
|
||||
|
||||
|
||||
class DocVerificationUpdate(BaseModel):
|
||||
verified_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
|
||||
|
||||
|
||||
class DocRetrievalUpdate(LoggerUpdate, BaseModel):
|
||||
query_retrieval_results: Annotated[list[QueryRetrievalResult], add] = []
|
||||
retrieved_documents: Annotated[
|
||||
list[InferenceSection], dedup_inference_sections
|
||||
] = []
|
||||
|
||||
|
||||
class DocRerankingUpdate(LoggerUpdate, BaseModel):
|
||||
reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
|
||||
sub_question_retrieval_stats: RetrievalFitStats | None = None
|
||||
|
||||
|
||||
class ExpandedRetrievalUpdate(LoggerUpdate, BaseModel):
|
||||
expanded_retrieval_result: QuestionRetrievalResult
|
||||
|
||||
|
||||
## Graph Output State
|
||||
|
||||
|
||||
class ExpandedRetrievalOutput(LoggerUpdate, BaseModel):
|
||||
expanded_retrieval_result: QuestionRetrievalResult = QuestionRetrievalResult()
|
||||
base_expanded_retrieval_result: QuestionRetrievalResult = QuestionRetrievalResult()
|
||||
retrieved_documents: Annotated[
|
||||
list[InferenceSection], dedup_inference_sections
|
||||
] = []
|
||||
|
||||
|
||||
## Graph State
|
||||
|
||||
|
||||
class ExpandedRetrievalState(
|
||||
# This includes the core state
|
||||
ExpandedRetrievalInput,
|
||||
QueryExpansionUpdate,
|
||||
DocRetrievalUpdate,
|
||||
DocVerificationUpdate,
|
||||
DocRerankingUpdate,
|
||||
ExpandedRetrievalOutput,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
## Conditional Input States
|
||||
|
||||
|
||||
class DocVerificationInput(ExpandedRetrievalInput):
|
||||
retrieved_document_to_verify: InferenceSection
|
||||
|
||||
|
||||
class RetrievalInput(ExpandedRetrievalInput):
|
||||
query_to_retrieve: str = ""
|
||||
backend/onyx/agents/agent_search/models.py
@@ -0,0 +1,90 @@
from uuid import UUID

from pydantic import BaseModel
from pydantic import model_validator
from sqlalchemy.orm import Session

from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.context.search.models import SearchRequest
from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool


class GraphInputs(BaseModel):
    """Input data required for the graph execution"""

    search_request: SearchRequest
    prompt_builder: AnswerPromptBuilder
    files: list[InMemoryChatFile] | None = None
    structured_response_format: dict | None = None

    class Config:
        arbitrary_types_allowed = True


class GraphTooling(BaseModel):
    """Tools and LLMs available to the graph"""

    primary_llm: LLM
    fast_llm: LLM
    search_tool: SearchTool | None = None
    tools: list[Tool]
    # Whether to force use of a tool, or to
    # force tool args IF the tool is used
    force_use_tool: ForceUseTool
    using_tool_calling_llm: bool = False

    class Config:
        arbitrary_types_allowed = True


class GraphPersistence(BaseModel):
    """Configuration for data persistence"""

    chat_session_id: UUID
    # The message ID of the to-be-created first agent message
    # in response to the user message that triggered the Pro Search
    message_id: int

    # The database session the user and initial agent
    # message were flushed to; only needed for agentic search
    db_session: Session

    class Config:
        arbitrary_types_allowed = True


class GraphSearchConfig(BaseModel):
    """Configuration controlling search behavior"""

    use_agentic_search: bool = False
    # Whether to perform initial search to inform decomposition
    perform_initial_search_decomposition: bool = True

    # Whether to allow creation of refinement questions (and entity extraction, etc.)
    allow_refinement: bool = True
    skip_gen_ai_answer_generation: bool = False


class GraphConfig(BaseModel):
    """
    Main container for data needed for Langgraph execution
    """

    inputs: GraphInputs
    tooling: GraphTooling
    behavior: GraphSearchConfig
    # Only needed for agentic search
    persistence: GraphPersistence

    @model_validator(mode="after")
    def validate_search_tool(self) -> "GraphConfig":
        if self.behavior.use_agentic_search and self.tooling.search_tool is None:
            raise ValueError("search_tool must be provided for agentic search")
        return self

    class Config:
        arbitrary_types_allowed = True
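Illustrative aside (not part of this diff): GraphConfig.validate_search_tool enforces that agentic search always ships with a search tool. A standalone sketch of that cross-field validation pattern, using hypothetical stand-in types rather than the real onyx models:

# Illustrative sketch only (stand-in types): an "after" model validator that rejects
# configurations enabling a feature without the tool it depends on.
from pydantic import BaseModel
from pydantic import model_validator


class _Behavior(BaseModel):
    use_agentic_search: bool = False


class _Tooling(BaseModel):
    search_tool: str | None = None  # stand-in for the real SearchTool


class _Config(BaseModel):
    behavior: _Behavior
    tooling: _Tooling

    @model_validator(mode="after")
    def _validate(self) -> "_Config":
        if self.behavior.use_agentic_search and self.tooling.search_tool is None:
            raise ValueError("search_tool must be provided for agentic search")
        return self


if __name__ == "__main__":
    _Config(behavior=_Behavior(), tooling=_Tooling())  # passes
    try:
        _Config(behavior=_Behavior(use_agentic_search=True), tooling=_Tooling())
    except ValueError as err:
        print(err)  # search_tool must be provided for agentic search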
@@ -0,0 +1,77 @@
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.basic.states import BasicOutput
|
||||
from onyx.agents.agent_search.basic.states import BasicState
|
||||
from onyx.agents.agent_search.basic.utils import process_llm_stream
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import OnyxContexts
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_DOC_CONTENT_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search_like_tool_utils import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def basic_use_tool_response(
|
||||
state: BasicState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> BasicOutput:
|
||||
agent_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
structured_response_format = agent_config.inputs.structured_response_format
|
||||
llm = agent_config.tooling.primary_llm
|
||||
tool_choice = state.tool_choice
|
||||
if tool_choice is None:
|
||||
raise ValueError("Tool choice is None")
|
||||
tool = tool_choice.tool
|
||||
prompt_builder = agent_config.inputs.prompt_builder
|
||||
if state.tool_call_output is None:
|
||||
raise ValueError("Tool call output is None")
|
||||
tool_call_output = state.tool_call_output
|
||||
tool_call_summary = tool_call_output.tool_call_summary
|
||||
tool_call_responses = tool_call_output.tool_call_responses
|
||||
|
||||
new_prompt_builder = tool.build_next_prompt(
|
||||
prompt_builder=prompt_builder,
|
||||
tool_call_summary=tool_call_summary,
|
||||
tool_responses=tool_call_responses,
|
||||
using_tool_calling_llm=agent_config.tooling.using_tool_calling_llm,
|
||||
)
|
||||
|
||||
final_search_results = []
|
||||
initial_search_results = []
|
||||
for yield_item in tool_call_responses:
|
||||
if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
final_search_results = cast(list[LlmDoc], yield_item.response)
|
||||
elif yield_item.id == SEARCH_DOC_CONTENT_ID:
|
||||
search_contexts = cast(OnyxContexts, yield_item.response).contexts
|
||||
for doc in search_contexts:
|
||||
if doc.document_id not in initial_search_results:
|
||||
initial_search_results.append(doc)
|
||||
|
||||
new_tool_call_chunk = AIMessageChunk(content="")
|
||||
if not agent_config.behavior.skip_gen_ai_answer_generation:
|
||||
stream = llm.stream(
|
||||
prompt=new_prompt_builder.build(),
|
||||
structured_response_format=structured_response_format,
|
||||
)
|
||||
|
||||
# For now, we don't do multiple tool calls, so we ignore the tool_message
|
||||
new_tool_call_chunk = process_llm_stream(
|
||||
stream,
|
||||
True,
|
||||
writer,
|
||||
final_search_results=final_search_results,
|
||||
# when the search tool is called with specific doc ids, initial search
|
||||
# results are not output. But we still want, e.g., citations to be processed.
|
||||
displayed_search_results=initial_search_results or final_search_results,
|
||||
)
|
||||
|
||||
return BasicOutput(tool_call_chunk=new_tool_call_chunk)
|
||||
@@ -0,0 +1,154 @@
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
from langchain_core.messages import ToolCall
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.basic.utils import process_llm_stream
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoice
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceState
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
|
||||
from onyx.chat.tool_handling.tool_response_handler import (
|
||||
get_tool_call_for_non_tool_calling_llm_impl,
|
||||
)
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# TODO: break this out into an implementation function
|
||||
# and a function that handles extracting the necessary fields
|
||||
# from the state and config
|
||||
# TODO: fan-out to multiple tool call nodes? Make this configurable?
|
||||
def llm_tool_choice(
|
||||
state: ToolChoiceState,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> ToolChoiceUpdate:
|
||||
"""
|
||||
This node is responsible for calling the LLM to choose a tool. If no tool is chosen,
|
||||
the node MAY emit an answer, depending on whether state["should_stream_answer"] is set.
|
||||
"""
|
||||
should_stream_answer = state.should_stream_answer
|
||||
|
||||
agent_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
|
||||
prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder
|
||||
|
||||
llm = agent_config.tooling.primary_llm
|
||||
skip_gen_ai_answer_generation = agent_config.behavior.skip_gen_ai_answer_generation
|
||||
|
||||
structured_response_format = agent_config.inputs.structured_response_format
|
||||
tools = [
|
||||
tool for tool in (agent_config.tooling.tools or []) if tool.name in state.tools
|
||||
]
|
||||
force_use_tool = agent_config.tooling.force_use_tool
|
||||
|
||||
tool, tool_args = None, None
|
||||
if force_use_tool.force_use and force_use_tool.args is not None:
|
||||
tool_name, tool_args = (
|
||||
force_use_tool.tool_name,
|
||||
force_use_tool.args,
|
||||
)
|
||||
tool = get_tool_by_name(tools, tool_name)
|
||||
|
||||
# special pre-logic for non-tool calling LLM case
|
||||
elif not using_tool_calling_llm and tools:
|
||||
chosen_tool_and_args = get_tool_call_for_non_tool_calling_llm_impl(
|
||||
force_use_tool=force_use_tool,
|
||||
tools=tools,
|
||||
prompt_builder=prompt_builder,
|
||||
llm=llm,
|
||||
)
|
||||
if chosen_tool_and_args:
|
||||
tool, tool_args = chosen_tool_and_args
|
||||
|
||||
# If we have a tool and tool args, we are ready to request a tool call.
|
||||
# This only happens if the tool call was forced or we are using a non-tool calling LLM.
|
||||
if tool and tool_args:
|
||||
return ToolChoiceUpdate(
|
||||
tool_choice=ToolChoice(
|
||||
tool=tool,
|
||||
tool_args=tool_args,
|
||||
id=str(uuid4()),
|
||||
),
|
||||
)
|
||||
|
||||
# if we're skipping gen ai answer generation, we should only
|
||||
# continue if we're forcing a tool call (which will be emitted by
|
||||
# the tool calling llm in the stream() below)
|
||||
if skip_gen_ai_answer_generation and not force_use_tool.force_use:
|
||||
return ToolChoiceUpdate(
|
||||
tool_choice=None,
|
||||
)
|
||||
|
||||
built_prompt = (
|
||||
prompt_builder.build()
|
||||
if isinstance(prompt_builder, AnswerPromptBuilder)
|
||||
else prompt_builder.built_prompt
|
||||
)
|
||||
# At this point, we are either using a tool calling LLM or we are skipping the tool call.
|
||||
# DEBUG: good breakpoint
|
||||
stream = llm.stream(
|
||||
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
|
||||
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
|
||||
prompt=built_prompt,
|
||||
tools=[tool.tool_definition() for tool in tools] or None,
|
||||
tool_choice=("required" if tools and force_use_tool.force_use else None),
|
||||
structured_response_format=structured_response_format,
|
||||
)
|
||||
|
||||
tool_message = process_llm_stream(
|
||||
stream,
|
||||
should_stream_answer
|
||||
and not agent_config.behavior.skip_gen_ai_answer_generation,
|
||||
writer,
|
||||
)
|
||||
|
||||
# If no tool calls are emitted by the LLM, we should not choose a tool
|
||||
if len(tool_message.tool_calls) == 0:
|
||||
logger.debug("No tool calls emitted by LLM")
|
||||
return ToolChoiceUpdate(
|
||||
tool_choice=None,
|
||||
)
|
||||
|
||||
# TODO: here we could handle parallel tool calls. Right now
|
||||
# we just pick the first one that matches.
|
||||
selected_tool: Tool | None = None
|
||||
selected_tool_call_request: ToolCall | None = None
|
||||
for tool_call_request in tool_message.tool_calls:
|
||||
known_tools_by_name = [
|
||||
tool for tool in tools if tool.name == tool_call_request["name"]
|
||||
]
|
||||
|
||||
if known_tools_by_name:
|
||||
selected_tool = known_tools_by_name[0]
|
||||
selected_tool_call_request = tool_call_request
|
||||
break
|
||||
|
||||
logger.error(
|
||||
"Tool call requested with unknown name field. \n"
|
||||
f"tools: {tools}"
|
||||
f"tool_call_request: {tool_call_request}"
|
||||
)
|
||||
|
||||
if not selected_tool or not selected_tool_call_request:
|
||||
raise ValueError(
|
||||
f"Tool call attempted with tool {selected_tool}, request {selected_tool_call_request}"
|
||||
)
|
||||
|
||||
logger.debug(f"Selected tool: {selected_tool.name}")
|
||||
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
|
||||
|
||||
return ToolChoiceUpdate(
|
||||
tool_choice=ToolChoice(
|
||||
tool=selected_tool,
|
||||
tool_args=selected_tool_call_request["args"],
|
||||
id=selected_tool_call_request["id"],
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,17 @@
from typing import Any
from typing import cast

from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput


def prepare_tool_input(state: Any, config: RunnableConfig) -> ToolChoiceInput:
    agent_config = cast(GraphConfig, config["metadata"]["config"])
    return ToolChoiceInput(
        # NOTE: this node is used at the top level of the agent, so we always stream
        should_stream_answer=True,
        prompt_snapshot=None,  # uses default prompt builder
        tools=[tool.name for tool in (agent_config.tooling.tools or [])],
    )
||||
@@ -0,0 +1,79 @@
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages.tool import ToolCall
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.orchestration.states import ToolCallOutput
|
||||
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import AnswerPacket
|
||||
from onyx.tools.message import build_tool_message
|
||||
from onyx.tools.message import ToolCallSummary
|
||||
from onyx.tools.tool_runner import ToolRunner
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class ToolCallException(Exception):
|
||||
"""Exception raised for errors during tool calls."""
|
||||
|
||||
|
||||
def emit_packet(packet: AnswerPacket, writer: StreamWriter) -> None:
|
||||
write_custom_event("basic_response", packet, writer)
|
||||
|
||||
|
||||
def tool_call(
|
||||
state: ToolChoiceUpdate,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> ToolCallUpdate:
|
||||
"""Calls the tool specified in the state and updates the state with the result"""
|
||||
|
||||
cast(GraphConfig, config["metadata"]["config"])
|
||||
|
||||
tool_choice = state.tool_choice
|
||||
if tool_choice is None:
|
||||
raise ValueError("Cannot invoke tool call node without a tool choice")
|
||||
|
||||
tool = tool_choice.tool
|
||||
tool_args = tool_choice.tool_args
|
||||
tool_id = tool_choice.id
|
||||
tool_runner = ToolRunner(tool, tool_args)
|
||||
tool_kickoff = tool_runner.kickoff()
|
||||
|
||||
emit_packet(tool_kickoff, writer)
|
||||
|
||||
try:
|
||||
tool_responses = []
|
||||
for response in tool_runner.tool_responses():
|
||||
tool_responses.append(response)
|
||||
emit_packet(response, writer)
|
||||
|
||||
tool_final_result = tool_runner.tool_final_result()
|
||||
emit_packet(tool_final_result, writer)
|
||||
except Exception as e:
|
||||
raise ToolCallException(
|
||||
f"Error during tool call for {tool.display_name}: {e}"
|
||||
) from e
|
||||
|
||||
tool_call = ToolCall(name=tool.name, args=tool_args, id=tool_id)
|
||||
tool_call_summary = ToolCallSummary(
|
||||
tool_call_request=AIMessageChunk(content="", tool_calls=[tool_call]),
|
||||
tool_call_result=build_tool_message(
|
||||
tool_call, tool_runner.tool_message_content()
|
||||
),
|
||||
)
|
||||
|
||||
tool_call_output = ToolCallOutput(
|
||||
tool_call_summary=tool_call_summary,
|
||||
tool_call_kickoff=tool_kickoff,
|
||||
tool_call_responses=tool_responses,
|
||||
tool_call_final_result=tool_final_result,
|
||||
)
|
||||
return ToolCallUpdate(tool_call_output=tool_call_output)
|
||||
backend/onyx/agents/agent_search/orchestration/states.py
@@ -0,0 +1,48 @@
from pydantic import BaseModel

from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool


# TODO: adapt the tool choice/tool call to allow for parallel tool calls by
# creating a subgraph that can be invoked in parallel via Send/Command APIs
class ToolChoiceInput(BaseModel):
    should_stream_answer: bool = True
    # default to the prompt builder from the config, but
    # allow overrides for arbitrary tool calls
    prompt_snapshot: PromptSnapshot | None = None

    # names of tools to use for tool calling. Filters the tools available in the config
    tools: list[str] = []


class ToolCallOutput(BaseModel):
    tool_call_summary: ToolCallSummary
    tool_call_kickoff: ToolCallKickoff
    tool_call_responses: list[ToolResponse]
    tool_call_final_result: ToolCallFinalResult


class ToolCallUpdate(BaseModel):
    tool_call_output: ToolCallOutput | None = None


class ToolChoice(BaseModel):
    tool: Tool
    tool_args: dict
    id: str | None

    class Config:
        arbitrary_types_allowed = True


class ToolChoiceUpdate(BaseModel):
    tool_choice: ToolChoice | None = None


class ToolChoiceState(ToolChoiceUpdate, ToolChoiceInput):
    pass
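Illustrative aside (not part of this diff): these orchestration states follow the same convention as the rest of the agent graphs — the full graph state is composed from input and update models, and each node returns only the update model it owns. A minimal standalone sketch of that convention, assuming langgraph with pydantic state schemas and using toy names:

# Illustrative sketch only (toy names): composed state plus partial node updates.
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from pydantic import BaseModel


class _Input(BaseModel):
    question: str = ""


class _ChoiceUpdate(BaseModel):
    choice: str | None = None


class _State(_ChoiceUpdate, _Input):
    pass


def _choose(state: _State) -> _ChoiceUpdate:
    # Nodes return partial updates; LangGraph merges them into the shared state.
    return _ChoiceUpdate(choice=f"tool for: {state.question}")


_graph = StateGraph(state_schema=_State, input=_Input, output=_ChoiceUpdate)
_graph.add_node(node="choose", action=_choose)
_graph.add_edge(start_key=START, end_key="choose")
_graph.add_edge(start_key="choose", end_key=END)

if __name__ == "__main__":
    print(_graph.compile().invoke({"question": "what can you do with onyx?"}))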
backend/onyx/agents/agent_search/run_graph.py
@@ -0,0 +1,213 @@
|
||||
from collections.abc import Iterable
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables.schema import CustomStreamEvent
|
||||
from langchain_core.runnables.schema import StreamEvent
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from onyx.agents.agent_search.basic.graph_builder import basic_graph_builder
|
||||
from onyx.agents.agent_search.basic.states import BasicInput
|
||||
from onyx.agents.agent_search.deep_search.main.graph_builder import (
|
||||
main_graph_builder as main_graph_builder_a,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
MainInput as MainInput_a,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import AnswerPacket
|
||||
from onyx.chat.models import AnswerStream
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.chat.models import RefinedAnswerImprovement
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import SubQueryPiece
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.chat.models import ToolResponse
|
||||
from onyx.configs.agent_configs import ALLOW_REFINEMENT
|
||||
from onyx.configs.agent_configs import INITIAL_SEARCH_DECOMPOSITION_ENABLED
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.tools.tool_runner import ToolCallKickoff
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_COMPILED_GRAPH: CompiledStateGraph | None = None
|
||||
|
||||
|
||||
def _parse_agent_event(
|
||||
event: StreamEvent,
|
||||
) -> AnswerPacket | None:
|
||||
"""
|
||||
Parse the event into a typed object.
|
||||
Return None if we are not interested in the event.
|
||||
"""
|
||||
|
||||
event_type = event["event"]
|
||||
|
||||
# We always just yield the event data, but this piece is useful for two development reasons:
|
||||
# 1. It's a list of the names of every place we dispatch a custom event
|
||||
# 2. We maintain the intended types yielded by each event
|
||||
if event_type == "on_custom_event":
|
||||
if event["name"] == "decomp_qs":
|
||||
return cast(SubQuestionPiece, event["data"])
|
||||
elif event["name"] == "subqueries":
|
||||
return cast(SubQueryPiece, event["data"])
|
||||
elif event["name"] == "sub_answers":
|
||||
return cast(AgentAnswerPiece, event["data"])
|
||||
elif event["name"] == "stream_finished":
|
||||
return cast(StreamStopInfo, event["data"])
|
||||
elif event["name"] == "initial_agent_answer":
|
||||
return cast(AgentAnswerPiece, event["data"])
|
||||
elif event["name"] == "refined_agent_answer":
|
||||
return cast(AgentAnswerPiece, event["data"])
|
||||
elif event["name"] == "start_refined_answer_creation":
|
||||
return cast(ToolCallKickoff, event["data"])
|
||||
elif event["name"] == "tool_response":
|
||||
return cast(ToolResponse, event["data"])
|
||||
elif event["name"] == "basic_response":
|
||||
return cast(AnswerPacket, event["data"])
|
||||
elif event["name"] == "refined_answer_improvement":
|
||||
return cast(RefinedAnswerImprovement, event["data"])
|
||||
return None
|
||||
|
||||
|
||||
def manage_sync_streaming(
|
||||
compiled_graph: CompiledStateGraph,
|
||||
config: GraphConfig,
|
||||
graph_input: BasicInput | MainInput_a,
|
||||
) -> Iterable[StreamEvent]:
|
||||
message_id = config.persistence.message_id if config.persistence else None
|
||||
for event in compiled_graph.stream(
|
||||
stream_mode="custom",
|
||||
input=graph_input,
|
||||
config={"metadata": {"config": config, "thread_id": str(message_id)}},
|
||||
):
|
||||
yield cast(CustomStreamEvent, event)
|
||||
|
||||
|
||||
def run_graph(
|
||||
compiled_graph: CompiledStateGraph,
|
||||
config: GraphConfig,
|
||||
input: BasicInput | MainInput_a,
|
||||
) -> AnswerStream:
|
||||
config.behavior.perform_initial_search_decomposition = (
|
||||
INITIAL_SEARCH_DECOMPOSITION_ENABLED
|
||||
)
|
||||
config.behavior.allow_refinement = ALLOW_REFINEMENT
|
||||
|
||||
for event in manage_sync_streaming(
|
||||
compiled_graph=compiled_graph, config=config, graph_input=input
|
||||
):
|
||||
if not (parsed_object := _parse_agent_event(event)):
|
||||
continue
|
||||
|
||||
yield parsed_object
|
||||
|
||||
|
||||
# It doesn't actually take very long to load the graph, but we'd rather
|
||||
# not compile it again on every request.
|
||||
def load_compiled_graph() -> CompiledStateGraph:
|
||||
global _COMPILED_GRAPH
|
||||
if _COMPILED_GRAPH is None:
|
||||
graph = main_graph_builder_a()
|
||||
_COMPILED_GRAPH = graph.compile()
|
||||
return _COMPILED_GRAPH
|
||||
|
||||
|
||||
def run_main_graph(
|
||||
config: GraphConfig,
|
||||
) -> AnswerStream:
|
||||
compiled_graph = load_compiled_graph()
|
||||
|
||||
input = MainInput_a(
|
||||
base_question=config.inputs.search_request.query, log_messages=[]
|
||||
)
|
||||
|
||||
# Agent search is not a Tool per se, but this is helpful for the frontend
|
||||
yield ToolCallKickoff(
|
||||
tool_name="agent_search_0",
|
||||
tool_args={"query": config.inputs.search_request.query},
|
||||
)
|
||||
yield from run_graph(compiled_graph, config, input)
|
||||
|
||||
|
||||
def run_basic_graph(
|
||||
config: GraphConfig,
|
||||
) -> AnswerStream:
|
||||
graph = basic_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
input = BasicInput()
|
||||
return run_graph(compiled_graph, config, input)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
for _ in range(1):
|
||||
query_start_time = datetime.now()
|
||||
logger.debug(f"Start at {query_start_time}")
|
||||
graph = main_graph_builder_a()
|
||||
compiled_graph = graph.compile()
|
||||
query_end_time = datetime.now()
|
||||
logger.debug(f"Graph compiled in {query_end_time - query_start_time} seconds")
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
search_request = SearchRequest(
|
||||
# query="what can you do with gitlab?",
|
||||
# query="What are the guiding principles behind the development of cockroachDB",
|
||||
# query="What are the temperatures in Munich, Hawaii, and New York?",
|
||||
# query="When was Washington born?",
|
||||
# query="What is Onyx?",
|
||||
# query="What is the difference between astronomy and astrology?",
|
||||
query="Do a search to tell me what is the difference between astronomy and astrology?",
|
||||
)
|
||||
|
||||
with get_session_context_manager() as db_session:
|
||||
config = get_test_config(db_session, primary_llm, fast_llm, search_request)
|
||||
assert (
|
||||
config.persistence is not None
|
||||
), "set a chat session id to run this test"
|
||||
|
||||
# search_request.persona = get_persona_by_id(1, None, db_session)
|
||||
# config.perform_initial_search_path_decision = False
|
||||
config.behavior.perform_initial_search_decomposition = True
|
||||
input = MainInput_a(
|
||||
base_question=config.inputs.search_request.query, log_messages=[]
|
||||
)
|
||||
|
||||
tool_responses: list = []
|
||||
for output in run_graph(compiled_graph, config, input):
|
||||
if isinstance(output, ToolCallKickoff):
|
||||
pass
|
||||
elif isinstance(output, ExtendedToolResponse):
|
||||
tool_responses.append(output.response)
|
||||
logger.info(
|
||||
f" ---- ET {output.level} - {output.level_question_num} | "
|
||||
)
|
||||
elif isinstance(output, SubQueryPiece):
|
||||
logger.info(
|
||||
f"Sq {output.level} - {output.level_question_num} - {output.sub_query} | "
|
||||
)
|
||||
elif isinstance(output, SubQuestionPiece):
|
||||
logger.info(
|
||||
f"SQ {output.level} - {output.level_question_num} - {output.sub_question} | "
|
||||
)
|
||||
elif (
|
||||
isinstance(output, AgentAnswerPiece)
|
||||
and output.answer_type == "agent_sub_answer"
|
||||
):
|
||||
logger.info(
|
||||
f" ---- SA {output.level} - {output.level_question_num} {output.answer_piece} | "
|
||||
)
|
||||
elif (
|
||||
isinstance(output, AgentAnswerPiece)
|
||||
and output.answer_type == "agent_level_answer"
|
||||
):
|
||||
logger.info(
|
||||
f" ---------- FA {output.level} - {output.level_question_num} {output.answer_piece} | "
|
||||
)
|
||||
elif isinstance(output, RefinedAnswerImprovement):
|
||||
logger.info(
|
||||
f" ---------- RE {output.refined_answer_improvement} | "
|
||||
)
|
||||
@@ -0,0 +1,152 @@
|
||||
from langchain.schema import AIMessage
|
||||
from langchain.schema import HumanMessage
|
||||
from langchain.schema import SystemMessage
|
||||
from langchain_core.messages.tool import ToolMessage
|
||||
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
AgentPromptEnrichmentComponents,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_persona_agent_prompt_expressions,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import remove_document_citations
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import summarize_history
|
||||
from onyx.configs.agent_configs import AGENT_MAX_STATIC_HISTORY_WORD_LENGTH
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.llm.interfaces import LLMConfig
|
||||
from onyx.llm.utils import get_max_input_tokens
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.natural_language_processing.utils import tokenizer_trim_content
|
||||
from onyx.prompts.agent_search import HISTORY_FRAMING_PROMPT
|
||||
from onyx.prompts.agent_search import SUB_QUESTION_RAG_PROMPT
|
||||
from onyx.prompts.prompt_utils import build_date_time_string
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def build_sub_question_answer_prompt(
|
||||
question: str,
|
||||
original_question: str,
|
||||
docs: list[InferenceSection],
|
||||
persona_specification: str,
|
||||
config: LLMConfig,
|
||||
) -> list[SystemMessage | HumanMessage | AIMessage | ToolMessage]:
|
||||
system_message = SystemMessage(
|
||||
content=persona_specification,
|
||||
)
|
||||
|
||||
date_str = build_date_time_string()
|
||||
|
||||
# TODO: This should include document metadata and title
|
||||
docs_format_list = [
|
||||
f"Document Number: [D{doc_num + 1}]\nContent: {doc.combined_content}\n\n"
|
||||
for doc_num, doc in enumerate(docs)
|
||||
]
|
||||
|
||||
docs_str = "\n\n".join(docs_format_list)
|
||||
|
||||
docs_str = trim_prompt_piece(
|
||||
config,
|
||||
docs_str,
|
||||
SUB_QUESTION_RAG_PROMPT + question + original_question + date_str,
|
||||
)
|
||||
human_message = HumanMessage(
|
||||
content=SUB_QUESTION_RAG_PROMPT.format(
|
||||
question=question,
|
||||
original_question=original_question,
|
||||
context=docs_str,
|
||||
date_prompt=date_str,
|
||||
)
|
||||
)
|
||||
|
||||
return [system_message, human_message]
|
||||
|
||||
|
||||
def trim_prompt_piece(config: LLMConfig, prompt_piece: str, reserved_str: str) -> str:
|
||||
# TODO: save the max input tokens in LLMConfig
|
||||
max_tokens = get_max_input_tokens(
|
||||
model_provider=config.model_provider,
|
||||
model_name=config.model_name,
|
||||
)
|
||||
|
||||
# no need to trim if a conservative estimate of one token
|
||||
# per character is already less than the max tokens
|
||||
if len(prompt_piece) + len(reserved_str) < max_tokens:
|
||||
return prompt_piece
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
provider_type=config.model_provider,
|
||||
model_name=config.model_name,
|
||||
)
|
||||
|
||||
# slightly conservative trimming
|
||||
return tokenizer_trim_content(
|
||||
content=prompt_piece,
|
||||
desired_length=max_tokens - len(llm_tokenizer.encode(reserved_str)),
|
||||
tokenizer=llm_tokenizer,
|
||||
)
|
||||
|
||||
|
||||
def build_history_prompt(config: GraphConfig, question: str) -> str:
|
||||
prompt_builder = config.inputs.prompt_builder
|
||||
persona_base = get_persona_agent_prompt_expressions(
|
||||
config.inputs.search_request.persona
|
||||
).base_prompt
|
||||
|
||||
if prompt_builder is None:
|
||||
return ""
|
||||
|
||||
if prompt_builder.single_message_history is not None:
|
||||
history = prompt_builder.single_message_history
|
||||
else:
|
||||
history_components = []
|
||||
previous_message_type = None
|
||||
for message in prompt_builder.raw_message_history:
|
||||
if message.message_type == MessageType.USER:
|
||||
history_components.append(f"User: {message.message}\n")
|
||||
previous_message_type = MessageType.USER
|
||||
elif message.message_type == MessageType.ASSISTANT:
|
||||
# Previously there could be multiple assistant messages in a row
|
||||
# Now this is handled at the message history construction
|
||||
assert previous_message_type is not MessageType.ASSISTANT
|
||||
history_components.append(f"You/Agent: {message.message}\n")
|
||||
previous_message_type = MessageType.ASSISTANT
|
||||
else:
|
||||
# Other message types are not included here; currently there should be no other message types
|
||||
logger.error(
|
||||
f"Unhandled message type: {message.message_type} with message: {message.message}"
|
||||
)
|
||||
continue
|
||||
|
||||
history = "\n".join(history_components)
|
||||
history = remove_document_citations(history)
|
||||
if len(history.split()) > AGENT_MAX_STATIC_HISTORY_WORD_LENGTH:
|
||||
history = summarize_history(
|
||||
history=history,
|
||||
question=question,
|
||||
persona_specification=persona_base,
|
||||
llm=config.tooling.fast_llm,
|
||||
)
|
||||
|
||||
return HISTORY_FRAMING_PROMPT.format(history=history) if history else ""
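
# Illustrative sketch (not part of the original code): for a two-turn conversation the
# constructed history block looks roughly like
#   User: What is Onyx?
#
#   You/Agent: Onyx is ...
# before citations are stripped and, if it exceeds AGENT_MAX_STATIC_HISTORY_WORD_LENGTH
# words, it is summarized and finally wrapped in HISTORY_FRAMING_PROMPT.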
|
||||
|
||||
|
||||
def get_prompt_enrichment_components(
|
||||
config: GraphConfig,
|
||||
) -> AgentPromptEnrichmentComponents:
|
||||
persona_prompts = get_persona_agent_prompt_expressions(
|
||||
config.inputs.search_request.persona
|
||||
)
|
||||
|
||||
history = build_history_prompt(config, config.inputs.search_request.query)
|
||||
|
||||
date_str = build_date_time_string()
|
||||
|
||||
return AgentPromptEnrichmentComponents(
|
||||
persona_prompts=persona_prompts,
|
||||
history=history,
|
||||
date_str=date_str,
|
||||
)
|
||||
@@ -0,0 +1,98 @@
|
||||
import numpy as np
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitScoreMetrics
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
|
||||
from onyx.chat.models import SectionRelevancePiece
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def unique_chunk_id(doc: InferenceSection) -> str:
|
||||
return f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}"
|
||||
|
||||
|
||||
def calculate_rank_shift(list1: list, list2: list, top_n: int = 20) -> float:
|
||||
shift = 0
|
||||
for rank_first, doc_id in enumerate(list1[:top_n], 1):
|
||||
try:
|
||||
rank_second = list2.index(doc_id) + 1
|
||||
except ValueError:
|
||||
rank_second = len(list2) # Document not found in second list
|
||||
|
||||
shift += np.abs(rank_first - rank_second) / np.log(1 + rank_first * rank_second)
|
||||
|
||||
return shift / top_n
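
# Illustrative sketch (not part of the original code): identical orderings produce no shift,
# while a fully reversed top-3 produces a small positive value because the summed shift is
# divided by top_n (20 by default), not by the list length.
def _example_calculate_rank_shift() -> None:
    assert calculate_rank_shift(["a", "b", "c"], ["a", "b", "c"]) == 0.0
    # (2 / ln(4) + 0 + 2 / ln(4)) / 20 ≈ 0.144
    print(calculate_rank_shift(["a", "b", "c"], ["c", "b", "a"]))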
|
||||
|
||||
|
||||
def get_fit_scores(
|
||||
pre_reranked_results: list[InferenceSection],
|
||||
post_reranked_results: list[InferenceSection] | list[SectionRelevancePiece],
|
||||
) -> RetrievalFitStats | None:
|
||||
"""
|
||||
Calculate retrieval metrics for search purposes
|
||||
"""
|
||||
|
||||
if len(pre_reranked_results) == 0 or len(post_reranked_results) == 0:
|
||||
return None
|
||||
|
||||
ranked_sections = {
|
||||
"initial": pre_reranked_results,
|
||||
"reranked": post_reranked_results,
|
||||
}
|
||||
|
||||
fit_eval: RetrievalFitStats = RetrievalFitStats(
|
||||
fit_score_lift=0,
|
||||
rerank_effect=0,
|
||||
fit_scores={
|
||||
"initial": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
|
||||
"reranked": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
|
||||
},
|
||||
)
|
||||
|
||||
for rank_type, docs in ranked_sections.items():
|
||||
logger.debug(f"rank_type: {rank_type}")
|
||||
|
||||
for i in [1, 5, 10]:
|
||||
fit_eval.fit_scores[rank_type].scores[str(i)] = (
|
||||
sum(
|
||||
[
|
||||
float(doc.center_chunk.score)
|
||||
for doc in docs[:i]
|
||||
if type(doc) == InferenceSection
|
||||
and doc.center_chunk.score is not None
|
||||
]
|
||||
)
|
||||
/ i
|
||||
)
|
||||
|
||||
fit_eval.fit_scores[rank_type].scores["fit_score"] = (
|
||||
1
|
||||
/ 3
|
||||
* (
|
||||
fit_eval.fit_scores[rank_type].scores["1"]
|
||||
+ fit_eval.fit_scores[rank_type].scores["5"]
|
||||
+ fit_eval.fit_scores[rank_type].scores["10"]
|
||||
)
|
||||
)
|
||||
|
||||
fit_eval.fit_scores[rank_type].scores["fit_score"] = fit_eval.fit_scores[
|
||||
rank_type
|
||||
].scores["1"]
|
||||
|
||||
fit_eval.fit_scores[rank_type].chunk_ids = [
|
||||
unique_chunk_id(doc) for doc in docs if type(doc) == InferenceSection
|
||||
]
|
||||
|
||||
fit_eval.fit_score_lift = (
|
||||
fit_eval.fit_scores["reranked"].scores["fit_score"]
|
||||
/ fit_eval.fit_scores["initial"].scores["fit_score"]
|
||||
)
|
||||
|
||||
fit_eval.rerank_effect = calculate_rank_shift(
|
||||
fit_eval.fit_scores["initial"].chunk_ids,
|
||||
fit_eval.fit_scores["reranked"].chunk_ids,
|
||||
)
|
||||
|
||||
return fit_eval
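
# Illustrative note (not part of the original code): fit_score_lift compares the reranked
# ordering's "fit_score" to the initial one (as written above, that value ends up being the
# top-1 score), while rerank_effect measures how far documents moved between the two
# orderings via calculate_rank_shift.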
|
||||
backend/onyx/agents/agent_search/shared_graph_utils/models.py (new file, 128 lines)
@@ -0,0 +1,128 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.deep_search.main.models import (
|
||||
AgentAdditionalMetrics,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.models import AgentBaseMetrics
|
||||
from onyx.agents.agent_search.deep_search.main.models import (
|
||||
AgentRefinedMetrics,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.models import AgentTimings
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.tools.models import SearchQueryInfo
|
||||
|
||||
|
||||
# Pydantic models for structured outputs
|
||||
# class RewrittenQueries(BaseModel):
|
||||
# rewritten_queries: list[str]
|
||||
|
||||
|
||||
# class BinaryDecision(BaseModel):
|
||||
# decision: Literal["yes", "no"]
|
||||
|
||||
|
||||
# class BinaryDecisionWithReasoning(BaseModel):
|
||||
# reasoning: str
|
||||
# decision: Literal["yes", "no"]
|
||||
|
||||
|
||||
class RetrievalFitScoreMetrics(BaseModel):
|
||||
scores: dict[str, float]
|
||||
chunk_ids: list[str]
|
||||
|
||||
|
||||
class RetrievalFitStats(BaseModel):
|
||||
fit_score_lift: float
|
||||
rerank_effect: float
|
||||
fit_scores: dict[str, RetrievalFitScoreMetrics]
|
||||
|
||||
|
||||
# class AgentChunkScores(BaseModel):
|
||||
# scores: dict[str, dict[str, list[int | float]]]
|
||||
|
||||
|
||||
class AgentChunkRetrievalStats(BaseModel):
|
||||
verified_count: int | None = None
|
||||
verified_avg_scores: float | None = None
|
||||
rejected_count: int | None = None
|
||||
rejected_avg_scores: float | None = None
|
||||
verified_doc_chunk_ids: list[str] = []
|
||||
dismissed_doc_chunk_ids: list[str] = []
|
||||
|
||||
|
||||
class InitialAgentResultStats(BaseModel):
|
||||
sub_questions: dict[str, float | int | None]
|
||||
original_question: dict[str, float | int | None]
|
||||
agent_effectiveness: dict[str, float | int | None]
|
||||
|
||||
|
||||
class RefinedAgentStats(BaseModel):
|
||||
revision_doc_efficiency: float | None
|
||||
revision_question_efficiency: float | None
|
||||
|
||||
|
||||
class Term(BaseModel):
|
||||
term_name: str = ""
|
||||
term_type: str = ""
|
||||
term_similar_to: list[str] = []
|
||||
|
||||
|
||||
### Models ###
|
||||
|
||||
|
||||
class Entity(BaseModel):
|
||||
entity_name: str = ""
|
||||
entity_type: str = ""
|
||||
|
||||
|
||||
class Relationship(BaseModel):
|
||||
relationship_name: str = ""
|
||||
relationship_type: str = ""
|
||||
relationship_entities: list[str] = []
|
||||
|
||||
|
||||
class EntityRelationshipTermExtraction(BaseModel):
|
||||
entities: list[Entity] = []
|
||||
relationships: list[Relationship] = []
|
||||
terms: list[Term] = []
|
||||
|
||||
|
||||
class EntityExtractionResult(BaseModel):
|
||||
retrieved_entities_relationships: EntityRelationshipTermExtraction
|
||||
|
||||
|
||||
class QueryRetrievalResult(BaseModel):
|
||||
query: str
|
||||
retrieved_documents: list[InferenceSection]
|
||||
stats: RetrievalFitStats | None
|
||||
query_info: SearchQueryInfo | None
|
||||
|
||||
|
||||
class SubQuestionAnswerResults(BaseModel):
|
||||
question: str
|
||||
question_id: str
|
||||
answer: str
|
||||
verified_high_quality: bool
|
||||
sub_query_retrieval_results: list[QueryRetrievalResult]
|
||||
verified_reranked_documents: list[InferenceSection]
|
||||
context_documents: list[InferenceSection]
|
||||
cited_documents: list[InferenceSection]
|
||||
sub_question_retrieval_stats: AgentChunkRetrievalStats
|
||||
|
||||
|
||||
class CombinedAgentMetrics(BaseModel):
|
||||
timings: AgentTimings
|
||||
base_metrics: AgentBaseMetrics | None
|
||||
refined_metrics: AgentRefinedMetrics
|
||||
additional_metrics: AgentAdditionalMetrics
|
||||
|
||||
|
||||
class PersonaPromptExpressions(BaseModel):
|
||||
contextualized_prompt: str
|
||||
base_prompt: str | None
|
||||
|
||||
|
||||
class AgentPromptEnrichmentComponents(BaseModel):
|
||||
persona_prompts: PersonaPromptExpressions
|
||||
history: str
|
||||
date_str: str
|
||||
@@ -0,0 +1,31 @@
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
    SubQuestionAnswerResults,
)
from onyx.chat.prune_and_merge import _merge_sections
from onyx.context.search.models import InferenceSection


def dedup_inference_sections(
    list1: list[InferenceSection], list2: list[InferenceSection]
) -> list[InferenceSection]:
    deduped = _merge_sections(list1 + list2)
    return deduped


def dedup_question_answer_results(
    question_answer_results_1: list[SubQuestionAnswerResults],
    question_answer_results_2: list[SubQuestionAnswerResults],
) -> list[SubQuestionAnswerResults]:
    deduped_question_answer_results: list[
        SubQuestionAnswerResults
    ] = question_answer_results_1
    utilized_question_ids: set[str] = set(
        [x.question_id for x in question_answer_results_1]
    )

    for question_answer_result in question_answer_results_2:
        if question_answer_result.question_id not in utilized_question_ids:
            deduped_question_answer_results.append(question_answer_result)
            utilized_question_ids.add(question_answer_result.question_id)

    return deduped_question_answer_results
|
||||
backend/onyx/agents/agent_search/shared_graph_utils/utils.py (new file, 433 lines)
@@ -0,0 +1,433 @@
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
from typing import TypedDict
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.types import StreamWriter
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.models import GraphInputs
|
||||
from onyx.agents.agent_search.models import GraphPersistence
|
||||
from onyx.agents.agent_search.models import GraphSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphTooling
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
EntityRelationshipTermExtraction,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import PersonaPromptExpressions
|
||||
from onyx.chat.models import AnswerPacket
|
||||
from onyx.chat.models import AnswerStyleConfig
|
||||
from onyx.chat.models import CitationConfig
|
||||
from onyx.chat.models import DocumentPruningConfig
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.models import SectionRelevancePiece
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.configs.constants import DISPATCH_SEP_CHAR
|
||||
from onyx.configs.constants import FORMAT_DOCS_SEPARATOR
|
||||
from onyx.context.search.enums import LLMEvaluationType
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.persona import Persona
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.prompts.agent_search import (
|
||||
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
|
||||
)
|
||||
from onyx.prompts.agent_search import (
|
||||
ASSISTANT_SYSTEM_PROMPT_PERSONA,
|
||||
)
|
||||
from onyx.prompts.agent_search import (
|
||||
HISTORY_CONTEXT_SUMMARY_PROMPT,
|
||||
)
|
||||
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.tool_constructor import SearchToolConfig
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.utils import explicit_tool_calling_supported
|
||||
|
||||
BaseMessage_Content = str | list[str | dict[str, Any]]
|
||||
|
||||
|
||||
# Post-processing
|
||||
def format_docs(docs: Sequence[InferenceSection]) -> str:
|
||||
formatted_doc_list = []
|
||||
|
||||
for doc_num, doc in enumerate(docs):
|
||||
title: str | None = doc.center_chunk.title
|
||||
metadata: dict[str, str | list[str]] | None = (
|
||||
doc.center_chunk.metadata if doc.center_chunk.metadata else None
|
||||
)
|
||||
|
||||
doc_str = f"**Document: D{doc_num + 1}**"
|
||||
if title:
|
||||
doc_str += f"\nTitle: {title}"
|
||||
if metadata:
|
||||
metadata_str = ""
|
||||
for key, value in metadata.items():
|
||||
if isinstance(value, str):
|
||||
metadata_str += f" - {key}: {value}"
|
||||
elif isinstance(value, list):
|
||||
metadata_str += f" - {key}: {', '.join(value)}"
|
||||
doc_str += f"\nMetadata: {metadata_str}"
|
||||
doc_str += f"\nContent:\n{doc.combined_content}"
|
||||
|
||||
formatted_doc_list.append(doc_str)
|
||||
|
||||
return FORMAT_DOCS_SEPARATOR.join(formatted_doc_list)
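
# Illustrative sketch (not part of the original code): each section is rendered roughly as
#   **Document: D1**
#   Title: <title>
#   Metadata:  - <key>: <value>
#   Content:
#   <combined content>
# and the per-document strings are joined with FORMAT_DOCS_SEPARATOR.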
|
||||
|
||||
|
||||
def format_entity_term_extraction(
|
||||
entity_term_extraction_dict: EntityRelationshipTermExtraction,
|
||||
) -> str:
|
||||
entities = entity_term_extraction_dict.entities
|
||||
terms = entity_term_extraction_dict.terms
|
||||
relationships = entity_term_extraction_dict.relationships
|
||||
|
||||
entity_strs = ["\nEntities:\n"]
|
||||
for entity in entities:
|
||||
entity_str = f"{entity.entity_name} ({entity.entity_type})"
|
||||
entity_strs.append(entity_str)
|
||||
|
||||
entity_str = "\n - ".join(entity_strs)
|
||||
|
||||
relationship_strs = ["\n\nRelationships:\n"]
|
||||
for relationship in relationships:
|
||||
relationship_name = relationship.relationship_name
|
||||
relationship_type = relationship.relationship_type
|
||||
relationship_entities = relationship.relationship_entities
|
||||
relationship_str = (
|
||||
f"""{relationship_name} ({relationship_type}): {relationship_entities}"""
|
||||
)
|
||||
relationship_strs.append(relationship_str)
|
||||
|
||||
relationship_str = "\n - ".join(relationship_strs)
|
||||
|
||||
term_strs = ["\n\nTerms:\n"]
|
||||
for term in terms:
|
||||
term_str = f"{term.term_name} ({term.term_type}): similar to {', '.join(term.term_similar_to)}"
|
||||
term_strs.append(term_str)
|
||||
|
||||
term_str = "\n - ".join(term_strs)
|
||||
|
||||
return "\n".join(entity_strs + relationship_strs + term_strs)
|
||||
|
||||
|
||||
def get_test_config(
|
||||
db_session: Session,
|
||||
primary_llm: LLM,
|
||||
fast_llm: LLM,
|
||||
search_request: SearchRequest,
|
||||
use_agentic_search: bool = True,
|
||||
) -> GraphConfig:
|
||||
persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session)
|
||||
document_pruning_config = DocumentPruningConfig(
|
||||
max_chunks=int(
|
||||
persona.num_chunks
|
||||
if persona.num_chunks is not None
|
||||
else MAX_CHUNKS_FED_TO_CHAT
|
||||
),
|
||||
max_window_percentage=CHAT_TARGET_CHUNK_PERCENTAGE,
|
||||
)
|
||||
|
||||
answer_style_config = AnswerStyleConfig(
|
||||
citation_config=CitationConfig(
|
||||
# The docs retrieved by this flow are already relevance-filtered
|
||||
all_docs_useful=True
|
||||
),
|
||||
document_pruning_config=document_pruning_config,
|
||||
structured_response_format=None,
|
||||
)
|
||||
|
||||
search_tool_config = SearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
document_pruning_config=document_pruning_config,
|
||||
retrieval_options=RetrievalDetails(), # may want to set dedupe_docs=True
|
||||
rerank_settings=None, # Can use this to change reranking model
|
||||
selected_sections=None,
|
||||
latest_query_files=None,
|
||||
bypass_acl=False,
|
||||
)
|
||||
|
||||
prompt_config = PromptConfig.from_model(persona.prompts[0])
|
||||
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=None,
|
||||
persona=persona,
|
||||
retrieval_options=search_tool_config.retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=primary_llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=search_tool_config.document_pruning_config,
|
||||
answer_style_config=search_tool_config.answer_style_config,
|
||||
selected_sections=search_tool_config.selected_sections,
|
||||
chunks_above=search_tool_config.chunks_above,
|
||||
chunks_below=search_tool_config.chunks_below,
|
||||
full_doc=search_tool_config.full_doc,
|
||||
evaluation_type=(
|
||||
LLMEvaluationType.BASIC
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP
|
||||
),
|
||||
rerank_settings=search_tool_config.rerank_settings,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
)
|
||||
|
||||
graph_inputs = GraphInputs(
|
||||
search_request=search_request,
|
||||
prompt_builder=AnswerPromptBuilder(
|
||||
user_message=HumanMessage(content=search_request.query),
|
||||
message_history=[],
|
||||
llm_config=primary_llm.config,
|
||||
raw_user_query=search_request.query,
|
||||
raw_user_uploaded_files=[],
|
||||
),
|
||||
structured_response_format=answer_style_config.structured_response_format,
|
||||
)
|
||||
|
||||
using_tool_calling_llm = explicit_tool_calling_supported(
|
||||
primary_llm.config.model_provider, primary_llm.config.model_name
|
||||
)
|
||||
graph_tooling = GraphTooling(
|
||||
primary_llm=primary_llm,
|
||||
fast_llm=fast_llm,
|
||||
search_tool=search_tool,
|
||||
tools=[search_tool],
|
||||
force_use_tool=ForceUseTool(force_use=False, tool_name=""),
|
||||
using_tool_calling_llm=using_tool_calling_llm,
|
||||
)
|
||||
|
||||
chat_session_id = os.environ.get("ONYX_AS_CHAT_SESSION_ID")
|
||||
assert (
|
||||
chat_session_id is not None
|
||||
), "ONYX_AS_CHAT_SESSION_ID must be set for backend tests"
|
||||
graph_persistence = GraphPersistence(
|
||||
db_session=db_session,
|
||||
chat_session_id=UUID(chat_session_id),
|
||||
message_id=1,
|
||||
)
|
||||
|
||||
search_behavior_config = GraphSearchConfig(
|
||||
use_agentic_search=use_agentic_search,
|
||||
skip_gen_ai_answer_generation=False,
|
||||
allow_refinement=True,
|
||||
)
|
||||
graph_config = GraphConfig(
|
||||
inputs=graph_inputs,
|
||||
tooling=graph_tooling,
|
||||
persistence=graph_persistence,
|
||||
behavior=search_behavior_config,
|
||||
)
|
||||
|
||||
return graph_config
|
||||
|
||||
|
||||
def get_persona_agent_prompt_expressions(
|
||||
persona: Persona | None,
|
||||
) -> PersonaPromptExpressions:
|
||||
if persona is None or len(persona.prompts) == 0:
|
||||
# TODO base_prompt should be None, but no time to properly fix
|
||||
return PersonaPromptExpressions(
|
||||
contextualized_prompt=ASSISTANT_SYSTEM_PROMPT_DEFAULT, base_prompt=""
|
||||
)
|
||||
|
||||
# Only a 1:1 mapping between personas and prompts currently
|
||||
prompt = persona.prompts[0]
|
||||
prompt_config = PromptConfig.from_model(prompt)
|
||||
datetime_aware_system_prompt = handle_onyx_date_awareness(
|
||||
prompt_str=prompt_config.system_prompt,
|
||||
prompt_config=prompt_config,
|
||||
add_additional_info_if_no_tag=prompt.datetime_aware,
|
||||
)
|
||||
|
||||
return PersonaPromptExpressions(
|
||||
contextualized_prompt=ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
|
||||
persona_prompt=datetime_aware_system_prompt
|
||||
),
|
||||
base_prompt=datetime_aware_system_prompt,
|
||||
)
|
||||
|
||||
|
||||
def make_question_id(level: int, question_num: int) -> str:
|
||||
return f"{level}_{question_num}"
|
||||
|
||||
|
||||
def parse_question_id(question_id: str) -> tuple[int, int]:
|
||||
level, question_num = question_id.split("_")
|
||||
return int(level), int(question_num)
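
# Illustrative note (not part of the original code): the two helpers round-trip, e.g.
# make_question_id(1, 3) == "1_3" and parse_question_id("1_3") == (1, 3).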
|
||||
|
||||
|
||||
def _dispatch_nonempty(
|
||||
content: str, dispatch_event: Callable[[str, int], None], sep_num: int
|
||||
) -> None:
|
||||
"""
|
||||
Dispatch a content string if it is not empty using the given callback.
|
||||
This function is used in the context of dispatching some arbitrary number
|
||||
of similar objects which are separated by a separator during the LLM stream.
|
||||
The callback expects a sep_num denoting which object is being dispatched; these
|
||||
numbers go from 1 to however many strings the LLM decides to stream.
|
||||
"""
|
||||
if content != "":
|
||||
dispatch_event(content, sep_num)
|
||||
|
||||
|
||||
def dispatch_separated(
|
||||
tokens: Iterator[BaseMessage],
|
||||
dispatch_event: Callable[[str, int], None],
|
||||
sep: str = DISPATCH_SEP_CHAR,
|
||||
) -> list[BaseMessage_Content]:
|
||||
num = 1
|
||||
streamed_tokens: list[BaseMessage_Content] = []
|
||||
for token in tokens:
|
||||
content = cast(str, token.content)
|
||||
if sep in content:
|
||||
sub_question_parts = content.split(sep)
|
||||
_dispatch_nonempty(sub_question_parts[0], dispatch_event, num)
|
||||
num += 1
|
||||
_dispatch_nonempty(
|
||||
"".join(sub_question_parts[1:]).strip(), dispatch_event, num
|
||||
)
|
||||
else:
|
||||
_dispatch_nonempty(content, dispatch_event, num)
|
||||
streamed_tokens.append(content)
|
||||
|
||||
return streamed_tokens
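
# Illustrative sketch (not part of the original code); the separator value below is an
# assumption for the example only (the real default is DISPATCH_SEP_CHAR):
def _example_dispatch_separated() -> None:
    def _print_piece(piece: str, num: int) -> None:
        print(num, piece)

    tokens = iter([HumanMessage(content="first part"), HumanMessage(content="!second part")])
    # Dispatches ("first part", 1) then ("second part", 2) and returns the raw token contents.
    dispatch_separated(tokens, _print_piece, sep="!")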
|
||||
|
||||
|
||||
def dispatch_main_answer_stop_info(level: int, writer: StreamWriter) -> None:
|
||||
stop_event = StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
stream_type=StreamType.MAIN_ANSWER,
|
||||
level=level,
|
||||
)
|
||||
write_custom_event("stream_finished", stop_event, writer)
|
||||
|
||||
|
||||
def retrieve_search_docs(
|
||||
search_tool: SearchTool, question: str
|
||||
) -> list[InferenceSection]:
|
||||
retrieved_docs: list[InferenceSection] = []
|
||||
|
||||
# new db session to avoid concurrency issues
|
||||
with get_session_context_manager() as db_session:
|
||||
for tool_response in search_tool.run(
|
||||
query=question,
|
||||
force_no_rerank=True,
|
||||
alternate_db_session=db_session,
|
||||
):
|
||||
# get retrieved docs to send to the rest of the graph
|
||||
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
response = cast(SearchResponseSummary, tool_response.response)
|
||||
retrieved_docs = response.top_sections
|
||||
break
|
||||
|
||||
return retrieved_docs
|
||||
|
||||
|
||||
def get_answer_citation_ids(answer_str: str) -> list[int]:
|
||||
"""
|
||||
Extract citation numbers of format [D<number>] from the answer string.
|
||||
"""
|
||||
citation_ids = re.findall(r"\[D(\d+)\]", answer_str)
|
||||
return list(set([(int(id) - 1) for id in citation_ids]))
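
# Illustrative note (not part of the original code):
# get_answer_citation_ids("Both [D1] and [D3] support this.") returns the zero-based
# ids [0, 2] (order not guaranteed, since duplicates are dropped via a set).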
|
||||
|
||||
|
||||
def summarize_history(
|
||||
history: str, question: str, persona_specification: str | None, llm: LLM
|
||||
) -> str:
|
||||
history_context_prompt = remove_document_citations(
|
||||
HISTORY_CONTEXT_SUMMARY_PROMPT.format(
|
||||
persona_specification=persona_specification,
|
||||
question=question,
|
||||
history=history,
|
||||
)
|
||||
)
|
||||
|
||||
history_response = llm.invoke(history_context_prompt)
|
||||
assert isinstance(history_response.content, str)
|
||||
return history_response.content
|
||||
|
||||
|
||||
# taken from langchain_core.runnables.schema
|
||||
# we don't use the one from their library because
|
||||
# it includes ids they generate
|
||||
class CustomStreamEvent(TypedDict):
|
||||
# Overwrite the event field to be more specific.
|
||||
event: Literal["on_custom_event"] # type: ignore[misc]
|
||||
"""The event type."""
|
||||
name: str
|
||||
"""User defined name for the event."""
|
||||
data: Any
|
||||
"""The data associated with the event. Free form and can be anything."""
|
||||
|
||||
|
||||
def write_custom_event(
|
||||
name: str, event: AnswerPacket, stream_writer: StreamWriter
|
||||
) -> None:
|
||||
stream_writer(CustomStreamEvent(event="on_custom_event", name=name, data=event))
|
||||
|
||||
|
||||
def relevance_from_docs(
|
||||
relevant_docs: list[InferenceSection],
|
||||
) -> list[SectionRelevancePiece]:
|
||||
return [
|
||||
SectionRelevancePiece(
|
||||
relevant=True,
|
||||
content=doc.center_chunk.content,
|
||||
document_id=doc.center_chunk.document_id,
|
||||
chunk_id=doc.center_chunk.chunk_id,
|
||||
)
|
||||
for doc in relevant_docs
|
||||
]
|
||||
|
||||
|
||||
def get_langgraph_node_log_string(
|
||||
graph_component: str,
|
||||
node_name: str,
|
||||
node_start_time: datetime,
|
||||
result: str | None = None,
|
||||
) -> str:
|
||||
duration = datetime.now() - node_start_time
|
||||
results_str = "" if result is None else f" -- Result: {result}"
|
||||
return f"{node_start_time} -- {graph_component} - {node_name} -- Time taken: {duration}{results_str}"
|
||||
|
||||
|
||||
def remove_document_citations(text: str) -> str:
|
||||
"""
|
||||
Removes the bracketed citation markers (e.g. '[D1]', '[Q2]') from text.
|
||||
The number after D can vary.
|
||||
|
||||
Args:
|
||||
text: Input text containing citations
|
||||
|
||||
Returns:
|
||||
Text with citations removed
|
||||
"""
|
||||
# Pattern explanation:
|
||||
# \[(?:D|Q)?\d+\] matches:
|
||||
# \[ - literal [ character
|
||||
# (?:D|Q)? - optional D or Q character
|
||||
# \d+ - one or more digits
|
||||
# \] - literal ] character
|
||||
return re.sub(r"\[(?:D|Q)?\d+\]", "", text)
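
# Illustrative note (not part of the original code):
# remove_document_citations("As noted in [D1] and [Q2], ...") -> "As noted in  and , ..."
# (the bracketed markers are dropped; the surrounding whitespace is left untouched).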
|
||||
@@ -10,6 +10,7 @@ from onyx.configs.app_configs import SMTP_PORT
|
||||
from onyx.configs.app_configs import SMTP_SERVER
|
||||
from onyx.configs.app_configs import SMTP_USER
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
|
||||
from onyx.db.models import User
|
||||
|
||||
|
||||
@@ -65,9 +66,13 @@ def send_forgot_password_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
tenant_id: str | None = None,
|
||||
) -> None:
|
||||
subject = "Onyx Forgot Password"
|
||||
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
|
||||
if tenant_id:
|
||||
link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
|
||||
# Keep search param same name as cookie for simplicity
|
||||
body = f"Click the following link to reset your password: {link}"
|
||||
send_email(user_email, subject, body, mail_from)
|
||||
|
||||
|
||||
@@ -42,6 +42,10 @@ class UserCreate(schemas.BaseUserCreate):
|
||||
tenant_id: str | None = None
|
||||
|
||||
|
||||
class UserUpdateWithRole(schemas.BaseUserUpdate):
|
||||
role: UserRole
|
||||
|
||||
|
||||
class UserUpdate(schemas.BaseUserUpdate):
|
||||
"""
|
||||
Role updates are not allowed through the user update endpoint for security reasons
|
||||
|
||||
@@ -57,7 +57,7 @@ from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.schemas import AuthBackend
|
||||
from onyx.auth.schemas import UserCreate
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.auth.schemas import UserUpdate
|
||||
from onyx.auth.schemas import UserUpdateWithRole
|
||||
from onyx.configs.app_configs import AUTH_BACKEND
|
||||
from onyx.configs.app_configs import AUTH_COOKIE_EXPIRE_TIME_SECONDS
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
@@ -73,6 +73,7 @@ from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
|
||||
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import PASSWORD_SPECIAL_CHARS
|
||||
@@ -216,9 +217,26 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
reset_password_token_secret = USER_AUTH_SECRET
|
||||
verification_token_secret = USER_AUTH_SECRET
|
||||
verification_token_lifetime_seconds = AUTH_COOKIE_EXPIRE_TIME_SECONDS
|
||||
|
||||
user_db: SQLAlchemyUserDatabase[User, uuid.UUID]
|
||||
|
||||
async def get_by_email(self, user_email: str) -> User:
|
||||
tenant_id = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
|
||||
)(user_email)
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
if MULTI_TENANT:
|
||||
tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
|
||||
db_session, User, OAuthAccount
|
||||
)
|
||||
user = await tenant_user_db.get_by_email(user_email)
|
||||
else:
|
||||
user = await self.user_db.get_by_email(user_email)
|
||||
|
||||
if not user:
|
||||
raise exceptions.UserNotExists()
|
||||
|
||||
return user
|
||||
|
||||
async def create(
|
||||
self,
|
||||
user_create: schemas.UC | UserCreate,
|
||||
@@ -246,10 +264,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
referral_source=referral_source,
|
||||
request=request,
|
||||
)
|
||||
user: User
|
||||
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
verify_email_is_invited(user_create.email)
|
||||
verify_email_domain(user_create.email)
|
||||
if MULTI_TENANT:
|
||||
@@ -268,16 +286,16 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user_create.role = UserRole.ADMIN
|
||||
else:
|
||||
user_create.role = UserRole.BASIC
|
||||
|
||||
try:
|
||||
user = await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||
except exceptions.UserAlreadyExists:
|
||||
user = await self.get_by_email(user_create.email)
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if not user.role.is_web_login() and user_create.role.is_web_login():
|
||||
user_update = UserUpdate(
|
||||
user_update = UserUpdateWithRole(
|
||||
password=user_create.password,
|
||||
is_verified=user_create.is_verified,
|
||||
role=user_create.role,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
else:
|
||||
@@ -285,7 +303,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
return user
|
||||
|
||||
async def validate_password(self, password: str, _: schemas.UC | models.UP) -> None:
|
||||
@@ -372,6 +389,8 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
|
||||
user: User
|
||||
|
||||
try:
|
||||
# Attempt to get user by OAuth account
|
||||
user = await self.get_by_oauth_account(oauth_name, account_id)
|
||||
@@ -504,9 +523,15 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
)
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
"Your admin has not enbaled this feature.",
|
||||
"Your admin has not enabled this feature.",
|
||||
)
|
||||
send_forgot_password_email(user.email, token)
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
)(email=user.email)
|
||||
|
||||
send_forgot_password_email(user.email, token, tenant_id=tenant_id)
|
||||
|
||||
async def on_after_request_verify(
|
||||
self, user: User, token: str, request: Optional[Request] = None
|
||||
@@ -580,6 +605,7 @@ async def get_user_manager(
|
||||
cookie_transport = CookieTransport(
|
||||
cookie_max_age=SESSION_EXPIRE_TIME_SECONDS,
|
||||
cookie_secure=WEB_DOMAIN.startswith("https"),
|
||||
cookie_name=FASTAPI_USERS_AUTH_COOKIE_NAME,
|
||||
)
|
||||
|
||||
|
||||
@@ -1047,6 +1073,8 @@ async def api_key_dep(
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return None
|
||||
|
||||
user: User | None = None
|
||||
|
||||
hashed_api_key = get_hashed_api_key_from_request(request)
|
||||
if not hashed_api_key:
|
||||
raise HTTPException(status_code=401, detail="Missing API key")
|
||||
|
||||
@@ -24,6 +24,7 @@ from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
@@ -197,7 +198,8 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
|
||||
def wait_for_redis(sender: Any, **kwargs: Any) -> None:
|
||||
"""Waits for redis to become ready subject to a hardcoded timeout.
|
||||
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
|
||||
Will raise WorkerShutdown to kill the celery worker if the timeout
|
||||
is reached."""
|
||||
|
||||
r = get_redis_client(tenant_id=None)
|
||||
|
||||
@@ -316,6 +318,8 @@ def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
|
||||
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
HttpxPool.close_all()
|
||||
|
||||
if not celery_is_worker_primary(sender):
|
||||
return
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
@@ -8,7 +7,6 @@ from celery.beat import PersistentScheduler # type: ignore
|
||||
from celery.signals import beat_init
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from onyx.db.engine import SqlEngine
|
||||
@@ -132,21 +130,25 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
# get current schedule and extract current tenants
|
||||
current_schedule = self.schedule.items()
|
||||
|
||||
current_tenants = set()
|
||||
for task_name, _ in current_schedule:
|
||||
task_name = cast(str, task_name)
|
||||
if task_name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
|
||||
continue
|
||||
# there are no more per tenant beat tasks, so comment this out
|
||||
# NOTE: we may not actually need this scheduler anymore and should
|
||||
# test reverting to a regular beat schedule implementation
|
||||
|
||||
if "_" in task_name:
|
||||
# example: "check-for-condition-tenant_12345678-abcd-efgh-ijkl-12345678"
|
||||
# -> "12345678-abcd-efgh-ijkl-12345678"
|
||||
current_tenants.add(task_name.split("_")[-1])
|
||||
logger.info(f"Found {len(current_tenants)} existing items in schedule")
|
||||
# current_tenants = set()
|
||||
# for task_name, _ in current_schedule:
|
||||
# task_name = cast(str, task_name)
|
||||
# if task_name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
|
||||
# continue
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
if tenant_id not in current_tenants:
|
||||
logger.info(f"Processing new tenant: {tenant_id}")
|
||||
# if "_" in task_name:
|
||||
# # example: "check-for-condition-tenant_12345678-abcd-efgh-ijkl-12345678"
|
||||
# # -> "12345678-abcd-efgh-ijkl-12345678"
|
||||
# current_tenants.add(task_name.split("_")[-1])
|
||||
# logger.info(f"Found {len(current_tenants)} existing items in schedule")
|
||||
|
||||
# for tenant_id in tenant_ids:
|
||||
# if tenant_id not in current_tenants:
|
||||
# logger.info(f"Processing new tenant: {tenant_id}")
|
||||
|
||||
new_schedule = self._generate_schedule(tenant_ids)
|
||||
|
||||
|
||||
@@ -10,6 +10,10 @@ from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -54,12 +58,23 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
EXTRA_CONCURRENCY = 8 # small extra fudge factor for connection limits
|
||||
|
||||
logger.info("worker_init signal received.")
|
||||
|
||||
logger.info(f"Concurrency: {sender.concurrency}") # type: ignore
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=EXTRA_CONCURRENCY) # type: ignore
|
||||
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
sender.concurrency + EXTRA_CONCURRENCY, # type: ignore
|
||||
ssl_cert=VESPA_CLOUD_CERT_PATH,
|
||||
ssl_key=VESPA_CLOUD_KEY_PATH,
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(sender.concurrency + EXTRA_CONCURRENCY) # type: ignore
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
|
||||
@@ -21,13 +21,16 @@ from onyx.background.celery.tasks.indexing.utils import (
|
||||
get_unfenced_index_attempt_ids,
|
||||
)
|
||||
from onyx.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
|
||||
from onyx.db.engine import get_session_with_default_tenant
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import mark_attempt_canceled
|
||||
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
from onyx.redis.redis_connector_credential_pair import (
|
||||
RedisGlobalConnectorCredentialPair,
|
||||
)
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
|
||||
@@ -141,23 +144,16 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
r.delete(OnyxRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
|
||||
r.delete(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
r.delete(RedisConnectorCredentialPair.get_taskset_key())
|
||||
r.delete(RedisConnectorCredentialPair.get_fence_key())
|
||||
r.delete(OnyxRedisConstants.ACTIVE_FENCES)
|
||||
|
||||
RedisGlobalConnectorCredentialPair.reset_all(r)
|
||||
RedisDocumentSet.reset_all(r)
|
||||
|
||||
RedisUserGroup.reset_all(r)
|
||||
|
||||
RedisConnectorDelete.reset_all(r)
|
||||
|
||||
RedisConnectorPrune.reset_all(r)
|
||||
|
||||
RedisConnectorIndex.reset_all(r)
|
||||
|
||||
RedisConnectorStop.reset_all(r)
|
||||
|
||||
RedisConnectorPermissionSync.reset_all(r)
|
||||
|
||||
RedisConnectorExternalGroupSync.reset_all(r)
|
||||
|
||||
# mark orphaned index attempts as failed
|
||||
|
||||
@@ -91,6 +91,28 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
|
||||
return False
|
||||
|
||||
|
||||
def celery_get_queued_task_ids(queue: str, r: Redis) -> set[str]:
|
||||
"""This is a redis specific way to build a list of tasks in a queue.
|
||||
|
||||
This helps us read the queue once and then efficiently look for missing tasks
|
||||
in the queue.
|
||||
"""
|
||||
|
||||
task_set: set[str] = set()
|
||||
|
||||
for priority in range(len(OnyxCeleryPriority)):
|
||||
queue_name = f"{queue}{CELERY_SEPARATOR}{priority}" if priority > 0 else queue
|
||||
|
||||
tasks = cast(list[bytes], r.lrange(queue_name, 0, -1))
|
||||
for task in tasks:
|
||||
task_dict: dict[str, Any] = json.loads(task.decode("utf-8"))
|
||||
task_id = task_dict.get("headers", {}).get("id")
|
||||
if task_id:
|
||||
task_set.add(task_id)
|
||||
|
||||
return task_set
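
# Illustrative usage sketch (not part of the original code); the queue name is a placeholder
# and a reachable Redis client is assumed:
# queued_ids = celery_get_queued_task_ids("some_queue", r)
# missing_ids = {task_id for task_id in expected_ids if task_id not in queued_ids}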
|
||||
|
||||
|
||||
def celery_inspect_get_workers(name_filter: str | None, app: Celery) -> list[str]:
|
||||
"""Returns a list of current workers containing name_filter, or all workers if
|
||||
name_filter is None.
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
|
||||
from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
@@ -17,6 +20,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import TaskStatus
|
||||
from onyx.db.models import TaskQueueState
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.server.documents.models import DeletionAttemptSnapshot
|
||||
@@ -154,3 +158,25 @@ def celery_is_worker_primary(worker: Any) -> bool:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def httpx_init_vespa_pool(
|
||||
max_keepalive_connections: int,
|
||||
timeout: int = VESPA_REQUEST_TIMEOUT,
|
||||
ssl_cert: str | None = None,
|
||||
ssl_key: str | None = None,
|
||||
) -> None:
|
||||
httpx_cert = None
|
||||
httpx_verify = False
|
||||
if ssl_cert and ssl_key:
|
||||
httpx_cert = cast(tuple[str, str], (ssl_cert, ssl_key))
|
||||
httpx_verify = True
|
||||
|
||||
HttpxPool.init_client(
|
||||
name="vespa",
|
||||
cert=httpx_cert,
|
||||
verify=httpx_verify,
|
||||
timeout=timeout,
|
||||
http2=False,
|
||||
limits=httpx.Limits(max_keepalive_connections=max_keepalive_connections),
|
||||
)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff.