forked from github/onyx

Compare commits


64 Commits

Author SHA1 Message Date
Justin Tahara
83073f3ded fix(infra): Add Playwright Directory (#5313) 2025-09-01 19:35:48 -07:00
Wenxi
439a27a775 scroll forms on invalid submit (#5310) 2025-09-01 15:33:52 -07:00
Justin Tahara
91773a4789 fix(jira): Upgrade the Jira Python Version (#5309) 2025-09-01 15:33:03 -07:00
Chris Weaver
185beca648 Small center bar improvements (#5306) 2025-09-01 13:32:56 -07:00
Justin Tahara
2dc564c8df feat(infra): Add IAM support for Redis (#5267)
* feat: JIRA support for custom JQL filter (#5164)

* jira jql support

* jira jql fixes

* Address comment

---------

Co-authored-by: sktbcpraha <131408565+sktbcpraha@users.noreply.github.com>
2025-09-01 10:52:28 -07:00
Chris Weaver
b259f53972 Remove console-log (#5304) 2025-09-01 10:18:39 -07:00
Chris Weaver
f8beb08e2f Fix web build (#5303) 2025-09-01 10:18:06 -07:00
Evan Lohn
83c88c7cf6 feat: mcp client1 (#5271)
* working mcp implementation v1

* attempt openapi fix

* fastmcp
2025-09-01 09:52:35 -07:00
Chris Weaver
2372dd40e0 fix: small formatting fixes (#5300)
* Small formatting fixes

* Fix mypy

* reorder imports
2025-08-31 23:19:22 -07:00
Chris Weaver
5cb6bafe81 Cleanup on ChatPage/MessagesDisplay (#5299) 2025-08-31 21:29:17 -07:00
Mohamed Mathari
a0309b31c7 feat: Add Outline Connector (#5284)
* Outline

* fixConnector

* fixTest

* The date filtering is implemented correctly as client-side filtering, which is the only way to achieve it with the Outline API since it doesn't support date parameters natively.

* Update web/src/lib/connectors/connectors.tsx

Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>

* no connector config for outline

* Update backend/onyx/connectors/outline/client.py

Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>

* Fix all PR review issues: document ID prefixes, error handling, test assertions, and null guards

* Update backend/onyx/connectors/outline/client.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* The test no longer depends on external network connectivity to httpbin.org

* I've enhanced the OutlineApiClient.post() method in backend/onyx/connectors/outline/client.py to properly handle network-level exceptions that could crash the connector during synchronization:

* Polling mechanism

* Removed flag-based approach

* commentOnClasses

* commentOnClasses

* commentOnClasses

* responseStatus

* startBound

* Changed the method signature to match the interface

* ConnectorMissingCredentials

* Time Out shared config

* Missing Credential message

---------

Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-08-31 20:56:10 -07:00
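
The Outline connector commits above describe two techniques: date filtering done client-side (because the Outline API exposes no date parameters) and wrapping the client's post() call so network-level failures surface as connector errors instead of crashing a sync. Below is a minimal sketch of both, assuming a requests-based client; the class name, endpoint handling, and the `updatedAt` field are assumptions, not the connector's actual code.

```python
# Minimal sketch of the approach described in the commit messages above; names and
# payload fields are assumptions, not the onyx OutlineApiClient implementation.
from datetime import datetime
from typing import Any

import requests


class OutlineClientError(Exception):
    """Raised when the Outline API cannot be reached or returns an error."""


class OutlineApiClient:
    def __init__(self, base_url: str, api_token: str, timeout: int = 30) -> None:
        self.base_url = base_url.rstrip("/")
        self.api_token = api_token
        self.timeout = timeout

    def post(self, endpoint: str, payload: dict[str, Any] | None = None) -> dict[str, Any]:
        # Wrap network-level failures so a transient outage surfaces as a connector
        # error instead of crashing the synchronization task.
        try:
            response = requests.post(
                f"{self.base_url}{endpoint}",
                json=payload or {},
                headers={"Authorization": f"Bearer {self.api_token}"},
                timeout=self.timeout,
            )
            response.raise_for_status()
            return response.json()
        except requests.RequestException as e:
            raise OutlineClientError(f"Outline API request to {endpoint} failed: {e}") from e


def filter_documents_by_date(
    documents: list[dict[str, Any]],
    start: datetime,
    end: datetime,
) -> list[dict[str, Any]]:
    # The Outline API does not accept date parameters, so filtering happens
    # client-side on the timestamp returned with each document.
    selected = []
    for doc in documents:
        updated_at = datetime.fromisoformat(doc["updatedAt"].replace("Z", "+00:00"))
        if start <= updated_at <= end:
            selected.append(doc)
    return selected
```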
Chris Weaver
0fd268dba7 fix: message render performance (#5297)
* Separate out message display into its own component

* Memoize AIMessage

* Cleanup

* Remove log

* Address greptile/cubic comments
2025-08-31 19:49:53 -07:00
Wenxi
f345da7487 fix: make radios and checkboxes actually clickable (#5298)
* don't nest labels, use htmlFor, fix slackbot form bug

* fix playwright tests for improved labels
2025-08-31 19:16:25 -07:00
Chris Weaver
f2dacf03f1 fix: Chat Page performance improvements (#5295)
* CC performance improvements r2

* More misc chat performance improvements

* Remove unused import

* Remove silly useMemo

* Fix small shift

* Address greptile + cubic + subash comments

* Fix build

* Improve document sidebar

* Remove console.log

* Remove more logs

* Fix build
2025-08-31 14:29:03 -07:00
Wenxi
e0fef50cf0 fix: don't skip ccpairs if embedding swap in progress (#5189)
* don't skip ccpairs if embedding swap in progress

* refactor check_for_indexing to properly handle search setting swaps

* mypy

* mypy

* comment debugging log

* nits and more efficient active index attempt check
2025-08-29 17:17:36 -07:00
Chris Weaver
6ba3eeefa5 feat: center bar + tool force + tool disable (#5272)
* Exploration

* Adding user-specific assistant preferences

* Small fixes

* Improvements

* Reset forced tools upon switch

* Add starter messages

* Improve starter messages

* Add missing file

* cleanup renaming imports

* Address greptile/cubic comments

* Fix build

* Add back assistant info

* Fix image icon

* rebase fix

* Color corrections

* Small tweak

* More color correction

* Remove animation for now

* fix test

* Fix coloring + allow only one forced tool
2025-08-29 17:17:09 -07:00
Richard Guan
aa158abaa9 . (#5286) 2025-08-29 17:07:50 -07:00
Wenxi
255c2af1d6 feat: reorganize connectors pages (#5186)
* Add popular connectors sections and cleanup connectors page

* Add other connectors env var

* other connectors env var to vscode env template

* update playwright tests

* sort by popularity

* recategorize and sort by popularity
2025-08-29 16:59:00 -07:00
Chris Weaver
9ece3b0310 fix: improve index attempts API (#5287)
* Improve index attempts API

* Fix import
2025-08-29 16:15:58 -07:00
joachim-danswer
9e3aca03a7 fix: various dr issues and comments (#5280)
* replacement of "message_delta" etc as Enums + removal

* prompt changes

* cubic fixes where appropriate

* schema fixes + citation symbols

* various fixes

* fix for kg context in new search

* cw comments

* updates
2025-08-29 15:08:23 -07:00
Wenxi
dbd5d4d8f1 fix: allow jira api v3 (#5285)
* allow jira api v3

* don't rely on api version for parsing issues and separate cloud and dc versions
2025-08-29 14:02:01 -07:00
Chris Weaver
cdb97c3ce4 fix: test_soft_delete_chat_session (#5283)
* Fix test_soft_delete_chat_session

* Fix flakiness
2025-08-29 09:01:55 -07:00
Chris Weaver
f30ced31a9 fix: IT (#5276)
* Fix IT

* test

* Fix test

* test

* fix

* Fix test
2025-08-28 20:42:14 -07:00
Wenxi
6cc6c43234 fix: explain why limit=None is appropriate for discord (#5278)
* explain why limit=None is appropriate for discord

* linting
2025-08-28 14:17:46 -07:00
Wenxi
224d934cf4 fix: ruff complaint about type comparison (#5279)
* ruff complaint about type comparison

* ruff complaint type comparison
2025-08-28 14:17:30 -07:00
Nigel Brown
8ecdc61ad3 fix: Explicitly add limit to the function calls (#5273)
* Explicitly add limit to the function calls
This means we miss fewer messages. The default limit is 100.

Signed-off-by: nigel brown <nigel@stacklok.com>

* Update backend/onyx/connectors/discord/connector.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

---------

Signed-off-by: nigel brown <nigel@stacklok.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-08-28 13:35:02 -07:00
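
Commits #5273 and #5278 above both concern the `limit` argument of discord.py's history call: the library defaults to 100 messages, so an explicit `limit=None` is what a connector needs when it wants the full channel history. A hedged sketch, assuming discord.py's `TextChannel.history()` API:

```python
# Sketch of the limit behavior discussed above, using discord.py's
# TextChannel.history(); client setup and token handling are omitted.
import discord


async def fetch_channel_messages(channel: discord.TextChannel) -> list[discord.Message]:
    # Without an explicit argument, history() defaults to limit=100 and silently
    # truncates older messages. limit=None asks for the entire history, which is
    # what a connector doing a complete index wants.
    messages: list[discord.Message] = []
    async for message in channel.history(limit=None):
        messages.append(message)
    return messages
```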
Chris Weaver
08161db7ea fix: playwright tests (#5259)
* Fix playwright tests

* Address comment

* Fix
2025-08-27 23:23:55 -07:00
Richard Guan
b139764631 feat: Fast internet search (#5238)
* squash: combine all DR commits into one

Co-authored-by: Joachim Rahmfeld <joachim@onyx.app>
Co-authored-by: Rei Meguro <rmeguro@umich.edu>

* Fixes

* show KG in Assistant only if available

* KG only usable for KG Beta (for now)

* base file upload

* improvements

* raise error if uploaded context is too long

* More improvements

* Fix citations

* jank implementation of internet search with deep research that can kind of work

* early implementation for google api support

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

---------

Co-authored-by: Weves <chrisweaver101@gmail.com>
Co-authored-by: Joachim Rahmfeld <joachim@onyx.app>
Co-authored-by: Rei Meguro <rmeguro@umich.edu>
Co-authored-by: joachim-danswer <joachim@danswer.ai>
2025-08-27 20:03:02 -07:00
joachim-danswer
2b23dbde8d fix: small DR/Thoughtful mode fixes (#5269)
* fix budget calculation

* Internal custom tool fix + Okta special casing

* nits

* CW comments
2025-08-26 22:33:54 -07:00
Wenxi
2dec009d63 feat: add api/versions to onyx (#5268)
* add api/versions to onyx

* add test and rename onyx

* cubic nit

Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>

* move api version constants and add explanatory comment

---------

Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2025-08-26 18:14:54 -07:00
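
A rough sketch of what a version-reporting endpoint like the one added in #5268 could look like; the route path, constant names, and response shape here are assumptions based only on the commit titles, not the actual onyx code (the backend uses FastAPI, as the diffs below show).

```python
# Hypothetical sketch only: route path, constants, and response shape are assumed.
from fastapi import APIRouter

API_VERSION = "v1"          # assumed constant, moved into one place per the commit note
ONYX_VERSION = "0.0.0-dev"  # matches the ONYX_VERSION build arg seen in the Dockerfiles

router = APIRouter()


@router.get("/api/versions")
def get_versions() -> dict[str, str]:
    # Report both the API contract version and the running application version.
    return {"api_version": API_VERSION, "onyx_version": ONYX_VERSION}
```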
Chris Weaver
91eadae353 Fix logger startup (#5263) 2025-08-26 17:33:25 -07:00
Wenxi
8bff616e27 fix: clarify jql instructions and embed links (#5264)
* clarify jql instructions and embed links

* typo

* lint

* fix unit test
2025-08-26 17:27:07 -07:00
sktbcpraha
2c049e170f feat: JIRA support for custom JQL filter (#5164)
* jira jql support

* jira jql fixes
2025-08-26 12:44:39 -07:00
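
The custom JQL feature from #5164 (also folded into #5267 above) amounts to letting a user-supplied JQL string drive the Jira search instead of a fixed project filter. A hedged sketch using the `jira` package's `search_issues()`; the fallback JQL and pagination shape are illustrative assumptions, not the connector's actual code:

```python
# Illustrative sketch of a user-supplied JQL filter driving the Jira search.
from jira import JIRA


def search_with_custom_jql(
    client: JIRA,
    project_key: str,
    custom_jql: str | None = None,
    batch_size: int = 50,
):
    # Fall back to a plain project filter when no custom JQL is configured.
    jql = custom_jql or f'project = "{project_key}"'
    start_at = 0
    while True:
        issues = client.search_issues(jql, startAt=start_at, maxResults=batch_size)
        if not issues:
            break
        yield from issues
        start_at += len(issues)
```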
Oht8wooWi8yait9n
23e6d7ef3c Update gemini model names. (#5262)
Co-authored-by: Aaron Sells <aaron.b.sells@nasa.gov>
2025-08-26 12:33:02 -07:00
Chris Weaver
ed81e75edd fix: add jira auto-sync option in UI (#5260)
* Add jira auto-sync option in UI

* Fix build
2025-08-26 11:21:04 -07:00
Wenxi
de22fc3a58 remove dead code (#5261) 2025-08-26 11:14:12 -07:00
Cameron
009b7f60f1 Update date format used for fetching from Bookstack (#5221) 2025-08-26 09:49:38 -07:00
Chris Weaver
9d997e20df feat: frontend refactor + DR (#5225)
* squash: combine all DR commits into one

Co-authored-by: Joachim Rahmfeld <joachim@onyx.app>
Co-authored-by: Rei Meguro <rmeguro@umich.edu>

* Fixes

* show KG in Assistant only if available

* KG only usable for KG Beta (for now)

* base file upload

* raise error if uploaded context is too long

* improvements

* More improvements

* Fix citations

* better decision making

* improved decision-making in Orchestrator

* generic_internal tools

* Small tweak

* tool use improvements

* add on

* More image gen stuff

* fixes

* Small color improvements

* Markdown utils

* fixed end conditions (incl early exit for image generation)

* remove agent search + image fixes

* Okta tool support for reload

* Some cleanup

* Stream back search tool results as they come

* tool forcing

* fixed no-Tool-Assistant

* Support anthropic tool calling

* Support anthropic models better

* More stuff

* prompt fixes and search step numbers

* Fix hook ordering issue

* internal search fix

* Improve citation look

* Small UI improvements

* Improvements

* Improve dot

* Small chat fixes

* Small UI tweaks

* Small improvements

* Remove un-used code

* Fix

* Remove test_answer.py for now

* Fix

* improvements

* Add foreign keys

* early forcing

* Fix tests

* Fix tests

---------

Co-authored-by: Joachim Rahmfeld <joachim@onyx.app>
Co-authored-by: Rei Meguro <rmeguro@umich.edu>
Co-authored-by: joachim-danswer <joachim@danswer.ai>
2025-08-26 00:26:14 -07:00
Denizhan Dakılır
e6423c4541 Handle disabled auth in connector indexing status endpoint (#5256) 2025-08-25 16:42:46 -07:00
Wenxi
cb969ad06a add require_email_verification to values.yaml (#5249) 2025-08-25 22:02:49 +00:00
Sam Waddell
c4076d16b6 fix: update all log paths to reflect change related to non-root user (#5244) 2025-08-25 14:11:18 -07:00
Evan Lohn
04a607a718 ensure multi-tenant contextvar is passed (#5240) 2025-08-25 13:35:50 -07:00
Evan Lohn
c1e1aa9dfd fix: downloads are never larger than 20mb (#5247)
* fix: downloads are never larger than 20mb

* JT comments

* import to fix integration tests
2025-08-25 18:10:14 +00:00
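
One way a hard download cap like the 20 MB limit in #5247 can be enforced is to stream the response and abort once the running byte count crosses the threshold. The sketch below is illustrative only; the constant and function name are not taken from the onyx codebase:

```python
# Sketch of enforcing a download size cap while streaming with requests.
import requests

MAX_DOWNLOAD_BYTES = 20 * 1024 * 1024  # 20 MB


def download_capped(url: str, timeout: int = 30) -> bytes:
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        chunks: list[bytes] = []
        total = 0
        for chunk in response.iter_content(chunk_size=64 * 1024):
            total += len(chunk)
            if total > MAX_DOWNLOAD_BYTES:
                # Abort rather than buffer an arbitrarily large file in memory.
                raise ValueError(
                    f"Download from {url} exceeded the {MAX_DOWNLOAD_BYTES} byte limit"
                )
            chunks.append(chunk)
        return b"".join(chunks)
```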
Chris Weaver
1ed7abae6e Small improvement (#5250) 2025-08-25 08:07:36 +05:30
SubashMohan
cf4855822b Perf/indexing status page (#5142)
* indexing status optimization first draft

* refactor: update pagination logic and enhance UI for indexing status table

* add index attempt pruning job and display federated connectors in index status page

* update celery worker command to include index_attempt_cleanup queue

* refactor: enhance indexing status table and remove deprecated components

* mypy fix

* address review comments

* fix pagination reset issue

* add TODO for optimizing connector materialization and performance in future deployments

* enhance connector indexing status retrieval by adding 'get_all_connectors' option and updating pagination logic

* refactor: transition to paginated connector indexing status retrieval and update related components

* fix: initialize latest_index_attempt_docs_indexed to 0 in CCPairIndexingStatusTable component

* feat: add mock connector file support for indexing status retrieval and update indexing_statuses type to Sequence

* mypy fix

* refactor: rename indexing status endpoint to simplify API and update related components
2025-08-24 17:43:47 -07:00
Justin Tahara
e242b1319c fix(infra): Fixed RDS IAM Issue (#5245) 2025-08-22 18:13:12 -07:00
Justin Tahara
eba4b6620e feat(infra): AWS IAM Terraform (#5228)
* feat(infra): AWS IAM Terraform

* Fixing dependency issue

* Fixing more weird logic

* Final cleanup

* one change

* oops
2025-08-22 16:39:16 -07:00
Justin Tahara
3534515e11 feat(infra): Utilize AWS RDS IAM Auth (#5226)
* feat(infra): Utilize AWS RDS IAM Auth

* Update spacing

* Bump helm version
2025-08-21 17:35:53 -07:00
Justin Tahara
5602ff8666 fix: use only celery-shared for security context (#5236) (#5239)
* fix: use only celery-shared for security context

* fix: bump helm chart version 0.2.8

Co-authored-by: Sam Waddell <shwaddell28@gmail.com>
2025-08-21 17:25:06 -07:00
Sam Waddell
2fc70781b4 fix: use only celery-shared for security context (#5236)
* fix: use only celery-shared for security context

* fix: bump helm chart version 0.2.8
2025-08-21 14:15:07 -07:00
Justin Tahara
f76b4dec4c feat(infra): Ignoring local Terraform files (#5227)
* feat(infra): Ignoring local Terraform files

* Addressing some comments
2025-08-21 09:43:18 -07:00
Jessica Singh
a5a516fa8a refactor(model): move api-based embeddings/reranking calls out of model server (#5216)
* move api-based embeddings/reranking calls to api server out of model server, added/modified unit tests

* ran pre-commit

* fix mypy errors

* mypy and precommit

* move utils to right place and add requirements

* precommit check

* removed extra constants, changed error msg

* Update backend/onyx/utils/search_nlp_models_utils.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* greptile

* addressed comments

* added code enforcement to throw error

---------

Co-authored-by: Jessica Singh <jessicasingh@Mac.attlocal.net>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-08-20 21:50:21 +00:00
Sam Waddell
811a198134 docs: add non-root user info (#5224) 2025-08-20 13:50:10 -07:00
Sam Waddell
5867ab1d7d feat: add non-root user to backend and model-server images (#5134)
* feat: add non-root user to backend and model-server image

* feat: update values to support security context for index, inference, and celery_shared

* feat: add security context support for index and inference

* feat: add celery_shared security context support to celery worker templates

* fix: cache management strategy

* fix: update deployment files for volume mount

* fix: address comments

* fix: bump helm chart version for new security context template changes

* fix: bump helm chart version for new security context template changes

* feat: move useradd earlier in build for reduced image size

---------

Co-authored-by: Phil Critchfield <phil.critchfield@liatrio.com>
2025-08-20 13:49:50 -07:00
Jose Bañez
dd6653eb1f fix(connector): #5178 Add error handling and logging for empty answer text in Loopio Connector (#5179)
* fix(connector): #5178 Add error handling and logging for empty answer text in LoopioConnector

* fix(connector): onyx-dot-app#5178:  Improve handling of empty answer text in LoopioConnector

---------

Co-authored-by: Jose Bañez <jose@4gclinical.com>
2025-08-20 09:14:08 -07:00
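
The Loopio fix in #5179 boils down to a skip-and-log pattern for records whose answer text is empty. A minimal sketch, with the record shape assumed rather than taken from Loopio's API:

```python
# Minimal sketch of the skip-and-log pattern for empty answers; the "answer"/"id"
# keys are assumptions, not Loopio's actual payload schema.
import logging

logger = logging.getLogger(__name__)


def extract_answer_text(record: dict) -> str | None:
    answer = (record.get("answer") or "").strip()
    if not answer:
        # Log and skip instead of indexing an empty document or failing mid-sync.
        logger.warning(
            "Loopio record %s has no answer text; skipping", record.get("id", "<unknown>")
        )
        return None
    return answer
```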
Richard Guan
db457ef432 fix(admin): [DAN-2202] Remove users from invited users after accept (#5214)
* .

* .

* .

* .

* .

* .

* .

---------

Co-authored-by: Richard Guan <richardguan@Richards-MacBook-Pro.local>
Co-authored-by: Richard Guan <richardguan@Mac.attlocal.net>
2025-08-20 03:55:02 +00:00
Richard Guan
de7fe939b2 . (#5212)
Co-authored-by: Richard Guan <richardguan@Richards-MBP.lan>
2025-08-20 02:36:44 +00:00
Chris Weaver
38114d9542 fix: PDF file upload (#5218)
* Fix / improve file upload

* Address cubic comment
2025-08-19 15:16:08 -07:00
Justin Tahara
32f20f2e2e feat(infra): Add WAF implementation (#5213) (#5217)
* feat(infra): Add WAF implementation

* Addressing greptile comments

* Additional removal of unnecessary code
2025-08-19 13:01:40 -07:00
Justin Tahara
3dd27099f7 feat(infra): Add WAF implementation (#5213)
* feat(infra): Add WAF implementation

* Addressing greptile comments

* Additional removal of unnecessary code
2025-08-18 17:45:50 -07:00
Cameron
91c4d43a80 Move @types packages to devDependencies (#5210) 2025-08-18 14:34:09 -07:00
SubashMohan
a63ba1bb03 fix: sharepoint group not found error and url with apostrophe (#5208)
* fix: handle ClientRequestException in SharePoint permission utils and connector

* feat: enhance SharePoint permission utilities with logging and URL handling

* greptile typo fix

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* enhance group sync handling for public groups

---------

Co-authored-by: Wenxi <wenxi@onyx.app>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-08-18 17:12:59 +00:00
Evan Lohn
7b6189e74c corrected routing (#5202) 2025-08-18 16:07:28 +00:00
Evan Lohn
ba423e5773 fix: model server concurrency (#5206)
* fix: model server race cond

* fix async

* different approach
2025-08-18 16:07:16 +00:00
508 changed files with 32483 additions and 20231 deletions

View File

@@ -70,7 +70,9 @@ jobs:
-i /local/openapi.json \
-g python \
-o /local/onyx_openapi_client \
--package-name onyx_openapi_client
--package-name onyx_openapi_client \
--skip-validate-spec \
--openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

View File

@@ -67,7 +67,9 @@ jobs:
-i /local/openapi.json \
-g python \
-o /local/onyx_openapi_client \
--package-name onyx_openapi_client
--package-name onyx_openapi_client \
--skip-validate-spec \
--openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

View File

@@ -48,6 +48,8 @@ jobs:
-g python \
-o /local/onyx_openapi_client \
--package-name onyx_openapi_client \
--skip-validate-spec \
--openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"
- name: Run MyPy
run: |

.gitignore (vendored)
View File

@@ -17,12 +17,24 @@ backend/tests/regression/answer_quality/test_data.json
backend/tests/regression/search_quality/eval-*
backend/tests/regression/search_quality/search_eval_config.yaml
backend/tests/regression/search_quality/*.json
*.log
# secret files
.env
jira_test_env
settings.json
# others
/deployment/data/nginx/app.conf
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml
# Local .terraform directories
**/.terraform/*
# Local .tfstate files
*.tfstate
*.tfstate.*
# Local .terraform.lock.hcl file
.terraform.lock.hcl

View File

@@ -23,6 +23,9 @@ DISABLE_LLM_DOC_RELEVANCE=False
# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically)
OAUTH_CLIENT_ID=<REPLACE THIS>
OAUTH_CLIENT_SECRET=<REPLACE THIS>
OPENID_CONFIG_URL=<REPLACE THIS>
SAML_CONF_DIR=/<ABSOLUTE PATH TO ONYX>/onyx/backend/ee/onyx/configs/saml_config
# Generally not useful for dev, we don't generally want to set up an SMTP server for dev
REQUIRE_EMAIL_VERIFICATION=False
@@ -46,7 +49,6 @@ PYTHONUNBUFFERED=1
# Internet Search
BING_API_KEY=<REPLACE THIS>
EXA_API_KEY=<REPLACE THIS>
@@ -65,3 +67,6 @@ S3_ENDPOINT_URL=http://localhost:9004
S3_FILE_STORE_BUCKET_NAME=onyx-file-store-bucket
S3_AWS_ACCESS_KEY_ID=minioadmin
S3_AWS_SECRET_ACCESS_KEY=minioadmin
# Show extra/uncommon connectors
SHOW_EXTRA_CONNECTORS=True

View File

@@ -31,14 +31,16 @@
],
"presentation": {
"group": "1"
}
},
"stopAll": true
},
{
"name": "Web / Model / API",
"configurations": ["Web Server", "Model Server", "API Server"],
"presentation": {
"group": "1"
}
},
"stopAll": true
},
{
"name": "Celery (all)",
@@ -53,7 +55,8 @@
],
"presentation": {
"group": "1"
}
},
"stopAll": true
}
],
"configurations": [
@@ -189,7 +192,7 @@
"--loglevel=INFO",
"--hostname=light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert"
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,index_attempt_cleanup"
],
"presentation": {
"group": "2"

View File

@@ -103,10 +103,10 @@ If using PowerShell, the command slightly differs:
Install the required python dependencies:
```bash
pip install -r onyx/backend/requirements/default.txt
pip install -r onyx/backend/requirements/dev.txt
pip install -r onyx/backend/requirements/ee.txt
pip install -r onyx/backend/requirements/model_server.txt
pip install -r backend/requirements/default.txt
pip install -r backend/requirements/dev.txt
pip install -r backend/requirements/ee.txt
pip install -r backend/requirements/model_server.txt
```
Install Playwright for Python (headless browser required by the Web Connector)

View File

@@ -12,7 +12,8 @@ ARG ONYX_VERSION=0.0.0-dev
# DO_NOT_TRACK is used to disable telemetry for Unstructured
ENV ONYX_VERSION=${ONYX_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true" \
DO_NOT_TRACK="true"
DO_NOT_TRACK="true" \
PLAYWRIGHT_BROWSERS_PATH="/app/.cache/ms-playwright"
RUN echo "ONYX_VERSION: ${ONYX_VERSION}"
@@ -116,6 +117,14 @@ COPY ./assets /app/assets
ENV PYTHONPATH=/app
# Create non-root user for security best practices
RUN groupadd -g 1001 onyx && \
useradd -u 1001 -g onyx -m -s /bin/bash onyx && \
chown -R onyx:onyx /app && \
mkdir -p /var/log/onyx && \
chmod 755 /var/log/onyx && \
chown onyx:onyx /var/log/onyx
# Default command which does nothing
# This container is used by api server and background which specify their own CMD
CMD ["tail", "-f", "/dev/null"]

View File

@@ -9,11 +9,20 @@ visit https://github.com/onyx-dot-app/onyx."
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
ARG ONYX_VERSION=0.0.0-dev
ENV ONYX_VERSION=${ONYX_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true"
DANSWER_RUNNING_IN_DOCKER="true" \
HF_HOME=/app/.cache/huggingface
RUN echo "ONYX_VERSION: ${ONYX_VERSION}"
# Create non-root user for security best practices
RUN mkdir -p /app && \
groupadd -g 1001 onyx && \
useradd -u 1001 -g onyx -m -s /bin/bash onyx && \
chown -R onyx:onyx /app && \
mkdir -p /var/log/onyx && \
chmod 755 /var/log/onyx && \
chown onyx:onyx /var/log/onyx
COPY ./requirements/model_server.txt /tmp/requirements.txt
RUN pip install --no-cache-dir --upgrade \
--retries 5 \
@@ -38,9 +47,11 @@ snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
from sentence_transformers import SentenceTransformer; \
SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True);"
# In case the user has volumes mounted to /root/.cache/huggingface that they've downloaded while
# running Onyx, don't overwrite it with the built in cache folder
RUN mv /root/.cache/huggingface /root/.cache/temp_huggingface
# In case the user has volumes mounted to /app/.cache/huggingface that they've downloaded while
# running Onyx, move the current contents of the cache folder to a temporary location to ensure
# it's preserved in order to combine with the user's cache contents
RUN mv /app/.cache/huggingface /app/.cache/temp_huggingface && \
chown -R onyx:onyx /app
WORKDIR /app

View File

@@ -0,0 +1,115 @@
"""add research agent database tables and chat message research fields
Revision ID: 5ae8240accb3
Revises: b558f51620b4
Create Date: 2025-08-06 14:29:24.691388
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "5ae8240accb3"
down_revision = "b558f51620b4"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add research_type and research_plan columns to chat_message table
op.add_column(
"chat_message",
sa.Column("research_type", sa.String(), nullable=True),
)
op.add_column(
"chat_message",
sa.Column("research_plan", postgresql.JSONB(), nullable=True),
)
# Create research_agent_iteration table
op.create_table(
"research_agent_iteration",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column(
"primary_question_id",
sa.Integer(),
sa.ForeignKey("chat_message.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("iteration_nr", sa.Integer(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column("purpose", sa.String(), nullable=True),
sa.Column("reasoning", sa.String(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint(
"primary_question_id",
"iteration_nr",
name="_research_agent_iteration_unique_constraint",
),
)
# Create research_agent_iteration_sub_step table
op.create_table(
"research_agent_iteration_sub_step",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column(
"primary_question_id",
sa.Integer(),
sa.ForeignKey("chat_message.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"parent_question_id",
sa.Integer(),
sa.ForeignKey("research_agent_iteration_sub_step.id", ondelete="CASCADE"),
nullable=True,
),
sa.Column("iteration_nr", sa.Integer(), nullable=False),
sa.Column("iteration_sub_step_nr", sa.Integer(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column("sub_step_instructions", sa.String(), nullable=True),
sa.Column(
"sub_step_tool_id",
sa.Integer(),
sa.ForeignKey("tool.id"),
nullable=True,
),
sa.Column("reasoning", sa.String(), nullable=True),
sa.Column("sub_answer", sa.String(), nullable=True),
sa.Column("cited_doc_results", postgresql.JSONB(), nullable=True),
sa.Column("claims", postgresql.JSONB(), nullable=True),
sa.Column("generated_images", postgresql.JSONB(), nullable=True),
sa.Column("additional_data", postgresql.JSONB(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(
["primary_question_id", "iteration_nr"],
[
"research_agent_iteration.primary_question_id",
"research_agent_iteration.iteration_nr",
],
ondelete="CASCADE",
),
)
def downgrade() -> None:
# Drop tables in reverse order
op.drop_table("research_agent_iteration_sub_step")
op.drop_table("research_agent_iteration")
# Remove columns from chat_message table
op.drop_column("chat_message", "research_plan")
op.drop_column("chat_message", "research_type")

View File

@@ -0,0 +1,249 @@
"""add_mcp_server_and_connection_config_models
Revision ID: 7ed603b64d5a
Revises: b329d00a9ea6
Create Date: 2025-07-28 17:35:59.900680
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from onyx.db.enums import MCPAuthenticationType
# revision identifiers, used by Alembic.
revision = "7ed603b64d5a"
down_revision = "b329d00a9ea6"
branch_labels = None
depends_on = None
def upgrade() -> None:
"""Create tables and columns for MCP Server support"""
# 1. MCP Server main table (no FK constraints yet to avoid circular refs)
op.create_table(
"mcp_server",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("owner", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("description", sa.String(), nullable=True),
sa.Column("server_url", sa.String(), nullable=False),
sa.Column(
"auth_type",
sa.Enum(
MCPAuthenticationType,
name="mcp_authentication_type",
native_enum=False,
),
nullable=False,
),
sa.Column("admin_connection_config_id", sa.Integer(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"), # type: ignore
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"), # type: ignore
nullable=False,
),
)
# 2. MCP Connection Config table (can reference mcp_server now that it exists)
op.create_table(
"mcp_connection_config",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("mcp_server_id", sa.Integer(), nullable=True),
sa.Column("user_email", sa.String(), nullable=False, default=""),
sa.Column("config", sa.LargeBinary(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"), # type: ignore
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"), # type: ignore
nullable=False,
),
sa.ForeignKeyConstraint(
["mcp_server_id"], ["mcp_server.id"], ondelete="CASCADE"
),
)
# Helpful indexes
op.create_index(
"ix_mcp_connection_config_server_user",
"mcp_connection_config",
["mcp_server_id", "user_email"],
)
op.create_index(
"ix_mcp_connection_config_user_email",
"mcp_connection_config",
["user_email"],
)
# 3. Add the back-references from mcp_server to connection configs
op.create_foreign_key(
"mcp_server_admin_config_fk",
"mcp_server",
"mcp_connection_config",
["admin_connection_config_id"],
["id"],
ondelete="SET NULL",
)
# 4. Association / access-control tables
op.create_table(
"mcp_server__user",
sa.Column("mcp_server_id", sa.Integer(), primary_key=True),
sa.Column("user_id", sa.UUID(), primary_key=True),
sa.ForeignKeyConstraint(
["mcp_server_id"], ["mcp_server.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
)
op.create_table(
"mcp_server__user_group",
sa.Column("mcp_server_id", sa.Integer(), primary_key=True),
sa.Column("user_group_id", sa.Integer(), primary_key=True),
sa.ForeignKeyConstraint(
["mcp_server_id"], ["mcp_server.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["user_group_id"], ["user_group.id"]),
)
# 5. Update existing `tool` table allow tools to belong to an MCP server
op.add_column(
"tool",
sa.Column("mcp_server_id", sa.Integer(), nullable=True),
)
# Add column for MCP tool input schema
op.add_column(
"tool",
sa.Column("mcp_input_schema", postgresql.JSONB(), nullable=True),
)
op.create_foreign_key(
"tool_mcp_server_fk",
"tool",
"mcp_server",
["mcp_server_id"],
["id"],
ondelete="CASCADE",
)
# 6. Update persona__tool foreign keys to cascade delete
# This ensures that when a tool is deleted (including via MCP server deletion),
# the corresponding persona__tool rows are also deleted
op.drop_constraint(
"persona__tool_tool_id_fkey", "persona__tool", type_="foreignkey"
)
op.drop_constraint(
"persona__tool_persona_id_fkey", "persona__tool", type_="foreignkey"
)
op.create_foreign_key(
"persona__tool_persona_id_fkey",
"persona__tool",
"persona",
["persona_id"],
["id"],
ondelete="CASCADE",
)
op.create_foreign_key(
"persona__tool_tool_id_fkey",
"persona__tool",
"tool",
["tool_id"],
["id"],
ondelete="CASCADE",
)
# 7. Update research_agent_iteration_sub_step foreign key to SET NULL on delete
# This ensures that when a tool is deleted, the sub_step_tool_id is set to NULL
# instead of causing a foreign key constraint violation
op.drop_constraint(
"research_agent_iteration_sub_step_sub_step_tool_id_fkey",
"research_agent_iteration_sub_step",
type_="foreignkey",
)
op.create_foreign_key(
"research_agent_iteration_sub_step_sub_step_tool_id_fkey",
"research_agent_iteration_sub_step",
"tool",
["sub_step_tool_id"],
["id"],
ondelete="SET NULL",
)
def downgrade() -> None:
"""Drop all MCP-related tables / columns"""
# # # 1. Drop FK & columns from tool
# op.drop_constraint("tool_mcp_server_fk", "tool", type_="foreignkey")
op.execute("DELETE FROM tool WHERE mcp_server_id IS NOT NULL")
op.drop_constraint(
"research_agent_iteration_sub_step_sub_step_tool_id_fkey",
"research_agent_iteration_sub_step",
type_="foreignkey",
)
op.create_foreign_key(
"research_agent_iteration_sub_step_sub_step_tool_id_fkey",
"research_agent_iteration_sub_step",
"tool",
["sub_step_tool_id"],
["id"],
)
# Restore original persona__tool foreign keys (without CASCADE)
op.drop_constraint(
"persona__tool_persona_id_fkey", "persona__tool", type_="foreignkey"
)
op.drop_constraint(
"persona__tool_tool_id_fkey", "persona__tool", type_="foreignkey"
)
op.create_foreign_key(
"persona__tool_persona_id_fkey",
"persona__tool",
"persona",
["persona_id"],
["id"],
)
op.create_foreign_key(
"persona__tool_tool_id_fkey",
"persona__tool",
"tool",
["tool_id"],
["id"],
)
op.drop_column("tool", "mcp_input_schema")
op.drop_column("tool", "mcp_server_id")
# 2. Drop association tables
op.drop_table("mcp_server__user_group")
op.drop_table("mcp_server__user")
# 3. Drop FK from mcp_server to connection configs
op.drop_constraint("mcp_server_admin_config_fk", "mcp_server", type_="foreignkey")
# 4. Drop connection config indexes & table
op.drop_index(
"ix_mcp_connection_config_user_email", table_name="mcp_connection_config"
)
op.drop_index(
"ix_mcp_connection_config_server_user", table_name="mcp_connection_config"
)
op.drop_table("mcp_connection_config")
# 5. Finally drop mcp_server table
op.drop_table("mcp_server")

View File

@@ -0,0 +1,38 @@
"""Adding assistant-specific user preferences
Revision ID: b329d00a9ea6
Revises: f9b8c7d6e5a4
Create Date: 2025-08-26 23:14:44.592985
"""
from alembic import op
import fastapi_users_db_sqlalchemy
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "b329d00a9ea6"
down_revision = "f9b8c7d6e5a4"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"assistant__user_specific_config",
sa.Column("assistant_id", sa.Integer(), nullable=False),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=False,
),
sa.Column("disabled_tool_ids", postgresql.ARRAY(sa.Integer()), nullable=False),
sa.ForeignKeyConstraint(["assistant_id"], ["persona.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("assistant_id", "user_id"),
)
def downgrade() -> None:
op.drop_table("assistant__user_specific_config")

View File

@@ -0,0 +1,147 @@
"""migrate_agent_sub_questions_to_research_iterations
Revision ID: bd7c3bf8beba
Revises: f8a9b2c3d4e5
Create Date: 2025-08-18 11:33:27.098287
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "bd7c3bf8beba"
down_revision = "f8a9b2c3d4e5"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Get connection to execute raw SQL
connection = op.get_bind()
# First, insert data into research_agent_iteration table
# This creates one iteration record per primary_question_id using the earliest time_created
connection.execute(
sa.text(
"""
INSERT INTO research_agent_iteration (primary_question_id, created_at, iteration_nr, purpose, reasoning)
SELECT
primary_question_id,
MIN(time_created) as created_at,
1 as iteration_nr,
'Generating and researching subquestions' as purpose,
'(No previous reasoning)' as reasoning
FROM agent__sub_question
JOIN chat_message on agent__sub_question.primary_question_id = chat_message.id
WHERE primary_question_id IS NOT NULL
AND chat_message.is_agentic = true
GROUP BY primary_question_id
ON CONFLICT DO NOTHING;
"""
)
)
# Then, insert data into research_agent_iteration_sub_step table
# This migrates each sub-question as a sub-step
connection.execute(
sa.text(
"""
INSERT INTO research_agent_iteration_sub_step (
primary_question_id,
iteration_nr,
iteration_sub_step_nr,
created_at,
sub_step_instructions,
sub_step_tool_id,
sub_answer,
cited_doc_results
)
SELECT
primary_question_id,
1 as iteration_nr,
level_question_num as iteration_sub_step_nr,
time_created as created_at,
sub_question as sub_step_instructions,
1 as sub_step_tool_id,
sub_answer,
sub_question_doc_results as cited_doc_results
FROM agent__sub_question
JOIN chat_message on agent__sub_question.primary_question_id = chat_message.id
WHERE chat_message.is_agentic = true
AND primary_question_id IS NOT NULL
ON CONFLICT DO NOTHING;
"""
)
)
# Update chat_message records: set legacy agentic type and answer purpose for existing agentic messages
connection.execute(
sa.text(
"""
UPDATE chat_message
SET research_answer_purpose = 'ANSWER'
WHERE is_agentic = true
AND research_type IS NULL and
message_type = 'ASSISTANT';
"""
)
)
connection.execute(
sa.text(
"""
UPDATE chat_message
SET research_type = 'LEGACY_AGENTIC'
WHERE is_agentic = true
AND research_type IS NULL;
"""
)
)
def downgrade() -> None:
# Get connection to execute raw SQL
connection = op.get_bind()
# Note: This downgrade removes all research agent iteration data
# There's no way to perfectly restore the original agent__sub_question data
# if it was deleted after this migration
# Delete all research_agent_iteration_sub_step records that were migrated
connection.execute(
sa.text(
"""
DELETE FROM research_agent_iteration_sub_step
USING chat_message
WHERE research_agent_iteration_sub_step.primary_question_id = chat_message.id
AND chat_message.research_type = 'LEGACY_AGENTIC';
"""
)
)
# Delete all research_agent_iteration records that were migrated
connection.execute(
sa.text(
"""
DELETE FROM research_agent_iteration
USING chat_message
WHERE research_agent_iteration.primary_question_id = chat_message.id
AND chat_message.research_type = 'LEGACY_AGENTIC';
"""
)
)
# Revert chat_message updates: clear research fields for legacy agentic messages
connection.execute(
sa.text(
"""
UPDATE chat_message
SET research_type = NULL,
research_answer_purpose = NULL
WHERE is_agentic = true
AND research_type = 'LEGACY_AGENTIC'
AND message_type = 'ASSISTANT';
"""
)
)

View File

@@ -0,0 +1,30 @@
"""add research_answer_purpose to chat_message
Revision ID: f8a9b2c3d4e5
Revises: 5ae8240accb3
Create Date: 2025-01-27 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "f8a9b2c3d4e5"
down_revision = "5ae8240accb3"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add research_answer_purpose column to chat_message table
op.add_column(
"chat_message",
sa.Column("research_answer_purpose", sa.String(), nullable=True),
)
def downgrade() -> None:
# Remove research_answer_purpose column from chat_message table
op.drop_column("chat_message", "research_answer_purpose")

View File

@@ -0,0 +1,69 @@
"""remove foreign key constraints from research_agent_iteration_sub_step
Revision ID: f9b8c7d6e5a4
Revises: bd7c3bf8beba
Create Date: 2025-01-27 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "f9b8c7d6e5a4"
down_revision = "bd7c3bf8beba"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Drop the existing foreign key constraint for parent_question_id
op.drop_constraint(
"research_agent_iteration_sub_step_parent_question_id_fkey",
"research_agent_iteration_sub_step",
type_="foreignkey",
)
# Drop the parent_question_id column entirely
op.drop_column("research_agent_iteration_sub_step", "parent_question_id")
# Drop the foreign key constraint for primary_question_id to chat_message.id
# (keep the column as it's needed for the composite foreign key)
op.drop_constraint(
"research_agent_iteration_sub_step_primary_question_id_fkey",
"research_agent_iteration_sub_step",
type_="foreignkey",
)
def downgrade() -> None:
# Restore the foreign key constraint for primary_question_id to chat_message.id
op.create_foreign_key(
"research_agent_iteration_sub_step_primary_question_id_fkey",
"research_agent_iteration_sub_step",
"chat_message",
["primary_question_id"],
["id"],
ondelete="CASCADE",
)
# Add back the parent_question_id column
op.add_column(
"research_agent_iteration_sub_step",
sa.Column(
"parent_question_id",
sa.Integer(),
nullable=True,
),
)
# Restore the foreign key constraint pointing to research_agent_iteration.id
op.create_foreign_key(
"research_agent_iteration_sub_step_parent_question_id_fkey",
"research_agent_iteration_sub_step",
"research_agent_iteration",
["parent_question_id"],
["id"],
ondelete="CASCADE",
)

View File

@@ -1,17 +1,17 @@
from ee.onyx.server.query_and_chat.models import OneShotQAResponse
from onyx.chat.models import AllCitations
from onyx.chat.models import AnswerStream
from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import QADocsResponse
from onyx.chat.models import StreamingError
from onyx.chat.process_message import ChatPacketStream
from onyx.server.query_and_chat.models import ChatMessageDetail
from onyx.utils.timing import log_function_time
@log_function_time()
def gather_stream_for_answer_api(
packets: ChatPacketStream,
packets: AnswerStream,
) -> OneShotQAResponse:
response = OneShotQAResponse()

View File

@@ -2,18 +2,14 @@ from collections.abc import Callable
from collections.abc import Generator
from typing import Optional
from typing import Protocol
from typing import TYPE_CHECKING
from ee.onyx.db.external_perm import ExternalUserGroup # noqa
from onyx.access.models import DocExternalAccess # noqa
from onyx.context.search.models import InferenceChunk
from onyx.db.models import ConnectorCredentialPair # noqa
from onyx.db.utils import DocumentRow
from onyx.db.utils import SortOrder
# Avoid circular imports
if TYPE_CHECKING:
from ee.onyx.db.external_perm import ExternalUserGroup # noqa
from onyx.access.models import DocExternalAccess # noqa
from onyx.db.models import ConnectorCredentialPair # noqa
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface # noqa
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface # noqa
class FetchAllDocumentsFunction(Protocol):
@@ -52,20 +48,20 @@ class FetchAllDocumentsIdsFunction(Protocol):
# Defining the input/output types for the sync functions
DocSyncFuncType = Callable[
[
"ConnectorCredentialPair",
ConnectorCredentialPair,
FetchAllDocumentsFunction,
FetchAllDocumentsIdsFunction,
Optional["IndexingHeartbeatInterface"],
Optional[IndexingHeartbeatInterface],
],
Generator["DocExternalAccess", None, None],
Generator[DocExternalAccess, None, None],
]
GroupSyncFuncType = Callable[
[
str, # tenant_id
"ConnectorCredentialPair", # cc_pair
ConnectorCredentialPair, # cc_pair
],
Generator["ExternalUserGroup", None, None],
Generator[ExternalUserGroup, None, None],
]
# list of chunks to be censored and the user email. returns censored chunks

View File

@@ -1,9 +1,12 @@
import re
from collections import deque
from typing import Any
from urllib.parse import unquote
from urllib.parse import urlparse
from office365.graph_client import GraphClient # type: ignore[import-untyped]
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
from office365.runtime.client_request import ClientRequestException # type: ignore
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
from office365.sharepoint.permissions.securable_object import RoleAssignmentCollection # type: ignore[import-untyped]
from pydantic import BaseModel
@@ -231,6 +234,7 @@ def _get_sharepoint_groups(
nonlocal groups, user_emails
for user in users:
logger.debug(f"User: {user.to_json()}")
if user.principal_type == USER_PRINCIPAL_TYPE and hasattr(
user, "user_principal_name"
):
@@ -285,7 +289,7 @@ def _get_azuread_groups(
for member in members:
member_data = member.to_json()
logger.debug(f"Member: {member_data}")
# Check for user-specific attributes
user_principal_name = member_data.get("userPrincipalName")
mail = member_data.get("mail")
@@ -366,6 +370,7 @@ def _get_groups_and_members_recursively(
client_context: ClientContext,
graph_client: GraphClient,
groups: set[SharepointGroup],
is_group_sync: bool = False,
) -> GroupsResult:
"""
Get all groups and their members recursively.
@@ -373,6 +378,7 @@ def _get_groups_and_members_recursively(
group_queue: deque[SharepointGroup] = deque(groups)
visited_groups: set[str] = set()
visited_group_name_to_emails: dict[str, set[str]] = {}
found_public_group = False
while group_queue:
group = group_queue.popleft()
if group.login_name in visited_groups:
@@ -390,19 +396,35 @@ def _get_groups_and_members_recursively(
if group_info:
group_queue.extend(group_info)
if group.principal_type == AZURE_AD_GROUP_PRINCIPAL_TYPE:
# if the site is public, we have default groups assigned to it, so we return early
if _is_public_login_name(group.login_name):
return GroupsResult(groups_to_emails={}, found_public_group=True)
group_info, user_emails = _get_azuread_groups(
graph_client, group.login_name
)
visited_group_name_to_emails[group.name].update(user_emails)
if group_info:
group_queue.extend(group_info)
try:
# if the site is public, we have default groups assigned to it, so we return early
if _is_public_login_name(group.login_name):
found_public_group = True
if not is_group_sync:
return GroupsResult(
groups_to_emails={}, found_public_group=True
)
else:
# we don't want to sync public groups, so we skip them
continue
group_info, user_emails = _get_azuread_groups(
graph_client, group.login_name
)
visited_group_name_to_emails[group.name].update(user_emails)
if group_info:
group_queue.extend(group_info)
except ClientRequestException as e:
# If the group is not found, we skip it. There is a chance that group is still referenced
# in sharepoint but it is removed from Azure AD. There is no actual documentation on this, but based on
# our testing we have seen this happen.
if e.response is not None and e.response.status_code == 404:
logger.warning(f"Group {group.login_name} not found")
continue
raise e
return GroupsResult(
groups_to_emails=visited_group_name_to_emails, found_public_group=False
groups_to_emails=visited_group_name_to_emails,
found_public_group=found_public_group,
)
@@ -427,6 +449,7 @@ def get_external_access_from_sharepoint(
) -> None:
nonlocal user_emails, groups
for assignment in role_assignments:
logger.debug(f"Assignment: {assignment.to_json()}")
if assignment.role_definition_bindings:
is_limited_access = True
for role_definition_binding in assignment.role_definition_bindings:
@@ -503,12 +526,19 @@ def get_external_access_from_sharepoint(
)
elif site_page:
site_url = site_page.get("webUrl")
site_pages = client_context.web.lists.get_by_title("Site Pages")
client_context.load(site_pages)
client_context.execute_query()
site_pages.items.get_by_url(site_url).role_assignments.expand(
["Member", "RoleDefinitionBindings"]
).get_all(page_loaded=add_user_and_group_to_sets).execute_query()
# Prefer server-relative URL to avoid OData filters that break on apostrophes
server_relative_url = unquote(urlparse(site_url).path)
file_obj = client_context.web.get_file_by_server_relative_url(
server_relative_url
)
item = file_obj.listItemAllFields
sleep_and_retry(
item.role_assignments.expand(["Member", "RoleDefinitionBindings"]).get_all(
page_loaded=add_user_and_group_to_sets,
),
"get_external_access_from_sharepoint",
)
else:
raise RuntimeError("No drive item or site page provided")
@@ -595,13 +625,9 @@ def get_sharepoint_external_groups(
"get_sharepoint_external_groups",
)
groups_and_members: GroupsResult = _get_groups_and_members_recursively(
client_context, graph_client, groups
client_context, graph_client, groups, is_group_sync=True
)
# We don't have any direct way to check if the site is public, so we check if any public group is present
if groups_and_members.found_public_group:
return []
# get all Azure AD groups because if any group is assigned to the drive item, we don't want to miss them
# We can't assign sharepoint groups to drive items or drives, so we don't need to get all sharepoint groups
azure_ad_groups = sleep_and_retry(

View File

@@ -1,43 +1,22 @@
import re
from typing import cast
from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from ee.onyx.server.query_and_chat.models import AgentAnswer
from ee.onyx.server.query_and_chat.models import AgentSubQuery
from ee.onyx.server.query_and_chat.models import AgentSubQuestion
from ee.onyx.server.query_and_chat.models import BasicCreateChatMessageRequest
from ee.onyx.server.query_and_chat.models import (
BasicCreateChatMessageWithHistoryRequest,
)
from ee.onyx.server.query_and_chat.models import ChatBasicResponse
from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AllCitations
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import FinalUsedContextDocsResponse
from onyx.chat.models import LlmDoc
from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import QADocsResponse
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamingError
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionIdentifier
from onyx.chat.models import SubQuestionPiece
from onyx.chat.process_message import ChatPacketStream
from onyx.chat.models import ChatBasicResponse
from onyx.chat.process_message import gather_stream
from onyx.chat.process_message import stream_chat_message_objects
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.constants import MessageType
from onyx.context.search.models import OptionalSearchSetting
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SavedSearchDoc
from onyx.db.chat import create_chat_session
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_or_create_root_message
@@ -46,7 +25,6 @@ from onyx.db.models import User
from onyx.llm.factory import get_llms_for_persona
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.secondary_llm_flows.query_expansion import thread_based_query_rephrase
from onyx.server.query_and_chat.models import ChatMessageDetail
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.utils.logger import setup_logger
@@ -55,180 +33,6 @@ logger = setup_logger()
router = APIRouter(prefix="/chat")
def _get_final_context_doc_indices(
final_context_docs: list[LlmDoc] | None,
top_docs: list[SavedSearchDoc] | None,
) -> list[int] | None:
"""
this function returns a list of indices of the simple search docs
that were actually fed to the LLM.
"""
if final_context_docs is None or top_docs is None:
return None
final_context_doc_ids = {doc.document_id for doc in final_context_docs}
return [
i for i, doc in enumerate(top_docs) if doc.document_id in final_context_doc_ids
]
def _convert_packet_stream_to_response(
packets: ChatPacketStream,
chat_session_id: UUID,
) -> ChatBasicResponse:
response = ChatBasicResponse()
final_context_docs: list[LlmDoc] = []
answer = ""
# accumulate stream data with these dicts
agent_sub_questions: dict[tuple[int, int], AgentSubQuestion] = {}
agent_answers: dict[tuple[int, int], AgentAnswer] = {}
agent_sub_queries: dict[tuple[int, int, int], AgentSubQuery] = {}
for packet in packets:
if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
answer += packet.answer_piece
elif isinstance(packet, QADocsResponse):
response.top_documents = packet.top_documents
# This is a no-op if agent_sub_questions hasn't already been filled
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if id in agent_sub_questions:
agent_sub_questions[id].document_ids = [
saved_search_doc.document_id
for saved_search_doc in packet.top_documents
]
elif isinstance(packet, StreamingError):
response.error_msg = packet.error
elif isinstance(packet, ChatMessageDetail):
response.message_id = packet.message_id
elif isinstance(packet, LLMRelevanceFilterResponse):
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
# TODO: deprecate `llm_chunks_indices`
response.llm_chunks_indices = packet.llm_selected_doc_indices
elif isinstance(packet, FinalUsedContextDocsResponse):
final_context_docs = packet.final_context_docs
elif isinstance(packet, AllCitations):
response.cited_documents = {
citation.citation_num: citation.document_id
for citation in packet.citations
}
# agentic packets
elif isinstance(packet, SubQuestionPiece):
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if agent_sub_questions.get(id) is None:
agent_sub_questions[id] = AgentSubQuestion(
level=packet.level,
level_question_num=packet.level_question_num,
sub_question=packet.sub_question,
document_ids=[],
)
else:
agent_sub_questions[id].sub_question += packet.sub_question
elif isinstance(packet, AgentAnswerPiece):
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if agent_answers.get(id) is None:
agent_answers[id] = AgentAnswer(
level=packet.level,
level_question_num=packet.level_question_num,
answer=packet.answer_piece,
answer_type=packet.answer_type,
)
else:
agent_answers[id].answer += packet.answer_piece
elif isinstance(packet, SubQueryPiece):
if packet.level is not None and packet.level_question_num is not None:
sub_query_id = (
packet.level,
packet.level_question_num,
packet.query_id,
)
if agent_sub_queries.get(sub_query_id) is None:
agent_sub_queries[sub_query_id] = AgentSubQuery(
level=packet.level,
level_question_num=packet.level_question_num,
sub_query=packet.sub_query,
query_id=packet.query_id,
)
else:
agent_sub_queries[sub_query_id].sub_query += packet.sub_query
elif isinstance(packet, ExtendedToolResponse):
# we shouldn't get this ... it gets intercepted and translated to QADocsResponse
logger.warning(
"_convert_packet_stream_to_response: Unexpected chat packet type ExtendedToolResponse!"
)
elif isinstance(packet, RefinedAnswerImprovement):
response.agent_refined_answer_improvement = (
packet.refined_answer_improvement
)
else:
logger.warning(
f"_convert_packet_stream_to_response - Unrecognized chat packet: type={type(packet)}"
)
response.final_context_doc_indices = _get_final_context_doc_indices(
final_context_docs, response.top_documents
)
# organize / sort agent metadata for output
if len(agent_sub_questions) > 0:
response.agent_sub_questions = cast(
dict[int, list[AgentSubQuestion]],
SubQuestionIdentifier.make_dict_by_level(agent_sub_questions),
)
if len(agent_answers) > 0:
# return the agent_level_answer from the first level or the last one depending
# on agent_refined_answer_improvement
response.agent_answers = cast(
dict[int, list[AgentAnswer]],
SubQuestionIdentifier.make_dict_by_level(agent_answers),
)
if response.agent_answers:
selected_answer_level = (
0
if not response.agent_refined_answer_improvement
else len(response.agent_answers) - 1
)
level_answers = response.agent_answers[selected_answer_level]
for level_answer in level_answers:
if level_answer.answer_type != "agent_level_answer":
continue
answer = level_answer.answer
break
if len(agent_sub_queries) > 0:
# subqueries are often emitted with trailing whitespace ... clean it up here
# perhaps fix at the source?
for v in agent_sub_queries.values():
v.sub_query = v.sub_query.strip()
response.agent_sub_queries = (
AgentSubQuery.make_dict_by_level_and_question_index(agent_sub_queries)
)
response.answer = answer
if answer:
response.answer_citationless = remove_answer_citations(answer)
response.chat_session_id = chat_session_id
return response
def remove_answer_citations(answer: str) -> str:
pattern = r"\s*\[\[\d+\]\]\(http[s]?://[^\s]+\)"
return re.sub(pattern, "", answer)
@router.post("/send-message-simple-api")
def handle_simplified_chat_message(
chat_message_req: BasicCreateChatMessageRequest,
@@ -310,7 +114,7 @@ def handle_simplified_chat_message(
enforce_chat_session_id_for_search_docs=False,
)
return _convert_packet_stream_to_response(packets, chat_session_id)
return gather_stream(packets)
@router.post("/send-message-simple-with-history")
@@ -430,4 +234,4 @@ def handle_send_message_simple_with_history(
enforce_chat_session_id_for_search_docs=False,
)
return _convert_packet_stream_to_response(packets, chat_session.id)
return gather_stream(packets)

View File

@@ -6,10 +6,8 @@ from pydantic import BaseModel
from pydantic import Field
from pydantic import model_validator
from onyx.chat.models import CitationInfo
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import SubQuestionIdentifier
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DocumentSource
from onyx.context.search.enums import LLMEvaluationType
@@ -17,8 +15,9 @@ from onyx.context.search.enums import SearchType
from onyx.context.search.models import ChunkContext
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SavedSearchDoc
from onyx.server.manage.models import StandardAnswer
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import SubQuestionIdentifier
class StandardAnswerRequest(BaseModel):
@@ -156,33 +155,6 @@ class AgentSubQuery(SubQuestionIdentifier):
return sorted_dict
class ChatBasicResponse(BaseModel):
# This is built piece by piece; any of these can be None as the flow could break
answer: str | None = None
answer_citationless: str | None = None
top_documents: list[SavedSearchDoc] | None = None
error_msg: str | None = None
message_id: int | None = None
llm_selected_doc_indices: list[int] | None = None
final_context_doc_indices: list[int] | None = None
# this is a map of the citation number to the document id
cited_documents: dict[int, str] | None = None
# FOR BACKWARDS COMPATIBILITY
llm_chunks_indices: list[int] | None = None
# agentic fields
agent_sub_questions: dict[int, list[AgentSubQuestion]] | None = None
agent_answers: dict[int, list[AgentAnswer]] | None = None
agent_sub_queries: dict[int, dict[int, list[AgentSubQuery]]] | None = None
agent_refined_answer_improvement: bool | None = None
# Chat session ID for tracking conversation continuity
chat_session_id: UUID | None = None
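A minimal construction sketch for the response model above; every field value below is hypothetical:
from uuid import uuid4
example_response = ChatBasicResponse(
    answer="Onyx is an open-source AI platform. [[1]](https://docs.onyx.app)",
    answer_citationless="Onyx is an open-source AI platform.",
    cited_documents={1: "web_connector__docs.onyx.app"},  # citation number -> document id
    chat_session_id=uuid4(),
)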
class OneShotQARequest(ChunkContext):
# Supports simpler APIs that don't deal with chat histories or message edits
# Easier APIs to work with for developers
@@ -193,7 +165,6 @@ class OneShotQARequest(ChunkContext):
prompt_id: int | None = None
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
rerank_settings: RerankingDetails | None = None
return_contexts: bool = False
# allows the caller to specify the exact search query they want to use
# can be used if the message sent to the LLM / query should not be the same

View File

@@ -20,8 +20,8 @@ from ee.onyx.server.query_and_chat.models import StandardAnswerResponse
from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import prepare_chat_message_request
from onyx.chat.models import AnswerStream
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.process_message import ChatPacketStream
from onyx.chat.process_message import stream_chat_message_objects
from onyx.configs.onyxbot_configs import MAX_THREAD_CONTEXT_PERCENTAGE
from onyx.context.search.models import SavedSearchDocWithContent
@@ -140,7 +140,7 @@ def get_answer_stream(
query_request: OneShotQARequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> ChatPacketStream:
) -> AnswerStream:
query = query_request.messages[0].message
logger.notice(f"Received query for Answer API: {query}")
@@ -205,7 +205,6 @@ def get_answer_stream(
new_msg_req=request,
user=user,
db_session=db_session,
include_contexts=query_request.return_contexts,
)
return packets

View File

@@ -1,34 +1,5 @@
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
MODEL_WARM_UP_STRING = "hi " * 512
INFORMATION_CONTENT_MODEL_WARM_UP_STRING = "hi " * 16
DEFAULT_OPENAI_MODEL = "text-embedding-3-small"
DEFAULT_COHERE_MODEL = "embed-english-light-v3.0"
DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"
DEFAULT_VERTEX_MODEL = "text-embedding-005"
class EmbeddingModelTextType:
PROVIDER_TEXT_TYPE_MAP = {
EmbeddingProvider.COHERE: {
EmbedTextType.QUERY: "search_query",
EmbedTextType.PASSAGE: "search_document",
},
EmbeddingProvider.VOYAGE: {
EmbedTextType.QUERY: "query",
EmbedTextType.PASSAGE: "document",
},
EmbeddingProvider.GOOGLE: {
EmbedTextType.QUERY: "RETRIEVAL_QUERY",
EmbedTextType.PASSAGE: "RETRIEVAL_DOCUMENT",
},
}
@staticmethod
def get_type(provider: EmbeddingProvider, text_type: EmbedTextType) -> str:
return EmbeddingModelTextType.PROVIDER_TEXT_TYPE_MAP[provider][text_type]
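A quick sketch of how the provider/text-type lookup above resolves, assuming the enums from shared_configs.enums:
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
assert EmbeddingModelTextType.get_type(EmbeddingProvider.COHERE, EmbedTextType.QUERY) == "search_query"
assert EmbeddingModelTextType.get_type(EmbeddingProvider.GOOGLE, EmbedTextType.PASSAGE) == "RETRIEVAL_DOCUMENT"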
class GPUStatus:

View File

@@ -1,55 +1,30 @@
import asyncio
import json
import time
from types import TracebackType
from typing import cast
from typing import Any
from typing import Optional
import aioboto3 # type: ignore
import httpx
import openai
import vertexai # type: ignore
import voyageai # type: ignore
from cohere import AsyncClient as CohereAsyncClient
from fastapi import APIRouter
from fastapi import HTTPException
from fastapi import Request
from google.oauth2 import service_account # type: ignore
from litellm import aembedding
from litellm.exceptions import RateLimitError
from retry import retry
from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
from vertexai.language_models import TextEmbeddingInput # type: ignore
from vertexai.language_models import TextEmbeddingModel # type: ignore
from model_server.constants import DEFAULT_COHERE_MODEL
from model_server.constants import DEFAULT_OPENAI_MODEL
from model_server.constants import DEFAULT_VERTEX_MODEL
from model_server.constants import DEFAULT_VOYAGE_MODEL
from model_server.constants import EmbeddingModelTextType
from model_server.constants import EmbeddingProvider
from model_server.utils import pass_aws_key
from model_server.utils import simple_log_function_time
from onyx.utils.logger import setup_logger
from shared_configs.configs import API_BASED_EMBEDDING_TIMEOUT
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import OPENAI_EMBEDDING_TIMEOUT
from shared_configs.configs import VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
from shared_configs.enums import EmbedTextType
from shared_configs.enums import RerankerProvider
from shared_configs.model_server_models import Embedding
from shared_configs.model_server_models import EmbedRequest
from shared_configs.model_server_models import EmbedResponse
from shared_configs.model_server_models import RerankRequest
from shared_configs.model_server_models import RerankResponse
from shared_configs.utils import batch_list
logger = setup_logger()
router = APIRouter(prefix="/encoder")
_GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
_RERANK_MODEL: Optional["CrossEncoder"] = None
@@ -57,315 +32,6 @@ _RERANK_MODEL: Optional["CrossEncoder"] = None
_RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
_RETRY_TRIES = 10 if INDEXING_ONLY else 2
# OpenAI only allows 2048 embeddings to be computed at once
_OPENAI_MAX_INPUT_LEN = 2048
# Cohere allows up to 96 embeddings in a single embedding call
_COHERE_MAX_INPUT_LEN = 96
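A generic sketch of the batching these limits imply; the repo's batch_list helper in shared_configs.utils may be implemented differently:
def _chunk_texts(items: list[str], max_len: int) -> list[list[str]]:
    # split texts into provider-sized batches, e.g. 2048 for OpenAI or 96 for Cohere
    return [items[i : i + max_len] for i in range(0, len(items), max_len)]
# _chunk_texts(["a", "b", "c"], 2) -> [["a", "b"], ["c"]]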
# Authentication error string constants
_AUTH_ERROR_401 = "401"
_AUTH_ERROR_UNAUTHORIZED = "unauthorized"
_AUTH_ERROR_INVALID_API_KEY = "invalid api key"
_AUTH_ERROR_PERMISSION = "permission"
def is_authentication_error(error: Exception) -> bool:
"""Check if an exception is related to authentication issues.
Args:
error: The exception to check
Returns:
bool: True if the error appears to be authentication-related
"""
error_str = str(error).lower()
return (
_AUTH_ERROR_401 in error_str
or _AUTH_ERROR_UNAUTHORIZED in error_str
or _AUTH_ERROR_INVALID_API_KEY in error_str
or _AUTH_ERROR_PERMISSION in error_str
)
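A minimal sketch of the substring checks above; the exception messages are hypothetical:
assert is_authentication_error(Exception("401 Unauthorized"))
assert is_authentication_error(Exception("Invalid API key provided"))
assert not is_authentication_error(Exception("connection reset by peer"))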
def format_embedding_error(
error: Exception,
service_name: str,
model: str | None,
provider: EmbeddingProvider,
sanitized_api_key: str | None = None,
status_code: int | None = None,
) -> str:
"""
Format a standardized error string for embedding errors.
"""
detail = f"Status {status_code}" if status_code else f"{type(error)}"
return (
f"{'HTTP error' if status_code else 'Exception'} embedding text with {service_name} - {detail}: "
f"Model: {model} "
f"Provider: {provider} "
f"API Key: {sanitized_api_key} "
f"Exception: {error}"
)
# Custom exception for authentication errors
class AuthenticationError(Exception):
"""Raised when authentication fails with a provider."""
def __init__(self, provider: str, message: str = "API key is invalid or expired"):
self.provider = provider
self.message = message
super().__init__(f"{provider} authentication failed: {message}")
class CloudEmbedding:
def __init__(
self,
api_key: str,
provider: EmbeddingProvider,
api_url: str | None = None,
api_version: str | None = None,
timeout: int = API_BASED_EMBEDDING_TIMEOUT,
) -> None:
self.provider = provider
self.api_key = api_key
self.api_url = api_url
self.api_version = api_version
self.timeout = timeout
self.http_client = httpx.AsyncClient(timeout=timeout)
self._closed = False
self.sanitized_api_key = api_key[:4] + "********" + api_key[-4:]
async def _embed_openai(
self, texts: list[str], model: str | None, reduced_dimension: int | None
) -> list[Embedding]:
if not model:
model = DEFAULT_OPENAI_MODEL
# Use the OpenAI specific timeout for this one
client = openai.AsyncOpenAI(
api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT
)
final_embeddings: list[Embedding] = []
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
response = await client.embeddings.create(
input=text_batch,
model=model,
dimensions=reduced_dimension or openai.NOT_GIVEN,
)
final_embeddings.extend(
[embedding.embedding for embedding in response.data]
)
return final_embeddings
async def _embed_cohere(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[Embedding]:
if not model:
model = DEFAULT_COHERE_MODEL
client = CohereAsyncClient(api_key=self.api_key)
final_embeddings: list[Embedding] = []
for text_batch in batch_list(texts, _COHERE_MAX_INPUT_LEN):
# Does not use the same tokenizer as the Onyx API server, but it's approximately the same;
# empirically it's only off by a few tokens, so it's not a big deal
response = await client.embed(
texts=text_batch,
model=model,
input_type=embedding_type,
truncate="END",
)
final_embeddings.extend(cast(list[Embedding], response.embeddings))
return final_embeddings
async def _embed_voyage(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[Embedding]:
if not model:
model = DEFAULT_VOYAGE_MODEL
client = voyageai.AsyncClient(
api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT
)
response = await client.embed(
texts=texts,
model=model,
input_type=embedding_type,
truncation=True,
)
return response.embeddings
async def _embed_azure(
self, texts: list[str], model: str | None
) -> list[Embedding]:
response = await aembedding(
model=model,
input=texts,
timeout=API_BASED_EMBEDDING_TIMEOUT,
api_key=self.api_key,
api_base=self.api_url,
api_version=self.api_version,
)
embeddings = [embedding["embedding"] for embedding in response.data]
return embeddings
async def _embed_vertex(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[Embedding]:
if not model:
model = DEFAULT_VERTEX_MODEL
credentials = service_account.Credentials.from_service_account_info(
json.loads(self.api_key)
)
project_id = json.loads(self.api_key)["project_id"]
vertexai.init(project=project_id, credentials=credentials)
client = TextEmbeddingModel.from_pretrained(model)
inputs = [TextEmbeddingInput(text, embedding_type) for text in texts]
# Split into batches of 25 texts
max_texts_per_batch = VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
batches = [
inputs[i : i + max_texts_per_batch]
for i in range(0, len(inputs), max_texts_per_batch)
]
# Dispatch all embedding calls asynchronously at once
tasks = [
client.get_embeddings_async(batch, auto_truncate=True) for batch in batches
]
# Wait for all tasks to complete in parallel
results = await asyncio.gather(*tasks)
return [embedding.values for batch in results for embedding in batch]
async def _embed_litellm_proxy(
self, texts: list[str], model_name: str | None
) -> list[Embedding]:
if not model_name:
raise ValueError("Model name is required for LiteLLM proxy embedding.")
if not self.api_url:
raise ValueError("API URL is required for LiteLLM proxy embedding.")
headers = (
{} if not self.api_key else {"Authorization": f"Bearer {self.api_key}"}
)
response = await self.http_client.post(
self.api_url,
json={
"model": model_name,
"input": texts,
},
headers=headers,
)
response.raise_for_status()
result = response.json()
return [embedding["embedding"] for embedding in result["data"]]
@retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
async def embed(
self,
*,
texts: list[str],
text_type: EmbedTextType,
model_name: str | None = None,
deployment_name: str | None = None,
reduced_dimension: int | None = None,
) -> list[Embedding]:
try:
if self.provider == EmbeddingProvider.OPENAI:
return await self._embed_openai(texts, model_name, reduced_dimension)
elif self.provider == EmbeddingProvider.AZURE:
return await self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return await self._embed_litellm_proxy(texts, model_name)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return await self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return await self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return await self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
except openai.AuthenticationError:
raise AuthenticationError(provider="OpenAI")
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
raise AuthenticationError(provider=str(self.provider))
error_string = format_embedding_error(
e,
str(self.provider),
model_name or deployment_name,
self.provider,
sanitized_api_key=self.sanitized_api_key,
status_code=e.response.status_code,
)
logger.error(error_string)
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
except Exception as e:
if is_authentication_error(e):
raise AuthenticationError(provider=str(self.provider))
error_string = format_embedding_error(
e,
str(self.provider),
model_name or deployment_name,
self.provider,
sanitized_api_key=self.sanitized_api_key,
)
logger.error(error_string)
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
@staticmethod
def create(
api_key: str,
provider: EmbeddingProvider,
api_url: str | None = None,
api_version: str | None = None,
) -> "CloudEmbedding":
logger.debug(f"Creating Embedding instance for provider: {provider}")
return CloudEmbedding(api_key, provider, api_url, api_version)
async def aclose(self) -> None:
"""Explicitly close the client."""
if not self._closed:
await self.http_client.aclose()
self._closed = True
async def __aenter__(self) -> "CloudEmbedding":
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.aclose()
def __del__(self) -> None:
"""Finalizer to warn about unclosed clients."""
if not self._closed:
logger.warning(
"CloudEmbedding was not properly closed. Use 'async with' or call aclose()"
)
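A hedged usage sketch of CloudEmbedding's async-context-manager lifecycle with an OpenAI-style call; the API key is a placeholder (a real key is required to run it) and the model name matches DEFAULT_OPENAI_MODEL:
import asyncio
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
async def _cloud_embedding_demo() -> None:
    async with CloudEmbedding(
        api_key="sk-placeholder",  # placeholder, not a real key
        provider=EmbeddingProvider.OPENAI,
    ) as cloud_model:
        embeddings = await cloud_model.embed(
            texts=["hello world"],
            text_type=EmbedTextType.QUERY,
            model_name="text-embedding-3-small",
        )
        print(len(embeddings), len(embeddings[0]))
# asyncio.run(_cloud_embedding_demo())  # requires a valid API key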
def get_embedding_model(
model_name: str,
@@ -404,20 +70,34 @@ def get_local_reranking_model(
return _RERANK_MODEL
ENCODING_RETRIES = 3
ENCODING_RETRY_DELAY = 0.1
def _concurrent_embedding(
texts: list[str], model: "SentenceTransformer", normalize_embeddings: bool
) -> Any:
"""Synchronous wrapper for concurrent_embedding to use with run_in_executor."""
for _ in range(ENCODING_RETRIES):
try:
return model.encode(texts, normalize_embeddings=normalize_embeddings)
except RuntimeError as e:
# There is a concurrency bug in the SentenceTransformer library that causes
# the model to fail to encode texts. It's pretty rare and we want to allow
# concurrent embedding, hence we retry (the specific error is
# "RuntimeError: Already borrowed" and occurs in the transformers library)
logger.error(f"Error encoding texts, retrying: {e}")
time.sleep(ENCODING_RETRY_DELAY)
return model.encode(texts, normalize_embeddings=normalize_embeddings)
@simple_log_function_time()
async def embed_text(
texts: list[str],
text_type: EmbedTextType,
model_name: str | None,
deployment_name: str | None,
max_context_length: int,
normalize_embeddings: bool,
api_key: str | None,
provider_type: EmbeddingProvider | None,
prefix: str | None,
api_url: str | None,
api_version: str | None,
reduced_dimension: int | None,
gpu_type: str = "UNKNOWN",
) -> list[Embedding]:
if not all(texts):
@@ -434,52 +114,10 @@ async def embed_text(
for text in texts:
total_chars += len(text)
if provider_type is not None:
logger.info(
f"Embedding {len(texts)} texts with {total_chars} total characters with provider: {provider_type}"
)
# Only local models should call this function now
# API providers should go directly to API server
if api_key is None:
logger.error("API key not provided for cloud model")
raise RuntimeError("API key not provided for cloud model")
if prefix:
logger.warning("Prefix provided for cloud model, which is not supported")
raise ValueError(
"Prefix string is not valid for cloud models. "
"Cloud models take an explicit text type instead."
)
async with CloudEmbedding(
api_key=api_key,
provider=provider_type,
api_url=api_url,
api_version=api_version,
) as cloud_model:
embeddings = await cloud_model.embed(
texts=texts,
model_name=model_name,
deployment_name=deployment_name,
text_type=text_type,
reduced_dimension=reduced_dimension,
)
if any(embedding is None for embedding in embeddings):
error_message = "Embeddings contain None values\n"
error_message += "Corresponding texts:\n"
error_message += "\n".join(texts)
logger.error(error_message)
raise ValueError(error_message)
elapsed = time.monotonic() - start
logger.info(
f"event=embedding_provider "
f"texts={len(texts)} "
f"chars={total_chars} "
f"provider={provider_type} "
f"elapsed={elapsed:.2f}"
)
elif model_name is not None:
if model_name is not None:
logger.info(
f"Embedding {len(texts)} texts with {total_chars} total characters with local model: {model_name}"
)
@@ -492,8 +130,8 @@ async def embed_text(
# Run CPU-bound embedding in a thread pool
embeddings_vectors = await asyncio.get_event_loop().run_in_executor(
None,
lambda: local_model.encode(
prefixed_texts, normalize_embeddings=normalize_embeddings
lambda: _concurrent_embedding(
prefixed_texts, local_model, normalize_embeddings
),
)
embeddings = [
@@ -515,10 +153,8 @@ async def embed_text(
f"elapsed={elapsed:.2f}"
)
else:
logger.error("Neither model name nor provider specified for embedding")
raise ValueError(
"Either model name or provider must be provided to run embeddings."
)
logger.error("Model name not specified for embedding")
raise ValueError("Model name must be provided to run embeddings.")
return embeddings
@@ -533,77 +169,6 @@ async def local_rerank(query: str, docs: list[str], model_name: str) -> list[flo
)
async def cohere_rerank_api(
query: str, docs: list[str], model_name: str, api_key: str
) -> list[float]:
cohere_client = CohereAsyncClient(api_key=api_key)
response = await cohere_client.rerank(query=query, documents=docs, model=model_name)
results = response.results
sorted_results = sorted(results, key=lambda item: item.index)
return [result.relevance_score for result in sorted_results]
async def cohere_rerank_aws(
query: str,
docs: list[str],
model_name: str,
region_name: str,
aws_access_key_id: str,
aws_secret_access_key: str,
) -> list[float]:
session = aioboto3.Session(
aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key
)
async with session.client(
"bedrock-runtime", region_name=region_name
) as bedrock_client:
body = json.dumps(
{
"query": query,
"documents": docs,
"api_version": 2,
}
)
# Invoke the Bedrock model asynchronously
response = await bedrock_client.invoke_model(
modelId=model_name,
accept="application/json",
contentType="application/json",
body=body,
)
# Read the response asynchronously
response_body = json.loads(await response["body"].read())
# Extract and sort the results
results = response_body.get("results", [])
sorted_results = sorted(results, key=lambda item: item["index"])
return [result["relevance_score"] for result in sorted_results]
async def litellm_rerank(
query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None
) -> list[float]:
headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"}
async with httpx.AsyncClient() as client:
response = await client.post(
api_url,
json={
"model": model_name,
"query": query,
"documents": docs,
},
headers=headers,
)
response.raise_for_status()
result = response.json()
return [
item["relevance_score"]
for item in sorted(result["results"], key=lambda x: x["index"])
]
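A hedged call sketch for litellm_rerank; the proxy URL and model name are placeholders for whatever LiteLLM deployment is configured:
async def _litellm_rerank_demo() -> list[float]:
    # assumes a LiteLLM proxy exposing a rerank route at this placeholder URL
    return await litellm_rerank(
        query="What is Onyx?",
        docs=["Onyx is an open-source AI platform.", "Unrelated text."],
        api_url="http://localhost:4000/rerank",
        model_name="rerank-english-v3.0",
        api_key=None,
    )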
@router.post("/bi-encoder-embed")
async def route_bi_encoder_embed(
request: Request,
@@ -615,6 +180,13 @@ async def route_bi_encoder_embed(
async def process_embed_request(
embed_request: EmbedRequest, gpu_type: str = "UNKNOWN"
) -> EmbedResponse:
# Only local models should use this endpoint - API providers should make direct API calls
if embed_request.provider_type is not None:
raise ValueError(
f"Model server embedding endpoint should only be used for local models. "
f"API provider '{embed_request.provider_type}' should make direct API calls instead."
)
if not embed_request.texts:
raise HTTPException(status_code=400, detail="No texts to be embedded")
@@ -632,26 +204,12 @@ async def process_embed_request(
embeddings = await embed_text(
texts=embed_request.texts,
model_name=embed_request.model_name,
deployment_name=embed_request.deployment_name,
max_context_length=embed_request.max_context_length,
normalize_embeddings=embed_request.normalize_embeddings,
api_key=embed_request.api_key,
provider_type=embed_request.provider_type,
text_type=embed_request.text_type,
api_url=embed_request.api_url,
api_version=embed_request.api_version,
reduced_dimension=embed_request.reduced_dimension,
prefix=prefix,
gpu_type=gpu_type,
)
return EmbedResponse(embeddings=embeddings)
except AuthenticationError as e:
# Handle authentication errors consistently
logger.error(f"Authentication error: {e.provider}")
raise HTTPException(
status_code=401,
detail=f"Authentication failed: {e.message}",
)
except RateLimitError as e:
raise HTTPException(
status_code=429,
@@ -669,6 +227,13 @@ async def process_embed_request(
@router.post("/cross-encoder-scores")
async def process_rerank_request(rerank_request: RerankRequest) -> RerankResponse:
"""Cross encoders can be purely black box from the app perspective"""
# Only local models should use this endpoint - API providers should make direct API calls
if rerank_request.provider_type is not None:
raise ValueError(
f"Model server reranking endpoint should only be used for local models. "
f"API provider '{rerank_request.provider_type}' should make direct API calls instead."
)
if INDEXING_ONLY:
raise RuntimeError("Indexing model server should not call intent endpoint")
@@ -680,55 +245,13 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
raise ValueError("Empty documents cannot be reranked.")
try:
if rerank_request.provider_type is None:
sim_scores = await local_rerank(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,
)
return RerankResponse(scores=sim_scores)
elif rerank_request.provider_type == RerankerProvider.LITELLM:
if rerank_request.api_url is None:
raise ValueError("API URL is required for LiteLLM reranking.")
sim_scores = await litellm_rerank(
query=rerank_request.query,
docs=rerank_request.documents,
api_url=rerank_request.api_url,
model_name=rerank_request.model_name,
api_key=rerank_request.api_key,
)
return RerankResponse(scores=sim_scores)
elif rerank_request.provider_type == RerankerProvider.COHERE:
if rerank_request.api_key is None:
raise RuntimeError("Cohere Rerank Requires an API Key")
sim_scores = await cohere_rerank_api(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,
api_key=rerank_request.api_key,
)
return RerankResponse(scores=sim_scores)
elif rerank_request.provider_type == RerankerProvider.BEDROCK:
if rerank_request.api_key is None:
raise RuntimeError("Bedrock Rerank Requires an API Key")
aws_access_key_id, aws_secret_access_key, aws_region = pass_aws_key(
rerank_request.api_key
)
sim_scores = await cohere_rerank_aws(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,
region_name=aws_region,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
)
return RerankResponse(scores=sim_scores)
else:
raise ValueError(f"Unsupported provider: {rerank_request.provider_type}")
# At this point, provider_type is None, so handle local reranking
sim_scores = await local_rerank(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,
)
return RerankResponse(scores=sim_scores)
except Exception as e:
logger.exception(f"Error during reranking process:\n{str(e)}")

View File

@@ -34,8 +34,8 @@ from shared_configs.configs import SENTRY_DSN
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
HF_CACHE_PATH = Path(os.path.expanduser("~")) / ".cache/huggingface"
TEMP_HF_CACHE_PATH = Path(os.path.expanduser("~")) / ".cache/temp_huggingface"
HF_CACHE_PATH = Path(".cache/huggingface")
TEMP_HF_CACHE_PATH = Path(".cache/temp_huggingface")
transformer_logging.set_verbosity_error()

View File

@@ -70,32 +70,3 @@ def get_gpu_type() -> str:
return GPUStatus.MAC_MPS
return GPUStatus.NONE
def pass_aws_key(api_key: str) -> tuple[str, str, str]:
"""Parse AWS API key string into components.
Args:
api_key: String in format 'aws_ACCESSKEY_SECRETKEY_REGION'
Returns:
Tuple of (access_key, secret_key, region)
Raises:
ValueError: If key format is invalid
"""
if not api_key.startswith("aws"):
raise ValueError("API key must start with 'aws' prefix")
parts = api_key.split("_")
if len(parts) != 4:
raise ValueError(
f"API key must be in format 'aws_ACCESSKEY_SECRETKEY_REGION', got {len(parts) - 1} parts"
"this is an onyx specific format for formatting the aws secrets for bedrock"
)
try:
_, aws_access_key_id, aws_secret_access_key, aws_region = parts
return aws_access_key_id, aws_secret_access_key, aws_region
except Exception as e:
raise ValueError(f"Failed to parse AWS key components: {str(e)}")
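A quick sketch of the 'aws_ACCESSKEY_SECRETKEY_REGION' format parsed above; the key components are dummy values:
access_key, secret_key, region = pass_aws_key("aws_AKIAEXAMPLE_dummysecret_us-east-1")
assert (access_key, secret_key, region) == ("AKIAEXAMPLE", "dummysecret", "us-east-1")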

View File

@@ -1,97 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.basic.states import BasicInput
from onyx.agents.agent_search.basic.states import BasicOutput
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.orchestration.nodes.call_tool import call_tool
from onyx.agents.agent_search.orchestration.nodes.choose_tool import choose_tool
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
prepare_tool_input,
)
from onyx.agents.agent_search.orchestration.nodes.use_tool_response import (
basic_use_tool_response,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def basic_graph_builder() -> StateGraph:
graph = StateGraph(
state_schema=BasicState,
input=BasicInput,
output=BasicOutput,
)
### Add nodes ###
graph.add_node(
node="prepare_tool_input",
action=prepare_tool_input,
)
graph.add_node(
node="choose_tool",
action=choose_tool,
)
graph.add_node(
node="call_tool",
action=call_tool,
)
graph.add_node(
node="basic_use_tool_response",
action=basic_use_tool_response,
)
### Add edges ###
graph.add_edge(start_key=START, end_key="prepare_tool_input")
graph.add_edge(start_key="prepare_tool_input", end_key="choose_tool")
graph.add_conditional_edges("choose_tool", should_continue, ["call_tool", END])
graph.add_edge(
start_key="call_tool",
end_key="basic_use_tool_response",
)
graph.add_edge(
start_key="basic_use_tool_response",
end_key=END,
)
return graph
def should_continue(state: BasicState) -> str:
return (
# If there are no tool calls, basic graph already streamed the answer
END
if state.tool_choice is None
else "call_tool"
)
if __name__ == "__main__":
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.context.search.models import SearchRequest
from onyx.llm.factory import get_default_llms
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
graph = basic_graph_builder()
compiled_graph = graph.compile()
input = BasicInput(unused=True)
primary_llm, fast_llm = get_default_llms()
with get_session_with_current_tenant() as db_session:
config, _ = get_test_config(
db_session=db_session,
primary_llm=primary_llm,
fast_llm=fast_llm,
search_request=SearchRequest(query="How does onyx use FastAPI?"),
)
compiled_graph.invoke(input, config={"metadata": {"config": config}})

View File

@@ -1,35 +0,0 @@
from typing import TypedDict
from langchain_core.messages import AIMessageChunk
from pydantic import BaseModel
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
# States contain values that change over the course of graph execution.
# Config is for values that are set at the start and never change.
# If you are using a value from the config and realize it needs to change,
# you should add it to the state and use/update the version in the state.
## Graph Input State
class BasicInput(BaseModel):
# Langgraph needs a nonempty input, but we pass in all static
# data through a RunnableConfig.
unused: bool = True
## Graph Output State
class BasicOutput(TypedDict):
tool_call_chunk: AIMessageChunk
## Graph State
class BasicState(
BasicInput,
ToolChoiceInput,
ToolCallUpdate,
ToolChoiceUpdate,
):
pass

View File

@@ -1,64 +0,0 @@
from collections.abc import Iterator
from typing import cast
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langgraph.types import StreamWriter
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import LlmDoc
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
PassThroughAnswerResponseHandler,
)
from onyx.chat.stream_processing.utils import map_document_id_order
from onyx.utils.logger import setup_logger
logger = setup_logger()
def process_llm_stream(
messages: Iterator[BaseMessage],
should_stream_answer: bool,
writer: StreamWriter,
final_search_results: list[LlmDoc] | None = None,
displayed_search_results: list[LlmDoc] | None = None,
) -> AIMessageChunk:
tool_call_chunk = AIMessageChunk(content="")
if final_search_results and displayed_search_results:
answer_handler: AnswerResponseHandler = CitationResponseHandler(
context_docs=final_search_results,
final_doc_id_to_rank_map=map_document_id_order(final_search_results),
display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
)
else:
answer_handler = PassThroughAnswerResponseHandler()
full_answer = ""
# This stream will be the llm answer if no tool is chosen. When a tool is chosen,
# the stream will contain AIMessageChunks with tool call information.
for message in messages:
answer_piece = message.content
if not isinstance(answer_piece, str):
# this is only used for logging, so fine to
# just add the string representation
answer_piece = str(answer_piece)
full_answer += answer_piece
if isinstance(message, AIMessageChunk) and (
message.tool_call_chunks or message.tool_calls
):
tool_call_chunk += message # type: ignore
elif should_stream_answer:
for response_part in answer_handler.handle_response_part(message, []):
write_custom_event(
"basic_response",
response_part,
writer,
)
logger.debug(f"Full answer: {full_answer}")
return cast(AIMessageChunk, tool_call_chunk)

View File

@@ -10,6 +10,7 @@ class CoreState(BaseModel):
"""
log_messages: Annotated[list[str], add] = []
current_step_nr: int = 1
class SubgraphCoreState(BaseModel):

View File

@@ -14,8 +14,6 @@ from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.configs.constants import DocumentSource
from onyx.prompts.agents.dc_prompts import DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT
from onyx.prompts.agents.dc_prompts import DC_OBJECT_SEPARATOR
@@ -139,17 +137,6 @@ def search_objects(
except Exception as e:
raise ValueError(f"Error in search_objects: {e}")
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=" Researching the individual objects for each source type... ",
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
return SearchSourcesObjectsUpdate(
analysis_objects=object_list,
analysis_sources=document_sources,

View File

@@ -9,8 +9,6 @@ from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.agents.agent_search.dc_search_analysis.states import (
ObjectResearchInformationUpdate,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -23,17 +21,6 @@ def structure_research_by_object(
LangGraph node to structure the research information by object.
"""
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=" consolidating the information across source types for each object...",
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
object_source_research_results = state.object_source_research_results
object_research_information_results: List[Dict[str, str]] = []

View File

@@ -12,8 +12,6 @@ from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.prompts.agents.dc_prompts import DC_FORMATTING_NO_BASE_DATA_PROMPT
from onyx.prompts.agents.dc_prompts import DC_FORMATTING_WITH_BASE_DATA_PROMPT
from onyx.utils.logger import setup_logger
@@ -33,17 +31,6 @@ def consolidate_research(
search_tool = graph_config.tooling.search_tool
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=" generating the answer\n\n\n",
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
if search_tool is None or graph_config.inputs.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")

View File

@@ -1,31 +0,0 @@
from collections.abc import Hashable
from datetime import datetime
from langgraph.types import Send
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def send_to_expanded_retrieval(state: SubQuestionAnsweringInput) -> Send | Hashable:
"""
LangGraph edge to send a sub-question to the expanded retrieval.
"""
edge_start_time = datetime.now()
return Send(
"initial_sub_question_expanded_retrieval",
ExpandedRetrievalInput(
question=state.question,
base_search=False,
sub_question_id=state.question_id,
log_messages=[f"{edge_start_time} -- Sending to expanded retrieval"],
),
)

View File

@@ -1,137 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.edges import (
send_to_expanded_retrieval,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.check_sub_answer import (
check_sub_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.format_sub_answer import (
format_sub_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.generate_sub_answer import (
generate_sub_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.ingest_retrieved_documents import (
ingest_retrieved_documents,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.graph_builder import (
expanded_retrieval_graph_builder,
)
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger
logger = setup_logger()
def answer_query_graph_builder() -> StateGraph:
"""
LangGraph sub-graph builder for the initial individual sub-answer generation.
"""
graph = StateGraph(
state_schema=AnswerQuestionState,
input=SubQuestionAnsweringInput,
output=AnswerQuestionOutput,
)
### Add nodes ###
# The sub-graph that executes the expanded retrieval process for a sub-question
expanded_retrieval = expanded_retrieval_graph_builder().compile()
graph.add_node(
node="initial_sub_question_expanded_retrieval",
action=expanded_retrieval,
)
# The node that ingests the retrieved documents and puts them into the proper
# state keys.
graph.add_node(
node="ingest_retrieval",
action=ingest_retrieved_documents,
)
# The node that generates the sub-answer
graph.add_node(
node="generate_sub_answer",
action=generate_sub_answer,
)
# The node that checks the sub-answer
graph.add_node(
node="answer_check",
action=check_sub_answer,
)
# The node that formats the sub-answer for the following initial answer generation
graph.add_node(
node="format_answer",
action=format_sub_answer,
)
### Add edges ###
graph.add_conditional_edges(
source=START,
path=send_to_expanded_retrieval,
path_map=["initial_sub_question_expanded_retrieval"],
)
graph.add_edge(
start_key="initial_sub_question_expanded_retrieval",
end_key="ingest_retrieval",
)
graph.add_edge(
start_key="ingest_retrieval",
end_key="generate_sub_answer",
)
graph.add_edge(
start_key="generate_sub_answer",
end_key="answer_check",
)
graph.add_edge(
start_key="answer_check",
end_key="format_answer",
)
graph.add_edge(
start_key="format_answer",
end_key=END,
)
return graph
if __name__ == "__main__":
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
graph = answer_query_graph_builder()
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
search_request = SearchRequest(
query="what can you do with onyx or danswer?",
)
with get_session_with_current_tenant() as db_session:
graph_config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)
inputs = SubQuestionAnsweringInput(
question="what can you do with onyx?",
question_id="0_0",
log_messages=[],
)
for thing in compiled_graph.stream(
input=inputs,
config={"configurable": {"config": graph_config}},
):
logger.debug(thing)

View File

@@ -1,136 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnswerCheckUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
binary_string_test,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_POSITIVE_VALUE_STR,
)
from onyx.agents.agent_search.shared_graph_utils.constants import AgentLLMErrorType
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_CHECK
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import SUB_ANSWER_CHECK_PROMPT
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="LLM Timeout Error. The sub-answer will be treated as 'relevant'",
rate_limit="LLM Rate Limit Error. The sub-answer will be treated as 'relevant'",
general_error="General LLM Error. The sub-answer will be treated as 'relevant'",
)
@log_function_time(print_only=True)
def check_sub_answer(
state: AnswerQuestionState, config: RunnableConfig
) -> SubQuestionAnswerCheckUpdate:
"""
LangGraph node to check the quality of the sub-answer. The quality
is represented as a boolean value.
"""
node_start_time = datetime.now()
level, question_num = parse_question_id(state.question_id)
if state.answer == UNKNOWN_ANSWER:
return SubQuestionAnswerCheckUpdate(
answer_quality=False,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate individual sub answer",
node_name="check sub answer",
node_start_time=node_start_time,
result="unknown answer",
)
],
)
msg = [
HumanMessage(
content=SUB_ANSWER_CHECK_PROMPT.format(
question=state.question,
base_answer=state.answer,
)
)
]
graph_config = cast(GraphConfig, config["metadata"]["config"])
fast_llm = graph_config.tooling.fast_llm
agent_error: AgentErrorLog | None = None
response: BaseMessage | None = None
try:
response = run_with_timeout(
AGENT_TIMEOUT_LLM_SUBANSWER_CHECK,
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
quality_str: str = cast(str, response.content)
answer_quality = binary_string_test(
text=quality_str, positive_value=AGENT_POSITIVE_VALUE_STR
)
log_result = f"Answer quality: {quality_str}"
except (LLMTimeoutError, TimeoutError):
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
answer_quality = True
log_result = agent_error.error_result
logger.error("LLM Timeout Error - check sub answer")
except LLMRateLimitError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
answer_quality = True
log_result = agent_error.error_result
logger.error("LLM Rate Limit Error - check sub answer")
return SubQuestionAnswerCheckUpdate(
answer_quality=answer_quality,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate individual sub answer",
node_name="check sub answer",
node_start_time=node_start_time,
result=log_result,
)
],
)

View File

@@ -1,30 +0,0 @@
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.shared_graph_utils.models import (
SubQuestionAnswerResults,
)
def format_sub_answer(state: AnswerQuestionState) -> AnswerQuestionOutput:
"""
LangGraph node to format the sub-answer results.
"""
return AnswerQuestionOutput(
answer_results=[
SubQuestionAnswerResults(
question=state.question,
question_id=state.question_id,
verified_high_quality=state.answer_quality,
answer=state.answer,
sub_query_retrieval_results=state.expanded_retrieval_results,
verified_reranked_documents=state.verified_reranked_documents,
context_documents=state.context_documents,
cited_documents=state.cited_documents,
sub_question_retrieval_stats=state.sub_question_retrieval_stats,
)
],
)

View File

@@ -1,185 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnswerGenerationUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_sub_question_answer_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.calculations import (
dedup_sort_inference_section_list,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
LLM_ANSWER_ERROR_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import get_answer_citation_ids
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_persona_agent_prompt_expressions,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import NO_RECOVERED_DOCS
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="LLM Timeout Error. A sub-answer could not be constructed and the sub-question will be ignored.",
rate_limit="LLM Rate Limit Error. A sub-answer could not be constructed and the sub-question will be ignored.",
general_error="General LLM Error. A sub-answer could not be constructed and the sub-question will be ignored.",
)
@log_function_time(print_only=True)
def generate_sub_answer(
state: AnswerQuestionState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubQuestionAnswerGenerationUpdate:
"""
LangGraph node to generate a sub-answer.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = state.question
state.verified_reranked_documents
level, question_num = parse_question_id(state.question_id)
context_docs = state.context_documents[:AGENT_MAX_ANSWER_CONTEXT_DOCS]
context_docs = dedup_sort_inference_section_list(context_docs)
persona_contextualized_prompt = get_persona_agent_prompt_expressions(
graph_config.inputs.persona
).contextualized_prompt
if len(context_docs) == 0:
answer_str = NO_RECOVERED_DOCS
cited_documents: list = []
log_results = "No documents retrieved"
write_custom_event(
"sub_answers",
AgentAnswerPiece(
answer_piece=answer_str,
level=level,
level_question_num=question_num,
answer_type="agent_sub_answer",
),
writer,
)
else:
fast_llm = graph_config.tooling.fast_llm
msg = build_sub_question_answer_prompt(
question=question,
original_question=graph_config.inputs.prompt_builder.raw_user_query,
docs=context_docs,
persona_specification=persona_contextualized_prompt,
config=fast_llm.config,
)
agent_error: AgentErrorLog | None = None
response: list[str] = []
try:
response, _ = run_with_timeout(
AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION,
lambda: stream_llm_answer(
llm=fast_llm,
prompt=msg,
event_name="sub_answers",
writer=writer,
agent_answer_level=level,
agent_answer_question_num=question_num,
agent_answer_type="agent_sub_answer",
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBANSWER_GENERATION,
),
)
except (LLMTimeoutError, TimeoutError):
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - generate sub answer")
except LLMRateLimitError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - generate sub answer")
if agent_error:
answer_str = LLM_ANSWER_ERROR_MESSAGE
cited_documents = []
log_results = (
agent_error.error_result
or "Sub-answer generation failed due to LLM error"
)
else:
answer_str = merge_message_runs(response, chunk_separator="")[0].content
answer_citation_ids = get_answer_citation_ids(answer_str)
cited_documents = [
context_docs[id] for id in answer_citation_ids if id < len(context_docs)
]
log_results = None
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type=StreamType.SUB_ANSWER,
level=level,
level_question_num=question_num,
)
write_custom_event("stream_finished", stop_event, writer)
return SubQuestionAnswerGenerationUpdate(
answer=answer_str,
cited_documents=cited_documents,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate individual sub answer",
node_name="generate sub answer",
node_start_time=node_start_time,
result=log_results or "",
)
],
)

View File

@@ -1,25 +0,0 @@
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionRetrievalIngestionUpdate,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalOutput,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
def ingest_retrieved_documents(
state: ExpandedRetrievalOutput,
) -> SubQuestionRetrievalIngestionUpdate:
"""
LangGraph node to ingest the retrieved documents and format them for the sub-answer generation.
"""
sub_question_retrieval_stats = state.expanded_retrieval_result.retrieval_stats
if sub_question_retrieval_stats is None:
sub_question_retrieval_stats = [AgentChunkRetrievalStats()]
return SubQuestionRetrievalIngestionUpdate(
expanded_retrieval_results=state.expanded_retrieval_result.expanded_query_results,
verified_reranked_documents=state.expanded_retrieval_result.verified_reranked_documents,
context_documents=state.expanded_retrieval_result.context_documents,
sub_question_retrieval_stats=sub_question_retrieval_stats,
)

View File

@@ -1,73 +0,0 @@
from operator import add
from typing import Annotated
from pydantic import BaseModel
from onyx.agents.agent_search.core_state import SubgraphCoreState
from onyx.agents.agent_search.deep_search.main.states import LoggerUpdate
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
from onyx.agents.agent_search.shared_graph_utils.models import (
SubQuestionAnswerResults,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.context.search.models import InferenceSection
## Update States
class SubQuestionAnswerCheckUpdate(LoggerUpdate, BaseModel):
answer_quality: bool = False
log_messages: list[str] = []
class SubQuestionAnswerGenerationUpdate(LoggerUpdate, BaseModel):
answer: str = ""
log_messages: list[str] = []
cited_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
# answer_stat: AnswerStats
class SubQuestionRetrievalIngestionUpdate(LoggerUpdate, BaseModel):
expanded_retrieval_results: list[QueryRetrievalResult] = []
verified_reranked_documents: Annotated[
list[InferenceSection], dedup_inference_sections
] = []
context_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
sub_question_retrieval_stats: AgentChunkRetrievalStats = AgentChunkRetrievalStats()
## Graph Input State
class SubQuestionAnsweringInput(SubgraphCoreState):
question: str
question_id: str
# level 0 is original question and first decomposition, level 1 is follow up, etc
# question_num is a unique number per original question per level.
## Graph State
class AnswerQuestionState(
SubQuestionAnsweringInput,
SubQuestionAnswerGenerationUpdate,
SubQuestionAnswerCheckUpdate,
SubQuestionRetrievalIngestionUpdate,
):
pass
## Graph Output State
class AnswerQuestionOutput(LoggerUpdate, BaseModel):
"""
This is a list of results even though each call of this subgraph only returns one result.
This is because, if we parallelize the answer query subgraph, there will be multiple
results in a list, so the add operator is used to combine them.
"""
answer_results: Annotated[list[SubQuestionAnswerResults], add] = []

View File

@@ -1,50 +0,0 @@
from collections.abc import Hashable
from datetime import datetime
from langgraph.types import Send
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
SubQuestionRetrievalState,
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
def parallelize_initial_sub_question_answering(
state: SubQuestionRetrievalState,
) -> list[Send | Hashable]:
"""
LangGraph edge to parallelize the initial sub-question answering. If there are no sub-questions,
we send empty answers to the initial answer generation, and that answer would be generated
solely based on the documents retrieved for the original question.
"""
edge_start_time = datetime.now()
if len(state.initial_sub_questions) > 0:
return [
Send(
"answer_query_subgraph",
SubQuestionAnsweringInput(
question=question,
question_id=make_question_id(0, question_num + 1),
log_messages=[
f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
],
),
)
for question_num, question in enumerate(state.initial_sub_questions)
]
else:
return [
Send(
"ingest_answers",
AnswerQuestionOutput(
answer_results=[],
),
)
]

View File

@@ -1,96 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.nodes.generate_initial_answer import (
generate_initial_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.nodes.validate_initial_answer import (
validate_initial_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
SubQuestionRetrievalInput,
)
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
SubQuestionRetrievalState,
)
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.graph_builder import (
generate_sub_answers_graph_builder,
)
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.graph_builder import (
retrieve_orig_question_docs_graph_builder,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def generate_initial_answer_graph_builder(test_mode: bool = False) -> StateGraph:
"""
LangGraph graph builder for the initial answer generation.
"""
graph = StateGraph(
state_schema=SubQuestionRetrievalState,
input=SubQuestionRetrievalInput,
)
# The sub-graph that generates the initial sub-answers
generate_sub_answers = generate_sub_answers_graph_builder().compile()
graph.add_node(
node="generate_sub_answers_subgraph",
action=generate_sub_answers,
)
# The sub-graph that retrieves the original question documents. This is run
# in parallel with the sub-answer generation process
retrieve_orig_question_docs = retrieve_orig_question_docs_graph_builder().compile()
graph.add_node(
node="retrieve_orig_question_docs_subgraph_wrapper",
action=retrieve_orig_question_docs,
)
# Node that generates the initial answer using the results of the previous
# two sub-graphs
graph.add_node(
node="generate_initial_answer",
action=generate_initial_answer,
)
# Node that validates the initial answer
graph.add_node(
node="validate_initial_answer",
action=validate_initial_answer,
)
### Add edges ###
graph.add_edge(
start_key=START,
end_key="retrieve_orig_question_docs_subgraph_wrapper",
)
graph.add_edge(
start_key=START,
end_key="generate_sub_answers_subgraph",
)
# Wait for both the original question docs and the sub-answers to be generated before proceeding
graph.add_edge(
start_key=[
"retrieve_orig_question_docs_subgraph_wrapper",
"generate_sub_answers_subgraph",
],
end_key="generate_initial_answer",
)
graph.add_edge(
start_key="generate_initial_answer",
end_key="validate_initial_answer",
)
graph.add_edge(
start_key="validate_initial_answer",
end_key=END,
)
return graph
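A hedged, minimal sketch (illustrative node names only) of the fork/join pattern this builder relies on: two branches fan out from START, and passing a list of start keys to add_edge makes the join node wait until both branches have finished:

from typing import TypedDict
from langgraph.graph import END, START, StateGraph

class JoinState(TypedDict, total=False):
    docs: str
    sub_answers: str
    combined: str

def retrieve_docs(state: JoinState) -> JoinState:
    return {"docs": "retrieved docs"}

def generate_sub_answers(state: JoinState) -> JoinState:
    return {"sub_answers": "sub answers"}

def combine(state: JoinState) -> JoinState:
    # Both upstream keys are guaranteed to be present by the time this runs.
    return {"combined": f"{state['docs']} | {state['sub_answers']}"}

graph = StateGraph(JoinState)
graph.add_node("retrieve_docs", retrieve_docs)
graph.add_node("generate_sub_answers", generate_sub_answers)
graph.add_node("combine", combine)
graph.add_edge(START, "retrieve_docs")
graph.add_edge(START, "generate_sub_answers")
graph.add_edge(["retrieve_docs", "generate_sub_answers"], "combine")  # wait for both branches
graph.add_edge("combine", END)
print(graph.compile().invoke({}))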


@@ -1,405 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
SubQuestionRetrievalState,
)
from onyx.agents.agent_search.deep_search.main.models import AgentBaseMetrics
from onyx.agents.agent_search.deep_search.main.operations import (
calculate_initial_agent_stats,
)
from onyx.agents.agent_search.deep_search.main.operations import get_query_info
from onyx.agents.agent_search.deep_search.main.operations import logger
from onyx.agents.agent_search.deep_search.main.states import (
InitialAnswerUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
get_prompt_enrichment_components,
)
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.calculations import (
get_answer_generation_documents,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_section_list,
)
from onyx.agents.agent_search.shared_graph_utils.utils import _should_restrict_tokens
from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_deduplicated_structured_subquestion_documents,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
from onyx.agents.agent_search.shared_graph_utils.utils import remove_document_citations
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS
from onyx.prompts.agent_search import (
INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS,
)
from onyx.prompts.agent_search import (
SUB_QUESTION_ANSWER_TEMPLATE,
)
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="LLM Timeout Error. The initial answer could not be generated.",
rate_limit="LLM Rate Limit Error. The initial answer could not be generated.",
general_error="General LLM Error. The initial answer could not be generated.",
)
@log_function_time(print_only=True)
def generate_initial_answer(
state: SubQuestionRetrievalState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> InitialAnswerUpdate:
"""
LangGraph node to generate the initial answer, using the initial sub-questions/sub-answers and the
documents retrieved for the original question.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.prompt_builder.raw_user_query
prompt_enrichment_components = get_prompt_enrichment_components(graph_config)
# get all documents cited in sub-questions
structured_subquestion_docs = get_deduplicated_structured_subquestion_documents(
state.sub_question_results
)
orig_question_retrieval_documents = state.orig_question_retrieved_documents
consolidated_context_docs = structured_subquestion_docs.cited_documents
counter = 0
for original_doc in orig_question_retrieval_documents:
if original_doc in structured_subquestion_docs.cited_documents:
continue
if (
counter <= AGENT_MIN_ORIG_QUESTION_DOCS
or len(consolidated_context_docs) < AGENT_MAX_ANSWER_CONTEXT_DOCS
):
consolidated_context_docs.append(original_doc)
counter += 1
# sort docs by their scores - though the scores refer to different questions
relevant_docs = dedup_inference_section_list(consolidated_context_docs)
sub_questions: list[str] = []
# Create the list of documents to stream out. Start with the
# ones that will be in the context (or, if len == 0, use docs
# that were retrieved for the original question)
answer_generation_documents = get_answer_generation_documents(
relevant_docs=relevant_docs,
context_documents=structured_subquestion_docs.context_documents,
original_question_docs=orig_question_retrieval_documents,
max_docs=AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER,
)
# Use the query info from the base document retrieval
query_info = get_query_info(state.orig_question_sub_query_retrieval_results)
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
relevance_list = relevance_from_docs(
answer_generation_documents.streaming_documents
)
for tool_response in yield_search_responses(
query=question,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
):
write_custom_event(
"tool_response",
ExtendedToolResponse(
id=tool_response.id,
response=tool_response.response,
level=0,
level_question_num=0, # 0, 0 is the base question
),
writer,
)
if len(answer_generation_documents.context_documents) == 0:
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=UNKNOWN_ANSWER,
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
dispatch_main_answer_stop_info(0, writer)
answer = UNKNOWN_ANSWER
initial_agent_stats = InitialAgentResultStats(
sub_questions={},
original_question={},
agent_effectiveness={},
)
else:
sub_question_answer_results = state.sub_question_results
# Collect the sub-questions and sub-answers and construct an appropriate
# prompt string.
# Consider replacing this with a helper function.
answered_sub_questions: list[str] = []
all_sub_questions: list[str] = [] # Separate list for tracking all questions
for idx, sub_question_answer_result in enumerate(
sub_question_answer_results, start=1
):
all_sub_questions.append(sub_question_answer_result.question)
is_valid_answer = (
sub_question_answer_result.verified_high_quality
and sub_question_answer_result.answer
and sub_question_answer_result.answer != UNKNOWN_ANSWER
)
if is_valid_answer:
answered_sub_questions.append(
SUB_QUESTION_ANSWER_TEMPLATE.format(
sub_question=sub_question_answer_result.question,
sub_answer=sub_question_answer_result.answer,
sub_question_num=idx,
)
)
sub_question_answer_str = (
"\n\n------\n\n".join(answered_sub_questions)
if answered_sub_questions
else ""
)
# Use the appropriate prompt based on whether there are sub-questions.
base_prompt = (
INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS
if answered_sub_questions
else INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS
)
sub_questions = all_sub_questions # Replace the original assignment
model = (
graph_config.tooling.fast_llm
if AGENT_ANSWER_GENERATION_BY_FAST_LLM
else graph_config.tooling.primary_llm
)
doc_context = format_docs(answer_generation_documents.context_documents)
doc_context = trim_prompt_piece(
config=model.config,
prompt_piece=doc_context,
reserved_str=(
base_prompt
+ sub_question_answer_str
+ prompt_enrichment_components.persona_prompts.contextualized_prompt
+ prompt_enrichment_components.history
+ prompt_enrichment_components.date_str
),
)
msg = [
HumanMessage(
content=base_prompt.format(
question=question,
answered_sub_questions=remove_document_citations(
sub_question_answer_str
),
relevant_docs=doc_context,
persona_specification=prompt_enrichment_components.persona_prompts.contextualized_prompt,
history=prompt_enrichment_components.history,
date_prompt=prompt_enrichment_components.date_str,
)
)
]
streamed_tokens: list[str] = [""]
dispatch_timings: list[float] = []
agent_error: AgentErrorLog | None = None
try:
streamed_tokens, dispatch_timings = run_with_timeout(
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
lambda: stream_llm_answer(
llm=model,
prompt=msg,
event_name="initial_agent_answer",
writer=writer,
agent_answer_level=0,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
max_tokens=(
AGENT_MAX_TOKENS_ANSWER_GENERATION
if _should_restrict_tokens(model.config)
else None
),
),
)
except (LLMTimeoutError, TimeoutError):
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - generate initial answer")
except LLMRateLimitError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - generate initial answer")
if agent_error:
write_custom_event(
"initial_agent_answer",
StreamingError(
error=AGENT_LLM_TIMEOUT_MESSAGE,
),
writer,
)
return InitialAnswerUpdate(
initial_answer=None,
answer_error=AgentErrorLog(
error_message=agent_error.error_message or "An LLM error occurred",
error_type=agent_error.error_type,
error_result=agent_error.error_result,
),
initial_agent_stats=None,
generated_sub_questions=sub_questions,
agent_base_end_time=None,
agent_base_metrics=None,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate initial answer",
node_name="generate initial answer",
node_start_time=node_start_time,
result=agent_error.error_result or "An LLM error occurred",
)
],
)
# Guard against an empty timing list so a token-less stream does not raise ZeroDivisionError
avg_dispatch_time = (
sum(dispatch_timings) / len(dispatch_timings) if dispatch_timings else 0.0
)
logger.debug(f"Average dispatch time for initial answer: {avg_dispatch_time}")
dispatch_main_answer_stop_info(0, writer)
response = merge_content(*streamed_tokens)
answer = cast(str, response)
initial_agent_stats = calculate_initial_agent_stats(
state.sub_question_results, state.orig_question_retrieval_stats
)
logger.debug(
f"\n\nYYYYY--Sub-Questions:\n\n{sub_question_answer_str}\n\nStats:\n\n"
)
if initial_agent_stats:
logger.debug(initial_agent_stats.original_question)
logger.debug(initial_agent_stats.sub_questions)
logger.debug(initial_agent_stats.agent_effectiveness)
agent_base_end_time = datetime.now()
if agent_base_end_time and state.agent_start_time:
duration_s = (agent_base_end_time - state.agent_start_time).total_seconds()
else:
duration_s = None
agent_base_metrics = AgentBaseMetrics(
num_verified_documents_total=len(relevant_docs),
num_verified_documents_core=state.orig_question_retrieval_stats.verified_count,
verified_avg_score_core=state.orig_question_retrieval_stats.verified_avg_scores,
num_verified_documents_base=initial_agent_stats.sub_questions.get(
"num_verified_documents"
),
verified_avg_score_base=initial_agent_stats.sub_questions.get(
"verified_avg_score"
),
base_doc_boost_factor=initial_agent_stats.agent_effectiveness.get(
"utilized_chunk_ratio"
),
support_boost_factor=initial_agent_stats.agent_effectiveness.get(
"support_ratio"
),
duration_s=duration_s,
)
return InitialAnswerUpdate(
initial_answer=answer,
initial_agent_stats=initial_agent_stats,
generated_sub_questions=sub_questions,
agent_base_end_time=agent_base_end_time,
agent_base_metrics=agent_base_metrics,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate initial answer",
node_name="generate initial answer",
node_start_time=node_start_time,
result="",
)
],
)
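The node above bounds the LLM call's wall-clock time with run_with_timeout, whose implementation is not part of this diff. A generic, hedged stand-in using only the standard library (an assumption about the behavior, not Onyx's code) might look like:

import concurrent.futures
import time
from collections.abc import Callable
from typing import Any, TypeVar

T = TypeVar("T")

def run_with_timeout_sketch(timeout: float, fn: Callable[..., T], *args: Any, **kwargs: Any) -> T:
    # Run fn in a worker thread and wait at most `timeout` seconds for its result.
    # Raises concurrent.futures.TimeoutError on expiry; the worker thread is not
    # forcibly killed, it is simply abandoned.
    pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        return pool.submit(fn, *args, **kwargs).result(timeout=timeout)
    finally:
        pool.shutdown(wait=False)

try:
    run_with_timeout_sketch(0.1, time.sleep, 5)  # stands in for a slow LLM call
except concurrent.futures.TimeoutError:
    print("timed out")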


@@ -1,42 +0,0 @@
from datetime import datetime
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
SubQuestionRetrievalState,
)
from onyx.agents.agent_search.deep_search.main.operations import logger
from onyx.agents.agent_search.deep_search.main.states import (
InitialAnswerQualityUpdate,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.timing import log_function_time
@log_function_time(print_only=True)
def validate_initial_answer(
state: SubQuestionRetrievalState,
) -> InitialAnswerQualityUpdate:
"""
Check whether the initial answer sufficiently addresses the original user question.
"""
node_start_time = datetime.now()
logger.debug(
f"--------{node_start_time}--------Checking for base answer validity - for not set True/False manually"
)
verdict = True  # not actually required, as the answer has already been streamed out; the refinement step performs a similar check
return InitialAnswerQualityUpdate(
initial_answer_quality_eval=verdict,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate initial answer",
node_name="validate initial answer",
node_start_time=node_start_time,
result="",
)
],
)


@@ -1,51 +0,0 @@
from operator import add
from typing import Annotated
from typing import TypedDict
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.deep_search.main.states import (
ExploratorySearchUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
InitialAnswerQualityUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
InitialAnswerUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
InitialQuestionDecompositionUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
OrigQuestionRetrievalUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
SubQuestionResultsUpdate,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.models import (
QuestionRetrievalResult,
)
from onyx.context.search.models import InferenceSection
### States ###
class SubQuestionRetrievalInput(CoreState):
exploratory_search_results: list[InferenceSection]
## Graph State
class SubQuestionRetrievalState(
# This includes the core state
SubQuestionRetrievalInput,
InitialQuestionDecompositionUpdate,
InitialAnswerUpdate,
SubQuestionResultsUpdate,
OrigQuestionRetrievalUpdate,
InitialAnswerQualityUpdate,
ExploratorySearchUpdate,
):
base_raw_search_result: Annotated[list[QuestionRetrievalResult], add]
## Graph Output State
class SubQuestionRetrievalOutput(TypedDict):
log_messages: list[str]


@@ -1,48 +0,0 @@
from collections.abc import Hashable
from datetime import datetime
from langgraph.types import Send
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
SubQuestionRetrievalState,
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
def parallelize_initial_sub_question_answering(
state: SubQuestionRetrievalState,
) -> list[Send | Hashable]:
"""
LangGraph edge to parallelize the initial sub-question answering.
"""
edge_start_time = datetime.now()
if len(state.initial_sub_questions) > 0:
return [
Send(
"answer_sub_question_subgraphs",
SubQuestionAnsweringInput(
question=question,
question_id=make_question_id(0, question_num + 1),
log_messages=[
f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
],
),
)
for question_num, question in enumerate(state.initial_sub_questions)
]
else:
return [
Send(
"ingest_answers",
AnswerQuestionOutput(
answer_results=[],
),
)
]


@@ -1,81 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.graph_builder import (
answer_query_graph_builder,
)
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.edges import (
parallelize_initial_sub_question_answering,
)
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.nodes.decompose_orig_question import (
decompose_orig_question,
)
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.nodes.format_initial_sub_answers import (
format_initial_sub_answers,
)
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.states import (
SubQuestionAnsweringState,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
test_mode = False
def generate_sub_answers_graph_builder() -> StateGraph:
"""
LangGraph graph builder for the initial sub-answer generation process.
It generates the initial sub-questions and produces the answers.
"""
graph = StateGraph(
state_schema=SubQuestionAnsweringState,
input=SubQuestionAnsweringInput,
)
# Decompose the original question into sub-questions
graph.add_node(
node="decompose_orig_question",
action=decompose_orig_question,
)
# The sub-graph that executes the initial sub-question answering for
# each of the sub-questions.
answer_sub_question_subgraphs = answer_query_graph_builder().compile()
graph.add_node(
node="answer_sub_question_subgraphs",
action=answer_sub_question_subgraphs,
)
# Node that collects and formats the initial sub-question answers
graph.add_node(
node="format_initial_sub_question_answers",
action=format_initial_sub_answers,
)
graph.add_edge(
start_key=START,
end_key="decompose_orig_question",
)
graph.add_conditional_edges(
source="decompose_orig_question",
path=parallelize_initial_sub_question_answering,
path_map=["answer_sub_question_subgraphs"],
)
graph.add_edge(
start_key=["answer_sub_question_subgraphs"],
end_key="format_initial_sub_question_answers",
)
graph.add_edge(
start_key="format_initial_sub_question_answers",
end_key=END,
)
return graph
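The builders in this diff repeatedly compile one graph and register it as a node of another, as done here with answer_query_graph_builder. A toy sketch of that pattern (illustrative names, a shared state schema assumed):

from typing import TypedDict
from langgraph.graph import END, START, StateGraph

class TextState(TypedDict):
    text: str

def shout(state: TextState) -> TextState:
    return {"text": state["text"].upper()}

inner = StateGraph(TextState)
inner.add_node("shout", shout)
inner.add_edge(START, "shout")
inner.add_edge("shout", END)
inner_compiled = inner.compile()  # a compiled graph is itself a runnable

outer = StateGraph(TextState)
outer.add_node("inner_subgraph", inner_compiled)  # used exactly like a node function
outer.add_edge(START, "inner_subgraph")
outer.add_edge("inner_subgraph", END)
print(outer.compile().invoke({"text": "hello"}))  # {'text': 'HELLO'}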


@@ -1,190 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
SubQuestionRetrievalState,
)
from onyx.agents.agent_search.deep_search.main.models import (
AgentRefinedMetrics,
)
from onyx.agents.agent_search.deep_search.main.operations import dispatch_subquestion
from onyx.agents.agent_search.deep_search.main.operations import (
dispatch_subquestion_sep,
)
from onyx.agents.agent_search.deep_search.main.states import (
InitialQuestionDecompositionUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUESTION_GENERATION
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH_ASSUMING_REFINEMENT,
)
from onyx.prompts.agent_search import (
INITIAL_QUESTION_DECOMPOSITION_PROMPT_ASSUMING_REFINEMENT,
)
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="LLM Timeout Error. Sub-questions could not be generated.",
rate_limit="LLM Rate Limit Error. Sub-questions could not be generated.",
general_error="General LLM Error. Sub-questions could not be generated.",
)
@log_function_time(print_only=True)
def decompose_orig_question(
state: SubQuestionRetrievalState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> InitialQuestionDecompositionUpdate:
"""
LangGraph node to decompose the original question into sub-questions.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.prompt_builder.raw_user_query
perform_initial_search_decomposition = (
graph_config.behavior.perform_initial_search_decomposition
)
# Get the rewritten queries in a defined format
model = graph_config.tooling.fast_llm
history = build_history_prompt(graph_config, question)
# Use the initial search results to inform the decomposition
agent_start_time = datetime.now()
# Initial search to inform decomposition. Just get top 3 fits
if perform_initial_search_decomposition:
# Due to the unfortunate state representation in LangGraph, we need to double-check here that the retrieval
# has happened prior to this point. We allow silent failure here since the exploratory search is not critical
# for decomposition in all queries.
if not state.exploratory_search_results:
logger.error("Initial search for decomposition failed")
sample_doc_str = "\n\n".join(
[
doc.combined_content
for doc in state.exploratory_search_results[
:AGENT_NUM_DOCS_FOR_DECOMPOSITION
]
]
)
decomposition_prompt = INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH_ASSUMING_REFINEMENT.format(
question=question, sample_doc_str=sample_doc_str, history=history
)
else:
decomposition_prompt = (
INITIAL_QUESTION_DECOMPOSITION_PROMPT_ASSUMING_REFINEMENT.format(
question=question, history=history
)
)
# Start decomposition
msg = [HumanMessage(content=decomposition_prompt)]
# Send the initial question as a subquestion with number 0
write_custom_event(
"decomp_qs",
SubQuestionPiece(
sub_question=question,
level=0,
level_question_num=0,
),
writer,
)
# dispatches custom events for subquestion tokens, adding in subquestion ids.
streamed_tokens: list[BaseMessage_Content] = []
try:
streamed_tokens = run_with_timeout(
AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION,
dispatch_separated,
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBQUESTION_GENERATION,
),
dispatch_subquestion(0, writer),
sep_callback=dispatch_subquestion_sep(0, writer),
)
decomposition_response = merge_content(*streamed_tokens)
list_of_subqs = cast(str, decomposition_response).split("\n")
initial_sub_questions = [sq.strip() for sq in list_of_subqs if sq.strip() != ""]
log_result = f"decomposed original question into {len(initial_sub_questions)} subquestions"
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type=StreamType.SUB_QUESTIONS,
level=0,
)
write_custom_event("stream_finished", stop_event, writer)
except (LLMTimeoutError, TimeoutError) as e:
logger.error("LLM Timeout Error - decompose orig question")
raise e # fail loudly on this critical step
except LLMRateLimitError as e:
logger.error("LLM Rate Limit Error - decompose orig question")
raise e
return InitialQuestionDecompositionUpdate(
initial_sub_questions=initial_sub_questions,
agent_start_time=agent_start_time,
agent_refined_start_time=None,
agent_refined_end_time=None,
agent_refined_metrics=AgentRefinedMetrics(
refined_doc_boost_factor=None,
refined_question_boost_factor=None,
duration_s=None,
),
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate sub answers",
node_name="decompose original question",
node_start_time=node_start_time,
result=log_result,
)
],
)
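The decomposition output above is parsed by simple newline splitting of the merged LLM response. A standalone sketch of that parsing step, assuming the prompt asks for one sub-question per line:

def parse_sub_questions(raw_llm_output: str) -> list[str]:
    # One sub-question per line; drop blank lines and surrounding whitespace.
    return [line.strip() for line in raw_llm_output.split("\n") if line.strip()]

print(parse_sub_questions("What changed?\n\n  Who changed it?  \n"))
# ['What changed?', 'Who changed it?']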


@@ -1,50 +0,0 @@
from datetime import datetime
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.main.states import (
SubQuestionResultsUpdate,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
def format_initial_sub_answers(
state: AnswerQuestionOutput,
) -> SubQuestionResultsUpdate:
"""
LangGraph node to format the answers to the initial sub-questions, including
deduping verified documents and context documents.
"""
node_start_time = datetime.now()
documents = []
context_documents = []
cited_documents = []
answer_results = state.answer_results
for answer_result in answer_results:
documents.extend(answer_result.verified_reranked_documents)
context_documents.extend(answer_result.context_documents)
cited_documents.extend(answer_result.cited_documents)
return SubQuestionResultsUpdate(
# Deduping is done by the documents operator for the main graph
# so we might not need to dedup here
verified_reranked_documents=dedup_inference_sections(documents, []),
context_documents=dedup_inference_sections(context_documents, []),
cited_documents=dedup_inference_sections(cited_documents, []),
sub_question_results=answer_results,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate sub answers",
node_name="format initial sub answers",
node_start_time=node_start_time,
result="",
)
],
)
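dedup_inference_sections itself is not shown in this diff; presumably it keys on document identity. A generic, order-preserving dedup sketch under that assumption (not the Onyx implementation):

from collections.abc import Callable, Iterable
from typing import TypeVar

T = TypeVar("T")

def dedup_by_key(items: Iterable[T], key: Callable[[T], str]) -> list[T]:
    seen: set[str] = set()
    deduped: list[T] = []
    for item in items:
        item_key = key(item)
        if item_key not in seen:  # keep the first occurrence, drop later duplicates
            seen.add(item_key)
            deduped.append(item)
    return deduped

print(dedup_by_key(["doc_a", "doc_b", "doc_a"], key=lambda doc_id: doc_id))  # ['doc_a', 'doc_b']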


@@ -1,34 +0,0 @@
from typing import TypedDict
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.deep_search.main.states import (
InitialAnswerUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
InitialQuestionDecompositionUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
SubQuestionResultsUpdate,
)
from onyx.context.search.models import InferenceSection
### States ###
class SubQuestionAnsweringInput(CoreState):
exploratory_search_results: list[InferenceSection]
## Graph State
class SubQuestionAnsweringState(
# This includes the core state
SubQuestionAnsweringInput,
InitialQuestionDecompositionUpdate,
InitialAnswerUpdate,
SubQuestionResultsUpdate,
):
pass
## Graph Output State
class SubQuestionAnsweringOutput(TypedDict):
log_messages: list[str]


@@ -1,81 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.nodes.format_orig_question_search_input import (
format_orig_question_search_input,
)
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.nodes.format_orig_question_search_output import (
format_orig_question_search_output,
)
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.states import (
BaseRawSearchInput,
)
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.states import (
BaseRawSearchOutput,
)
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.states import (
BaseRawSearchState,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.graph_builder import (
expanded_retrieval_graph_builder,
)
def retrieve_orig_question_docs_graph_builder() -> StateGraph:
"""
LangGraph graph builder for the retrieval of documents
that are relevant to the original question. This is
largely a wrapper around the expanded retrieval process to
ensure parallelism with the sub-question answer process.
"""
graph = StateGraph(
state_schema=BaseRawSearchState,
input=BaseRawSearchInput,
output=BaseRawSearchOutput,
)
### Add nodes ###
# Format the original question search output
graph.add_node(
node="format_orig_question_search_output",
action=format_orig_question_search_output,
)
# The sub-graph that executes the expanded retrieval process
expanded_retrieval = expanded_retrieval_graph_builder().compile()
graph.add_node(
node="retrieve_orig_question_docs_subgraph",
action=expanded_retrieval,
)
# Format the original question search input
graph.add_node(
node="format_orig_question_search_input",
action=format_orig_question_search_input,
)
### Add edges ###
graph.add_edge(start_key=START, end_key="format_orig_question_search_input")
graph.add_edge(
start_key="format_orig_question_search_input",
end_key="retrieve_orig_question_docs_subgraph",
)
graph.add_edge(
start_key="retrieve_orig_question_docs_subgraph",
end_key="format_orig_question_search_output",
)
graph.add_edge(
start_key="format_orig_question_search_output",
end_key=END,
)
return graph
if __name__ == "__main__":
pass
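As used in this builder, StateGraph accepts separate input and output schemas, so callers of the compiled subgraph only deal with the keys they need. A toy sketch with illustrative schema names:

from typing import TypedDict
from langgraph.graph import END, START, StateGraph

class QuestionInput(TypedDict):
    question: str

class AnswerOutput(TypedDict):
    answer: str

class FullState(QuestionInput, AnswerOutput):
    pass

def answer_node(state: FullState) -> dict:
    return {"answer": f"echo: {state['question']}"}

graph = StateGraph(FullState, input=QuestionInput, output=AnswerOutput)
graph.add_node("answer_node", answer_node)
graph.add_edge(START, "answer_node")
graph.add_edge("answer_node", END)
print(graph.compile().invoke({"question": "hi"}))  # {'answer': 'echo: hi'}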


@@ -1,28 +0,0 @@
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.utils.logger import setup_logger
logger = setup_logger()
def format_orig_question_search_input(
state: CoreState, config: RunnableConfig
) -> ExpandedRetrievalInput:
"""
LangGraph node to format the search input for the original question.
"""
logger.debug("generate_raw_search_data")
graph_config = cast(GraphConfig, config["metadata"]["config"])
return ExpandedRetrievalInput(
question=graph_config.inputs.prompt_builder.raw_user_query,
base_search=True,
sub_question_id=None, # This graph is always and only used for the original question
log_messages=[],
)


@@ -1,30 +0,0 @@
from onyx.agents.agent_search.deep_search.main.states import OrigQuestionRetrievalUpdate
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalOutput,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.utils.logger import setup_logger
logger = setup_logger()
def format_orig_question_search_output(
state: ExpandedRetrievalOutput,
) -> OrigQuestionRetrievalUpdate:
"""
LangGraph node to format the search result for the original question into the
proper format.
"""
sub_question_retrieval_stats = state.expanded_retrieval_result.retrieval_stats
if sub_question_retrieval_stats is None:
sub_question_retrieval_stats = AgentChunkRetrievalStats()
return OrigQuestionRetrievalUpdate(
orig_question_verified_reranked_documents=state.expanded_retrieval_result.verified_reranked_documents,
orig_question_sub_query_retrieval_results=state.expanded_retrieval_result.expanded_query_results,
orig_question_retrieved_documents=state.retrieved_documents,
orig_question_retrieval_stats=sub_question_retrieval_stats,
log_messages=[],
)


@@ -1,29 +0,0 @@
from onyx.agents.agent_search.deep_search.main.states import (
OrigQuestionRetrievalUpdate,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
## Graph Input State
class BaseRawSearchInput(ExpandedRetrievalInput):
pass
## Graph Output State
class BaseRawSearchOutput(OrigQuestionRetrievalUpdate):
"""
This is a list of results even though each call of this subgraph only returns one result.
This is because if we parallelize the answer query subgraph, there will be multiple
results in a list so the add operator is used to add them together.
"""
# base_expanded_retrieval_result: QuestionRetrievalResult = QuestionRetrievalResult()
## Graph State
class BaseRawSearchState(
BaseRawSearchInput, BaseRawSearchOutput, OrigQuestionRetrievalUpdate
):
pass


@@ -1,113 +0,0 @@
from collections.abc import Hashable
from datetime import datetime
from typing import cast
from typing import Literal
from langchain_core.runnables import RunnableConfig
from langgraph.types import Send
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.main.states import (
RequireRefinemenEvalUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.utils.logger import setup_logger
logger = setup_logger()
def route_initial_tool_choice(
state: MainState, config: RunnableConfig
) -> Literal["call_tool", "start_agent_search", "logging_node"]:
"""
LangGraph edge to route to agent search.
"""
agent_config = cast(GraphConfig, config["metadata"]["config"])
if state.tool_choice is None:
return "logging_node"
if (
agent_config.behavior.use_agentic_search
and agent_config.tooling.search_tool is not None
and state.tool_choice.tool.name == agent_config.tooling.search_tool.name
):
return "start_agent_search"
else:
return "call_tool"
def parallelize_initial_sub_question_answering(
state: MainState,
) -> list[Send | Hashable]:
edge_start_time = datetime.now()
if len(state.initial_sub_questions) > 0:
return [
Send(
"answer_query_subgraph",
SubQuestionAnsweringInput(
question=question,
question_id=make_question_id(0, question_num + 1),
log_messages=[
f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
],
),
)
for question_num, question in enumerate(state.initial_sub_questions)
]
else:
return [
Send(
"ingest_answers",
AnswerQuestionOutput(
answer_results=[],
),
)
]
# Define the function that determines whether to continue or not
def continue_to_refined_answer_or_end(
state: RequireRefinemenEvalUpdate,
) -> Literal["create_refined_sub_questions", "logging_node"]:
if state.require_refined_answer_eval:
return "create_refined_sub_questions"
else:
return "logging_node"
def parallelize_refined_sub_question_answering(
state: MainState,
) -> list[Send | Hashable]:
edge_start_time = datetime.now()
if len(state.refined_sub_questions) > 0:
return [
Send(
"answer_refined_question_subgraphs",
SubQuestionAnsweringInput(
question=question_data.sub_question,
question_id=make_question_id(1, question_num),
log_messages=[
f"{edge_start_time} -- Main Edge - Parallelize Refined Sub-question Answering"
],
),
)
for question_num, question_data in state.refined_sub_questions.items()
]
else:
return [
Send(
"ingest_refined_sub_answers",
AnswerQuestionOutput(
answer_results=[],
),
)
]
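route_initial_tool_choice above is a plain routing function whose Literal return value names the next node. A reduced sketch of that routing style (toy nodes and state, not Onyx's):

from typing import Literal, TypedDict
from langgraph.graph import END, START, StateGraph

class RouteState(TypedDict):
    use_tool: bool
    result: str

def router(state: RouteState) -> Literal["call_tool", "skip_tool"]:
    # The returned label selects which registered edge the graph follows next.
    return "call_tool" if state["use_tool"] else "skip_tool"

def call_tool(state: RouteState) -> dict:
    return {"result": "tool was called"}

def skip_tool(state: RouteState) -> dict:
    return {"result": "no tool needed"}

graph = StateGraph(RouteState)
graph.add_node("call_tool", call_tool)
graph.add_node("skip_tool", skip_tool)
graph.add_conditional_edges(START, router, ["call_tool", "skip_tool"])
graph.add_edge("call_tool", END)
graph.add_edge("skip_tool", END)
print(graph.compile().invoke({"use_tool": True, "result": ""}))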


@@ -1,263 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.graph_builder import (
generate_initial_answer_graph_builder,
)
from onyx.agents.agent_search.deep_search.main.edges import (
continue_to_refined_answer_or_end,
)
from onyx.agents.agent_search.deep_search.main.edges import (
parallelize_refined_sub_question_answering,
)
from onyx.agents.agent_search.deep_search.main.edges import (
route_initial_tool_choice,
)
from onyx.agents.agent_search.deep_search.main.nodes.compare_answers import (
compare_answers,
)
from onyx.agents.agent_search.deep_search.main.nodes.create_refined_sub_questions import (
create_refined_sub_questions,
)
from onyx.agents.agent_search.deep_search.main.nodes.decide_refinement_need import (
decide_refinement_need,
)
from onyx.agents.agent_search.deep_search.main.nodes.extract_entities_terms import (
extract_entities_terms,
)
from onyx.agents.agent_search.deep_search.main.nodes.generate_validate_refined_answer import (
generate_validate_refined_answer,
)
from onyx.agents.agent_search.deep_search.main.nodes.ingest_refined_sub_answers import (
ingest_refined_sub_answers,
)
from onyx.agents.agent_search.deep_search.main.nodes.persist_agent_results import (
persist_agent_results,
)
from onyx.agents.agent_search.deep_search.main.nodes.start_agent_search import (
start_agent_search,
)
from onyx.agents.agent_search.deep_search.main.states import MainInput
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.refinement.consolidate_sub_answers.graph_builder import (
answer_refined_query_graph_builder,
)
from onyx.agents.agent_search.orchestration.nodes.call_tool import call_tool
from onyx.agents.agent_search.orchestration.nodes.choose_tool import choose_tool
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
prepare_tool_input,
)
from onyx.agents.agent_search.orchestration.nodes.use_tool_response import (
basic_use_tool_response,
)
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger
logger = setup_logger()
test_mode = False
def agent_search_graph_builder() -> StateGraph:
"""
LangGraph graph builder for the main agent search process.
"""
graph = StateGraph(
state_schema=MainState,
input=MainInput,
)
# Prepare the tool input
graph.add_node(
node="prepare_tool_input",
action=prepare_tool_input,
)
# Choose the initial tool
graph.add_node(
node="choose_tool",
action=choose_tool,
)
# Call the tool, if required
graph.add_node(
node="call_tool",
action=call_tool,
)
# Use the tool response
graph.add_node(
node="basic_use_tool_response",
action=basic_use_tool_response,
)
# Start the agent search process
graph.add_node(
node="start_agent_search",
action=start_agent_search,
)
# The sub-graph for the initial answer generation
generate_initial_answer_subgraph = generate_initial_answer_graph_builder().compile()
graph.add_node(
node="generate_initial_answer_subgraph",
action=generate_initial_answer_subgraph,
)
# Create the refined sub-questions
graph.add_node(
node="create_refined_sub_questions",
action=create_refined_sub_questions,
)
# Subgraph for the refined sub-answer generation
answer_refined_question = answer_refined_query_graph_builder().compile()
graph.add_node(
node="answer_refined_question_subgraphs",
action=answer_refined_question,
)
# Ingest the refined sub-answers
graph.add_node(
node="ingest_refined_sub_answers",
action=ingest_refined_sub_answers,
)
# Node to generate the refined answer
graph.add_node(
node="generate_validate_refined_answer",
action=generate_validate_refined_answer,
)
# Early node to extract the entities and terms from the initial answer.
# This information is used to inform the creation of the refined sub-questions.
graph.add_node(
node="extract_entity_term",
action=extract_entities_terms,
)
# Decide if the answer needs to be refined (currently always true)
graph.add_node(
node="decide_refinement_need",
action=decide_refinement_need,
)
# Compare the initial and refined answers, and determine whether
# the refined answer is sufficiently better
graph.add_node(
node="compare_answers",
action=compare_answers,
)
# Log the results. This will log the stats as well as the answers, sub-questions, and sub-answers
graph.add_node(
node="logging_node",
action=persist_agent_results,
)
### Add edges ###
graph.add_edge(start_key=START, end_key="prepare_tool_input")
graph.add_edge(
start_key="prepare_tool_input",
end_key="choose_tool",
)
graph.add_conditional_edges(
"choose_tool",
route_initial_tool_choice,
["call_tool", "start_agent_search", "logging_node"],
)
graph.add_edge(
start_key="call_tool",
end_key="basic_use_tool_response",
)
graph.add_edge(
start_key="basic_use_tool_response",
end_key="logging_node",
)
graph.add_edge(
start_key="start_agent_search",
end_key="generate_initial_answer_subgraph",
)
graph.add_edge(
start_key="start_agent_search",
end_key="extract_entity_term",
)
# Wait for the initial answer generation and the entity/term extraction to be complete
# before deciding if a refinement is needed.
graph.add_edge(
start_key=["generate_initial_answer_subgraph", "extract_entity_term"],
end_key="decide_refinement_need",
)
graph.add_conditional_edges(
source="decide_refinement_need",
path=continue_to_refined_answer_or_end,
path_map=["create_refined_sub_questions", "logging_node"],
)
graph.add_conditional_edges(
source="create_refined_sub_questions",
path=parallelize_refined_sub_question_answering,
path_map=["answer_refined_question_subgraphs"],
)
graph.add_edge(
start_key="answer_refined_question_subgraphs",
end_key="ingest_refined_sub_answers",
)
graph.add_edge(
start_key="ingest_refined_sub_answers",
end_key="generate_validate_refined_answer",
)
graph.add_edge(
start_key="generate_validate_refined_answer",
end_key="compare_answers",
)
graph.add_edge(
start_key="compare_answers",
end_key="logging_node",
)
graph.add_edge(
start_key="logging_node",
end_key=END,
)
return graph
if __name__ == "__main__":
pass
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
graph = agent_search_graph_builder()
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
with get_session_with_current_tenant() as db_session:
search_request = SearchRequest(query="Who created Excel?")
graph_config = get_test_config(
db_session, primary_llm, fast_llm, search_request
)
inputs = MainInput(log_messages=[])
for thing in compiled_graph.stream(
input=inputs,
config={"configurable": {"config": graph_config}},
stream_mode="custom",
subgraphs=True,
):
logger.debug(thing)


@@ -1,168 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.main.states import (
InitialRefinedAnswerComparisonUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
binary_string_test,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_POSITIVE_VALUE_STR,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import RefinedAnswerImprovement
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_COMPARE_ANSWERS
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
INITIAL_REFINED_ANSWER_COMPARISON_PROMPT,
)
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="The LLM timed out, and the answers could not be compared.",
rate_limit="The LLM encountered a rate limit, and the answers could not be compared.",
general_error="The LLM encountered an error, and the answers could not be compared.",
)
_ANSWER_QUALITY_NOT_SUFFICIENT_MESSAGE = (
"Answer quality is not sufficient, so stay with the initial answer."
)
@log_function_time(print_only=True)
def compare_answers(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> InitialRefinedAnswerComparisonUpdate:
"""
LangGraph node to compare the initial answer and the refined answer and determine if the
refined answer is sufficiently better than the initial answer.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.prompt_builder.raw_user_query
initial_answer = state.initial_answer
refined_answer = state.refined_answer
# if answer quality is not sufficient, then stay with the initial answer
if not state.refined_answer_quality:
write_custom_event(
"refined_answer_improvement",
RefinedAnswerImprovement(
refined_answer_improvement=False,
),
writer,
)
return InitialRefinedAnswerComparisonUpdate(
refined_answer_improvement_eval=False,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="compare answers",
node_start_time=node_start_time,
result=_ANSWER_QUALITY_NOT_SUFFICIENT_MESSAGE,
)
],
)
compare_answers_prompt = INITIAL_REFINED_ANSWER_COMPARISON_PROMPT.format(
question=question, initial_answer=initial_answer, refined_answer=refined_answer
)
msg = [HumanMessage(content=compare_answers_prompt)]
agent_error: AgentErrorLog | None = None
# Get the rewritten queries in a defined format
model = graph_config.tooling.fast_llm
resp: BaseMessage | None = None
refined_answer_improvement: bool | None = None
# no need to stream this
try:
resp = run_with_timeout(
AGENT_TIMEOUT_LLM_COMPARE_ANSWERS,
model.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
except (LLMTimeoutError, TimeoutError):
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - compare answers")
# continue as True in this support step
except LLMRateLimitError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - compare answers")
# continue as True in this support step
if agent_error or resp is None:
refined_answer_improvement = True
if agent_error:
log_result = agent_error.error_result
else:
log_result = "An answer could not be generated."
else:
refined_answer_improvement = binary_string_test(
text=cast(str, resp.content),
positive_value=AGENT_POSITIVE_VALUE_STR,
)
log_result = f"Answer comparison: {refined_answer_improvement}"
write_custom_event(
"refined_answer_improvement",
RefinedAnswerImprovement(
refined_answer_improvement=refined_answer_improvement,
),
writer,
)
return InitialRefinedAnswerComparisonUpdate(
refined_answer_improvement_eval=refined_answer_improvement,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="compare answers",
node_start_time=node_start_time,
result=log_result,
)
],
)
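binary_string_test is not included in this diff; presumably it checks whether the model's verdict contains the configured positive token. A hedged sketch of that kind of check:

def binary_string_test_sketch(text: str, positive_value: str = "yes") -> bool:
    # Treat the response as positive if the expected token appears, case-insensitively.
    return positive_value.lower() in text.strip().lower()

print(binary_string_test_sketch("Yes - the refined answer is clearly better."))  # True
print(binary_string_test_sketch("No, keep the initial answer."))                 # False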


@@ -1,213 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.main.models import (
RefinementSubQuestion,
)
from onyx.agents.agent_search.deep_search.main.operations import dispatch_subquestion
from onyx.agents.agent_search.deep_search.main.operations import (
dispatch_subquestion_sep,
)
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.main.states import (
RefinedQuestionDecompositionUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
format_entity_term_extraction,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUESTION_GENERATION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
REFINEMENT_QUESTION_DECOMPOSITION_PROMPT_W_INITIAL_SUBQUESTION_ANSWERS,
)
from onyx.tools.models import ToolCallKickoff
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
_ANSWERED_SUBQUESTIONS_DIVIDER = "\n\n---\n\n"
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="The LLM timed out. The sub-questions could not be generated.",
rate_limit="The LLM encountered a rate limit. The sub-questions could not be generated.",
general_error="The LLM encountered an error. The sub-questions could not be generated.",
)
@log_function_time(print_only=True)
def create_refined_sub_questions(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> RefinedQuestionDecompositionUpdate:
"""
LangGraph node to create refined sub-questions based on the initial answer, the history,
the entity term extraction results found earlier, and the sub-questions that were answered and failed.
"""
graph_config = cast(GraphConfig, config["metadata"]["config"])
write_custom_event(
"start_refined_answer_creation",
ToolCallKickoff(
tool_name="agent_search_1",
tool_args={
"query": graph_config.inputs.prompt_builder.raw_user_query,
"answer": state.initial_answer,
},
),
writer,
)
node_start_time = datetime.now()
agent_refined_start_time = datetime.now()
question = graph_config.inputs.prompt_builder.raw_user_query
base_answer = state.initial_answer
history = build_history_prompt(graph_config, question)
# get the entity term extraction dict and properly format it
entity_relation_term_extractions = state.entity_relation_term_extractions
entity_term_extraction_str = format_entity_term_extraction(
entity_relation_term_extractions
)
initial_question_answers = state.sub_question_results
addressed_subquestions_with_answers = [
f"Subquestion: {x.question}\nSubanswer:\n{x.answer}"
for x in initial_question_answers
if x.verified_high_quality and x.answer
]
failed_question_list = [
x.question for x in initial_question_answers if not x.verified_high_quality
]
msg = [
HumanMessage(
content=REFINEMENT_QUESTION_DECOMPOSITION_PROMPT_W_INITIAL_SUBQUESTION_ANSWERS.format(
question=question,
history=history,
entity_term_extraction_str=entity_term_extraction_str,
base_answer=base_answer,
answered_subquestions_with_answers=_ANSWERED_SUBQUESTIONS_DIVIDER.join(
addressed_subquestions_with_answers
),
failed_sub_questions="\n - ".join(failed_question_list),
),
)
]
# Grader
model = graph_config.tooling.fast_llm
agent_error: AgentErrorLog | None = None
streamed_tokens: list[BaseMessage_Content] = []
try:
streamed_tokens = run_with_timeout(
AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION,
dispatch_separated,
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBQUESTION_GENERATION,
),
dispatch_subquestion(1, writer),
sep_callback=dispatch_subquestion_sep(1, writer),
)
except (LLMTimeoutError, TimeoutError):
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - create refined sub questions")
except LLMRateLimitError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - create refined sub questions")
if agent_error:
refined_sub_question_dict: dict[int, RefinementSubQuestion] = {}
log_result = agent_error.error_result
write_custom_event(
"refined_sub_question_creation_error",
StreamingError(
error="Your LLM was not able to create refined sub questions in time and timed out. Please try again.",
),
writer,
)
else:
response = merge_content(*streamed_tokens)
if isinstance(response, str):
parsed_response = [q for q in response.split("\n") if q.strip() != ""]
else:
raise ValueError("LLM response is not a string")
refined_sub_question_dict = {}
for sub_question_num, sub_question in enumerate(parsed_response):
refined_sub_question = RefinementSubQuestion(
sub_question=sub_question,
sub_question_id=make_question_id(1, sub_question_num + 1),
verified=False,
answered=False,
answer="",
)
refined_sub_question_dict[sub_question_num + 1] = refined_sub_question
log_result = f"Created {len(refined_sub_question_dict)} refined sub questions"
return RefinedQuestionDecompositionUpdate(
refined_sub_questions=refined_sub_question_dict,
agent_refined_start_time=agent_refined_start_time,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="create refined sub questions",
node_start_time=node_start_time,
result=log_result,
)
],
)


@@ -1,56 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.main.states import (
RequireRefinemenEvalUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.timing import log_function_time
@log_function_time(print_only=True)
def decide_refinement_need(
state: MainState, config: RunnableConfig
) -> RequireRefinemenEvalUpdate:
"""
LangGraph node to decide if refinement is needed based on the initial answer and the question.
At present, we always refine.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
decision = graph_config.behavior.allow_refinement
if state.answer_error:
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=False,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="decide refinement need",
node_start_time=node_start_time,
result="Timeout Error",
)
],
)
log_messages = [
get_langgraph_node_log_string(
graph_component="main",
node_name="decide refinement need",
node_start_time=node_start_time,
result=f"Refinement decision: {decision}",
)
]
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=graph_config.behavior.allow_refinement and decision,
log_messages=log_messages,
)


@@ -1,142 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search.main.operations import logger
from onyx.agents.agent_search.deep_search.main.states import (
EntityTermExtractionUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.models import EntityExtractionResult
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
)
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION,
)
from onyx.configs.constants import NUM_EXPLORATORY_DOCS
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
@log_function_time(print_only=True)
def extract_entities_terms(
state: MainState, config: RunnableConfig
) -> EntityTermExtractionUpdate:
"""
LangGraph node to extract entities, relationships, and terms from the initial search results.
This data is used to inform particularly the sub-questions that are created for the refined answer.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
if not graph_config.behavior.allow_refinement:
return EntityTermExtractionUpdate(
entity_relation_term_extractions=EntityRelationshipTermExtraction(
entities=[],
relationships=[],
terms=[],
),
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="extract entities terms",
node_start_time=node_start_time,
result="Refinement is not allowed",
)
],
)
# first four lines are duplicated from generate_initial_answer
question = graph_config.inputs.prompt_builder.raw_user_query
initial_search_docs = state.exploratory_search_results[:NUM_EXPLORATORY_DOCS]
# start with the entity/term/extraction
doc_context = format_docs(initial_search_docs)
# Calculation here is only approximate
doc_context = trim_prompt_piece(
config=graph_config.tooling.fast_llm.config,
prompt_piece=doc_context,
reserved_str=ENTITY_TERM_EXTRACTION_PROMPT
+ question
+ ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE,
)
msg = [
HumanMessage(
content=ENTITY_TERM_EXTRACTION_PROMPT.format(
question=question, context=doc_context
)
+ ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE,
)
]
fast_llm = graph_config.tooling.fast_llm
# Grader
try:
llm_response = run_with_timeout(
AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION,
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
max_tokens=AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION,
)
cleaned_response = (
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
)
first_bracket = cleaned_response.find("{")
last_bracket = cleaned_response.rfind("}")
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
try:
entity_extraction_result = EntityExtractionResult.model_validate_json(
cleaned_response
)
except ValueError:
logger.error(
"Failed to parse LLM response as JSON in Entity-Term Extraction"
)
entity_extraction_result = EntityExtractionResult(
retrieved_entities_relationships=EntityRelationshipTermExtraction(),
)
except (LLMTimeoutError, TimeoutError):
logger.error("LLM Timeout Error - extract entities terms")
entity_extraction_result = EntityExtractionResult(
retrieved_entities_relationships=EntityRelationshipTermExtraction(),
)
except LLMRateLimitError:
logger.error("LLM Rate Limit Error - extract entities terms")
entity_extraction_result = EntityExtractionResult(
retrieved_entities_relationships=EntityRelationshipTermExtraction(),
)
return EntityTermExtractionUpdate(
entity_relation_term_extractions=entity_extraction_result.retrieved_entities_relationships,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="extract entities terms",
node_start_time=node_start_time,
)
],
)
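The fence-stripping and bracket-slicing above can be read as one small helper; a minimal sketch of that cleanup step (the name _extract_json_block is hypothetical and not part of the codebase):

def _extract_json_block(raw: str) -> str:
    """Strip markdown code fences and keep only the outermost JSON object."""
    cleaned = raw.replace("```json\n", "").replace("\n```", "")
    first_bracket = cleaned.find("{")
    last_bracket = cleaned.rfind("}")
    if first_bracket == -1 or last_bracket == -1:
        return ""  # no JSON object found; the caller falls back to an empty extraction
    return cleaned[first_bracket : last_bracket + 1]

# An LLM reply wrapped in a fenced block reduces to the bare JSON object.
print(_extract_json_block('```json\n{"entities": []}\n```'))  # {"entities": []}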


@@ -1,445 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.main.models import (
AgentRefinedMetrics,
)
from onyx.agents.agent_search.deep_search.main.operations import get_query_info
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.main.states import (
RefinedAnswerUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
binary_string_test_after_answer_separator,
)
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
get_prompt_enrichment_components,
)
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.calculations import (
get_answer_generation_documents,
)
from onyx.agents.agent_search.shared_graph_utils.constants import AGENT_ANSWER_SEPARATOR
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_POSITIVE_VALUE_STR,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_section_list,
)
from onyx.agents.agent_search.shared_graph_utils.utils import _should_restrict_tokens
from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_deduplicated_structured_subquestion_documents,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
remove_document_citations,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
REFINED_ANSWER_PROMPT_W_SUB_QUESTIONS,
)
from onyx.prompts.agent_search import (
REFINED_ANSWER_PROMPT_WO_SUB_QUESTIONS,
)
from onyx.prompts.agent_search import (
REFINED_ANSWER_VALIDATION_PROMPT,
)
from onyx.prompts.agent_search import (
SUB_QUESTION_ANSWER_TEMPLATE_REFINED,
)
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="The LLM timed out. The refined answer could not be generated.",
rate_limit="The LLM encountered a rate limit. The refined answer could not be generated.",
general_error="The LLM encountered an error. The refined answer could not be generated.",
)
@log_function_time(print_only=True)
def generate_validate_refined_answer(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> RefinedAnswerUpdate:
"""
LangGraph node to generate the refined answer and validate it.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.prompt_builder.raw_user_query
prompt_enrichment_components = get_prompt_enrichment_components(graph_config)
persona_contextualized_prompt = (
prompt_enrichment_components.persona_prompts.contextualized_prompt
)
verified_reranked_documents = state.verified_reranked_documents
# get all documents cited in sub-questions
structured_subquestion_docs = get_deduplicated_structured_subquestion_documents(
state.sub_question_results
)
original_question_verified_documents = (
state.orig_question_verified_reranked_documents
)
original_question_retrieved_documents = state.orig_question_retrieved_documents
consolidated_context_docs = structured_subquestion_docs.cited_documents
counter = 0
for original_doc in original_question_verified_documents:
if original_doc not in structured_subquestion_docs.cited_documents:
if (
counter <= AGENT_MIN_ORIG_QUESTION_DOCS
or len(consolidated_context_docs)
< 1.5
* AGENT_MAX_ANSWER_CONTEXT_DOCS # allow for larger context in refinement
):
consolidated_context_docs.append(original_doc)
counter += 1
# sort docs by their scores - though the scores refer to different questions
relevant_docs = dedup_inference_section_list(consolidated_context_docs)
# Create the list of documents to stream out. Start with the
# ones that will be in the context (or, if len == 0, use docs
# that were retrieved for the original question)
answer_generation_documents = get_answer_generation_documents(
relevant_docs=relevant_docs,
context_documents=structured_subquestion_docs.context_documents,
original_question_docs=original_question_retrieved_documents,
max_docs=AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER,
)
query_info = get_query_info(state.orig_question_sub_query_retrieval_results)
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
# stream refined answer docs, or original question docs if no relevant docs are found
relevance_list = relevance_from_docs(
answer_generation_documents.streaming_documents
)
for tool_response in yield_search_responses(
query=question,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
):
write_custom_event(
"tool_response",
ExtendedToolResponse(
id=tool_response.id,
response=tool_response.response,
level=1,
level_question_num=0, # 0, 0 is the base question
),
writer,
)
if len(verified_reranked_documents) > 0:
refined_doc_effectiveness = len(relevant_docs) / len(
verified_reranked_documents
)
else:
refined_doc_effectiveness = 10.0
sub_question_answer_results = state.sub_question_results
answered_sub_question_answer_list: list[str] = []
sub_questions: list[str] = []
initial_answered_sub_questions: set[str] = set()
refined_answered_sub_questions: set[str] = set()
for i, result in enumerate(sub_question_answer_results, 1):
question_level, _ = parse_question_id(result.question_id)
sub_questions.append(result.question)
if (
result.verified_high_quality
and result.answer
and result.answer != UNKNOWN_ANSWER
):
sub_question_type = "initial" if question_level == 0 else "refined"
question_set = (
initial_answered_sub_questions
if question_level == 0
else refined_answered_sub_questions
)
question_set.add(result.question)
answered_sub_question_answer_list.append(
SUB_QUESTION_ANSWER_TEMPLATE_REFINED.format(
sub_question=result.question,
sub_answer=result.answer,
sub_question_num=i,
sub_question_type=sub_question_type,
)
)
# Calculate efficiency
total_answered_questions = (
initial_answered_sub_questions | refined_answered_sub_questions
)
revision_question_efficiency = (
len(total_answered_questions) / len(initial_answered_sub_questions)
if initial_answered_sub_questions
else 10.0 if refined_answered_sub_questions else 1.0
)
sub_question_answer_str = "\n\n------\n\n".join(
set(answered_sub_question_answer_list)
)
initial_answer = state.initial_answer or ""
# Choose appropriate prompt template
base_prompt = (
REFINED_ANSWER_PROMPT_W_SUB_QUESTIONS
if answered_sub_question_answer_list
else REFINED_ANSWER_PROMPT_WO_SUB_QUESTIONS
)
model = (
graph_config.tooling.fast_llm
if AGENT_ANSWER_GENERATION_BY_FAST_LLM
else graph_config.tooling.primary_llm
)
relevant_docs_str = format_docs(answer_generation_documents.context_documents)
relevant_docs_str = trim_prompt_piece(
config=model.config,
prompt_piece=relevant_docs_str,
reserved_str=base_prompt
+ question
+ sub_question_answer_str
+ initial_answer
+ persona_contextualized_prompt
+ prompt_enrichment_components.history,
)
msg = [
HumanMessage(
content=base_prompt.format(
question=question,
history=prompt_enrichment_components.history,
answered_sub_questions=remove_document_citations(
sub_question_answer_str
),
relevant_docs=relevant_docs_str,
initial_answer=(
remove_document_citations(initial_answer)
if initial_answer
else None
),
persona_specification=persona_contextualized_prompt,
date_prompt=prompt_enrichment_components.date_str,
)
)
]
streamed_tokens: list[str] = [""]
dispatch_timings: list[float] = []
agent_error: AgentErrorLog | None = None
try:
streamed_tokens, dispatch_timings = run_with_timeout(
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION,
lambda: stream_llm_answer(
llm=model,
prompt=msg,
event_name="refined_agent_answer",
writer=writer,
agent_answer_level=1,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
max_tokens=(
AGENT_MAX_TOKENS_ANSWER_GENERATION
if _should_restrict_tokens(model.config)
else None
),
),
)
except (LLMTimeoutError, TimeoutError):
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - generate refined answer")
except LLMRateLimitError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - generate refined answer")
if agent_error:
write_custom_event(
"initial_agent_answer",
StreamingError(
error=AGENT_LLM_TIMEOUT_MESSAGE,
),
writer,
)
return RefinedAnswerUpdate(
refined_answer=None,
refined_answer_quality=False, # TODO: replace this with the actual check value
refined_agent_stats=None,
agent_refined_end_time=None,
agent_refined_metrics=AgentRefinedMetrics(
refined_doc_boost_factor=0.0,
refined_question_boost_factor=0.0,
duration_s=None,
),
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="generate refined answer",
node_start_time=node_start_time,
result=agent_error.error_result or "An LLM error occurred",
)
],
)
logger.debug(
f"Average dispatch time for refined answer: {sum(dispatch_timings) / len(dispatch_timings)}"
)
dispatch_main_answer_stop_info(1, writer)
response = merge_content(*streamed_tokens)
answer = cast(str, response)
# run a validation step for the refined answer only
msg = [
HumanMessage(
content=REFINED_ANSWER_VALIDATION_PROMPT.format(
question=question,
history=prompt_enrichment_components.history,
answered_sub_questions=sub_question_answer_str,
relevant_docs=relevant_docs_str,
proposed_answer=answer,
persona_specification=persona_contextualized_prompt,
)
)
]
validation_model = graph_config.tooling.fast_llm
try:
validation_response = run_with_timeout(
AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION,
validation_model.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
refined_answer_quality = binary_string_test_after_answer_separator(
text=cast(str, validation_response.content),
positive_value=AGENT_POSITIVE_VALUE_STR,
separator=AGENT_ANSWER_SEPARATOR,
)
except (LLMTimeoutError, TimeoutError):
refined_answer_quality = True
logger.error("LLM Timeout Error - validate refined answer")
except LLMRateLimitError:
refined_answer_quality = True
logger.error("LLM Rate Limit Error - validate refined answer")
refined_agent_stats = RefinedAgentStats(
revision_doc_efficiency=refined_doc_effectiveness,
revision_question_efficiency=revision_question_efficiency,
)
agent_refined_end_time = datetime.now()
if state.agent_refined_start_time:
agent_refined_duration = (
agent_refined_end_time - state.agent_refined_start_time
).total_seconds()
else:
agent_refined_duration = None
agent_refined_metrics = AgentRefinedMetrics(
refined_doc_boost_factor=refined_agent_stats.revision_doc_efficiency,
refined_question_boost_factor=refined_agent_stats.revision_question_efficiency,
duration_s=agent_refined_duration,
)
return RefinedAnswerUpdate(
refined_answer=answer,
refined_answer_quality=refined_answer_quality,
refined_agent_stats=refined_agent_stats,
agent_refined_end_time=agent_refined_end_time,
agent_refined_metrics=agent_refined_metrics,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="generate refined answer",
node_start_time=node_start_time,
)
],
)
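To make the two refinement statistics above concrete, a small worked sketch with made-up counts (not from a real run), using the same formulas:

# Toy inputs: sub-questions answered in each loop, and document counts.
initial_answered = {"Q1", "Q2"}
refined_answered = {"Q2", "Q3", "Q4"}
relevant_docs_count = 12        # docs consolidated for the refined answer
verified_reranked_count = 8     # docs verified/reranked for the initial answer

revision_question_efficiency = (
    len(initial_answered | refined_answered) / len(initial_answered)
    if initial_answered
    else 10.0 if refined_answered else 1.0
)
refined_doc_effectiveness = (
    relevant_docs_count / verified_reranked_count if verified_reranked_count else 10.0
)
print(revision_question_efficiency)  # 4 / 2 = 2.0
print(refined_doc_effectiveness)     # 12 / 8 = 1.5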


@@ -1,42 +0,0 @@
from datetime import datetime
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.main.states import (
SubQuestionResultsUpdate,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
def ingest_refined_sub_answers(
state: AnswerQuestionOutput,
) -> SubQuestionResultsUpdate:
"""
LangGraph node to ingest and format the refined sub-answers and retrieved documents.
"""
node_start_time = datetime.now()
documents = []
answer_results = state.answer_results
for answer_result in answer_results:
documents.extend(answer_result.verified_reranked_documents)
return SubQuestionResultsUpdate(
# Deduping is done by the documents operator for the main graph
# so we might not need to dedup here
verified_reranked_documents=dedup_inference_sections(documents, []),
sub_question_results=answer_results,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="ingest refined answers",
node_start_time=node_start_time,
)
],
)


@@ -1,129 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search.main.models import (
AgentAdditionalMetrics,
)
from onyx.agents.agent_search.deep_search.main.models import AgentTimings
from onyx.agents.agent_search.deep_search.main.operations import logger
from onyx.agents.agent_search.deep_search.main.states import MainOutput
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.models import CombinedAgentMetrics
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.db.chat import log_agent_metrics
from onyx.db.chat import log_agent_sub_question_results
def persist_agent_results(state: MainState, config: RunnableConfig) -> MainOutput:
"""
LangGraph node to persist the agent results, including agent logging data.
"""
node_start_time = datetime.now()
agent_start_time = state.agent_start_time
agent_base_end_time = state.agent_base_end_time
agent_refined_start_time = state.agent_refined_start_time
agent_refined_end_time = state.agent_refined_end_time
agent_end_time = agent_refined_end_time or agent_base_end_time
agent_base_duration = None
if agent_base_end_time and agent_start_time:
agent_base_duration = (agent_base_end_time - agent_start_time).total_seconds()
agent_refined_duration = None
if agent_refined_start_time and agent_refined_end_time:
agent_refined_duration = (
agent_refined_end_time - agent_refined_start_time
).total_seconds()
agent_full_duration = None
if agent_end_time and agent_start_time:
agent_full_duration = (agent_end_time - agent_start_time).total_seconds()
agent_type = "refined" if agent_refined_duration else "base"
agent_base_metrics = state.agent_base_metrics
agent_refined_metrics = state.agent_refined_metrics
combined_agent_metrics = CombinedAgentMetrics(
timings=AgentTimings(
base_duration_s=agent_base_duration,
refined_duration_s=agent_refined_duration,
full_duration_s=agent_full_duration,
),
base_metrics=agent_base_metrics,
refined_metrics=agent_refined_metrics,
additional_metrics=AgentAdditionalMetrics(),
)
persona_id = None
graph_config = cast(GraphConfig, config["metadata"]["config"])
if graph_config.inputs.persona:
persona_id = graph_config.inputs.persona.id
user_id = None
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
user = graph_config.tooling.search_tool.user
if user:
user_id = user.id
# log the agent metrics
if graph_config.persistence:
if agent_base_duration is not None:
log_agent_metrics(
db_session=graph_config.persistence.db_session,
user_id=user_id,
persona_id=persona_id,
agent_type=agent_type,
start_time=agent_start_time,
agent_metrics=combined_agent_metrics,
)
# Persist the sub-answer in the database
db_session = graph_config.persistence.db_session
chat_session_id = graph_config.persistence.chat_session_id
primary_message_id = graph_config.persistence.message_id
sub_question_answer_results = state.sub_question_results
log_agent_sub_question_results(
db_session=db_session,
chat_session_id=chat_session_id,
primary_message_id=primary_message_id,
sub_question_answer_results=sub_question_answer_results,
)
main_output = MainOutput(
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="persist agent results",
node_start_time=node_start_time,
)
],
)
for log_message in state.log_messages:
logger.debug(log_message)
if state.agent_base_metrics:
logger.debug(f"Initial loop: {state.agent_base_metrics.duration_s}")
if state.agent_refined_metrics:
logger.debug(f"Refined loop: {state.agent_refined_metrics.duration_s}")
if (
state.agent_base_metrics
and state.agent_refined_metrics
and state.agent_base_metrics.duration_s
and state.agent_refined_metrics.duration_s
):
logger.debug(
f"Total time: {float(state.agent_base_metrics.duration_s) + float(state.agent_refined_metrics.duration_s)}"
)
return main_output


@@ -1,52 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search.main.states import (
ExploratorySearchUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import retrieve_search_docs
from onyx.configs.agent_configs import AGENT_EXPLORATORY_SEARCH_RESULTS
from onyx.context.search.models import InferenceSection
def start_agent_search(
state: MainState, config: RunnableConfig
) -> ExploratorySearchUpdate:
"""
LangGraph node to start the agentic search process.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.prompt_builder.raw_user_query
history = build_history_prompt(graph_config, question)
# Initial search to inform decomposition. Just get top 3 fits
search_tool = graph_config.tooling.search_tool
assert search_tool, "search_tool must be provided for agentic search"
retrieved_docs: list[InferenceSection] = retrieve_search_docs(search_tool, question)
exploratory_search_results = retrieved_docs[:AGENT_EXPLORATORY_SEARCH_RESULTS]
return ExploratorySearchUpdate(
exploratory_search_results=exploratory_search_results,
previous_history_summary=history,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="start agent search",
node_start_time=node_start_time,
)
],
)


@@ -1,148 +0,0 @@
from collections.abc import Callable
from langgraph.types import StreamWriter
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
from onyx.agents.agent_search.shared_graph_utils.models import (
SubQuestionAnswerResults,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.tools.models import SearchQueryInfo
from onyx.utils.logger import setup_logger
logger = setup_logger()
def dispatch_subquestion(
level: int, writer: StreamWriter
) -> Callable[[str, int], None]:
def _helper(sub_question_part: str, sep_num: int) -> None:
write_custom_event(
"decomp_qs",
SubQuestionPiece(
sub_question=sub_question_part,
level=level,
level_question_num=sep_num,
),
writer,
)
return _helper
def dispatch_subquestion_sep(level: int, writer: StreamWriter) -> Callable[[int], None]:
def _helper(sep_num: int) -> None:
write_custom_event(
"stream_finished",
StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type=StreamType.SUB_QUESTIONS,
level=level,
level_question_num=sep_num,
),
writer,
)
return _helper
def calculate_initial_agent_stats(
decomp_answer_results: list[SubQuestionAnswerResults],
original_question_stats: AgentChunkRetrievalStats,
) -> InitialAgentResultStats:
initial_agent_result_stats: InitialAgentResultStats = InitialAgentResultStats(
sub_questions={},
original_question={},
agent_effectiveness={},
)
orig_verified = original_question_stats.verified_count
orig_support_score = original_question_stats.verified_avg_scores
verified_document_chunk_ids = []
support_scores = 0.0
for decomp_answer_result in decomp_answer_results:
verified_document_chunk_ids += (
decomp_answer_result.sub_question_retrieval_stats.verified_doc_chunk_ids
)
if (
decomp_answer_result.sub_question_retrieval_stats.verified_avg_scores
is not None
):
support_scores += (
decomp_answer_result.sub_question_retrieval_stats.verified_avg_scores
)
verified_document_chunk_ids = list(set(verified_document_chunk_ids))
# Calculate sub-question stats
if (
verified_document_chunk_ids
and len(verified_document_chunk_ids) > 0
and support_scores is not None
):
sub_question_stats: dict[str, float | int | None] = {
"num_verified_documents": len(verified_document_chunk_ids),
"verified_avg_score": float(support_scores / len(decomp_answer_results)),
}
else:
sub_question_stats = {"num_verified_documents": 0, "verified_avg_score": None}
initial_agent_result_stats.sub_questions.update(sub_question_stats)
# Get original question stats
initial_agent_result_stats.original_question.update(
{
"num_verified_documents": original_question_stats.verified_count,
"verified_avg_score": original_question_stats.verified_avg_scores,
}
)
# Calculate chunk utilization ratio
sub_verified = initial_agent_result_stats.sub_questions["num_verified_documents"]
chunk_ratio: float | None = None
if sub_verified is not None and orig_verified is not None and orig_verified > 0:
chunk_ratio = (float(sub_verified) / orig_verified) if sub_verified > 0 else 0.0
elif sub_verified is not None and sub_verified > 0:
chunk_ratio = 10.0
initial_agent_result_stats.agent_effectiveness["utilized_chunk_ratio"] = chunk_ratio
if (
orig_support_score is None
or orig_support_score == 0.0
and initial_agent_result_stats.sub_questions["verified_avg_score"] is None
):
initial_agent_result_stats.agent_effectiveness["support_ratio"] = None
elif orig_support_score is None or orig_support_score == 0.0:
initial_agent_result_stats.agent_effectiveness["support_ratio"] = 10
elif initial_agent_result_stats.sub_questions["verified_avg_score"] is None:
initial_agent_result_stats.agent_effectiveness["support_ratio"] = 0
else:
initial_agent_result_stats.agent_effectiveness["support_ratio"] = (
initial_agent_result_stats.sub_questions["verified_avg_score"]
/ orig_support_score
)
return initial_agent_result_stats
def get_query_info(results: list[QueryRetrievalResult]) -> SearchQueryInfo:
# Use the query info from the base document retrieval
# this is used for some fields that are the same across the searches done
query_info = None
for result in results:
if result.query_info is not None:
query_info = result.query_info
break
assert query_info is not None, "must have query info"
return query_info
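A toy walk-through of the two effectiveness ratios above, with made-up counts and scores and the branching simplified to the common case (both the original question and the sub-questions have verified chunks and non-zero scores):

# Toy stats mirroring calculate_initial_agent_stats
sub_verified = 9        # verified chunks across all sub-questions
orig_verified = 6       # verified chunks for the original question
sub_avg_score = 0.72    # average verification score over sub-questions
orig_avg_score = 0.60   # average verification score for the original question

# Common case only; the full function also handles missing/zero stats with 0, 10, or None.
utilized_chunk_ratio = float(sub_verified) / orig_verified if sub_verified > 0 else 0.0
support_ratio = sub_avg_score / orig_avg_score
print(utilized_chunk_ratio)  # 1.5: sub-questions surfaced 50% more verified chunks
print(support_ratio)         # ~1.2: sub-question chunks score ~20% higher on average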


@@ -1,175 +0,0 @@
from datetime import datetime
from operator import add
from typing import Annotated
from typing import TypedDict
from pydantic import BaseModel
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.deep_search.main.models import AgentBaseMetrics
from onyx.agents.agent_search.deep_search.main.models import (
AgentRefinedMetrics,
)
from onyx.agents.agent_search.deep_search.main.models import (
RefinementSubQuestion,
)
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
)
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
from onyx.agents.agent_search.shared_graph_utils.models import (
SubQuestionAnswerResults,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_question_answer_results,
)
from onyx.context.search.models import InferenceSection
### States ###
class LoggerUpdate(BaseModel):
log_messages: Annotated[list[str], add] = []
class RefinedAgentStartStats(BaseModel):
agent_refined_start_time: datetime | None = None
class RefinedAgentEndStats(BaseModel):
agent_refined_end_time: datetime | None = None
agent_refined_metrics: AgentRefinedMetrics = AgentRefinedMetrics()
class InitialQuestionDecompositionUpdate(
RefinedAgentStartStats, RefinedAgentEndStats, LoggerUpdate
):
agent_start_time: datetime | None = None
previous_history: str | None = None
initial_sub_questions: list[str] = []
class ExploratorySearchUpdate(LoggerUpdate):
exploratory_search_results: list[InferenceSection] = []
previous_history_summary: str | None = None
class InitialRefinedAnswerComparisonUpdate(LoggerUpdate):
"""
Evaluation of whether the refined answer is better than the initial answer
"""
refined_answer_improvement_eval: bool = False
class InitialAnswerUpdate(LoggerUpdate):
"""
Initial answer information
"""
initial_answer: str | None = None
answer_error: AgentErrorLog | None = None
initial_agent_stats: InitialAgentResultStats | None = None
generated_sub_questions: list[str] = []
agent_base_end_time: datetime | None = None
agent_base_metrics: AgentBaseMetrics | None = None
class RefinedAnswerUpdate(RefinedAgentEndStats, LoggerUpdate):
"""
Refined answer information
"""
refined_answer: str | None = None
answer_error: AgentErrorLog | None = None
refined_agent_stats: RefinedAgentStats | None = None
refined_answer_quality: bool = False
class InitialAnswerQualityUpdate(LoggerUpdate):
"""
Initial answer quality evaluation
"""
initial_answer_quality_eval: bool = False
class RequireRefinemenEvalUpdate(LoggerUpdate):
require_refined_answer_eval: bool = True
class SubQuestionResultsUpdate(LoggerUpdate):
verified_reranked_documents: Annotated[
list[InferenceSection], dedup_inference_sections
] = []
context_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
cited_documents: Annotated[list[InferenceSection], dedup_inference_sections] = (
[]
) # cited docs from sub-answers are used for answer context
sub_question_results: Annotated[
list[SubQuestionAnswerResults], dedup_question_answer_results
] = []
class OrigQuestionRetrievalUpdate(LoggerUpdate):
orig_question_retrieved_documents: Annotated[
list[InferenceSection], dedup_inference_sections
]
orig_question_verified_reranked_documents: Annotated[
list[InferenceSection], dedup_inference_sections
]
orig_question_sub_query_retrieval_results: list[QueryRetrievalResult] = []
orig_question_retrieval_stats: AgentChunkRetrievalStats = AgentChunkRetrievalStats()
class EntityTermExtractionUpdate(LoggerUpdate):
entity_relation_term_extractions: EntityRelationshipTermExtraction = (
EntityRelationshipTermExtraction()
)
class RefinedQuestionDecompositionUpdate(RefinedAgentStartStats, LoggerUpdate):
refined_sub_questions: dict[int, RefinementSubQuestion] = {}
## Graph Input State
class MainInput(CoreState):
pass
## Graph State
class MainState(
# This includes the core state
MainInput,
ToolChoiceInput,
ToolCallUpdate,
ToolChoiceUpdate,
InitialQuestionDecompositionUpdate,
InitialAnswerUpdate,
SubQuestionResultsUpdate,
OrigQuestionRetrievalUpdate,
EntityTermExtractionUpdate,
InitialAnswerQualityUpdate,
RequireRefinemenEvalUpdate,
RefinedQuestionDecompositionUpdate,
RefinedAnswerUpdate,
RefinedAgentStartStats,
RefinedAgentEndStats,
InitialRefinedAnswerComparisonUpdate,
ExploratorySearchUpdate,
):
pass
## Graph Output State - presently not used
class MainOutput(TypedDict):
log_messages: list[str]
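A note on the Annotated[..., reducer] pattern used throughout these states: the second argument (operator.add, dedup_inference_sections, dedup_question_answer_results) is a reducer that LangGraph applies when several nodes write the same key, merging updates instead of overwriting them. A minimal illustration of the add reducer, independent of any graph:

from operator import add

# Two parallel node updates to the same Annotated[list[str], add] field...
state_log_messages = ["node A done"]
update_log_messages = ["node B done"]
# ...are combined by the reducer rather than one replacing the other.
state_log_messages = add(state_log_messages, update_log_messages)
print(state_log_messages)  # ['node A done', 'node B done']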


@@ -1,33 +0,0 @@
from collections.abc import Hashable
from datetime import datetime
from langgraph.types import Send
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def send_to_expanded_refined_retrieval(
state: SubQuestionAnsweringInput,
) -> Send | Hashable:
"""
LangGraph edge to send a refined sub-question to expanded retrieval.
"""
logger.debug("sending to expanded retrieval for follow up question via edge")
datetime.now()
return Send(
"refined_sub_question_expanded_retrieval",
ExpandedRetrievalInput(
question=state.question,
sub_question_id=state.question_id,
base_search=False,
log_messages=[f"{datetime.now()} -- Sending to expanded retrieval"],
),
)


@@ -1,132 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.check_sub_answer import (
check_sub_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.format_sub_answer import (
format_sub_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.generate_sub_answer import (
generate_sub_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.ingest_retrieved_documents import (
ingest_retrieved_documents,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.refinement.consolidate_sub_answers.edges import (
send_to_expanded_refined_retrieval,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.graph_builder import (
expanded_retrieval_graph_builder,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def answer_refined_query_graph_builder() -> StateGraph:
"""
LangGraph graph builder for the refined sub-answer generation process.
"""
graph = StateGraph(
state_schema=AnswerQuestionState,
input=SubQuestionAnsweringInput,
output=AnswerQuestionOutput,
)
### Add nodes ###
# Subgraph for the expanded retrieval process
expanded_retrieval = expanded_retrieval_graph_builder().compile()
graph.add_node(
node="refined_sub_question_expanded_retrieval",
action=expanded_retrieval,
)
# Ingest the retrieved documents
graph.add_node(
node="ingest_refined_retrieval",
action=ingest_retrieved_documents,
)
# Generate the refined sub-answer
graph.add_node(
node="generate_refined_sub_answer",
action=generate_sub_answer,
)
# Check if the refined sub-answer is correct
graph.add_node(
node="refined_sub_answer_check",
action=check_sub_answer,
)
# Format the refined sub-answer
graph.add_node(
node="format_refined_sub_answer",
action=format_sub_answer,
)
### Add edges ###
graph.add_conditional_edges(
source=START,
path=send_to_expanded_refined_retrieval,
path_map=["refined_sub_question_expanded_retrieval"],
)
graph.add_edge(
start_key="refined_sub_question_expanded_retrieval",
end_key="ingest_refined_retrieval",
)
graph.add_edge(
start_key="ingest_refined_retrieval",
end_key="generate_refined_sub_answer",
)
graph.add_edge(
start_key="generate_refined_sub_answer",
end_key="refined_sub_answer_check",
)
graph.add_edge(
start_key="refined_sub_answer_check",
end_key="format_refined_sub_answer",
)
graph.add_edge(
start_key="format_refined_sub_answer",
end_key=END,
)
return graph
if __name__ == "__main__":
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
graph = answer_refined_query_graph_builder()
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
search_request = SearchRequest(
query="what can you do with onyx or danswer?",
)
with get_session_with_current_tenant() as db_session:
inputs = SubQuestionAnsweringInput(
question="what can you do with onyx?",
question_id="0_0",
log_messages=[],
)
for thing in compiled_graph.stream(
input=inputs,
stream_mode="custom",
):
logger.debug(thing)


@@ -1,44 +0,0 @@
from collections.abc import Hashable
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import Send
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
RetrievalInput,
)
from onyx.agents.agent_search.models import GraphConfig
def parallel_retrieval_edge(
state: ExpandedRetrievalState, config: RunnableConfig
) -> list[Send | Hashable]:
"""
LangGraph edge to parallelize the retrieval process for each of the
generated sub-queries and the original question.
"""
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = (
state.question
if state.question
else graph_config.inputs.prompt_builder.raw_user_query
)
query_expansions = state.expanded_queries + [question]
return [
Send(
"retrieve_documents",
RetrievalInput(
query_to_retrieve=query,
question=question,
base_search=False,
sub_question_id=state.sub_question_id,
log_messages=[],
),
)
for query in query_expansions
]


@@ -1,161 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.edges import (
parallel_retrieval_edge,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.expand_queries import (
expand_queries,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.format_queries import (
format_queries,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.format_results import (
format_results,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.kickoff_verification import (
kickoff_verification,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.rerank_documents import (
rerank_documents,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.retrieve_documents import (
retrieve_documents,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.verify_documents import (
verify_documents,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalOutput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger
logger = setup_logger()
def expanded_retrieval_graph_builder() -> StateGraph:
"""
LangGraph graph builder for the expanded retrieval process.
"""
graph = StateGraph(
state_schema=ExpandedRetrievalState,
input=ExpandedRetrievalInput,
output=ExpandedRetrievalOutput,
)
### Add nodes ###
# Convert the question into multiple sub-queries
graph.add_node(
node="expand_queries",
action=expand_queries,
)
# Format the sub-queries into a list of strings
graph.add_node(
node="format_queries",
action=format_queries,
)
# Retrieve the documents for each sub-query
graph.add_node(
node="retrieve_documents",
action=retrieve_documents,
)
# Start verification process that the documents are relevant to the question (not the query)
graph.add_node(
node="kickoff_verification",
action=kickoff_verification,
)
# Verify that a given document is relevant to the question (not the query)
graph.add_node(
node="verify_documents",
action=verify_documents,
)
# Rerank the documents that have been verified
graph.add_node(
node="rerank_documents",
action=rerank_documents,
)
# Format the results into a list of strings
graph.add_node(
node="format_results",
action=format_results,
)
### Add edges ###
graph.add_edge(
start_key=START,
end_key="expand_queries",
)
graph.add_edge(
start_key="expand_queries",
end_key="format_queries",
)
graph.add_conditional_edges(
source="format_queries",
path=parallel_retrieval_edge,
path_map=["retrieve_documents"],
)
graph.add_edge(
start_key="retrieve_documents",
end_key="kickoff_verification",
)
graph.add_edge(
start_key="verify_documents",
end_key="rerank_documents",
)
graph.add_edge(
start_key="rerank_documents",
end_key="format_results",
)
graph.add_edge(
start_key="format_results",
end_key=END,
)
return graph
if __name__ == "__main__":
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
graph = expanded_retrieval_graph_builder()
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
search_request = SearchRequest(
query="what can you do with onyx or danswer?",
)
with get_session_with_current_tenant() as db_session:
graph_config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)
inputs = ExpandedRetrievalInput(
question="what can you do with onyx?",
base_search=False,
sub_question_id=None,
log_messages=[],
)
for thing in compiled_graph.stream(
input=inputs,
config={"configurable": {"config": graph_config}},
stream_mode="custom",
subgraphs=True,
):
logger.debug(thing)


@@ -1,13 +0,0 @@
from pydantic import BaseModel
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
from onyx.context.search.models import InferenceSection
class QuestionRetrievalResult(BaseModel):
expanded_query_results: list[QueryRetrievalResult] = []
retrieved_documents: list[InferenceSection] = []
verified_reranked_documents: list[InferenceSection] = []
context_documents: list[InferenceSection] = []
retrieval_stats: AgentChunkRetrievalStats = AgentChunkRetrievalStats()


@@ -1,139 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.operations import (
dispatch_subquery,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
QueryExpansionUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUERY_GENERATION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
)
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
QUERY_REWRITING_PROMPT,
)
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="Query rewriting failed due to LLM timeout - the original question will be used.",
rate_limit="Query rewriting failed due to LLM rate limit - the original question will be used.",
general_error="Query rewriting failed due to LLM error - the original question will be used.",
)
@log_function_time(print_only=True)
def expand_queries(
state: ExpandedRetrievalInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> QueryExpansionUpdate:
"""
LangGraph node to expand a question into multiple search queries.
"""
# Sometimes we want to expand the original question, sometimes we want to expand a sub-question.
# When we are running this node on the original question, no question is explicitly passed in.
# Instead, we use the original question from the search request.
graph_config = cast(GraphConfig, config["metadata"]["config"])
node_start_time = datetime.now()
question = state.question
model = graph_config.tooling.fast_llm
sub_question_id = state.sub_question_id
if sub_question_id is None:
level, question_num = 0, 0
else:
level, question_num = parse_question_id(sub_question_id)
msg = [
HumanMessage(
content=QUERY_REWRITING_PROMPT.format(question=question),
)
]
agent_error: AgentErrorLog | None = None
llm_response_list: list[BaseMessage_Content] = []
llm_response = ""
rewritten_queries = []
try:
llm_response_list = run_with_timeout(
AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION,
dispatch_separated,
model.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBQUERY_GENERATION,
),
dispatch_subquery(level, question_num, writer),
)
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[
0
].content
rewritten_queries = llm_response.split("\n")
log_result = f"Number of expanded queries: {len(rewritten_queries)}"
except (LLMTimeoutError, TimeoutError):
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - expand queries")
log_result = agent_error.error_result
except LLMRateLimitError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - expand queries")
log_result = agent_error.error_result
# use subquestion as query if query generation fails
return QueryExpansionUpdate(
expanded_queries=rewritten_queries,
log_messages=[
get_langgraph_node_log_string(
graph_component="shared - expanded retrieval",
node_name="expand queries",
node_start_time=node_start_time,
result=log_result,
)
],
)


@@ -1,19 +0,0 @@
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
QueryExpansionUpdate,
)
def format_queries(
state: ExpandedRetrievalState, config: RunnableConfig
) -> QueryExpansionUpdate:
"""
LangGraph node to format the expanded queries into a list of strings.
"""
return QueryExpansionUpdate(
expanded_queries=state.expanded_queries,
)


@@ -1,91 +0,0 @@
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.main.operations import get_query_info
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.models import (
QuestionRetrievalResult,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.operations import (
calculate_sub_question_retrieval_stats,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import ExtendedToolResponse
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
def format_results(
state: ExpandedRetrievalState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> ExpandedRetrievalUpdate:
"""
LangGraph node that constructs the proper expanded retrieval format.
"""
level, question_num = parse_question_id(state.sub_question_id or "0_0")
query_info = get_query_info(state.query_retrieval_results)
graph_config = cast(GraphConfig, config["metadata"]["config"])
# Main question docs will be sent later after aggregation and deduping with sub-question docs
reranked_documents = state.reranked_documents
if not (level == 0 and question_num == 0):
if len(reranked_documents) == 0:
# The sub-question is used as the last query. If no verified documents are found, stream
# the top 3 for that one. We may want to revisit this.
reranked_documents = state.query_retrieval_results[-1].retrieved_documents[
:3
]
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
relevance_list = relevance_from_docs(reranked_documents)
for tool_response in yield_search_responses(
query=state.question,
get_retrieved_sections=lambda: reranked_documents,
get_final_context_sections=lambda: reranked_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
):
write_custom_event(
"tool_response",
ExtendedToolResponse(
id=tool_response.id,
response=tool_response.response,
level=level,
level_question_num=question_num,
),
writer,
)
sub_question_retrieval_stats = calculate_sub_question_retrieval_stats(
verified_documents=state.verified_documents,
expanded_retrieval_results=state.query_retrieval_results,
)
if sub_question_retrieval_stats is None:
sub_question_retrieval_stats = AgentChunkRetrievalStats()
return ExpandedRetrievalUpdate(
expanded_retrieval_result=QuestionRetrievalResult(
expanded_query_results=state.query_retrieval_results,
retrieved_documents=state.retrieved_documents,
verified_reranked_documents=reranked_documents,
context_documents=state.reranked_documents,
retrieval_stats=sub_question_retrieval_stats,
),
)


@@ -1,45 +0,0 @@
from typing import Literal
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import Command
from langgraph.types import Send
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
DocVerificationInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.configs.agent_configs import AGENT_MAX_VERIFICATION_HITS
def kickoff_verification(
state: ExpandedRetrievalState,
config: RunnableConfig,
) -> Command[Literal["verify_documents"]]:
"""
LangGraph node (Command node!) that kicks off the verification process for the retrieved documents.
Note that this is a Command node and does the routing as well. (At present, no state updates
are done here, so this could be replaced with an edge. But we may choose to make state
updates later.)
"""
retrieved_documents = state.retrieved_documents[:AGENT_MAX_VERIFICATION_HITS]
verification_question = state.question
sub_question_id = state.sub_question_id
return Command(
update={},
goto=[
Send(
node="verify_documents",
arg=DocVerificationInput(
retrieved_document_to_verify=document,
question=verification_question,
base_search=False,
sub_question_id=sub_question_id,
log_messages=[],
),
)
for document in retrieved_documents
],
)


@@ -1,110 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.operations import (
logger,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
DocRerankingUpdate,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.calculations import get_fit_scores
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS
from onyx.configs.agent_configs import AGENT_RERANKING_STATS
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RerankingDetails
from onyx.context.search.postprocessing.postprocessing import rerank_sections
from onyx.context.search.postprocessing.postprocessing import should_rerank
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.search_settings import get_current_search_settings
from onyx.utils.timing import log_function_time
@log_function_time(print_only=True)
def rerank_documents(
state: ExpandedRetrievalState, config: RunnableConfig
) -> DocRerankingUpdate:
"""
LangGraph node to rerank the retrieved and verified documents. A part of the
pre-existing pipeline is used here.
"""
node_start_time = datetime.now()
verified_documents = state.verified_documents
# Rerank post retrieval and verification. First, create a search query
# then create the list of reranked sections
# If no question defined/question is None in the state, use the original
# question from the search request as query
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = (
state.question
if state.question
else graph_config.inputs.prompt_builder.raw_user_query
)
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
# Note: these are passed-in override values from the API and are typically None
rerank_settings = graph_config.inputs.rerank_settings
allow_agent_reranking = graph_config.behavior.allow_agent_reranking
if rerank_settings is None:
with get_session_with_current_tenant() as db_session:
search_settings = get_current_search_settings(db_session)
if not search_settings.disable_rerank_for_streaming:
rerank_settings = RerankingDetails.from_db_model(search_settings)
# Initial default: no reranking. Will be overwritten below if reranking is warranted
reranked_documents = verified_documents
if should_rerank(rerank_settings) and len(verified_documents) > 0:
if len(verified_documents) > 1:
if not allow_agent_reranking:
logger.info("Use of local rerank model without GPU, skipping reranking")
# No reranking, stay with verified_documents as default
else:
# Reranking is warranted, use the rerank_sections function
reranked_documents = rerank_sections(
query_str=question,
# if reranking is warranted here, rerank_settings is not None (checked by should_rerank)
rerank_settings=cast(RerankingDetails, rerank_settings),
sections_to_rerank=verified_documents,
)
else:
logger.warning(
f"{len(verified_documents)} verified document(s) found, skipping reranking"
)
# No reranking, stay with verified_documents as default
else:
logger.warning("No reranking settings found, using unranked documents")
# No reranking, stay with verified_documents as default
if AGENT_RERANKING_STATS:
fit_scores = get_fit_scores(verified_documents, reranked_documents)
else:
fit_scores = RetrievalFitStats(fit_score_lift=0, rerank_effect=0, fit_scores={})
return DocRerankingUpdate(
reranked_documents=[
doc for doc in reranked_documents if isinstance(doc, InferenceSection)
][:AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS],
sub_question_retrieval_stats=fit_scores,
log_messages=[
get_langgraph_node_log_string(
graph_component="shared - expanded retrieval",
node_name="rerank documents",
node_start_time=node_start_time,
)
],
)
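The gating above reduces to roughly the predicate below (illustrative names, not the actual helpers): reranking only runs when settings are available, more than one document was verified, and agent reranking is allowed.

def should_apply_agent_rerank(
    rerank_settings_present: bool,
    num_verified_docs: int,
    allow_agent_reranking: bool,
) -> bool:
    # Mirrors the nested if/else above: all three conditions must hold.
    return (
        rerank_settings_present
        and num_verified_docs > 1
        and allow_agent_reranking
    )

assert should_apply_agent_rerank(True, 2, True)
assert not should_apply_agent_rerank(True, 1, True)  # single doc: keep as-is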


@@ -1,119 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.operations import (
logger,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
DocRetrievalUpdate,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
RetrievalInput,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.calculations import get_fit_scores
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import AGENT_MAX_QUERY_RETRIEVAL_RESULTS
from onyx.configs.agent_configs import AGENT_RETRIEVAL_STATS
from onyx.context.search.models import InferenceSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.tools.models import SearchQueryInfo
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.utils.timing import log_function_time
@log_function_time(print_only=True)
def retrieve_documents(
state: RetrievalInput, config: RunnableConfig
) -> DocRetrievalUpdate:
"""
LangGraph node to retrieve documents from the search tool.
"""
node_start_time = datetime.now()
query_to_retrieve = state.query_to_retrieve
graph_config = cast(GraphConfig, config["metadata"]["config"])
search_tool = graph_config.tooling.search_tool
retrieved_docs: list[InferenceSection] = []
if not query_to_retrieve.strip():
logger.warning("Empty query, skipping retrieval")
return DocRetrievalUpdate(
query_retrieval_results=[],
retrieved_documents=[],
log_messages=[
get_langgraph_node_log_string(
graph_component="shared - expanded retrieval",
node_name="retrieve documents",
node_start_time=node_start_time,
result="Empty query, skipping retrieval",
)
],
)
query_info = None
if search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
callback_container: list[list[InferenceSection]] = []
# new db session to avoid concurrency issues
with get_session_with_current_tenant() as db_session:
for tool_response in search_tool.run(
query=query_to_retrieve,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=True,
alternate_db_session=db_session,
retrieved_sections_callback=callback_container.append,
skip_query_analysis=not state.base_search,
),
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
response = cast(SearchResponseSummary, tool_response.response)
retrieved_docs = response.top_sections
query_info = SearchQueryInfo(
predicted_search=response.predicted_search,
final_filters=response.final_filters,
recency_bias_multiplier=response.recency_bias_multiplier,
)
break
retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS]
if AGENT_RETRIEVAL_STATS:
pre_rerank_docs = callback_container[0] if callback_container else []
fit_scores = get_fit_scores(
pre_rerank_docs,
retrieved_docs,
)
else:
fit_scores = None
expanded_retrieval_result = QueryRetrievalResult(
query=query_to_retrieve,
retrieved_documents=retrieved_docs,
stats=fit_scores,
query_info=query_info,
)
return DocRetrievalUpdate(
query_retrieval_results=[expanded_retrieval_result],
retrieved_documents=retrieved_docs,
log_messages=[
get_langgraph_node_log_string(
graph_component="shared - expanded retrieval",
node_name="retrieve documents",
node_start_time=node_start_time,
)
],
)
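A hedged sketch of the callback-capture pattern above: the search tool is handed callback_container.append so the pre-rerank sections can be inspected afterwards without changing the tool's normal responses (stand-in names, not the real SearchTool API):

from collections.abc import Callable

def run_search_sketch(
    query: str,
    retrieved_sections_callback: Callable[[list[str]], None],
) -> list[str]:
    pre_rerank = [f"section-1 for {query}", f"section-2 for {query}"]
    retrieved_sections_callback(pre_rerank)  # expose the intermediate result
    return list(reversed(pre_rerank))  # stand-in for later post-processing

callback_container: list[list[str]] = []
final_sections = run_search_sketch("example query", callback_container.append)
pre_rerank_docs = callback_container[0] if callback_container else []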


@@ -1,127 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
DocVerificationInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
DocVerificationUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
binary_string_test,
)
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_POSITIVE_VALUE_STR,
)
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
DOCUMENT_VERIFICATION_PROMPT,
)
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="The LLM timed out. The document could not be verified. The document will be treated as 'relevant'",
rate_limit="The LLM encountered a rate limit. The document could not be verified. The document will be treated as 'relevant'",
general_error="The LLM encountered an error. The document could not be verified. The document will be treated as 'relevant'",
)
@log_function_time(print_only=True)
def verify_documents(
state: DocVerificationInput, config: RunnableConfig
) -> DocVerificationUpdate:
"""
LangGraph node to check whether the document is relevant for the original user question
Args:
state (DocVerificationInput): The current state
config (RunnableConfig): Configuration containing AgentSearchConfig
Updates:
verified_documents: list[InferenceSection]
"""
node_start_time = datetime.now()
question = state.question
retrieved_document_to_verify = state.retrieved_document_to_verify
document_content = retrieved_document_to_verify.combined_content
graph_config = cast(GraphConfig, config["metadata"]["config"])
fast_llm = graph_config.tooling.fast_llm
document_content = trim_prompt_piece(
config=fast_llm.config,
prompt_piece=document_content,
reserved_str=DOCUMENT_VERIFICATION_PROMPT + question,
)
msg = [
HumanMessage(
content=DOCUMENT_VERIFICATION_PROMPT.format(
question=question, document_content=document_content
)
)
]
response: BaseMessage | None = None
verified_documents = [
retrieved_document_to_verify
] # default is to treat document as relevant
try:
response = run_with_timeout(
AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION,
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
assert isinstance(response.content, str)
if not binary_string_test(
text=response.content, positive_value=AGENT_POSITIVE_VALUE_STR
):
verified_documents = []
except (LLMTimeoutError, TimeoutError):
# In this case, we decide to continue and don't raise an error, as there is
# little harm in letting through some docs that are less relevant.
logger.error("LLM Timeout Error - verify documents")
except LLMRateLimitError:
# In this case, we decide to continue and don't raise an error, as there is
# little harm in letting through some docs that are less relevant.
logger.error("LLM Rate Limit Error - verify documents")
return DocVerificationUpdate(
verified_documents=verified_documents,
log_messages=[
get_langgraph_node_log_string(
graph_component="shared - expanded retrieval",
node_name="verify documents",
node_start_time=node_start_time,
)
],
)
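A hedged sketch of the fail-open pattern above: treat the document as relevant by default, drop it only on a clearly negative answer, and swallow timeouts rather than failing the whole verification fan-out (ask_llm and the "yes" marker are stand-ins, not binary_string_test):

def verify_sketch(ask_llm, document: str) -> list[str]:
    verified = [document]  # fail-open default: keep the document
    try:
        answer = ask_llm(document)
        if "yes" not in answer.lower():  # stand-in for binary_string_test
            verified = []
    except TimeoutError:
        pass  # keep the default rather than erroring the whole graph
    return verified

assert verify_sketch(lambda _doc: "Yes, relevant.", "doc_a") == ["doc_a"]
assert verify_sketch(lambda _doc: "No.", "doc_a") == []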


@@ -1,93 +0,0 @@
from collections import defaultdict
from collections.abc import Callable
import numpy as np
from langgraph.types import StreamWriter
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import SubQueryPiece
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger
logger = setup_logger()
def dispatch_subquery(
level: int, question_num: int, writer: StreamWriter
) -> Callable[[str, int], None]:
def helper(token: str, num: int) -> None:
write_custom_event(
"subqueries",
SubQueryPiece(
sub_query=token,
level=level,
level_question_num=question_num,
query_id=num,
),
writer,
)
return helper
def calculate_sub_question_retrieval_stats(
verified_documents: list[InferenceSection],
expanded_retrieval_results: list[QueryRetrievalResult],
) -> AgentChunkRetrievalStats:
chunk_scores: dict[str, dict[str, list[int | float]]] = defaultdict(
lambda: defaultdict(list)
)
for expanded_retrieval_result in expanded_retrieval_results:
for doc in expanded_retrieval_result.retrieved_documents:
doc_chunk_id = f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}"
if doc.center_chunk.score is not None:
chunk_scores[doc_chunk_id]["score"].append(doc.center_chunk.score)
verified_doc_chunk_ids = [
f"{verified_document.center_chunk.document_id}_{verified_document.center_chunk.chunk_id}"
for verified_document in verified_documents
]
dismissed_doc_chunk_ids = []
raw_chunk_stats_counts: dict[str, int] = defaultdict(int)
raw_chunk_stats_scores: dict[str, float] = defaultdict(float)
for doc_chunk_id, chunk_data in chunk_scores.items():
valid_chunk_scores = [
score for score in chunk_data["score"] if score is not None
]
key = "verified" if doc_chunk_id in verified_doc_chunk_ids else "rejected"
raw_chunk_stats_counts[f"{key}_count"] += 1
raw_chunk_stats_scores[f"{key}_scores"] += float(np.mean(valid_chunk_scores))
if key == "rejected":
dismissed_doc_chunk_ids.append(doc_chunk_id)
if raw_chunk_stats_counts["verified_count"] == 0:
verified_avg_scores = 0.0
else:
verified_avg_scores = raw_chunk_stats_scores["verified_scores"] / float(
raw_chunk_stats_counts["verified_count"]
)
rejected_scores = raw_chunk_stats_scores.get("rejected_scores")
if rejected_scores is not None:
rejected_avg_scores = rejected_scores / float(
raw_chunk_stats_counts["rejected_count"]
)
else:
rejected_avg_scores = None
chunk_stats = AgentChunkRetrievalStats(
verified_count=raw_chunk_stats_counts["verified_count"],
verified_avg_scores=verified_avg_scores,
rejected_count=raw_chunk_stats_counts["rejected_count"],
rejected_avg_scores=rejected_avg_scores,
verified_doc_chunk_ids=verified_doc_chunk_ids,
dismissed_doc_chunk_ids=dismissed_doc_chunk_ids,
)
return chunk_stats
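A small worked example of the aggregation above: each chunk contributes the mean of its scores, and verified vs. rejected chunks are then averaged separately (toy numbers):

import numpy as np

chunk_scores = {"docA_0": [0.9, 0.7], "docB_3": [0.2]}
verified_ids = {"docA_0"}

verified_means = [float(np.mean(s)) for cid, s in chunk_scores.items() if cid in verified_ids]
rejected_means = [float(np.mean(s)) for cid, s in chunk_scores.items() if cid not in verified_ids]
verified_avg = sum(verified_means) / len(verified_means) if verified_means else 0.0  # -> 0.8
rejected_avg = sum(rejected_means) / len(rejected_means) if rejected_means else None  # -> 0.2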


@@ -1,95 +0,0 @@
from operator import add
from typing import Annotated
from pydantic import BaseModel
from onyx.agents.agent_search.core_state import SubgraphCoreState
from onyx.agents.agent_search.deep_search.main.states import LoggerUpdate
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.models import (
QuestionRetrievalResult,
)
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.context.search.models import InferenceSection
### States ###
## Graph Input State
class ExpandedRetrievalInput(SubgraphCoreState):
# Exception to the 'no default value' rule for LangGraph input states.
# Here, sub_question_id default None implies usage for the
# original question. This is sometimes needed for nested sub-graphs
sub_question_id: str | None = None
question: str
base_search: bool
## Update/Return States
class QueryExpansionUpdate(LoggerUpdate, BaseModel):
expanded_queries: list[str] = []
log_messages: list[str] = []
class DocVerificationUpdate(LoggerUpdate, BaseModel):
verified_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
class DocRetrievalUpdate(LoggerUpdate, BaseModel):
query_retrieval_results: Annotated[list[QueryRetrievalResult], add] = []
retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections] = (
[]
)
class DocRerankingUpdate(LoggerUpdate, BaseModel):
reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
sub_question_retrieval_stats: RetrievalFitStats | None = None
class ExpandedRetrievalUpdate(LoggerUpdate, BaseModel):
expanded_retrieval_result: QuestionRetrievalResult
## Graph Output State
class ExpandedRetrievalOutput(LoggerUpdate, BaseModel):
expanded_retrieval_result: QuestionRetrievalResult = QuestionRetrievalResult()
base_expanded_retrieval_result: QuestionRetrievalResult = QuestionRetrievalResult()
retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections] = (
[]
)
## Graph State
class ExpandedRetrievalState(
# This includes the core state
ExpandedRetrievalInput,
QueryExpansionUpdate,
DocRetrievalUpdate,
DocVerificationUpdate,
DocRerankingUpdate,
ExpandedRetrievalOutput,
):
pass
## Conditional Input States
class DocVerificationInput(ExpandedRetrievalInput):
retrieved_document_to_verify: InferenceSection
class RetrievalInput(ExpandedRetrievalInput):
query_to_retrieve: str
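A hedged note on the Annotated reducers above: in LangGraph the second argument (add, dedup_inference_sections) tells the framework how to merge updates arriving from parallel branches; conceptually each merge is just one application of the reducer.

from operator import add

branch_a_update = ["result_a"]
branch_b_update = ["result_b"]
merged = add(branch_a_update, branch_b_update)
assert merged == ["result_a", "result_b"]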


@@ -0,0 +1,59 @@
from collections.abc import Hashable
from langgraph.graph import END
from langgraph.types import Send
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.states import MainState
def decision_router(state: MainState) -> list[Send | Hashable] | DRPath | str:
if not state.tools_used:
raise IndexError("state.tools_used cannot be empty")
# next_tool is either a generic tool name or a DRPath string
next_tool_name = state.tools_used[-1]
available_tools = state.available_tools
if not available_tools:
raise ValueError("No tool is available. This should not happen.")
if next_tool_name in available_tools:
next_tool_path = available_tools[next_tool_name].path
elif next_tool_name == DRPath.END.value:
return END
elif next_tool_name == DRPath.LOGGER.value:
return DRPath.LOGGER
else:
return DRPath.ORCHESTRATOR
# handle invalid paths
if next_tool_path == DRPath.CLARIFIER:
raise ValueError("CLARIFIER is not a valid path during iteration")
# handle tool calls without a query
if (
next_tool_path
in (
DRPath.INTERNAL_SEARCH,
DRPath.INTERNET_SEARCH,
DRPath.KNOWLEDGE_GRAPH,
DRPath.IMAGE_GENERATION,
)
and len(state.query_list) == 0
):
return DRPath.CLOSER
return next_tool_path
def completeness_router(state: MainState) -> DRPath | str:
if not state.tools_used:
raise IndexError("tools_used cannot be empty")
# return to the orchestrator if it was requested, otherwise proceed to the logger
next_path = state.tools_used[-1]
if next_path == DRPath.ORCHESTRATOR.value:
return DRPath.ORCHESTRATOR
return DRPath.LOGGER


@@ -0,0 +1,30 @@
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchType
MAX_CHAT_HISTORY_MESSAGES = (
3 # note: actual count is x2 to account for user and assistant messages
)
MAX_DR_PARALLEL_SEARCH = 4
# TODO: test more, generally not needed/adds unnecessary iterations
MAX_NUM_CLOSER_SUGGESTIONS = (
0 # how many times the closer can send back to the orchestrator
)
CLARIFICATION_REQUEST_PREFIX = "PLEASE CLARIFY:"
HIGH_LEVEL_PLAN_PREFIX = "The Plan:"
AVERAGE_TOOL_COSTS: dict[DRPath, float] = {
DRPath.INTERNAL_SEARCH: 1.0,
DRPath.KNOWLEDGE_GRAPH: 2.0,
DRPath.INTERNET_SEARCH: 1.5,
DRPath.IMAGE_GENERATION: 3.0,
DRPath.GENERIC_TOOL: 1.5, # TODO: see todo in OrchestratorTool
DRPath.CLOSER: 0.0,
}
DR_TIME_BUDGET_BY_TYPE = {
ResearchType.THOUGHTFUL: 3.0,
ResearchType.DEEP: 6.0,
}
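Hedged arithmetic on the constants above, assuming tool costs are deducted from the per-run time budget on each call (an inference from the names, not shown in this file): a THOUGHTFUL run using only the internal search tool would afford roughly three tool iterations before the closer.

thoughtful_budget = 3.0  # DR_TIME_BUDGET_BY_TYPE[ResearchType.THOUGHTFUL]
internal_search_cost = 1.0  # AVERAGE_TOOL_COSTS[DRPath.INTERNAL_SEARCH]
max_search_iterations = int(thoughtful_budget // internal_search_cost)  # -> 3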


@@ -0,0 +1,112 @@
from datetime import datetime
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import DRPromptPurpose
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.prompts.dr_prompts import GET_CLARIFICATION_PROMPT
from onyx.prompts.dr_prompts import KG_TYPES_DESCRIPTIONS
from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
from onyx.prompts.dr_prompts import TOOL_DIFFERENTIATION_HINTS
from onyx.prompts.dr_prompts import TOOL_QUESTION_HINTS
from onyx.prompts.prompt_template import PromptTemplate
def get_dr_prompt_orchestration_templates(
purpose: DRPromptPurpose,
research_type: ResearchType,
available_tools: dict[str, OrchestratorTool],
entity_types_string: str | None = None,
relationship_types_string: str | None = None,
reasoning_result: str | None = None,
tool_calls_string: str | None = None,
) -> PromptTemplate:
available_tools = available_tools or {}
tool_names = list(available_tools.keys())
tool_description_str = "\n\n".join(
f"- {tool_name}: {tool.description}"
for tool_name, tool in available_tools.items()
)
tool_cost_str = "\n".join(
f"{tool_name}: {tool.cost}" for tool_name, tool in available_tools.items()
)
tool_differentiations: list[str] = [
TOOL_DIFFERENTIATION_HINTS[(tool_1, tool_2)]
for tool_1 in available_tools
for tool_2 in available_tools
if (tool_1, tool_2) in TOOL_DIFFERENTIATION_HINTS
]
tool_differentiation_hint_string = (
"\n".join(tool_differentiations) or "(No differentiating hints available)"
)
# TODO: add tool differentiation pairs for custom tools as well
tool_question_hint_string = (
"\n".join(
"- " + TOOL_QUESTION_HINTS[tool]
for tool in available_tools
if tool in TOOL_QUESTION_HINTS
)
or "(No examples available)"
)
if DRPath.KNOWLEDGE_GRAPH.value in available_tools and (
entity_types_string or relationship_types_string
):
kg_types_descriptions = KG_TYPES_DESCRIPTIONS.build(
possible_entities=entity_types_string or "",
possible_relationships=relationship_types_string or "",
)
else:
kg_types_descriptions = "(The Knowledge Graph is not used.)"
if purpose == DRPromptPurpose.PLAN:
if research_type == ResearchType.THOUGHTFUL:
raise ValueError("plan generation is not supported for FAST time budget")
base_template = ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
elif purpose == DRPromptPurpose.NEXT_STEP_REASONING:
if research_type == ResearchType.THOUGHTFUL:
base_template = ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
else:
raise ValueError(
"reasoning is not separately required for DEEP time budget"
)
elif purpose == DRPromptPurpose.NEXT_STEP_PURPOSE:
base_template = ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
elif purpose == DRPromptPurpose.NEXT_STEP:
if research_type == ResearchType.THOUGHTFUL:
base_template = ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
else:
base_template = ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
elif purpose == DRPromptPurpose.CLARIFICATION:
if research_type == ResearchType.THOUGHTFUL:
raise ValueError("clarification is not supported for FAST time budget")
base_template = GET_CLARIFICATION_PROMPT
else:
# for mypy, clearly a mypy bug
raise ValueError(f"Invalid purpose: {purpose}")
return base_template.partial_build(
num_available_tools=str(len(tool_names)),
available_tools=", ".join(tool_names),
tool_choice_options=" or ".join(tool_names),
current_time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
kg_types_descriptions=kg_types_descriptions,
tool_descriptions=tool_description_str,
tool_differentiation_hints=tool_differentiation_hint_string,
tool_question_hints=tool_question_hint_string,
average_tool_costs=tool_cost_str,
reasoning_result=reasoning_result or "(No reasoning result provided.)",
tool_calls_string=tool_calls_string or "(No tool calls provided.)",
)
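A hedged stand-in for the two-stage templating flow used above; the real PromptTemplate lives in onyx.prompts.prompt_template, so this only illustrates the partial_build-then-build pattern inferred from the calls in this file and in the orchestration nodes:

class TwoStageTemplateSketch:
    def __init__(self, text: str, **known: str) -> None:
        self._text = text
        self._known = known

    def partial_build(self, **more: str) -> "TwoStageTemplateSketch":
        # Pin the fields known up front (tool lists, hints, current time, ...).
        return TwoStageTemplateSketch(self._text, **{**self._known, **more})

    def build(self, **final: str) -> str:
        # Fill in the per-iteration fields (question, chat history, ...).
        return self._text.format(**{**self._known, **final})

template = TwoStageTemplateSketch("{tools}\n\nQuestion: {question}")
partial = template.partial_build(tools="SEARCH_TOOL, CLOSER")
assert partial.build(question="What changed?") == "SEARCH_TOOL, CLOSER\n\nQuestion: What changed?"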


@@ -0,0 +1,31 @@
from enum import Enum
class ResearchType(str, Enum):
"""Research type options for agent search operations"""
# BASIC = "BASIC"
LEGACY_AGENTIC = "LEGACY_AGENTIC" # only used for legacy agentic search migrations
THOUGHTFUL = "THOUGHTFUL"
DEEP = "DEEP"
class ResearchAnswerPurpose(str, Enum):
"""Research answer purpose options for agent search operations"""
ANSWER = "ANSWER"
CLARIFICATION_REQUEST = "CLARIFICATION_REQUEST"
class DRPath(str, Enum):
CLARIFIER = "Clarifier"
ORCHESTRATOR = "Orchestrator"
INTERNAL_SEARCH = "Search Tool"
GENERIC_TOOL = "Generic Tool"
KNOWLEDGE_GRAPH = "Knowledge Graph Search"
INTERNET_SEARCH = "Internet Search"
IMAGE_GENERATION = "Image Generation"
GENERIC_INTERNAL_TOOL = "Generic Internal Tool"
CLOSER = "Closer"
LOGGER = "Logger"
END = "End"


@@ -0,0 +1,88 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.conditional_edges import completeness_router
from onyx.agents.agent_search.dr.conditional_edges import decision_router
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.nodes.dr_a0_clarification import clarifier
from onyx.agents.agent_search.dr.nodes.dr_a1_orchestrator import orchestrator
from onyx.agents.agent_search.dr.nodes.dr_a2_closer import closer
from onyx.agents.agent_search.dr.nodes.dr_a3_logger import logging
from onyx.agents.agent_search.dr.states import MainInput
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_graph_builder import (
dr_basic_search_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_graph_builder import (
dr_custom_tool_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_graph_builder import (
dr_generic_internal_tool_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_graph_builder import (
dr_image_generation_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.internet_search.dr_is_graph_builder import (
dr_is_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_graph_builder import (
dr_kg_search_graph_builder,
)
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import search
def dr_graph_builder() -> StateGraph:
"""
LangGraph graph builder for the deep research agent.
"""
graph = StateGraph(state_schema=MainState, input=MainInput)
### Add nodes ###
graph.add_node(DRPath.CLARIFIER, clarifier)
graph.add_node(DRPath.ORCHESTRATOR, orchestrator)
basic_search_graph = dr_basic_search_graph_builder().compile()
graph.add_node(DRPath.INTERNAL_SEARCH, basic_search_graph)
kg_search_graph = dr_kg_search_graph_builder().compile()
graph.add_node(DRPath.KNOWLEDGE_GRAPH, kg_search_graph)
internet_search_graph = dr_is_graph_builder().compile()
graph.add_node(DRPath.INTERNET_SEARCH, internet_search_graph)
image_generation_graph = dr_image_generation_graph_builder().compile()
graph.add_node(DRPath.IMAGE_GENERATION, image_generation_graph)
custom_tool_graph = dr_custom_tool_graph_builder().compile()
graph.add_node(DRPath.GENERIC_TOOL, custom_tool_graph)
generic_internal_tool_graph = dr_generic_internal_tool_graph_builder().compile()
graph.add_node(DRPath.GENERIC_INTERNAL_TOOL, generic_internal_tool_graph)
graph.add_node(DRPath.CLOSER, closer)
graph.add_node(DRPath.LOGGER, logging)
### Add edges ###
graph.add_edge(start_key=START, end_key=DRPath.CLARIFIER)
graph.add_conditional_edges(DRPath.CLARIFIER, decision_router)
graph.add_conditional_edges(DRPath.ORCHESTRATOR, decision_router)
graph.add_edge(start_key=DRPath.INTERNAL_SEARCH, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.KNOWLEDGE_GRAPH, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.INTERNET_SEARCH, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.IMAGE_GENERATION, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.GENERIC_TOOL, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.GENERIC_INTERNAL_TOOL, end_key=DRPath.ORCHESTRATOR)
graph.add_conditional_edges(DRPath.CLOSER, completeness_router)
graph.add_edge(start_key=DRPath.LOGGER, end_key=END)
return graph
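A hedged, minimal LangGraph sketch of the same wiring pattern (toy state, nodes, and router; not the Onyx graph itself):

from typing import TypedDict

from langgraph.graph import END, START, StateGraph

class ToyState(TypedDict):
    tools_used: list[str]

def clarifier_node(state: ToyState) -> ToyState:
    return {"tools_used": state["tools_used"] + ["search"]}

def search_node(state: ToyState) -> ToyState:
    return {"tools_used": state["tools_used"] + ["closer"]}

def router(state: ToyState) -> str:
    # Route on the most recent decision, mirroring decision_router above.
    return "search" if state["tools_used"][-1] == "search" else END

builder = StateGraph(ToyState)
builder.add_node("clarifier", clarifier_node)
builder.add_node("search", search_node)
builder.add_edge(START, "clarifier")
builder.add_conditional_edges("clarifier", router)
builder.add_edge("search", END)
final_state = builder.compile().invoke({"tools_used": []})  # {"tools_used": ["search", "closer"]}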


@@ -0,0 +1,126 @@
from enum import Enum
from pydantic import BaseModel
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
GeneratedImage,
)
from onyx.context.search.models import InferenceSection
from onyx.tools.tool import Tool
class OrchestratorStep(BaseModel):
tool: str
questions: list[str]
class OrchestratorDecisonsNoPlan(BaseModel):
reasoning: str
next_step: OrchestratorStep
class OrchestrationPlan(BaseModel):
reasoning: str
plan: str
class ClarificationGenerationResponse(BaseModel):
clarification_needed: bool
clarification_question: str
class DecisionResponse(BaseModel):
reasoning: str
decision: str
class QueryEvaluationResponse(BaseModel):
reasoning: str
query_permitted: bool
class OrchestrationClarificationInfo(BaseModel):
clarification_question: str
clarification_response: str | None = None
class WebSearchAnswer(BaseModel):
urls_to_open_indices: list[int]
class SearchAnswer(BaseModel):
reasoning: str
answer: str
claims: list[str] | None = None
class TestInfoCompleteResponse(BaseModel):
reasoning: str
complete: bool
gaps: list[str]
# TODO: revisit with custom tools implementation in v2
# each tool should be a class with the attributes below, plus the actual tool implementation
# this will also allow custom tools to have their own cost
class OrchestratorTool(BaseModel):
tool_id: int
name: str
llm_path: str # the path for the LLM to refer by
path: DRPath # the actual path in the graph
description: str
metadata: dict[str, str]
cost: float
tool_object: Tool | None = None # None for CLOSER
class Config:
arbitrary_types_allowed = True
class IterationInstructions(BaseModel):
iteration_nr: int
plan: str | None
reasoning: str
purpose: str
class IterationAnswer(BaseModel):
tool: str
tool_id: int
iteration_nr: int
parallelization_nr: int
question: str
reasoning: str | None
answer: str
cited_documents: dict[int, InferenceSection]
background_info: str | None = None
claims: list[str] | None = None
additional_data: dict[str, str] | None = None
response_type: str | None = None
data: dict | list | str | int | float | bool | None = None
file_ids: list[str] | None = None
# for image generation step-types
generated_images: list[GeneratedImage] | None = None
class AggregatedDRContext(BaseModel):
context: str
cited_documents: list[InferenceSection]
is_internet_marker_dict: dict[str, bool]
global_iteration_responses: list[IterationAnswer]
class DRPromptPurpose(str, Enum):
PLAN = "PLAN"
NEXT_STEP = "NEXT_STEP"
NEXT_STEP_REASONING = "NEXT_STEP_REASONING"
NEXT_STEP_PURPOSE = "NEXT_STEP_PURPOSE"
CLARIFICATION = "CLARIFICATION"
class BaseSearchProcessingResponse(BaseModel):
specified_source_types: list[str]
rewritten_query: str
time_filter: str


@@ -0,0 +1,773 @@
import re
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from sqlalchemy.orm import Session
from onyx.agents.agent_search.dr.constants import AVERAGE_TOOL_COSTS
from onyx.agents.agent_search.dr.constants import MAX_CHAT_HISTORY_MESSAGES
from onyx.agents.agent_search.dr.dr_prompt_builder import (
get_dr_prompt_orchestration_templates,
)
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import ClarificationGenerationResponse
from onyx.agents.agent_search.dr.models import DecisionResponse
from onyx.agents.agent_search.dr.models import DRPromptPurpose
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.agents.agent_search.dr.process_llm_stream import process_llm_stream
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.states import OrchestrationSetup
from onyx.agents.agent_search.dr.utils import get_chat_history_string
from onyx.agents.agent_search.dr.utils import update_db_session_with_messages
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import run_with_timeout
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import DocumentSourceDescription
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
from onyx.db.connector import fetch_unique_document_sources
from onyx.db.kg_config import get_kg_config_settings
from onyx.db.models import Tool
from onyx.db.tools import get_tools
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.kg.utils.extraction_utils import get_entity_types_str
from onyx.kg.utils.extraction_utils import get_relationship_types_str
from onyx.llm.utils import check_number_of_tokens
from onyx.llm.utils import get_max_input_tokens
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.dr_prompts import ANSWER_PROMPT_WO_TOOL_CALLING
from onyx.prompts.dr_prompts import DECISION_PROMPT_W_TOOL_CALLING
from onyx.prompts.dr_prompts import DECISION_PROMPT_WO_TOOL_CALLING
from onyx.prompts.dr_prompts import DEFAULT_DR_SYSTEM_PROMPT
from onyx.prompts.dr_prompts import EVAL_SYSTEM_PROMPT_W_TOOL_CALLING
from onyx.prompts.dr_prompts import EVAL_SYSTEM_PROMPT_WO_TOOL_CALLING
from onyx.prompts.dr_prompts import REPEAT_PROMPT
from onyx.prompts.dr_prompts import TOOL_DESCRIPTION
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.server.query_and_chat.streaming_models import StreamingType
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
InternetSearchTool,
)
from onyx.tools.tool_implementations.knowledge_graph.knowledge_graph_tool import (
KnowledgeGraphTool,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.b64 import get_image_type
from onyx.utils.b64 import get_image_type_from_bytes
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _format_tool_name(tool_name: str) -> str:
"""Convert tool name to LLM-friendly format."""
name = tool_name.replace(" ", "_")
# take care of camel case like GetAPIKey -> GET_API_KEY for LLM readability
name = re.sub(r"(?<=[a-z0-9])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])", "_", name)
return name.upper()
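# Hedged sanity check of the normalization above (regex behavior worked out by
# hand, not taken from the repo's tests):
#   _format_tool_name("GetAPIKey")        -> "GET_API_KEY"
#   _format_tool_name("Internet Search")  -> "INTERNET_SEARCH"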
def _get_available_tools(
db_session: Session,
graph_config: GraphConfig,
kg_enabled: bool,
active_source_types: list[DocumentSource],
) -> dict[str, OrchestratorTool]:
available_tools: dict[str, OrchestratorTool] = {}
kg_enabled = graph_config.behavior.kg_config_settings.KG_ENABLED
persona = graph_config.inputs.persona
if persona:
include_kg = persona.name == TMP_DRALPHA_PERSONA_NAME and kg_enabled
else:
include_kg = False
tool_dict: dict[int, Tool] = {tool.id: tool for tool in get_tools(db_session)}
for tool in graph_config.tooling.tools:
tool_db_info = tool_dict.get(tool.id)
if tool_db_info:
incode_tool_id = tool_db_info.in_code_tool_id
else:
raise ValueError(f"Tool {tool.name} is not found in the database")
if isinstance(tool, InternetSearchTool):
llm_path = DRPath.INTERNET_SEARCH.value
path = DRPath.INTERNET_SEARCH
elif isinstance(tool, SearchTool):
llm_path = DRPath.INTERNAL_SEARCH.value
path = DRPath.INTERNAL_SEARCH
elif isinstance(tool, KnowledgeGraphTool) and include_kg:
if len(active_source_types) == 0:
logger.error(
"No active source types found, skipping Knowledge Graph tool"
)
continue
llm_path = DRPath.KNOWLEDGE_GRAPH.value
path = DRPath.KNOWLEDGE_GRAPH
elif isinstance(tool, ImageGenerationTool):
llm_path = DRPath.IMAGE_GENERATION.value
path = DRPath.IMAGE_GENERATION
elif incode_tool_id:
# if incode tool id is found, it is a generic internal tool
llm_path = DRPath.GENERIC_INTERNAL_TOOL.value
path = DRPath.GENERIC_INTERNAL_TOOL
else:
# otherwise it is a custom tool
llm_path = DRPath.GENERIC_TOOL.value
path = DRPath.GENERIC_TOOL
if path not in {DRPath.GENERIC_INTERNAL_TOOL, DRPath.GENERIC_TOOL}:
description = TOOL_DESCRIPTION.get(path, tool.description)
cost = AVERAGE_TOOL_COSTS[path]
else:
description = tool.description
cost = 1.0
tool_info = OrchestratorTool(
tool_id=tool.id,
name=tool.llm_name,
llm_path=llm_path,
path=path,
description=description,
metadata={},
cost=cost,
tool_object=tool,
)
# TODO: handle custom tools with same name as other tools (e.g., CLOSER)
available_tools[tool.llm_name] = tool_info
available_tool_paths = [tool.path for tool in available_tools.values()]
# make sure KG isn't enabled without internal search
if (
DRPath.KNOWLEDGE_GRAPH in available_tool_paths
and DRPath.INTERNAL_SEARCH not in available_tool_paths
):
raise ValueError(
"The Knowledge Graph is not supported without internal search tool"
)
# add CLOSER tool, which is always available
available_tools[DRPath.CLOSER.value] = OrchestratorTool(
tool_id=-1,
name=DRPath.CLOSER.value,
llm_path=DRPath.CLOSER.value,
path=DRPath.CLOSER,
description=TOOL_DESCRIPTION[DRPath.CLOSER],
metadata={},
cost=0.0,
tool_object=None,
)
return available_tools
def _construct_uploaded_text_context(files: list[InMemoryChatFile]) -> str:
"""Construct the uploaded context from the files."""
file_contents = []
for file in files:
if file.file_type in (
ChatFileType.DOC,
ChatFileType.PLAIN_TEXT,
ChatFileType.CSV,
):
file_contents.append(file.content.decode("utf-8"))
if len(file_contents) > 0:
return "Uploaded context:\n\n\n" + "\n\n".join(file_contents)
return ""
def _construct_uploaded_image_context(
files: list[InMemoryChatFile] | None = None,
img_urls: list[str] | None = None,
b64_imgs: list[str] | None = None,
) -> list[dict[str, Any]] | None:
"""Construct the uploaded image context from the files."""
# Only include image files for user messages
if files is None:
return None
img_files = [file for file in files if file.file_type == ChatFileType.IMAGE]
img_urls = img_urls or []
b64_imgs = b64_imgs or []
if not (img_files or img_urls or b64_imgs):
return None
return cast(
list[dict[str, Any]],
[
{
"type": "image_url",
"image_url": {
"url": (
f"data:{get_image_type_from_bytes(file.content)};"
f"base64,{file.to_base64()}"
),
},
}
for file in img_files
]
+ [
{
"type": "image_url",
"image_url": {
"url": f"data:{get_image_type(b64_img)};base64,{b64_img}",
},
}
for b64_img in b64_imgs
]
+ [
{
"type": "image_url",
"image_url": {
"url": url,
},
}
for url in img_urls
],
)
def _get_existing_clarification_request(
graph_config: GraphConfig,
) -> tuple[OrchestrationClarificationInfo, str, str] | None:
"""
Returns the clarification info, original question, and updated chat history if
a clarification request and response exists, otherwise returns None.
"""
# check for clarification request and response in message history
previous_raw_messages = graph_config.inputs.prompt_builder.raw_message_history
if len(previous_raw_messages) == 0 or (
previous_raw_messages[-1].research_answer_purpose
!= ResearchAnswerPurpose.CLARIFICATION_REQUEST
):
return None
# get the clarification request and response
previous_messages = graph_config.inputs.prompt_builder.message_history
last_message = previous_raw_messages[-1].message
clarification = OrchestrationClarificationInfo(
clarification_question=last_message.strip(),
clarification_response=graph_config.inputs.prompt_builder.raw_user_query,
)
original_question = graph_config.inputs.prompt_builder.raw_user_query
chat_history_string = "(No chat history yet available)"
# get the original user query and chat history string before the original query
# e.g., if history = [user query, assistant clarification request, user clarification response],
# previous_messages = [user query, assistant clarification request], we want the user query
for i, message in enumerate(reversed(previous_messages), 1):
if (
isinstance(message, HumanMessage)
and message.content
and isinstance(message.content, str)
):
original_question = message.content
chat_history_string = (
get_chat_history_string(
graph_config.inputs.prompt_builder.message_history[:-i],
MAX_CHAT_HISTORY_MESSAGES,
)
or "(No chat history yet available)"
)
break
return clarification, original_question, chat_history_string
_ARTIFICIAL_ALL_ENCOMPASSING_TOOL = {
"type": "function",
"function": {
"name": "run_any_knowledge_retrieval_and_any_action_tool",
"description": "Use this tool to get ANY external information \
that is relevant to the question, or for any action to be taken, including image generation. In fact, \
ANY tool mentioned can be accessed through this generic tool.",
"parameters": {
"type": "object",
"properties": {
"request": {
"type": "string",
"description": "The request to be made to the tool",
},
},
"required": ["request"],
},
},
}
def clarifier(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> OrchestrationSetup:
"""
Perform a quick pass on the question as-is and see whether a set of clarification
questions is needed. For now this is based on the model's judgment.
"""
node_start_time = datetime.now()
current_step_nr = 0
graph_config = cast(GraphConfig, config["metadata"]["config"])
llm_provider = graph_config.tooling.primary_llm.config.model_provider
llm_model_name = graph_config.tooling.primary_llm.config.model_name
llm_tokenizer = get_tokenizer(
model_name=llm_model_name,
provider_type=llm_provider,
)
max_input_tokens = get_max_input_tokens(
model_name=llm_model_name,
model_provider=llm_provider,
)
use_tool_calling_llm = graph_config.tooling.using_tool_calling_llm
db_session = graph_config.persistence.db_session
original_question = graph_config.inputs.prompt_builder.raw_user_query
research_type = graph_config.behavior.research_type
force_use_tool = graph_config.tooling.force_use_tool
message_id = graph_config.persistence.message_id
# Perform a commit to ensure the message_id is set and saved
db_session.commit()
# get the connected tools and format for the Deep Research flow
kg_enabled = graph_config.behavior.kg_config_settings.KG_ENABLED
db_session = graph_config.persistence.db_session
active_source_types = fetch_unique_document_sources(db_session)
available_tools = _get_available_tools(
db_session, graph_config, kg_enabled, active_source_types
)
available_tool_descriptions_str = "\n -" + "\n -".join(
[tool.description for tool in available_tools.values()]
)
kg_config = get_kg_config_settings()
if kg_config.KG_ENABLED and kg_config.KG_EXPOSED:
all_entity_types = get_entity_types_str(active=True)
all_relationship_types = get_relationship_types_str(active=True)
else:
all_entity_types = ""
all_relationship_types = ""
# if not active_source_types:
# raise ValueError("No active source types found")
active_source_types_descriptions = [
DocumentSourceDescription[source_type] for source_type in active_source_types
]
if len(active_source_types_descriptions) > 0:
active_source_type_descriptions_str = "\n -" + "\n -".join(
active_source_types_descriptions
)
else:
active_source_type_descriptions_str = ""
if graph_config.inputs.persona and len(graph_config.inputs.persona.prompts) > 0:
assistant_system_prompt = (
graph_config.inputs.persona.prompts[0].system_prompt
or DEFAULT_DR_SYSTEM_PROMPT
) + "\n\n"
if graph_config.inputs.persona.prompts[0].task_prompt:
assistant_task_prompt = (
"\n\nHere are more specifications from the user:\n\n"
+ graph_config.inputs.persona.prompts[0].task_prompt
)
else:
assistant_task_prompt = ""
else:
assistant_system_prompt = DEFAULT_DR_SYSTEM_PROMPT + "\n\n"
assistant_task_prompt = ""
chat_history_string = (
get_chat_history_string(
graph_config.inputs.prompt_builder.message_history,
MAX_CHAT_HISTORY_MESSAGES,
)
or "(No chat history yet available)"
)
uploaded_text_context = (
_construct_uploaded_text_context(graph_config.inputs.files)
if graph_config.inputs.files
else ""
)
uploaded_context_tokens = check_number_of_tokens(
uploaded_text_context, llm_tokenizer.encode
)
if uploaded_context_tokens > 0.5 * max_input_tokens:
raise ValueError(
f"Uploaded context is too long. {uploaded_context_tokens} tokens, "
f"but for this model we only allow {0.5 * max_input_tokens} tokens for uploaded context"
)
uploaded_image_context = _construct_uploaded_image_context(
graph_config.inputs.files
)
if not (force_use_tool and force_use_tool.force_use):
if not use_tool_calling_llm or len(available_tools) == 1:
if len(available_tools) > 1:
decision_prompt = DECISION_PROMPT_WO_TOOL_CALLING.build(
question=original_question,
chat_history_string=chat_history_string,
uploaded_context=uploaded_text_context or "",
active_source_type_descriptions_str=active_source_type_descriptions_str,
available_tool_descriptions_str=available_tool_descriptions_str,
)
llm_decision = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
EVAL_SYSTEM_PROMPT_WO_TOOL_CALLING,
decision_prompt,
),
schema=DecisionResponse,
)
else:
# if there is only one tool (Closer), we don't need to decide. It's an LLM answer
llm_decision = DecisionResponse(decision="LLM", reasoning="")
if llm_decision.decision == "LLM":
write_custom_event(
current_step_nr,
MessageStart(
content="",
final_documents=[],
),
writer,
)
answer_prompt = ANSWER_PROMPT_WO_TOOL_CALLING.build(
question=original_question,
chat_history_string=chat_history_string,
uploaded_context=uploaded_text_context or "",
active_source_type_descriptions_str=active_source_type_descriptions_str,
available_tool_descriptions_str=available_tool_descriptions_str,
)
answer_tokens, _, _ = run_with_timeout(
80,
lambda: stream_llm_answer(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt,
answer_prompt + assistant_task_prompt,
),
event_name="basic_response",
writer=writer,
answer_piece=StreamingType.MESSAGE_DELTA.value,
agent_answer_level=0,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=60,
ind=current_step_nr,
context_docs=None,
replace_citations=True,
max_tokens=None,
),
)
write_custom_event(
current_step_nr,
SectionEnd(
type="section_end",
),
writer,
)
current_step_nr += 1
answer_str = cast(str, merge_content(*answer_tokens))
write_custom_event(
current_step_nr,
OverallStop(),
writer,
)
update_db_session_with_messages(
db_session=db_session,
chat_message_id=message_id,
chat_session_id=str(graph_config.persistence.chat_session_id),
is_agentic=graph_config.behavior.use_agentic_search,
message=answer_str,
update_parent_message=True,
research_answer_purpose=ResearchAnswerPurpose.ANSWER,
)
db_session.commit()
return OrchestrationSetup(
original_question=original_question,
chat_history_string="",
tools_used=[DRPath.END.value],
available_tools=available_tools,
query_list=[],
assistant_system_prompt=assistant_system_prompt,
assistant_task_prompt=assistant_task_prompt,
)
else:
decision_prompt = DECISION_PROMPT_W_TOOL_CALLING.build(
question=original_question,
chat_history_string=chat_history_string,
uploaded_context=uploaded_text_context or "",
active_source_type_descriptions_str=active_source_type_descriptions_str,
)
stream = graph_config.tooling.primary_llm.stream(
prompt=create_question_prompt(
assistant_system_prompt + EVAL_SYSTEM_PROMPT_W_TOOL_CALLING,
decision_prompt + assistant_task_prompt,
uploaded_image_context=uploaded_image_context,
),
tools=([_ARTIFICIAL_ALL_ENCOMPASSING_TOOL]),
tool_choice=(None),
structured_response_format=graph_config.inputs.structured_response_format,
)
full_response = process_llm_stream(
messages=stream,
should_stream_answer=True,
writer=writer,
ind=0,
generate_final_answer=True,
chat_message_id=str(graph_config.persistence.chat_session_id),
)
if len(full_response.ai_message_chunk.tool_calls) == 0:
if isinstance(full_response.full_answer, str):
full_answer = full_response.full_answer
else:
full_answer = None
update_db_session_with_messages(
db_session=db_session,
chat_message_id=message_id,
chat_session_id=str(graph_config.persistence.chat_session_id),
is_agentic=graph_config.behavior.use_agentic_search,
message=full_answer,
update_parent_message=True,
research_answer_purpose=ResearchAnswerPurpose.ANSWER,
)
db_session.commit()
return OrchestrationSetup(
original_question=original_question,
chat_history_string="",
tools_used=[DRPath.END.value],
query_list=[],
available_tools=available_tools,
assistant_system_prompt=assistant_system_prompt,
assistant_task_prompt=assistant_task_prompt,
)
# Continue, as external knowledge is required.
current_step_nr += 1
clarification = None
if research_type != ResearchType.THOUGHTFUL:
result = _get_existing_clarification_request(graph_config)
if result is not None:
clarification, original_question, chat_history_string = result
else:
# generate clarification questions if needed
chat_history_string = (
get_chat_history_string(
graph_config.inputs.prompt_builder.message_history,
MAX_CHAT_HISTORY_MESSAGES,
)
or "(No chat history yet available)"
)
base_clarification_prompt = get_dr_prompt_orchestration_templates(
DRPromptPurpose.CLARIFICATION,
research_type,
entity_types_string=all_entity_types,
relationship_types_string=all_relationship_types,
available_tools=available_tools,
)
clarification_prompt = base_clarification_prompt.build(
question=original_question,
chat_history_string=chat_history_string,
)
try:
clarification_response = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt, clarification_prompt
),
schema=ClarificationGenerationResponse,
timeout_override=25,
# max_tokens=1500,
)
except Exception as e:
logger.error(f"Error in clarification generation: {e}")
raise e
if (
clarification_response.clarification_needed
and clarification_response.clarification_question
):
clarification = OrchestrationClarificationInfo(
clarification_question=clarification_response.clarification_question,
clarification_response=None,
)
write_custom_event(
current_step_nr,
MessageStart(
content="",
final_documents=None,
),
writer,
)
repeat_prompt = REPEAT_PROMPT.build(
original_information=clarification_response.clarification_question
)
_, _, _ = run_with_timeout(
80,
lambda: stream_llm_answer(
llm=graph_config.tooling.primary_llm,
prompt=repeat_prompt,
event_name="basic_response",
writer=writer,
agent_answer_level=0,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=60,
answer_piece=StreamingType.MESSAGE_DELTA.value,
ind=current_step_nr,
# max_tokens=None,
),
)
write_custom_event(
current_step_nr,
SectionEnd(
type="section_end",
),
writer,
)
write_custom_event(
1,
OverallStop(),
writer,
)
update_db_session_with_messages(
db_session=db_session,
chat_message_id=message_id,
chat_session_id=str(graph_config.persistence.chat_session_id),
is_agentic=graph_config.behavior.use_agentic_search,
message=clarification_response.clarification_question,
update_parent_message=True,
research_type=research_type,
research_answer_purpose=ResearchAnswerPurpose.CLARIFICATION_REQUEST,
)
db_session.commit()
else:
chat_history_string = (
get_chat_history_string(
graph_config.inputs.prompt_builder.message_history,
MAX_CHAT_HISTORY_MESSAGES,
)
or "(No chat history yet available)"
)
if (
clarification
and clarification.clarification_question
and clarification.clarification_response is None
):
update_db_session_with_messages(
db_session=db_session,
chat_message_id=message_id,
chat_session_id=str(graph_config.persistence.chat_session_id),
is_agentic=graph_config.behavior.use_agentic_search,
message=clarification.clarification_question,
update_parent_message=True,
research_type=research_type,
research_answer_purpose=ResearchAnswerPurpose.CLARIFICATION_REQUEST,
)
db_session.commit()
next_tool = DRPath.END.value
else:
next_tool = DRPath.ORCHESTRATOR.value
return OrchestrationSetup(
original_question=original_question,
chat_history_string=chat_history_string,
tools_used=[next_tool],
query_list=[],
iteration_nr=0,
current_step_nr=current_step_nr,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="clarifier",
node_start_time=node_start_time,
)
],
clarification=clarification,
available_tools=available_tools,
active_source_types=active_source_types,
active_source_types_descriptions="\n".join(active_source_types_descriptions),
assistant_system_prompt=assistant_system_prompt,
assistant_task_prompt=assistant_task_prompt,
uploaded_test_context=uploaded_text_context,
uploaded_image_context=uploaded_image_context,
)


@@ -0,0 +1,569 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.constants import DR_TIME_BUDGET_BY_TYPE
from onyx.agents.agent_search.dr.constants import HIGH_LEVEL_PLAN_PREFIX
from onyx.agents.agent_search.dr.dr_prompt_builder import (
get_dr_prompt_orchestration_templates,
)
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import DRPromptPurpose
from onyx.agents.agent_search.dr.models import OrchestrationPlan
from onyx.agents.agent_search.dr.models import OrchestratorDecisonsNoPlan
from onyx.agents.agent_search.dr.states import IterationInstructions
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.states import OrchestrationUpdate
from onyx.agents.agent_search.dr.utils import aggregate_context
from onyx.agents.agent_search.dr.utils import create_tool_call_string
from onyx.agents.agent_search.dr.utils import get_prompt_question
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import run_with_timeout
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.kg.utils.extraction_utils import get_entity_types_str
from onyx.kg.utils.extraction_utils import get_relationship_types_str
from onyx.prompts.dr_prompts import DEFAULLT_DECISION_PROMPT
from onyx.prompts.dr_prompts import REPEAT_PROMPT
from onyx.prompts.dr_prompts import SUFFICIENT_INFORMATION_STRING
from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.server.query_and_chat.streaming_models import StreamingType
from onyx.utils.logger import setup_logger
logger = setup_logger()
_DECISION_SYSTEM_PROMPT_PREFIX = "Here are general instructions by the user, which \
may or may not influence the decision of what to do next:\n\n"
def _get_implied_next_tool_based_on_tool_call_history(
tools_used: list[str],
) -> str | None:
"""
Identify the next tool based on the tool call history. Initially, we only support
special handling of the image generation tool.
"""
if tools_used[-1] == DRPath.IMAGE_GENERATION.value:
return DRPath.LOGGER.value
else:
return None
def orchestrator(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> OrchestrationUpdate:
"""
LangGraph node to decide the next step in the DR process.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = state.original_question
if not question:
raise ValueError("Question is required for orchestrator")
available_tools = state.available_tools
plan_of_record = state.plan_of_record
clarification = state.clarification
assistant_system_prompt = state.assistant_system_prompt
if assistant_system_prompt:
decision_system_prompt: str = (
DEFAULLT_DECISION_PROMPT
+ _DECISION_SYSTEM_PROMPT_PREFIX
+ assistant_system_prompt
)
else:
decision_system_prompt = DEFAULLT_DECISION_PROMPT
iteration_nr = state.iteration_nr + 1
current_step_nr = state.current_step_nr
research_type = graph_config.behavior.research_type
remaining_time_budget = state.remaining_time_budget
chat_history_string = state.chat_history_string or "(No chat history yet available)"
answer_history_string = (
aggregate_context(state.iteration_responses, include_documents=True).context
or "(No answer history yet available)"
)
next_tool_name = None
# Identify early exit condition based on tool call history
next_tool_based_on_tool_call_history = (
_get_implied_next_tool_based_on_tool_call_history(state.tools_used)
)
if next_tool_based_on_tool_call_history == DRPath.LOGGER.value:
return OrchestrationUpdate(
tools_used=[DRPath.LOGGER.value],
query_list=[],
iteration_nr=iteration_nr,
current_step_nr=current_step_nr,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="orchestrator",
node_start_time=node_start_time,
)
],
plan_of_record=plan_of_record,
remaining_time_budget=remaining_time_budget,
iteration_instructions=[
IterationInstructions(
iteration_nr=iteration_nr,
plan=plan_of_record.plan if plan_of_record else None,
reasoning="",
purpose="",
)
],
)
# no early exit forced. Continue.
available_tools = state.available_tools or {}
uploaded_context = state.uploaded_test_context or ""
questions = [
f"{iteration_response.tool}: {iteration_response.question}"
for iteration_response in state.iteration_responses
if len(iteration_response.question) > 0
]
question_history_string = (
"\n".join(f" - {question}" for question in questions)
if questions
else "(No question history yet available)"
)
prompt_question = get_prompt_question(question, clarification)
gaps_str = (
("\n - " + "\n - ".join(state.gaps))
if state.gaps
else "(No explicit gaps were pointed out so far)"
)
all_entity_types = get_entity_types_str(active=True)
all_relationship_types = get_relationship_types_str(active=True)
# default to closer
query_list = ["Answer the question with the information you have."]
decision_prompt = None
reasoning_result = "(No reasoning result provided yet.)"
tool_calls_string = "(No tool calls provided yet.)"
if research_type == ResearchType.THOUGHTFUL:
if iteration_nr == 1:
remaining_time_budget = DR_TIME_BUDGET_BY_TYPE[ResearchType.THOUGHTFUL]
elif iteration_nr > 1:
# for each iteration past the first one, we need to see whether we
# have enough information to answer the question.
# if we do, we can stop the iteration and return the answer.
# if we do not, we need to continue the iteration.
base_reasoning_prompt = get_dr_prompt_orchestration_templates(
DRPromptPurpose.NEXT_STEP_REASONING,
ResearchType.THOUGHTFUL,
entity_types_string=all_entity_types,
relationship_types_string=all_relationship_types,
available_tools=available_tools,
)
reasoning_prompt = base_reasoning_prompt.build(
question=question,
chat_history_string=chat_history_string,
answer_history_string=answer_history_string,
iteration_nr=str(iteration_nr),
remaining_time_budget=str(remaining_time_budget),
uploaded_context=uploaded_context,
)
reasoning_tokens: list[str] = [""]
reasoning_tokens, _, _ = run_with_timeout(
80,
lambda: stream_llm_answer(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
decision_system_prompt, reasoning_prompt
),
event_name="basic_response",
writer=writer,
agent_answer_level=0,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=60,
answer_piece=StreamingType.REASONING_DELTA.value,
ind=current_step_nr,
# max_tokens=None,
),
)
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
current_step_nr += 1
reasoning_result = cast(str, merge_content(*reasoning_tokens))
if SUFFICIENT_INFORMATION_STRING in reasoning_result:
return OrchestrationUpdate(
tools_used=[DRPath.CLOSER.value],
current_step_nr=current_step_nr,
query_list=[],
iteration_nr=iteration_nr,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="orchestrator",
node_start_time=node_start_time,
)
],
plan_of_record=plan_of_record,
remaining_time_budget=remaining_time_budget,
iteration_instructions=[
IterationInstructions(
iteration_nr=iteration_nr,
plan=None,
reasoning=reasoning_result,
purpose="",
)
],
)
        # for Thoughtful mode, we force a tool if requested and available
available_tools_for_decision = available_tools
force_use_tool = graph_config.tooling.force_use_tool
if iteration_nr == 1 and force_use_tool and force_use_tool.force_use:
forced_tool_name = force_use_tool.tool_name
available_tool_dict = {
available_tool.tool_object.name: available_tool
for _, available_tool in available_tools.items()
if available_tool.tool_object
}
if forced_tool_name in available_tool_dict.keys():
forced_tool = available_tool_dict[forced_tool_name]
available_tools_for_decision = {forced_tool.name: forced_tool}
base_decision_prompt = get_dr_prompt_orchestration_templates(
DRPromptPurpose.NEXT_STEP,
ResearchType.THOUGHTFUL,
entity_types_string=all_entity_types,
relationship_types_string=all_relationship_types,
available_tools=available_tools_for_decision,
)
decision_prompt = base_decision_prompt.build(
question=question,
chat_history_string=chat_history_string,
answer_history_string=answer_history_string,
iteration_nr=str(iteration_nr),
remaining_time_budget=str(remaining_time_budget),
reasoning_result=reasoning_result,
uploaded_context=uploaded_context,
)
if remaining_time_budget > 0:
try:
orchestrator_action = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
decision_system_prompt,
decision_prompt,
),
schema=OrchestratorDecisonsNoPlan,
timeout_override=35,
# max_tokens=2500,
)
next_step = orchestrator_action.next_step
next_tool_name = next_step.tool
query_list = [q for q in (next_step.questions or [])]
tool_calls_string = create_tool_call_string(next_tool_name, query_list)
except Exception as e:
logger.error(f"Error in approach extraction: {e}")
raise e
if next_tool_name in available_tools.keys():
remaining_time_budget -= available_tools[next_tool_name].cost
else:
logger.warning(f"Tool {next_tool_name} not found in available tools")
remaining_time_budget -= 1.0
else:
reasoning_result = "Time to wrap up."
next_tool_name = DRPath.CLOSER.value
else:
if iteration_nr == 1 and not plan_of_record:
# by default, we start a new iteration, but if there is a feedback request,
# we start a new iteration 0 again (set a bit later)
remaining_time_budget = DR_TIME_BUDGET_BY_TYPE[ResearchType.DEEP]
base_plan_prompt = get_dr_prompt_orchestration_templates(
DRPromptPurpose.PLAN,
ResearchType.DEEP,
entity_types_string=all_entity_types,
relationship_types_string=all_relationship_types,
available_tools=available_tools,
)
plan_generation_prompt = base_plan_prompt.build(
question=prompt_question,
chat_history_string=chat_history_string,
uploaded_context=uploaded_context,
)
try:
plan_of_record = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
decision_system_prompt,
plan_generation_prompt,
),
schema=OrchestrationPlan,
timeout_override=25,
# max_tokens=3000,
)
except Exception as e:
logger.error(f"Error in plan generation: {e}")
raise
write_custom_event(
current_step_nr,
ReasoningStart(),
writer,
)
start_time = datetime.now()
repeat_plan_prompt = REPEAT_PROMPT.build(
original_information=f"{HIGH_LEVEL_PLAN_PREFIX}\n\n {plan_of_record.plan}"
)
_, _, _ = run_with_timeout(
80,
lambda: stream_llm_answer(
llm=graph_config.tooling.primary_llm,
prompt=repeat_plan_prompt,
event_name="basic_response",
writer=writer,
agent_answer_level=0,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=60,
answer_piece=StreamingType.REASONING_DELTA.value,
ind=current_step_nr,
),
)
end_time = datetime.now()
logger.debug(f"Time taken for plan streaming: {end_time - start_time}")
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
current_step_nr += 1
if not plan_of_record:
raise ValueError(
"Plan information is required for iterative decision making"
)
base_decision_prompt = get_dr_prompt_orchestration_templates(
DRPromptPurpose.NEXT_STEP,
ResearchType.DEEP,
entity_types_string=all_entity_types,
relationship_types_string=all_relationship_types,
available_tools=available_tools,
)
decision_prompt = base_decision_prompt.build(
answer_history_string=answer_history_string,
question_history_string=question_history_string,
question=prompt_question,
iteration_nr=str(iteration_nr),
current_plan_of_record_string=plan_of_record.plan,
chat_history_string=chat_history_string,
remaining_time_budget=str(remaining_time_budget),
gaps=gaps_str,
uploaded_context=uploaded_context,
)
if remaining_time_budget > 0:
try:
orchestrator_action = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
decision_system_prompt,
decision_prompt,
),
schema=OrchestratorDecisonsNoPlan,
timeout_override=15,
# max_tokens=1500,
)
next_step = orchestrator_action.next_step
next_tool_name = next_step.tool
query_list = [q for q in (next_step.questions or [])]
reasoning_result = orchestrator_action.reasoning
tool_calls_string = create_tool_call_string(next_tool_name, query_list)
except Exception as e:
logger.error(f"Error in approach extraction: {e}")
raise e
if next_tool_name in available_tools.keys():
remaining_time_budget -= available_tools[next_tool_name].cost
else:
logger.warning(f"Tool {next_tool_name} not found in available tools")
remaining_time_budget -= 1.0
else:
reasoning_result = "Time to wrap up."
next_tool_name = DRPath.CLOSER.value
write_custom_event(
current_step_nr,
ReasoningStart(),
writer,
)
repeat_reasoning_prompt = REPEAT_PROMPT.build(
original_information=reasoning_result
)
_, _, _ = run_with_timeout(
80,
lambda: stream_llm_answer(
llm=graph_config.tooling.primary_llm,
prompt=repeat_reasoning_prompt,
event_name="basic_response",
writer=writer,
agent_answer_level=0,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=60,
answer_piece=StreamingType.REASONING_DELTA.value,
ind=current_step_nr,
# max_tokens=None,
),
)
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
current_step_nr += 1
base_next_step_purpose_prompt = get_dr_prompt_orchestration_templates(
DRPromptPurpose.NEXT_STEP_PURPOSE,
ResearchType.DEEP,
entity_types_string=all_entity_types,
relationship_types_string=all_relationship_types,
available_tools=available_tools,
)
orchestration_next_step_purpose_prompt = base_next_step_purpose_prompt.build(
question=prompt_question,
reasoning_result=reasoning_result,
tool_calls=tool_calls_string,
)
purpose_tokens: list[str] = [""]
try:
write_custom_event(
current_step_nr,
ReasoningStart(),
writer,
)
purpose_tokens, _, _ = run_with_timeout(
80,
lambda: stream_llm_answer(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
decision_system_prompt,
orchestration_next_step_purpose_prompt,
),
event_name="basic_response",
writer=writer,
agent_answer_level=0,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=60,
answer_piece=StreamingType.REASONING_DELTA.value,
ind=current_step_nr,
# max_tokens=None,
),
)
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
current_step_nr += 1
except Exception as e:
logger.error(f"Error in orchestration next step purpose: {e}")
raise e
purpose = cast(str, merge_content(*purpose_tokens))
if not next_tool_name:
raise ValueError("The next step has not been defined. This should not happen.")
return OrchestrationUpdate(
tools_used=[next_tool_name],
query_list=query_list or [],
iteration_nr=iteration_nr,
current_step_nr=current_step_nr,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="orchestrator",
node_start_time=node_start_time,
)
],
plan_of_record=plan_of_record,
remaining_time_budget=remaining_time_budget,
iteration_instructions=[
IterationInstructions(
iteration_nr=iteration_nr,
plan=plan_of_record.plan if plan_of_record else None,
reasoning=reasoning_result,
purpose=purpose,
)
],
)


@@ -0,0 +1,384 @@
import re
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from sqlalchemy.orm import Session
from onyx.agents.agent_search.dr.constants import MAX_CHAT_HISTORY_MESSAGES
from onyx.agents.agent_search.dr.constants import MAX_NUM_CLOSER_SUGGESTIONS
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import AggregatedDRContext
from onyx.agents.agent_search.dr.models import TestInfoCompleteResponse
from onyx.agents.agent_search.dr.states import FinalUpdate
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.states import OrchestrationUpdate
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
GeneratedImageFullResult,
)
from onyx.agents.agent_search.dr.utils import aggregate_context
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.dr.utils import get_chat_history_string
from onyx.agents.agent_search.dr.utils import get_prompt_question
from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
from onyx.agents.agent_search.dr.utils import update_db_session_with_messages
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.chat.chat_utils import llm_doc_from_inference_section
from onyx.context.search.models import InferenceSection
from onyx.db.chat import create_search_doc_from_inference_section
from onyx.db.models import ChatMessage__SearchDoc
from onyx.db.models import ResearchAgentIteration
from onyx.db.models import ResearchAgentIterationSubStep
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
from onyx.prompts.dr_prompts import TEST_INFO_COMPLETE_PROMPT
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationStart
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.server.query_and_chat.streaming_models import StreamingType
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def extract_citation_numbers(text: str) -> list[int]:
"""
Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
Returns a list of all unique citation numbers found.
"""
# Pattern to match [[number]] or [[number1, number2, ...]]
pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
matches = re.findall(pattern, text)
cited_numbers = []
for match in matches:
# Split by comma and extract all numbers
numbers = [int(num.strip()) for num in match.split(",")]
cited_numbers.extend(numbers)
return list(set(cited_numbers)) # Return unique numbers
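# Example of the citation format handled above (illustrative input, not from a real answer):
#
#   extract_citation_numbers("See [[1]] and [[2, 5]]; also [[1]].")
#   # -> [1, 2, 5] in some order (duplicates collapse because a set is used internally)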
def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
citation_content = match.group(1) # e.g., "3" or "3, 5, 7"
numbers = [int(num.strip()) for num in citation_content.split(",")]
# For multiple citations like [[3, 5, 7]], create separate linked citations
linked_citations = []
for num in numbers:
        if 1 <= num <= len(docs):  # check bounds (citations are 1-indexed)
link = docs[num - 1].link or ""
linked_citations.append(f"[[{num}]]({link})")
else:
linked_citations.append(f"[[{num}]]") # No link if out of bounds
return "".join(linked_citations)
def insert_chat_message_search_doc_pair(
message_id: int, search_doc_ids: list[int], db_session: Session
) -> None:
"""
    Insert a (message_id, search_doc_id) pair into the chat_message__search_doc table
    for each of the given search_doc_ids.
    Args:
        message_id: The ID of the chat message
        search_doc_ids: The IDs of the search documents to associate with the message
        db_session: The database session
"""
for search_doc_id in search_doc_ids:
chat_message_search_doc = ChatMessage__SearchDoc(
chat_message_id=message_id, search_doc_id=search_doc_id
)
db_session.add(chat_message_search_doc)
def save_iteration(
state: MainState,
graph_config: GraphConfig,
aggregated_context: AggregatedDRContext,
final_answer: str,
all_cited_documents: list[InferenceSection],
is_internet_marker_dict: dict[str, bool],
) -> None:
db_session = graph_config.persistence.db_session
message_id = graph_config.persistence.message_id
research_type = graph_config.behavior.research_type
# first, insert the search_docs
search_docs = [
create_search_doc_from_inference_section(
inference_section=inference_section,
is_internet=is_internet_marker_dict.get(
inference_section.center_chunk.document_id, False
), # TODO: revisit
db_session=db_session,
commit=False,
)
for inference_section in all_cited_documents
]
# then, map_search_docs to message
insert_chat_message_search_doc_pair(
message_id, [search_doc.id for search_doc in search_docs], db_session
)
# lastly, insert the citations
citation_dict: dict[int, int] = {}
cited_doc_nrs = extract_citation_numbers(final_answer)
    for cited_doc_nr in cited_doc_nrs:
        if 0 < cited_doc_nr <= len(search_docs):
            citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
# TODO: generate plan as dict in the first place
plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
plan_of_record_dict = parse_plan_to_dict(plan_of_record)
# Update the chat message and its parent message in database
update_db_session_with_messages(
db_session=db_session,
chat_message_id=message_id,
chat_session_id=str(graph_config.persistence.chat_session_id),
is_agentic=graph_config.behavior.use_agentic_search,
message=final_answer,
citations=citation_dict,
research_type=research_type,
research_plan=plan_of_record_dict,
final_documents=search_docs,
update_parent_message=True,
research_answer_purpose=ResearchAnswerPurpose.ANSWER,
)
for iteration_preparation in state.iteration_instructions:
research_agent_iteration_step = ResearchAgentIteration(
primary_question_id=message_id,
reasoning=iteration_preparation.reasoning,
purpose=iteration_preparation.purpose,
iteration_nr=iteration_preparation.iteration_nr,
)
db_session.add(research_agent_iteration_step)
for iteration_answer in aggregated_context.global_iteration_responses:
retrieved_search_docs = convert_inference_sections_to_search_docs(
list(iteration_answer.cited_documents.values())
)
# Convert SavedSearchDoc objects to JSON-serializable format
serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
primary_question_id=message_id,
iteration_nr=iteration_answer.iteration_nr,
iteration_sub_step_nr=iteration_answer.parallelization_nr,
sub_step_instructions=iteration_answer.question,
sub_step_tool_id=iteration_answer.tool_id,
sub_answer=iteration_answer.answer,
reasoning=iteration_answer.reasoning,
claims=iteration_answer.claims,
cited_doc_results=serialized_search_docs,
generated_images=(
GeneratedImageFullResult(images=iteration_answer.generated_images)
if iteration_answer.generated_images
else None
),
additional_data=iteration_answer.additional_data,
)
db_session.add(research_agent_iteration_sub_step)
db_session.commit()
def closer(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> FinalUpdate | OrchestrationUpdate:
"""
LangGraph node to close the DR process and finalize the answer.
"""
node_start_time = datetime.now()
# TODO: generate final answer using all the previous steps
# (right now, answers from each step are concatenated onto each other)
# Also, add missing fields once usage in UI is clear.
current_step_nr = state.current_step_nr
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = state.original_question
if not base_question:
raise ValueError("Question is required for closer")
research_type = graph_config.behavior.research_type
assistant_system_prompt = state.assistant_system_prompt
assistant_task_prompt = state.assistant_task_prompt
uploaded_context = state.uploaded_test_context or ""
clarification = state.clarification
prompt_question = get_prompt_question(base_question, clarification)
chat_history_string = (
get_chat_history_string(
graph_config.inputs.prompt_builder.message_history,
MAX_CHAT_HISTORY_MESSAGES,
)
or "(No chat history yet available)"
)
aggregated_context = aggregate_context(
state.iteration_responses, include_documents=True
)
iteration_responses_string = aggregated_context.context
all_cited_documents = aggregated_context.cited_documents
    is_internet_marker_dict = aggregated_context.is_internet_marker_dict
num_closer_suggestions = state.num_closer_suggestions
if (
num_closer_suggestions < MAX_NUM_CLOSER_SUGGESTIONS
and research_type == ResearchType.DEEP
):
test_info_complete_prompt = TEST_INFO_COMPLETE_PROMPT.build(
base_question=prompt_question,
questions_answers_claims=iteration_responses_string,
chat_history_string=chat_history_string,
high_level_plan=(
state.plan_of_record.plan
if state.plan_of_record
else "No plan available"
),
)
test_info_complete_json = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt,
test_info_complete_prompt + (assistant_task_prompt or ""),
),
schema=TestInfoCompleteResponse,
timeout_override=40,
# max_tokens=1000,
)
        if not test_info_complete_json.complete:
            return OrchestrationUpdate(
tools_used=[DRPath.ORCHESTRATOR.value],
query_list=[],
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="closer",
node_start_time=node_start_time,
)
],
gaps=test_info_complete_json.gaps,
num_closer_suggestions=num_closer_suggestions + 1,
)
retrieved_search_docs = convert_inference_sections_to_search_docs(
all_cited_documents
)
write_custom_event(
current_step_nr,
MessageStart(
content="",
final_documents=retrieved_search_docs,
),
writer,
)
if research_type == ResearchType.THOUGHTFUL:
final_answer_base_prompt = FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
else:
final_answer_base_prompt = FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
final_answer_prompt = final_answer_base_prompt.build(
base_question=prompt_question,
iteration_responses_string=iteration_responses_string,
chat_history_string=chat_history_string,
uploaded_context=uploaded_context,
)
all_context_llmdocs = [
llm_doc_from_inference_section(inference_section)
for inference_section in all_cited_documents
]
try:
streamed_output, _, citation_infos = run_with_timeout(
240,
lambda: stream_llm_answer(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt,
final_answer_prompt + (assistant_task_prompt or ""),
),
event_name="basic_response",
writer=writer,
agent_answer_level=0,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=60,
answer_piece=StreamingType.MESSAGE_DELTA.value,
ind=current_step_nr,
context_docs=all_context_llmdocs,
replace_citations=True,
# max_tokens=None,
),
)
final_answer = "".join(streamed_output)
    except Exception as e:
        raise ValueError(f"Error in consolidate_research: {e}") from e
write_custom_event(current_step_nr, SectionEnd(), writer)
current_step_nr += 1
write_custom_event(current_step_nr, CitationStart(), writer)
write_custom_event(current_step_nr, CitationDelta(citations=citation_infos), writer)
write_custom_event(current_step_nr, SectionEnd(), writer)
current_step_nr += 1
# Log the research agent steps
# save_iteration(
# state,
# graph_config,
# aggregated_context,
# final_answer,
# all_cited_documents,
# is_internet_marker_dict,
# )
return FinalUpdate(
final_answer=final_answer,
all_cited_documents=all_cited_documents,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="closer",
node_start_time=node_start_time,
)
],
)


@@ -0,0 +1,235 @@
import re
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from sqlalchemy.orm import Session
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
from onyx.agents.agent_search.dr.models import AggregatedDRContext
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
GeneratedImageFullResult,
)
from onyx.agents.agent_search.dr.utils import aggregate_context
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
from onyx.agents.agent_search.dr.utils import update_db_session_with_messages
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.context.search.models import InferenceSection
from onyx.db.chat import create_search_doc_from_inference_section
from onyx.db.models import ChatMessage__SearchDoc
from onyx.db.models import ResearchAgentIteration
from onyx.db.models import ResearchAgentIterationSubStep
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _extract_citation_numbers(text: str) -> list[int]:
"""
Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
Returns a list of all unique citation numbers found.
"""
# Pattern to match [[number]] or [[number1, number2, ...]]
pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
matches = re.findall(pattern, text)
cited_numbers = []
for match in matches:
# Split by comma and extract all numbers
numbers = [int(num.strip()) for num in match.split(",")]
cited_numbers.extend(numbers)
return list(set(cited_numbers)) # Return unique numbers
def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
citation_content = match.group(1) # e.g., "3" or "3, 5, 7"
numbers = [int(num.strip()) for num in citation_content.split(",")]
# For multiple citations like [[3, 5, 7]], create separate linked citations
linked_citations = []
for num in numbers:
        if 1 <= num <= len(docs):  # check bounds (citations are 1-indexed)
link = docs[num - 1].link or ""
linked_citations.append(f"[[{num}]]({link})")
else:
linked_citations.append(f"[[{num}]]") # No link if out of bounds
return "".join(linked_citations)
def _insert_chat_message_search_doc_pair(
message_id: int, search_doc_ids: list[int], db_session: Session
) -> None:
"""
    Insert a (message_id, search_doc_id) pair into the chat_message__search_doc table
    for each of the given search_doc_ids.
    Args:
        message_id: The ID of the chat message
        search_doc_ids: The IDs of the search documents to associate with the message
        db_session: The database session
"""
for search_doc_id in search_doc_ids:
chat_message_search_doc = ChatMessage__SearchDoc(
chat_message_id=message_id, search_doc_id=search_doc_id
)
db_session.add(chat_message_search_doc)
def save_iteration(
state: MainState,
graph_config: GraphConfig,
aggregated_context: AggregatedDRContext,
final_answer: str,
all_cited_documents: list[InferenceSection],
is_internet_marker_dict: dict[str, bool],
) -> None:
db_session = graph_config.persistence.db_session
message_id = graph_config.persistence.message_id
research_type = graph_config.behavior.research_type
# first, insert the search_docs
search_docs = [
create_search_doc_from_inference_section(
inference_section=inference_section,
is_internet=is_internet_marker_dict.get(
inference_section.center_chunk.document_id, False
), # TODO: revisit
db_session=db_session,
commit=False,
)
for inference_section in all_cited_documents
]
# then, map_search_docs to message
_insert_chat_message_search_doc_pair(
message_id, [search_doc.id for search_doc in search_docs], db_session
)
# lastly, insert the citations
citation_dict: dict[int, int] = {}
cited_doc_nrs = _extract_citation_numbers(final_answer)
    for cited_doc_nr in cited_doc_nrs:
        if 0 < cited_doc_nr <= len(search_docs):
            citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
# TODO: generate plan as dict in the first place
plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
plan_of_record_dict = parse_plan_to_dict(plan_of_record)
# Update the chat message and its parent message in database
update_db_session_with_messages(
db_session=db_session,
chat_message_id=message_id,
chat_session_id=str(graph_config.persistence.chat_session_id),
is_agentic=graph_config.behavior.use_agentic_search,
message=final_answer,
citations=citation_dict,
research_type=research_type,
research_plan=plan_of_record_dict,
final_documents=search_docs,
update_parent_message=True,
research_answer_purpose=ResearchAnswerPurpose.ANSWER,
)
for iteration_preparation in state.iteration_instructions:
research_agent_iteration_step = ResearchAgentIteration(
primary_question_id=message_id,
reasoning=iteration_preparation.reasoning,
purpose=iteration_preparation.purpose,
iteration_nr=iteration_preparation.iteration_nr,
)
db_session.add(research_agent_iteration_step)
for iteration_answer in aggregated_context.global_iteration_responses:
retrieved_search_docs = convert_inference_sections_to_search_docs(
list(iteration_answer.cited_documents.values())
)
# Convert SavedSearchDoc objects to JSON-serializable format
serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
primary_question_id=message_id,
iteration_nr=iteration_answer.iteration_nr,
iteration_sub_step_nr=iteration_answer.parallelization_nr,
sub_step_instructions=iteration_answer.question,
sub_step_tool_id=iteration_answer.tool_id,
sub_answer=iteration_answer.answer,
reasoning=iteration_answer.reasoning,
claims=iteration_answer.claims,
cited_doc_results=serialized_search_docs,
generated_images=(
GeneratedImageFullResult(images=iteration_answer.generated_images)
if iteration_answer.generated_images
else None
),
additional_data=iteration_answer.additional_data,
)
db_session.add(research_agent_iteration_sub_step)
db_session.commit()
def logging(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
    LangGraph node to persist the DR iterations, citations, and final answer.
"""
node_start_time = datetime.now()
# TODO: generate final answer using all the previous steps
# (right now, answers from each step are concatenated onto each other)
# Also, add missing fields once usage in UI is clear.
current_step_nr = state.current_step_nr
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = state.original_question
if not base_question:
raise ValueError("Question is required for closer")
aggregated_context = aggregate_context(
state.iteration_responses, include_documents=True
)
all_cited_documents = aggregated_context.cited_documents
is_internet_marker_dict = aggregated_context.is_internet_marker_dict
final_answer = state.final_answer or ""
write_custom_event(current_step_nr, OverallStop(), writer)
# Log the research agent steps
save_iteration(
state,
graph_config,
aggregated_context,
final_answer,
all_cited_documents,
is_internet_marker_dict,
)
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="logger",
node_start_time=node_start_time,
)
],
)


@@ -0,0 +1,117 @@
from collections.abc import Iterator
from typing import cast
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langgraph.types import StreamWriter
from pydantic import BaseModel
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
PassThroughAnswerResponseHandler,
)
from onyx.chat.stream_processing.utils import map_document_id_order
from onyx.context.search.models import InferenceSection
from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
logger = setup_logger()
class BasicSearchProcessedStreamResults(BaseModel):
ai_message_chunk: AIMessageChunk = AIMessageChunk(content="")
full_answer: str | None = None
cited_references: list[InferenceSection] = []
retrieved_documents: list[LlmDoc] = []
def process_llm_stream(
messages: Iterator[BaseMessage],
should_stream_answer: bool,
writer: StreamWriter,
ind: int,
final_search_results: list[LlmDoc] | None = None,
displayed_search_results: list[LlmDoc] | None = None,
generate_final_answer: bool = False,
chat_message_id: str | None = None,
) -> BasicSearchProcessedStreamResults:
tool_call_chunk = AIMessageChunk(content="")
if final_search_results and displayed_search_results:
answer_handler: AnswerResponseHandler = CitationResponseHandler(
context_docs=final_search_results,
final_doc_id_to_rank_map=map_document_id_order(final_search_results),
display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
)
else:
answer_handler = PassThroughAnswerResponseHandler()
full_answer = ""
start_final_answer_streaming_set = False
# This stream will be the llm answer if no tool is chosen. When a tool is chosen,
# the stream will contain AIMessageChunks with tool call information.
for message in messages:
answer_piece = message.content
if not isinstance(answer_piece, str):
# this is only used for logging, so fine to
# just add the string representation
answer_piece = str(answer_piece)
full_answer += answer_piece
if isinstance(message, AIMessageChunk) and (
message.tool_call_chunks or message.tool_calls
):
tool_call_chunk += message # type: ignore
elif should_stream_answer:
for response_part in answer_handler.handle_response_part(message, []):
# only stream out answer parts
if (
isinstance(response_part, (OnyxAnswerPiece, AgentAnswerPiece))
and generate_final_answer
and response_part.answer_piece
):
if chat_message_id is None:
raise ValueError(
"chat_message_id is required when generating final answer"
)
if not start_final_answer_streaming_set:
# Convert LlmDocs to SavedSearchDocs
saved_search_docs = saved_search_docs_from_llm_docs(
final_search_results
)
write_custom_event(
ind,
MessageStart(content="", final_documents=saved_search_docs),
writer,
)
start_final_answer_streaming_set = True
write_custom_event(
ind,
MessageDelta(content=response_part.answer_piece),
writer,
)
if generate_final_answer and start_final_answer_streaming_set:
# start_final_answer_streaming_set is only set if the answer is verbal and not a tool call
write_custom_event(
ind,
SectionEnd(),
writer,
)
logger.debug(f"Full answer: {full_answer}")
return BasicSearchProcessedStreamResults(
ai_message_chunk=cast(AIMessageChunk, tool_call_chunk), full_answer=full_answer
)
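# Hypothetical call-site sketch: `llm_stream`, `context_docs`, `shown_docs`, `writer`,
# `current_step_nr`, and `message_id` are placeholders for values supplied by the caller.
#
#   result = process_llm_stream(
#       messages=llm_stream,
#       should_stream_answer=True,
#       writer=writer,
#       ind=current_step_nr,
#       final_search_results=context_docs,
#       displayed_search_results=shown_docs,
#       generate_final_answer=True,
#       chat_message_id=str(message_id),
#   )
#   tool_calls = result.ai_message_chunk.tool_calls  # non-empty only if the LLM chose a tool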


@@ -0,0 +1,82 @@
from operator import add
from typing import Annotated
from typing import Any
from typing import TypedDict
from pydantic import BaseModel
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
from onyx.agents.agent_search.dr.models import OrchestrationPlan
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.context.search.models import InferenceSection
from onyx.db.connector import DocumentSource
### States ###
class LoggerUpdate(BaseModel):
log_messages: Annotated[list[str], add] = []
class OrchestrationUpdate(LoggerUpdate):
tools_used: Annotated[list[str], add] = []
query_list: list[str] = []
iteration_nr: int = 0
current_step_nr: int = 1
plan_of_record: OrchestrationPlan | None = None # None for Thoughtful
remaining_time_budget: float = 2.0 # set by default to about 2 searches
num_closer_suggestions: int = 0 # how many times the closer was suggested
gaps: list[str] = (
[]
) # gaps that may be identified by the closer before being able to answer the question.
iteration_instructions: Annotated[list[IterationInstructions], add] = []
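# Note on the Annotated[..., add] fields above: LangGraph treats the second argument as a
# reducer, so partial updates returned by different nodes are concatenated rather than
# overwritten. For example (with illustrative values), nodes returning
# OrchestrationUpdate(tools_used=["internal_search"]) and
# OrchestrationUpdate(tools_used=["closer"]) leave tools_used == ["internal_search", "closer"].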
class OrchestrationSetup(OrchestrationUpdate):
original_question: str | None = None
chat_history_string: str | None = None
clarification: OrchestrationClarificationInfo | None = None
available_tools: dict[str, OrchestratorTool] | None = None
num_closer_suggestions: int = 0 # how many times the closer was suggested
active_source_types: list[DocumentSource] | None = None
active_source_types_descriptions: str | None = None
assistant_system_prompt: str | None = None
assistant_task_prompt: str | None = None
uploaded_test_context: str | None = None
uploaded_image_context: list[dict[str, Any]] | None = None
class AnswerUpdate(LoggerUpdate):
iteration_responses: Annotated[list[IterationAnswer], add] = []
class FinalUpdate(LoggerUpdate):
final_answer: str | None = None
all_cited_documents: list[InferenceSection] = []
## Graph Input State
class MainInput(CoreState):
pass
## Graph State
class MainState(
# This includes the core state
MainInput,
OrchestrationSetup,
AnswerUpdate,
FinalUpdate,
):
pass
## Graph Output State
class MainOutput(TypedDict):
log_messages: list[str]
final_answer: str | None
all_cited_documents: list[InferenceSection]


@@ -0,0 +1,47 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.utils.logger import setup_logger
logger = setup_logger()
def basic_search_branch(
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
LangGraph node to perform a standard search as part of the DR process.
"""
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
current_step_nr = state.current_step_nr
logger.debug(f"Search start for Basic Search {iteration_nr} at {datetime.now()}")
write_custom_event(
current_step_nr,
SearchToolStart(
is_internet_search=False,
),
writer,
)
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="basic_search",
node_name="branching",
node_start_time=node_start_time,
)
],
)


@@ -0,0 +1,266 @@
import re
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import BaseSearchProcessingResponse
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import SearchAnswer
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.dr.utils import extract_document_citations
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.chat.models import LlmDoc
from onyx.context.search.models import InferenceSection
from onyx.db.connector import DocumentSource
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.prompts.dr_prompts import BASE_SEARCH_PROCESSING_PROMPT
from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger
logger = setup_logger()
def basic_search(
state: BranchInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
"""
LangGraph node to perform a standard search as part of the DR process.
"""
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
parallelization_nr = state.parallelization_nr
current_step_nr = state.current_step_nr
assistant_system_prompt = state.assistant_system_prompt
assistant_task_prompt = state.assistant_task_prompt
branch_query = state.branch_question
if not branch_query:
raise ValueError("branch_query is not set")
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = graph_config.inputs.prompt_builder.raw_user_query
research_type = graph_config.behavior.research_type
if not state.available_tools:
raise ValueError("available_tools is not set")
elif len(state.tools_used) == 0:
raise ValueError("tools_used is empty")
search_tool_info = state.available_tools[state.tools_used[-1]]
search_tool = cast(SearchTool, search_tool_info.tool_object)
# sanity check
if search_tool != graph_config.tooling.search_tool:
raise ValueError("search_tool does not match the configured search tool")
# rewrite query and identify source types
active_source_types_str = ", ".join(
[source.value for source in state.active_source_types or []]
)
base_search_processing_prompt = BASE_SEARCH_PROCESSING_PROMPT.build(
active_source_types_str=active_source_types_str,
branch_query=branch_query,
current_time=datetime.now().strftime("%Y-%m-%d %H:%M"),
)
try:
search_processing = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt, base_search_processing_prompt
),
schema=BaseSearchProcessingResponse,
timeout_override=15,
# max_tokens=100,
)
except Exception as e:
logger.error(f"Could not process query: {e}")
raise e
rewritten_query = search_processing.rewritten_query
# give back the query so we can render it in the UI
write_custom_event(
current_step_nr,
SearchToolDelta(
queries=[rewritten_query],
documents=[],
),
writer,
)
implied_start_date = search_processing.time_filter
# Validate time_filter format if it exists
implied_time_filter = None
if implied_start_date:
# Check if time_filter is in YYYY-MM-DD format
date_pattern = r"^\d{4}-\d{2}-\d{2}$"
if re.match(date_pattern, implied_start_date):
implied_time_filter = datetime.strptime(implied_start_date, "%Y-%m-%d")
specified_source_types: list[DocumentSource] | None = [
DocumentSource(source_type)
for source_type in search_processing.specified_source_types
]
    if not specified_source_types:
        specified_source_types = None
logger.debug(
f"Search start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
retrieved_docs: list[InferenceSection] = []
callback_container: list[list[InferenceSection]] = []
# new db session to avoid concurrency issues
with get_session_with_current_tenant() as search_db_session:
for tool_response in search_tool.run(
query=rewritten_query,
document_sources=specified_source_types,
time_filter=implied_time_filter,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=True,
alternate_db_session=search_db_session,
retrieved_sections_callback=callback_container.append,
skip_query_analysis=True,
),
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
response = cast(SearchResponseSummary, tool_response.response)
retrieved_docs = response.top_sections
break
# render the retrieved docs in the UI
write_custom_event(
current_step_nr,
SearchToolDelta(
queries=[],
documents=convert_inference_sections_to_search_docs(
retrieved_docs, is_internet=False
),
),
writer,
)
document_texts_list = []
for doc_num, retrieved_doc in enumerate(retrieved_docs[:15]):
if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
chunk_text = build_document_context(retrieved_doc, doc_num + 1)
document_texts_list.append(chunk_text)
document_texts = "\n\n".join(document_texts_list)
logger.debug(
f"Search end/LLM start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
    # Build prompt
if research_type == ResearchType.DEEP:
search_prompt = INTERNAL_SEARCH_PROMPTS[research_type].build(
search_query=branch_query,
base_question=base_question,
document_text=document_texts,
)
# Run LLM
# search_answer_json = None
search_answer_json = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
),
schema=SearchAnswer,
timeout_override=40,
# max_tokens=1500,
)
logger.debug(
f"LLM/all done for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# get cited documents
answer_string = search_answer_json.answer
claims = search_answer_json.claims or []
reasoning = search_answer_json.reasoning
# answer_string = ""
# claims = []
(
citation_numbers,
answer_string,
claims,
) = extract_document_citations(answer_string, claims)
if citation_numbers and max(citation_numbers) > len(retrieved_docs):
raise ValueError("Citation numbers are out of range for retrieved docs.")
cited_documents = {
citation_number: retrieved_docs[citation_number - 1]
for citation_number in citation_numbers
}
else:
answer_string = ""
claims = []
cited_documents = {
doc_num + 1: retrieved_doc
for doc_num, retrieved_doc in enumerate(retrieved_docs[:15])
}
reasoning = ""
return BranchUpdate(
branch_iteration_responses=[
IterationAnswer(
tool=search_tool_info.llm_path,
tool_id=search_tool_info.tool_id,
iteration_nr=iteration_nr,
parallelization_nr=parallelization_nr,
question=branch_query,
answer=answer_string,
claims=claims,
cited_documents=cited_documents,
reasoning=reasoning,
additional_data=None,
)
],
log_messages=[
get_langgraph_node_log_string(
graph_component="basic_search",
node_name="searching",
node_start_time=node_start_time,
)
],
)


@@ -0,0 +1,77 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.dr.utils import chunks_or_sections_to_search_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.context.search.models import SavedSearchDoc
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
logger = setup_logger()
def is_reducer(
state: SubAgentMainState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
"""
    LangGraph node to consolidate the parallel basic search results as part of the DR process.
"""
node_start_time = datetime.now()
branch_updates = state.branch_iteration_responses
current_iteration = state.iteration_nr
current_step_nr = state.current_step_nr
new_updates = [
update for update in branch_updates if update.iteration_nr == current_iteration
]
    doc_list = [
        doc
        for update in new_updates
        for doc in update.cited_documents.values()
    ]
# Convert InferenceSections to SavedSearchDocs
search_docs = chunks_or_sections_to_search_docs(doc_list)
retrieved_saved_search_docs = [
SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
for search_doc in search_docs
]
for retrieved_saved_search_doc in retrieved_saved_search_docs:
retrieved_saved_search_doc.is_internet = False
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
current_step_nr += 1
return SubAgentUpdate(
iteration_responses=new_updates,
current_step_nr=current_step_nr,
log_messages=[
get_langgraph_node_log_string(
graph_component="basic_search",
node_name="consolidation",
node_start_time=node_start_time,
)
],
)


@@ -0,0 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_1_branch import (
basic_search_branch,
)
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import (
basic_search,
)
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_3_reduce import (
is_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_image_generation_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
logger = setup_logger()
def dr_basic_search_graph_builder() -> StateGraph:
"""
    LangGraph graph builder for the Basic Search Sub-Agent
"""
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
### Add nodes ###
graph.add_node("branch", basic_search_branch)
graph.add_node("act", basic_search)
graph.add_node("reducer", is_reducer)
### Add edges ###
graph.add_edge(start_key=START, end_key="branch")
graph.add_conditional_edges("branch", branching_router)
graph.add_edge(start_key="act", end_key="reducer")
graph.add_edge(start_key="reducer", end_key=END)
return graph
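# Minimal usage sketch: the builder returns an uncompiled StateGraph, so callers compile it
# before invoking. `sub_agent_input` is a placeholder for a populated SubAgentInput provided
# by the parent DR graph.
#
#   compiled = dr_basic_search_graph_builder().compile()
#   update = compiled.invoke(sub_agent_input)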


@@ -0,0 +1,30 @@
from collections.abc import Hashable
from langgraph.types import Send
from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
return [
Send(
"act",
BranchInput(
iteration_nr=state.iteration_nr,
parallelization_nr=parallelization_nr,
branch_question=query,
current_step_nr=state.current_step_nr,
context="",
active_source_types=state.active_source_types,
tools_used=state.tools_used,
available_tools=state.available_tools,
assistant_system_prompt=state.assistant_system_prompt,
assistant_task_prompt=state.assistant_task_prompt,
),
)
for parallelization_nr, query in enumerate(
state.query_list[:MAX_DR_PARALLEL_SEARCH]
)
]
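# Note: each Send() fans the state out into a separate "act" branch, so up to
# MAX_DR_PARALLEL_SEARCH queries from query_list run the basic_search node in parallel,
# each with its own parallelization_nr and branch_question.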


@@ -0,0 +1,36 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def custom_tool_branch(
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
LangGraph node to perform a generic tool call as part of the DR process.
"""
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="custom_tool",
node_name="branching",
node_start_time=node_start_time,
)
],
)


@@ -0,0 +1,167 @@
import json
from datetime import datetime
from typing import cast
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
from onyx.tools.tool_implementations.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
from onyx.tools.tool_implementations.mcp.mcp_tool import MCP_TOOL_RESPONSE_ID
from onyx.utils.logger import setup_logger
logger = setup_logger()
def custom_tool_act(
state: BranchInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
"""
LangGraph node to perform a generic tool call as part of the DR process.
"""
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
parallelization_nr = state.parallelization_nr
if not state.available_tools:
raise ValueError("available_tools is not set")
custom_tool_info = state.available_tools[state.tools_used[-1]]
custom_tool_name = custom_tool_info.name
custom_tool = cast(CustomTool, custom_tool_info.tool_object)
branch_query = state.branch_question
if not branch_query:
raise ValueError("branch_query is not set")
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = graph_config.inputs.prompt_builder.raw_user_query
logger.debug(
f"Tool call start for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# get tool call args
tool_args: dict | None = None
if graph_config.tooling.using_tool_calling_llm:
# get tool call args from tool-calling LLM
tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
query=branch_query,
base_question=base_question,
tool_description=custom_tool_info.description,
)
tool_calling_msg = graph_config.tooling.primary_llm.invoke(
tool_use_prompt,
tools=[custom_tool.tool_definition()],
tool_choice="required",
timeout_override=40,
)
# make sure we got a tool call
if (
isinstance(tool_calling_msg, AIMessage)
and len(tool_calling_msg.tool_calls) == 1
):
tool_args = tool_calling_msg.tool_calls[0]["args"]
else:
logger.warning("Tool-calling LLM did not emit a tool call")
if tool_args is None:
# get tool call args from non-tool-calling LLM or for failed tool-calling LLM
tool_args = custom_tool.get_args_for_non_tool_calling_llm(
query=branch_query,
history=[],
llm=graph_config.tooling.primary_llm,
force_run=True,
)
if tool_args is None:
raise ValueError("Failed to obtain tool arguments from LLM")
# run the tool
response_summary: CustomToolCallSummary | None = None
for tool_response in custom_tool.run(**tool_args):
if tool_response.id in {CUSTOM_TOOL_RESPONSE_ID, MCP_TOOL_RESPONSE_ID}:
response_summary = cast(CustomToolCallSummary, tool_response.response)
break
if not response_summary:
raise ValueError("Custom tool did not return a valid response summary")
# summarise tool result
if not response_summary.response_type:
raise ValueError("Response type is not returned.")
if response_summary.response_type == "json":
tool_result_str = json.dumps(response_summary.tool_result, ensure_ascii=False)
elif response_summary.response_type in {"image", "csv"}:
tool_result_str = f"{response_summary.response_type} files: {response_summary.tool_result.file_ids}"
else:
tool_result_str = str(response_summary.tool_result)
tool_str = (
f"Tool used: {custom_tool_name}\n"
f"Description: {custom_tool_info.description}\n"
f"Result: {tool_result_str}"
)
tool_summary_prompt = CUSTOM_TOOL_USE_PROMPT.build(
query=branch_query, base_question=base_question, tool_response=tool_str
)
answer_string = str(
graph_config.tooling.primary_llm.invoke(
tool_summary_prompt, timeout_override=40
).content
).strip()
# get file_ids:
file_ids = None
if response_summary.response_type in {"image", "csv"} and hasattr(
response_summary.tool_result, "file_ids"
):
file_ids = response_summary.tool_result.file_ids
logger.debug(
f"Tool call end for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
return BranchUpdate(
branch_iteration_responses=[
IterationAnswer(
tool=custom_tool_name,
tool_id=custom_tool_info.tool_id,
iteration_nr=iteration_nr,
parallelization_nr=parallelization_nr,
question=branch_query,
answer=answer_string,
claims=[],
cited_documents={},
reasoning="",
additional_data=None,
response_type=response_summary.response_type,
data=response_summary.tool_result,
file_ids=file_ids,
)
],
log_messages=[
get_langgraph_node_log_string(
graph_component="custom_tool",
node_name="tool_calling",
node_start_time=node_start_time,
)
],
)


@@ -0,0 +1,82 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
logger = setup_logger()
def custom_tool_reducer(
state: SubAgentMainState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
"""
    LangGraph node to consolidate the results of the generic tool calls as part of the DR process.
"""
node_start_time = datetime.now()
current_step_nr = state.current_step_nr
branch_updates = state.branch_iteration_responses
current_iteration = state.iteration_nr
new_updates = [
update for update in branch_updates if update.iteration_nr == current_iteration
]
for new_update in new_updates:
if not new_update.response_type:
raise ValueError("Response type is not returned.")
write_custom_event(
current_step_nr,
CustomToolStart(
tool_name=new_update.tool,
),
writer,
)
write_custom_event(
current_step_nr,
CustomToolDelta(
tool_name=new_update.tool,
response_type=new_update.response_type,
data=new_update.data,
file_ids=new_update.file_ids,
),
writer,
)
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
current_step_nr += 1
return SubAgentUpdate(
iteration_responses=new_updates,
log_messages=[
get_langgraph_node_log_string(
graph_component="custom_tool",
node_name="consolidation",
node_start_time=node_start_time,
)
],
)


@@ -0,0 +1,28 @@
from collections.abc import Hashable
from langgraph.types import Send
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import (
SubAgentInput,
)
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
return [
Send(
"act",
BranchInput(
iteration_nr=state.iteration_nr,
parallelization_nr=parallelization_nr,
branch_question=query,
context="",
active_source_types=state.active_source_types,
tools_used=state.tools_used,
available_tools=state.available_tools,
),
)
for parallelization_nr, query in enumerate(
state.query_list[:1] # no parallel call for now
)
]

Some files were not shown because too many files have changed in this diff.