forked from github/onyx

Compare commits


128 Commits

Author SHA1 Message Date
Chris Weaver
c3702b76b6 docs: add agent files (#5412) 2025-09-14 20:07:18 -07:00
Chris Weaver
bb239d574c feat: single default assistant (#5351) 2025-09-14 20:05:33 -07:00
Chris Weaver
172e5f0e24 feat: Move reg IT to parallel + blacksmith and have MIT only run on merge q… (#5413) 2025-09-13 17:33:45 -07:00
Nils
26b026fb88 SharePoint Connector Fix - Nested Subfolder Indexing (#5404)
Co-authored-by: nsklei <nils.kleinrahm@pledoc.de>
2025-09-13 11:33:01 +00:00
joachim-danswer
870629e8a9 fix: Azure adjustment (#5410) 2025-09-13 00:03:37 +00:00
danielkravets
a547112321 feat: bitbucket connector (#5294) 2025-09-12 18:15:09 -07:00
joachim-danswer
da5a94815e fix: initial response quality, particularly for General assistant (#5399) 2025-09-12 00:14:49 -07:00
Jessica Singh
e024472b74 fix(federated-slack): pass in valid query (#5402) 2025-09-11 19:27:43 -07:00
Chris Weaver
e74855e633 feat: use private registry (#5401) 2025-09-11 18:20:56 -07:00
Justin Tahara
e4c26a933d fix(infra): Fix helm test timeout (#5386) 2025-09-11 18:19:07 -07:00
Chris Weaver
36c96f2d98 fix: playwright (#5396) 2025-09-11 14:06:03 -07:00
Justin Tahara
1ea94dcd8d fix(security): Remove Hard Fail from Trivy (#5394) 2025-09-11 10:35:26 -07:00
Wenxi
2b1c5a0755 fix: remove unneeded dependency from requirements (#5390) 2025-09-10 21:49:02 -07:00
Chris Weaver
82b5f806ab feat: Improve migration (#5391) 2025-09-10 19:29:11 -07:00
Chris Weaver
6340c517d1 fix: missing connectors section (#5387)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-09-10 19:28:56 -07:00
joachim-danswer
3baae2d4f0 fix: tf/dr flow improvements (#5380) 2025-09-10 16:39:19 -07:00
Chris Weaver
d7c223ddd4 feat: playwright test speed improvement (#5388) 2025-09-10 16:19:56 -07:00
Chris Weaver
df4917243b fix: parallelized IT (#5389) 2025-09-10 14:37:36 -07:00
Justin Tahara
a79ab713ce feat(infra): Adding rety to Trivy tests (#5383) 2025-09-10 14:13:58 -07:00
Chris Weaver
d1f7cee959 feat: parallelized integration tests (#5021)
Co-authored-by: Claude <noreply@anthropic.com>
2025-09-10 12:15:02 -07:00
Justin Tahara
a3f41e20da feat(infra): Add Node Selector option to all Templates (#5384) 2025-09-10 10:23:54 -07:00
Chris Weaver
458ed93da0 feat: remove prompt table (#5348) 2025-09-10 10:21:57 -07:00
Chris Weaver
273d073bd7 fix: non-image gen models (#5381) 2025-09-09 15:52:03 -07:00
Wenxi
9455c8e5ae fix: add back reverted changes to readme (#5377) 2025-09-09 10:23:33 -07:00
Justin Tahara
d45d4389a0 Revert "fix: update contribution guide" (#5376)
Co-authored-by: Wenxi <wenxi@onyx.app>
2025-09-09 09:37:16 -07:00
Chris Weaver
bd901c0da1 fix: playwright tests (#5372)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-09-09 00:29:52 -07:00
Wenxi
2192605c95 feat: Bedrock API Keys & filter available models (#5343)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2025-09-08 18:50:04 -07:00
Wenxi
d248d2f4e9 refactor: update seeded docs (#5364)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-09-08 18:06:29 -07:00
Chris Weaver
331c53871a fix: image gen display (#5367) 2025-09-08 17:47:17 -07:00
SubashMohan
f62d0d9144 feat(admin/connectors): Disable Auto Sync for unsupported auth; add disabled dropdown + tooltip (#5358) 2025-09-08 21:39:47 +00:00
Chris Weaver
427945e757 fix: model server build (#5362) 2025-09-08 14:00:33 -07:00
Wenxi
e55cdc6250 fix: new docs links (#5363) 2025-09-08 13:49:19 -07:00
sktbcpraha
6a01db9ff2 fix: IMAP - mail processing fixes (#5360) 2025-09-08 12:11:09 -07:00
Richard Guan
82e9df5c22 fix: various bug bash improvements (#5330) 2025-09-07 23:17:01 -07:00
Chris Weaver
16c2ef2852 feat: Make usage report gen a background job (#5342) 2025-09-07 14:44:40 -07:00
Edwin Luo
224a70eea9 fix: update contribution guide (#5354) 2025-09-07 13:06:37 -07:00
Chris Weaver
c457982120 fix: connector tests (#5353) 2025-09-07 11:57:34 -07:00
Chris Weaver
0649748da2 fix: playwright tests (#5352) 2025-09-07 11:24:26 -07:00
Wenxi
ddceddaa28 chore: bump litellm to fix self-hosted inference (#5349) 2025-09-06 19:29:26 -07:00
Evan Lohn
c6733a5026 fix: handle new error type (#5345) 2025-09-06 18:26:54 -07:00
Wenxi
7db744a5de refactor: simplify sharepoint document extraction (#5341) 2025-09-06 20:17:33 +00:00
Chris Weaver
cd2a8b0def Fix mypy (#5347) 2025-09-05 23:28:35 -07:00
Richard Guan
f15bc26cd6 fix: deep research and thoughtful assistant message context and trace all llm calls in langsmith (#5344)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-09-05 23:13:45 -07:00
Chris Weaver
65f35f0293 fix: whitelabeling (#5346) 2025-09-05 21:06:44 -07:00
joachim-danswer
4e3e608249 fix: Tweaks to Deep Research and some KG adjustments (#5305) 2025-09-06 00:57:26 +00:00
Richard Guan
719a092a12 fix: web search bugs [DAN-2351] (#5281) 2025-09-05 19:50:59 +00:00
wichmann-git
6a8fde7eb1 fix(teams): sanitize None displayName to 'Unknown User' before parsing (#5322) 2025-09-05 10:21:26 -07:00
Justin Tahara
4fdd0812a0 fix(admin): Block access to Custom Analytics Page (#5319)
Co-authored-by: Wenxi <wenxi@onyx.app>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2025-09-05 10:02:31 -07:00
Wenxi
4913dc1e85 fix: skip large sharepoint files (#5338) 2025-09-05 06:47:30 +00:00
Chris Weaver
4a43a9642e fix: citatons endpoint (#5336) 2025-09-04 21:51:52 -07:00
Evan Lohn
cc48a0c38e fix: jira cloud api v3 (#5337) 2025-09-04 21:50:35 -07:00
Chris Weaver
01ccfd2df7 fix: Try to avoid timeouts on image gen (#5316) 2025-09-04 16:14:28 -07:00
Wenxi
36d75786ee fix: honor freshdesk 429 (#5334) 2025-09-04 12:06:19 -07:00
Chris Weaver
f9bc38ba65 fix: Add back MCP (#5333) 2025-09-04 11:13:38 -07:00
Chris Weaver
3da283221d feat: Re-enable sentry (#5329) 2025-09-03 19:07:37 -07:00
Wenxi
90568d3bbb refactor: remove option to exclude citations from assistants (#5320) 2025-09-03 17:32:18 -07:00
Wenxi
7955ca938c fix: freshdesk password and rate limits (#5325) 2025-09-03 17:32:00 -07:00
Chris Weaver
f5d357eb28 fix: old send-message (#5328) 2025-09-03 16:25:38 -07:00
Evan Lohn
d83f616214 fix: incorrect assumptions about fields (#5324) 2025-09-03 21:58:46 +00:00
Chris Weaver
275c1bec3d fix: adjust search tool display (#5317) 2025-09-02 16:44:23 -07:00
Wenxi
7d1ef912e8 fix: allow chats to be moved out of chat groups (#5315)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2025-09-02 21:35:55 +00:00
Wenxi
2fe1d4c373 fix: better tool tips (#5314) 2025-09-02 19:39:10 +00:00
SubashMohan
2396ad309e fix: enhance SharePoint connector error handling and content retrieval (#5302) 2025-09-02 08:57:30 -07:00
Wenxi
0b13ef963a fix: allow web and file to show in results (#5290)
* allow web and file to show in results

* don't lag on backspacing 2nd char
2025-09-01 21:53:09 -07:00
Justin Tahara
83073f3ded fix(infra): Add Playwright Directory (#5313) 2025-09-01 19:35:48 -07:00
Wenxi
439a27a775 scroll forms on invalid submit (#5310) 2025-09-01 15:33:52 -07:00
Justin Tahara
91773a4789 fix(jira): Upgrade the Jira Python Version (#5309) 2025-09-01 15:33:03 -07:00
Chris Weaver
185beca648 Small center bar improvements (#5306) 2025-09-01 13:32:56 -07:00
Justin Tahara
2dc564c8df feat(infra): Add IAM support for Redis (#5267)
* feat: JIRA support for custom JQL filter (#5164)

* jira jql support

* jira jql fixes

* Address comment

---------

Co-authored-by: sktbcpraha <131408565+sktbcpraha@users.noreply.github.com>
2025-09-01 10:52:28 -07:00
Chris Weaver
b259f53972 Remove console-log (#5304) 2025-09-01 10:18:39 -07:00
Chris Weaver
f8beb08e2f Fix web build (#5303) 2025-09-01 10:18:06 -07:00
Evan Lohn
83c88c7cf6 feat: mcp client1 (#5271)
* working mcp implementation v1

* attempt openapi fix

* fastmcp
2025-09-01 09:52:35 -07:00
Chris Weaver
2372dd40e0 fix: small formatting fixes (#5300)
* SMall formatting fixes

* Fix mypy

* reorder imports
2025-08-31 23:19:22 -07:00
Chris Weaver
5cb6bafe81 Cleanup on ChatPage/MessagesDisply (#5299) 2025-08-31 21:29:17 -07:00
Mohamed Mathari
a0309b31c7 feat: Add Outline Connector (#5284)
* Outline

* fixConnector

* fixTest

* The date filtering is implemented correctly as client-side filtering, which is the only way to achieve it with the Outline API since it doesn't support date parameters natively.

* Update web/src/lib/connectors/connectors.tsx

Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>

* no connector config for outline

* Update backend/onyx/connectors/outline/client.py

Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>

* Fix all PR review issues: document ID prefixes, error handling, test assertions, and null guards

* Update backend/onyx/connectors/outline/client.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* The test no longer depends on external network connectivity to httpbin.org

* I've enhanced the OutlineApiClient.post() method in backend/onyx/connectors/outline/client.py to properly handle network-level exceptions that could crash the connector during synchronization:

* Polling mechanism

* Removed flag-based approach

* commentOnClasses

* commentOnClasses

* commentOnClasses

* responseStatus

* startBound

* Changed the method signature to match the interface

* ConnectorMissingCredentials

* Time Out shared config

* Missing Credential message

---------

Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-08-31 20:56:10 -07:00
Chris Weaver
0fd268dba7 fix: message render performance (#5297)
* Separate out message display into it's own component

* Memoize AIMessage

* Cleanup

* Remove log

* Address greptile/cubic comments
2025-08-31 19:49:53 -07:00
Wenxi
f345da7487 fix: make radios and checkboxes actually clickable (#5298)
* dont nest labels, use htmlfor, fix slackbot form bug

* fix playwright tests for improved labels
2025-08-31 19:16:25 -07:00
Chris Weaver
f2dacf03f1 fix: Chat Page performance improvements (#5295)
* CC performance improvements r2

* More misc chat performance improvements

* Remove unused import

* Remove silly useMemo

* Fix small shift

* Address greptile + cubic + subash comments

* Fix build

* Improve document sidebar

* Remove console.log

* Remove more logs

* Fix build
2025-08-31 14:29:03 -07:00
Wenxi
e0fef50cf0 fix: don't skip ccpairs if embedding swap in progress (#5189)
* don't skip ccpairs if embedding swap in progress

* refactor check_for_indexing to properly handle search setting swaps

* mypy

* mypy

* comment debugging log

* nits and more efficient active index attempt check
2025-08-29 17:17:36 -07:00
Chris Weaver
6ba3eeefa5 feat: center bar + tool force + tool disable (#5272)
* Exploration

* Adding user-specific assistant preferences

* Small fixes

* Improvements

* Reset forced tools upon switch

* Add starter messages

* Improve starter messages

* Add misisng file

* cleanup renaming imports

* Address greptile/cubic comments

* Fix build

* Add back assistant info

* Fix image icon

* rebase fix

* Color corrections

* Small tweak

* More color correction

* Remove animation for now

* fix test

* Fix coloring + allow only one forced tool
2025-08-29 17:17:09 -07:00
Richard Guan
aa158abaa9 . (#5286) 2025-08-29 17:07:50 -07:00
Wenxi
255c2af1d6 feat: reorganize connectors pages (#5186)
* Add popular connectors sections and cleanup connectors page

* Add other connectors env var

* other connectors env var to vscode env template

* update playwright tests

* sort by popuarlity

* recategorize and sort by popularity
2025-08-29 16:59:00 -07:00
Chris Weaver
9ece3b0310 fix: improve index attempts API (#5287)
* Improve index attempts API

* Fix import
2025-08-29 16:15:58 -07:00
joachim-danswer
9e3aca03a7 fix: various dr issues and comments (#5280)
* replacement of "message_delta" etc as Enums + removal

* prompt changes

* cubic fixes where appropriate

* schema fixes + citation symbols

* various fixes

* fix for kg context in new search

* cw comments

* updates
2025-08-29 15:08:23 -07:00
Wenxi
dbd5d4d8f1 fix: allow jira api v3 (#5285)
* allow jira api v3

* don't rely on api version for parsing issues and separate cloud and dc versions
2025-08-29 14:02:01 -07:00
Chris Weaver
cdb97c3ce4 fix: test_soft_delete_chat_session (#5283)
* Fix test_soft_delete_chat_session

* Fix flakiness
2025-08-29 09:01:55 -07:00
Chris Weaver
f30ced31a9 fix: IT (#5276)
* Fix IT

* test

* Fix test

* test

* fix

* Fix test
2025-08-28 20:42:14 -07:00
Wenxi
6cc6c43234 fix: explain why limit=None is appropriate for discord (#5278)
* explain why limit=None is appropriate for discord

* linting
2025-08-28 14:17:46 -07:00
Wenxi
224d934cf4 fix: ruff complaint about type comparison (#5279)
* ruff complaint about type comparison

* ruff complaint type comparison
2025-08-28 14:17:30 -07:00
Nigel Brown
8ecdc61ad3 fix: Explicitly add limit to the function calls (#5273)
* Explicitly add limit to the function calls
This means we miss fewer messages. The default limit is 100.

Signed-off-by: nigel brown <nigel@stacklok.com>

* Update backend/onyx/connectors/discord/connector.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

---------

Signed-off-by: nigel brown <nigel@stacklok.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-08-28 13:35:02 -07:00
Chris Weaver
08161db7ea fix: playwright tests (#5259)
* Fix playwright tests

* Address comment

* Fix
2025-08-27 23:23:55 -07:00
Richard Guan
b139764631 feat: Fast internet search (#5238)
* squash: combine all DR commits into one

Co-authored-by: Joachim Rahmfeld <joachim@onyx.app>
Co-authored-by: Rei Meguro <rmeguro@umich.edu>

* Fixes

* show KG in Assistant only if available

* KG only usable for KG Beta (for now)

* base file upload

* improvements

* raise error if uploaded context is too long

* More improvements

* Fix citations

* jank implementation of internet search with deep research that can kind of work

* early implementation for google api support

* . [30 consecutive "." commit messages, collapsed]

---------

Co-authored-by: Weves <chrisweaver101@gmail.com>
Co-authored-by: Joachim Rahmfeld <joachim@onyx.app>
Co-authored-by: Rei Meguro <rmeguro@umich.edu>
Co-authored-by: joachim-danswer <joachim@danswer.ai>
2025-08-27 20:03:02 -07:00
joachim-danswer
2b23dbde8d fix: small DR/Thoughtful mode fixes (#5269)
* fix budget calculation

* Internal custom tool fix + Okta special casing

* nits

* CW comments
2025-08-26 22:33:54 -07:00
Wenxi
2dec009d63 feat: add api/versions to onyx (#5268)
* add api/versions to onyx

* add test and rename onyx

* cubic nit

Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>

* move api version constants and add explanatory comment

---------

Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2025-08-26 18:14:54 -07:00
Chris Weaver
91eadae353 Fix logger startup (#5263) 2025-08-26 17:33:25 -07:00
Wenxi
8bff616e27 fix: clarify jql instructions and embed links (#5264)
* clarify jql instructions and embed links

* typo

* lint

* fix unit test
2025-08-26 17:27:07 -07:00
sktbcpraha
2c049e170f feat: JIRA support for custom JQL filter (#5164)
* jira jql support

* jira jql fixes
2025-08-26 12:44:39 -07:00
Oht8wooWi8yait9n
23e6d7ef3c Update gemini model names. (#5262)
Co-authored-by: Aaron Sells <aaron.b.sells@nasa.gov>
2025-08-26 12:33:02 -07:00
Chris Weaver
ed81e75edd fix: add jira auto-sync option in UI (#5260)
* Add jira auto-sync option in UI

* Fix build
2025-08-26 11:21:04 -07:00
Wenxi
de22fc3a58 remove dead code (#5261) 2025-08-26 11:14:12 -07:00
Cameron
009b7f60f1 Update date format used for fetching from Bookstack (#5221) 2025-08-26 09:49:38 -07:00
Chris Weaver
9d997e20df feat: frontend refactor + DR (#5225)
* squash: combine all DR commits into one

Co-authored-by: Joachim Rahmfeld <joachim@onyx.app>
Co-authored-by: Rei Meguro <rmeguro@umich.edu>

* Fixes

* show KG in Assistant only if available

* KG only usable for KG Beta (for now)

* base file upload

* raise error if uploaded context is too long

* improvements

* More improvements

* Fix citations

* better decision making

* improved decision-making in Orchestrator

* generic_internal tools

* Small tweak

* tool use improvements

* add on

* More image gen stuff

* fixes

* Small color improvements

* Markdown utils

* fixed end conditions (incl early exit for image generation)

* remove agent search + image fixes

* Okta tool support for reload

* Some cleanup

* Stream back search tool results as they come

* tool forcing

* fixed no-Tool-Assistant

* Support anthropic tool calling

* Support anthropic models better

* More stuff

* prompt fixes and search step numbers

* Fix hook ordering issue

* internal search fix

* Improve citation look

* Small UI improvements

* Improvements

* Improve dot

* Small chat fixes

* Small UI tweaks

* Small improvements

* Remove un-used code

* Fix

* Remove test_answer.py for now

* Fix

* improvements

* Add foreign keys

* early forcing

* Fix tests

* Fix tests

---------

Co-authored-by: Joachim Rahmfeld <joachim@onyx.app>
Co-authored-by: Rei Meguro <rmeguro@umich.edu>
Co-authored-by: joachim-danswer <joachim@danswer.ai>
2025-08-26 00:26:14 -07:00
Denizhan Dakılır
e6423c4541 Handle disabled auth in connector indexing status endpoint (#5256) 2025-08-25 16:42:46 -07:00
Wenxi
cb969ad06a add require_email_verification to values.yaml (#5249) 2025-08-25 22:02:49 +00:00
Sam Waddell
c4076d16b6 fix: update all log paths to reflect change related to non-root user (#5244) 2025-08-25 14:11:18 -07:00
Evan Lohn
04a607a718 ensure multi-tenant contextvar is passed (#5240) 2025-08-25 13:35:50 -07:00
Evan Lohn
c1e1aa9dfd fix: downloads are never larger than 20mb (#5247)
* fix: downloads are never larger than 20mb

* JT comments

* import to fix integration tests
2025-08-25 18:10:14 +00:00
Chris Weaver
1ed7abae6e Small improvement (#5250) 2025-08-25 08:07:36 +05:30
SubashMohan
cf4855822b Perf/indexing status page (#5142)
* indexing status optimization first draft

* refactor: update pagination logic and enhance UI for indexing status table

* add index attempt pruning job and display federated connectors in index status page

* update celery worker command to include index_attempt_cleanup queue

* refactor: enhance indexing status table and remove deprecated components

* mypy fix

* address review comments

* fix pagination reset issue

* add TODO for optimizing connector materialization and performance in future deployments

* enhance connector indexing status retrieval by adding 'get_all_connectors' option and updating pagination logic

* refactor: transition to paginated connector indexing status retrieval and update related components

* fix: initialize latest_index_attempt_docs_indexed to 0 in CCPairIndexingStatusTable component

* feat: add mock connector file support for indexing status retrieval and update indexing_statuses type to Sequence

* mypy fix

* refactor: rename indexing status endpoint to simplify API and update related components
2025-08-24 17:43:47 -07:00
Justin Tahara
e242b1319c fix(infra): Fixed RDS IAM Issue (#5245) 2025-08-22 18:13:12 -07:00
Justin Tahara
eba4b6620e feat(infra): AWS IAM Terraform (#5228)
* feat(infra): AWS IAM Terraform

* Fixing dependency issue

* Fixing more weird logic

* Final cleanup

* one change

* oops
2025-08-22 16:39:16 -07:00
Justin Tahara
3534515e11 feat(infra): Utilize AWS RDS IAM Auth (#5226)
* feat(infra): Utilize AWS RDS IAM Auth

* Update spacing

* Bump helm version
2025-08-21 17:35:53 -07:00
Justin Tahara
5602ff8666 fix: use only celery-shared for security context (#5236) (#5239)
* fix: use only celery-shared for security context

* fix: bump helm chart version 0.2.8

Co-authored-by: Sam Waddell <shwaddell28@gmail.com>
2025-08-21 17:25:06 -07:00
Sam Waddell
2fc70781b4 fix: use only celery-shared for security context (#5236)
* fix: use only celery-shared for security context

* fix: bump helm chart version 0.2.8
2025-08-21 14:15:07 -07:00
Justin Tahara
f76b4dec4c feat(infra): Ignoring local Terraform files (#5227)
* feat(infra): Ignoring local Terraform files

* Addressing some comments
2025-08-21 09:43:18 -07:00
Jessica Singh
a5a516fa8a refactor(model): move api-based embeddings/reranking calls out of model server (#5216)
* move api-based embeddings/reranking calls to api server out of model server, added/modified unit tests

* ran pre-commit

* fix mypy errors

* mypy and precommit

* move utils to right place and add requirements

* precommit check

* removed extra constants, changed error msg

* Update backend/onyx/utils/search_nlp_models_utils.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* greptile

* addressed comments

* added code enforcement to throw error

---------

Co-authored-by: Jessica Singh <jessicasingh@Mac.attlocal.net>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-08-20 21:50:21 +00:00
Sam Waddell
811a198134 docs: add non-root user info (#5224) 2025-08-20 13:50:10 -07:00
Sam Waddell
5867ab1d7d feat: add non-root user to backend and model-server images (#5134)
* feat: add non-root user to backend and model-server image

* feat: update values to support security context for index, inference, and celery_shared

* feat: add security context support for index and inference

* feat: add celery_shared security context support to celery worker templates

* fix: cache management strategy

* fix: update deployment files for volume mount

* fix: address comments

* fix: bump helm chart version for new security context template changes

* fix: bump helm chart version for new security context template changes

* feat: move useradd earlier in build for reduced image size

---------

Co-authored-by: Phil Critchfield <phil.critchfield@liatrio.com>
2025-08-20 13:49:50 -07:00
Jose Bañez
dd6653eb1f fix(connector): #5178 Add error handling and logging for empty answer text in Loopio Connector (#5179)
* fix(connector): #5178 Add error handling and logging for empty answer text in LoopioConnector

* fix(connector): onyx-dot-app#5178:  Improve handling of empty answer text in LoopioConnector

---------

Co-authored-by: Jose Bañez <jose@4gclinical.com>
2025-08-20 09:14:08 -07:00
Richard Guan
db457ef432 fix(admin): [DAN-2202] Remove users from invited users after accept (#5214)
* . [7 consecutive "." commit messages, collapsed]

---------

Co-authored-by: Richard Guan <richardguan@Richards-MacBook-Pro.local>
Co-authored-by: Richard Guan <richardguan@Mac.attlocal.net>
2025-08-20 03:55:02 +00:00
Richard Guan
de7fe939b2 . (#5212)
Co-authored-by: Richard Guan <richardguan@Richards-MBP.lan>
2025-08-20 02:36:44 +00:00
Chris Weaver
38114d9542 fix: PDF file upload (#5218)
* Fix / improve file upload

* Address cubic comment
2025-08-19 15:16:08 -07:00
Justin Tahara
32f20f2e2e feat(infra): Add WAF implementation (#5213) (#5217)
* feat(infra): Add WAF implementation

* Addressing greptile comments

* Additional removal of unnecessary code
2025-08-19 13:01:40 -07:00
Justin Tahara
3dd27099f7 feat(infra): Add WAF implementation (#5213)
* feat(infra): Add WAF implementation

* Addressing greptile comments

* Additional removal of unnecessary code
2025-08-18 17:45:50 -07:00
Cameron
91c4d43a80 Move @types packages to devDependencies (#5210) 2025-08-18 14:34:09 -07:00
SubashMohan
a63ba1bb03 fix: sharepoint group not found error and url with apostrophe (#5208)
* fix: handle ClientRequestException in SharePoint permission utils and connector

* feat: enhance SharePoint permission utilities with logging and URL handling

* greptile typo fix

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* enhance group sync handling for public groups

---------

Co-authored-by: Wenxi <wenxi@onyx.app>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-08-18 17:12:59 +00:00
Evan Lohn
7b6189e74c corrected routing (#5202) 2025-08-18 16:07:28 +00:00
Evan Lohn
ba423e5773 fix: model server concurrency (#5206)
* fix: model server race cond

* fix async

* different approach
2025-08-18 16:07:16 +00:00
676 changed files with 77901 additions and 37232 deletions

View File

@@ -35,6 +35,16 @@ inputs:
cache-to:
description: 'Cache destinations'
required: false
outputs:
description: 'Output destinations'
required: false
provenance:
description: 'Generate provenance attestation'
required: false
default: 'false'
build-args:
description: 'Build arguments'
required: false
retry-wait-time:
description: 'Time to wait before attempt 2 in seconds'
required: false
@@ -62,6 +72,9 @@ runs:
no-cache: ${{ inputs.no-cache }}
cache-from: ${{ inputs.cache-from }}
cache-to: ${{ inputs.cache-to }}
outputs: ${{ inputs.outputs }}
provenance: ${{ inputs.provenance }}
build-args: ${{ inputs.build-args }}
- name: Wait before attempt 2
if: steps.buildx1.outcome != 'success'
@@ -85,6 +98,9 @@ runs:
no-cache: ${{ inputs.no-cache }}
cache-from: ${{ inputs.cache-from }}
cache-to: ${{ inputs.cache-to }}
outputs: ${{ inputs.outputs }}
provenance: ${{ inputs.provenance }}
build-args: ${{ inputs.build-args }}
- name: Wait before attempt 3
if: steps.buildx1.outcome != 'success' && steps.buildx2.outcome != 'success'
@@ -108,6 +124,9 @@ runs:
no-cache: ${{ inputs.no-cache }}
cache-from: ${{ inputs.cache-from }}
cache-to: ${{ inputs.cache-to }}
outputs: ${{ inputs.outputs }}
provenance: ${{ inputs.provenance }}
build-args: ${{ inputs.build-args }}
- name: Report failure
if: steps.buildx1.outcome != 'success' && steps.buildx2.outcome != 'success' && steps.buildx3.outcome != 'success'
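The hunk above threads three new optional inputs (outputs, provenance, build-args) through every retry attempt of the composite action. A minimal sketch of a caller, assuming a hypothetical image tag and build argument (not taken from this repo):

    - name: Build and push example image
      uses: ./.github/actions/custom-build-and-push
      with:
        context: ./backend
        file: ./backend/Dockerfile
        tags: example-org/example-image:test  # hypothetical tag
        push: true
        provenance: 'false'      # new input: skip the provenance attestation
        outputs: type=registry   # new input: forwarded as the output destination
        build-args: |            # new input: arbitrary Docker build arguments
          EXAMPLE_ARG=1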

View File

@@ -142,15 +142,25 @@ jobs:
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
# Security: Using pinned digest (0.65.0@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436)
# Security: No Docker socket mount needed for remote registry scanning
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
TRIVY_USERNAME: ${{ secrets.DOCKER_USERNAME }}
TRIVY_PASSWORD: ${{ secrets.DOCKER_TOKEN }}
uses: nick-fields/retry@v3
with:
# To run locally: trivy image --severity HIGH,CRITICAL onyxdotapp/onyx-backend
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: "CRITICAL,HIGH"
trivyignores: ./backend/.trivyignore
timeout_minutes: 30
max_attempts: 3
retry_wait_seconds: 10
command: |
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
-v ${{ github.workspace }}/backend/.trivyignore:/tmp/.trivyignore:ro \
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
image \
--skip-version-check \
--timeout 20m \
--severity CRITICAL,HIGH \
--ignorefile /tmp/.trivyignore \
docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
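To reproduce the pinned scan outside CI, roughly the same command works locally (a sketch; the target image is the example from the diff's own comment, and the registry credentials can be omitted for public images):

    docker run --rm -v "$HOME/.cache/trivy:/root/.cache/trivy" \
      -e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
      -e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
      aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
      image --skip-version-check --timeout 20m --severity CRITICAL,HIGH \
      docker.io/onyxdotapp/onyx-backend:latest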

View File

@@ -139,12 +139,20 @@ jobs:
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
TRIVY_USERNAME: ${{ secrets.DOCKER_USERNAME }}
TRIVY_PASSWORD: ${{ secrets.DOCKER_TOKEN }}
uses: nick-fields/retry@v3
with:
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: "CRITICAL,HIGH"
timeout_minutes: 30
max_attempts: 3
retry_wait_seconds: 10
command: |
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
image \
--skip-version-check \
--timeout 20m \
--severity CRITICAL,HIGH \
docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}

View File

@@ -99,7 +99,7 @@ jobs:
needs: [check_model_server_changes]
if: needs.check_model_server_changes.outputs.changed == 'true'
runs-on:
[runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-arm64"]
[runs-on, runner=8cpu-linux-arm64, "run-id=${{ github.run_id }}-arm64"]
env:
PLATFORM_PAIR: linux-arm64
steps:
@@ -164,13 +164,20 @@ jobs:
fi
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
TRIVY_USERNAME: ${{ secrets.DOCKER_USERNAME }}
TRIVY_PASSWORD: ${{ secrets.DOCKER_TOKEN }}
uses: nick-fields/retry@v3
with:
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: "CRITICAL,HIGH"
timeout: "10m"
timeout_minutes: 30
max_attempts: 3
retry_wait_seconds: 10
command: |
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
image \
--skip-version-check \
--timeout 20m \
--severity CRITICAL,HIGH \
docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}

View File

@@ -150,12 +150,20 @@ jobs:
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
TRIVY_USERNAME: ${{ secrets.DOCKER_USERNAME }}
TRIVY_PASSWORD: ${{ secrets.DOCKER_TOKEN }}
uses: nick-fields/retry@v3
with:
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: "CRITICAL,HIGH"
timeout_minutes: 30
max_attempts: 3
retry_wait_seconds: 10
command: |
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
image \
--skip-version-check \
--timeout 20m \
--severity CRITICAL,HIGH \
docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}

View File

@@ -21,6 +21,9 @@ env:
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
# LLMs
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
jobs:
discover-test-dirs:
runs-on: ubuntu-latest

View File

@@ -53,27 +53,154 @@ jobs:
if: steps.list-changed.outputs.changed == 'true'
uses: helm/kind-action@v1.12.0
- name: Run chart-testing (install)
- name: Pre-install cluster status check
if: steps.list-changed.outputs.changed == 'true'
run: ct install --all \
--helm-extra-set-args="\
--set=nginx.enabled=false \
--set=postgresql.enabled=false \
--set=redis.enabled=false \
--set=minio.enabled=false \
--set=vespa.enabled=false \
--set=slackbot.enabled=false \
--set=api.replicaCount=0 \
--set=inferenceCapability.replicaCount=0 \
--set=indexCapability.replicaCount=0 \
--set=celery_beat.replicaCount=0 \
--set=celery_worker_heavy.replicaCount=0 \
--set=celery_worker_docprocessing.replicaCount=0 \
--set=celery_worker_light.replicaCount=0 \
--set=celery_worker_monitoring.replicaCount=0 \
--set=celery_worker_primary.replicaCount=0 \
--set=celery_worker_user_files_indexing.replicaCount=0" \
--debug --config ct.yaml
run: |
echo "=== Pre-install Cluster Status ==="
kubectl get nodes -o wide
kubectl get pods --all-namespaces
kubectl get storageclass
- name: Add Helm repositories and update
if: steps.list-changed.outputs.changed == 'true'
run: |
echo "=== Adding Helm repositories ==="
helm repo add bitnami https://charts.bitnami.com/bitnami
helm repo add vespa https://onyx-dot-app.github.io/vespa-helm-charts
helm repo update
- name: Pre-pull critical images
if: steps.list-changed.outputs.changed == 'true'
run: |
echo "=== Pre-pulling critical images to avoid timeout ==="
# Get kind cluster name
KIND_CLUSTER=$(kubectl config current-context | sed 's/kind-//')
echo "Kind cluster: $KIND_CLUSTER"
# Pre-pull images that are likely to be used
echo "Pre-pulling PostgreSQL image..."
docker pull postgres:15-alpine || echo "Failed to pull postgres:15-alpine"
kind load docker-image postgres:15-alpine --name $KIND_CLUSTER || echo "Failed to load postgres image"
echo "Pre-pulling Redis image..."
docker pull redis:7-alpine || echo "Failed to pull redis:7-alpine"
kind load docker-image redis:7-alpine --name $KIND_CLUSTER || echo "Failed to load redis image"
echo "Pre-pulling Onyx images..."
docker pull docker.io/onyxdotapp/onyx-web-server:latest || echo "Failed to pull onyx web server"
docker pull docker.io/onyxdotapp/onyx-backend:latest || echo "Failed to pull onyx backend"
kind load docker-image docker.io/onyxdotapp/onyx-web-server:latest --name $KIND_CLUSTER || echo "Failed to load onyx web server"
kind load docker-image docker.io/onyxdotapp/onyx-backend:latest --name $KIND_CLUSTER || echo "Failed to load onyx backend"
echo "=== Images loaded into Kind cluster ==="
docker exec $KIND_CLUSTER-control-plane crictl images | grep -E "(postgres|redis|onyx)" || echo "Some images may still be loading..."
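# Note: `kind load docker-image` copies an image from the host Docker daemon
# into the kind node's containerd image store, so subsequent in-cluster pulls
# resolve locally instead of hitting Docker Hub (and its rate limits).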
- name: Validate chart dependencies
if: steps.list-changed.outputs.changed == 'true'
run: |
echo "=== Validating chart dependencies ==="
cd deployment/helm/charts/onyx
helm dependency update
helm lint .
- name: Run chart-testing (install) with enhanced monitoring
timeout-minutes: 25
if: steps.list-changed.outputs.changed == 'true'
run: |
echo "=== Starting chart installation with monitoring ==="
# Function to monitor cluster state
monitor_cluster() {
while true; do
echo "=== Cluster Status Check at $(date) ==="
# Only show non-running pods to reduce noise
NON_RUNNING_PODS=$(kubectl get pods --all-namespaces --field-selector=status.phase!=Running,status.phase!=Succeeded --no-headers 2>/dev/null | wc -l)
if [ "$NON_RUNNING_PODS" -gt 0 ]; then
echo "Non-running pods:"
kubectl get pods --all-namespaces --field-selector=status.phase!=Running,status.phase!=Succeeded
else
echo "All pods running successfully"
fi
# Only show recent events if there are issues
RECENT_EVENTS=$(kubectl get events --sort-by=.lastTimestamp --all-namespaces --field-selector=type!=Normal 2>/dev/null | tail -5)
if [ -n "$RECENT_EVENTS" ]; then
echo "Recent warnings/errors:"
echo "$RECENT_EVENTS"
fi
sleep 60
done
}
# Start monitoring in background
monitor_cluster &
MONITOR_PID=$!
# Set up cleanup
cleanup() {
echo "=== Cleaning up monitoring process ==="
kill $MONITOR_PID 2>/dev/null || true
echo "=== Final cluster state ==="
kubectl get pods --all-namespaces
kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -20
}
# Trap cleanup on exit
trap cleanup EXIT
# Run the actual installation with detailed logging
echo "=== Starting ct install ==="
ct install --all \
--helm-extra-set-args="\
--set=nginx.enabled=false \
--set=minio.enabled=false \
--set=vespa.enabled=false \
--set=slackbot.enabled=false \
--set=postgresql.enabled=true \
--set=postgresql.primary.persistence.enabled=false \
--set=redis.enabled=true \
--set=webserver.replicaCount=1 \
--set=api.replicaCount=0 \
--set=inferenceCapability.replicaCount=0 \
--set=indexCapability.replicaCount=0 \
--set=celery_beat.replicaCount=0 \
--set=celery_worker_heavy.replicaCount=0 \
--set=celery_worker_docfetching.replicaCount=0 \
--set=celery_worker_docprocessing.replicaCount=0 \
--set=celery_worker_light.replicaCount=0 \
--set=celery_worker_monitoring.replicaCount=0 \
--set=celery_worker_primary.replicaCount=0 \
--set=celery_worker_user_files_indexing.replicaCount=0" \
--helm-extra-args="--timeout 900s --debug" \
--debug --config ct.yaml
echo "=== Installation completed successfully ==="
kubectl get pods --all-namespaces
- name: Post-install verification
if: steps.list-changed.outputs.changed == 'true'
run: |
echo "=== Post-install verification ==="
kubectl get pods --all-namespaces
kubectl get services --all-namespaces
# Only show issues if they exist
kubectl describe pods --all-namespaces | grep -A 5 -B 2 "Failed\|Error\|Warning" || echo "No pod issues found"
- name: Cleanup on failure
if: failure() && steps.list-changed.outputs.changed == 'true'
run: |
echo "=== Cleanup on failure ==="
echo "=== Final cluster state ==="
kubectl get pods --all-namespaces
kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -10
echo "=== Pod descriptions for debugging ==="
kubectl describe pods --all-namespaces | grep -A 10 -B 3 "Failed\|Error\|Warning\|Pending" || echo "No problematic pods found"
echo "=== Recent logs for debugging ==="
kubectl logs --all-namespaces --tail=50 | grep -i "error\|timeout\|failed\|pull" || echo "No error logs found"
echo "=== Helm releases ==="
helm list --all-namespaces
# the following would install only changed charts, but we only have one chart so
# don't worry about that for now
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
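For local debugging outside chart-testing, the ct invocation above is roughly equivalent to a direct helm install with the same overrides (a sketch; the release name is hypothetical, and the chart path matches the dependency-validation step above):

    helm dependency update deployment/helm/charts/onyx
    helm install onyx-test deployment/helm/charts/onyx \
      --set nginx.enabled=false \
      --set minio.enabled=false \
      --set vespa.enabled=false \
      --set slackbot.enabled=false \
      --set postgresql.enabled=true \
      --set postgresql.primary.persistence.enabled=false \
      --set redis.enabled=true \
      --set api.replicaCount=0 \
      --timeout 900s --debug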

View File

@@ -11,6 +11,12 @@ on:
- "release/**"
env:
# Private Registry Configuration
PRIVATE_REGISTRY: experimental-registry.blacksmith.sh:5000
PRIVATE_REGISTRY_USERNAME: ${{ secrets.PRIVATE_REGISTRY_USERNAME }}
PRIVATE_REGISTRY_PASSWORD: ${{ secrets.PRIVATE_REGISTRY_PASSWORD }}
# Test Environment Variables
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
@@ -23,18 +29,38 @@ env:
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
PLATFORM_PAIR: linux-amd64
jobs:
integration-tests:
# See https://runs-on.com/runners/linux/
runs-on:
[
runs-on,
runner=32cpu-linux-x64,
disk=large,
"run-id=${{ github.run_id }}",
]
discover-test-dirs:
runs-on: blacksmith-2vcpu-ubuntu-2404-arm
outputs:
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Discover test directories
id: set-matrix
run: |
# Find all leaf-level directories in both test directories
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
connector_dirs=$(find backend/tests/integration/connector_job_tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
# Create JSON array with directory info
all_dirs=""
for dir in $tests_dirs; do
all_dirs="$all_dirs{\"path\":\"tests/$dir\",\"name\":\"tests-$dir\"},"
done
for dir in $connector_dirs; do
all_dirs="$all_dirs{\"path\":\"connector_job_tests/$dir\",\"name\":\"connector-$dir\"},"
done
# Remove trailing comma and wrap in array
all_dirs="[${all_dirs%,}]"
echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT
prepare-build:
runs-on: blacksmith-2vcpu-ubuntu-2404-arm
steps:
- name: Checkout code
uses: actions/checkout@v4
@@ -47,12 +73,12 @@ jobs:
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/ee.txt
- run: |
- name: Install Python dependencies
run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/ee.txt
- name: Generate OpenAPI schema
working-directory: ./backend
@@ -70,132 +96,155 @@ jobs:
-i /local/openapi.json \
-g python \
-o /local/onyx_openapi_client \
--package-name onyx_openapi_client
--package-name onyx_openapi_client \
--skip-validate-spec \
--openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"
- name: Upload OpenAPI artifacts
uses: actions/upload-artifact@v4
with:
name: openapi-artifacts
path: backend/generated/
build-backend-image:
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Login to Private Registry
uses: docker/login-action@v3
with:
registry: ${{ env.PRIVATE_REGISTRY }}
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
uses: useblacksmith/setup-docker-builder@v1
- name: Build and push Backend Docker image
uses: useblacksmith/build-push-action@v2
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/arm64
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}
push: true
build-model-server-image:
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Login to Private Registry
uses: docker/login-action@v3
with:
registry: ${{ env.PRIVATE_REGISTRY }}
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
- name: Set up Docker Buildx
uses: useblacksmith/setup-docker-builder@v1
- name: Build and push Model Server Docker image
uses: useblacksmith/build-push-action@v2
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/arm64
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}
push: true
outputs: type=registry
provenance: false
build-integration-image:
needs: prepare-build
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Login to Private Registry
uses: docker/login-action@v3
with:
registry: ${{ env.PRIVATE_REGISTRY }}
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
- name: Download OpenAPI artifacts
uses: actions/download-artifact@v4
with:
name: openapi-artifacts
path: backend/generated/
- name: Set up Docker Buildx
uses: useblacksmith/setup-docker-builder@v1
- name: Build and push integration test Docker image
uses: useblacksmith/build-push-action@v2
with:
context: ./backend
file: ./backend/tests/integration/Dockerfile
platforms: linux/arm64
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}
push: true
integration-tests:
needs:
[
discover-test-dirs,
build-backend-image,
build-model-server-image,
build-integration-image,
]
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
strategy:
fail-fast: false
matrix:
test-dir: ${{ fromJson(needs.discover-test-dirs.outputs.test-dirs) }}
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Login to Private Registry
uses: docker/login-action@v3
with:
registry: ${{ env.PRIVATE_REGISTRY }}
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# tag every docker image with "test" so that we can spin up the correct set
# of images during testing
# We don't need to build the Web Docker image since it's not yet used
# in the integration tests. We have a separate action to verify that it builds
# successfully.
- name: Pull Web Docker image
- name: Pull Docker images
run: |
docker pull onyxdotapp/onyx-web-server:latest
docker tag onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:test
# Pull all images from registry in parallel
echo "Pulling Docker images in parallel..."
# Pull images from private registry
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}) &
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}) &
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}) &
# we use the runs-on cache for docker builds
# in conjunction with runs-on runners, it has better speed and unlimited caching
# https://runs-on.com/caching/s3-cache-for-github-actions/
# https://runs-on.com/caching/docker/
# https://github.com/moby/buildkit#s3-cache-experimental
# Wait for all background jobs to complete
wait
echo "All Docker images pulled successfully"
# images are built and run locally for testing purposes. Not pushed.
- name: Build Backend Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64
tags: onyxdotapp/onyx-backend:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build Model Server Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64
tags: onyxdotapp/onyx-model-server:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build integration test Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/tests/integration/Dockerfile
platforms: linux/amd64
tags: onyxdotapp/onyx-integration:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
# Start containers for multi-tenant tests
- name: Start Docker containers for multi-tenant tests
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
MULTI_TENANT=true \
AUTH_TYPE=cloud \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
DEV_MODE=true \
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack up -d
id: start_docker_multi_tenant
# In practice, `cloud` Auth type would require OAUTH credentials to be set.
- name: Run Multi-Tenant Integration Tests
run: |
echo "Waiting for 3 minutes to ensure API server is ready..."
sleep 180
echo "Running integration tests..."
docker run --rm --network onyx-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e DB_READONLY_USER=db_readonly_user \
-e DB_READONLY_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
-e AUTH_TYPE=cloud \
-e MULTI_TENANT=true \
-e REQUIRE_EMAIL_VERIFICATION=false \
-e DISABLE_TELEMETRY=true \
-e IMAGE_TAG=test \
-e DEV_MODE=true \
onyxdotapp/onyx-integration:test \
/app/tests/integration/multitenant_tests
continue-on-error: true
id: run_multitenant_tests
- name: Check multi-tenant test results
run: |
if [ ${{ steps.run_multitenant_tests.outcome }} == 'failure' ]; then
echo "Multi-tenant integration tests failed. Exiting with error."
exit 1
else
echo "All multi-tenant integration tests passed successfully."
fi
- name: Stop multi-tenant Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack down -v
# Re-tag to remove registry prefix for docker-compose
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }} onyxdotapp/onyx-backend:test
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }} onyxdotapp/onyx-model-server:test
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }} onyxdotapp/onyx-integration:test
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
# NOTE: don't need web server for integration tests
- name: Start Docker containers
run: |
cd deployment/docker_compose
@@ -208,7 +257,16 @@ jobs:
IMAGE_TAG=test \
INTEGRATION_TESTS_MODE=true \
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001 \
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
docker compose -f docker-compose.dev.yml -p onyx-stack up \
relational_db \
index \
cache \
minio \
api_server \
inference_model_server \
indexing_model_server \
background \
-d
id: start_docker
- name: Wait for service to be ready
@@ -251,52 +309,44 @@ jobs:
docker compose -f docker-compose.mock-it-services.yml \
-p mock-it-services-stack up -d
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
- name: Run Standard Integration Tests
run: |
echo "Running integration tests..."
docker run --rm --network onyx-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e DB_READONLY_USER=db_readonly_user \
-e DB_READONLY_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
onyxdotapp/onyx-integration:test \
/app/tests/integration/tests \
/app/tests/integration/connector_job_tests
continue-on-error: true
id: run_tests
- name: Check test results
run: |
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
echo "Integration tests failed. Exiting with error."
exit 1
else
echo "All integration tests passed successfully."
fi
- name: Run Integration Tests for ${{ matrix.test-dir.name }}
uses: nick-fields/retry@v3
with:
timeout_minutes: 20
max_attempts: 3
retry_wait_seconds: 10
command: |
echo "Running integration tests for ${{ matrix.test-dir.path }}..."
docker run --rm --network onyx-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e DB_READONLY_USER=db_readonly_user \
-e DB_READONLY_PASSWORD=password \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
onyxdotapp/onyx-integration:test \
/app/tests/integration/${{ matrix.test-dir.path }}
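# With a hypothetical matrix entry {"path":"tests/dir_a","name":"tests-dir_a"},
# the final argument expands to /app/tests/integration/tests/dir_a, so each
# matrix job runs a single test directory and retries it up to three times.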
# ------------------------------------------------------------
# Always gather logs BEFORE "down":
@@ -316,7 +366,7 @@ jobs:
if: always()
uses: actions/upload-artifact@v4
with:
name: docker-all-logs
name: docker-all-logs-${{ matrix.test-dir.name }}
path: ${{ github.workspace }}/docker-compose.log
# ------------------------------------------------------------
@@ -325,3 +375,157 @@ jobs:
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack down -v
multitenant-tests:
needs:
[
build-backend-image,
build-model-server-image,
build-integration-image,
]
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Login to Private Registry
uses: docker/login-action@v3
with:
registry: ${{ env.PRIVATE_REGISTRY }}
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Pull Docker images
run: |
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}) &
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}) &
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}) &
wait
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }} onyxdotapp/onyx-backend:test
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }} onyxdotapp/onyx-model-server:test
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }} onyxdotapp/onyx-integration:test
- name: Start Docker containers for multi-tenant tests
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
MULTI_TENANT=true \
AUTH_TYPE=cloud \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
DEV_MODE=true \
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack up \
relational_db \
index \
cache \
minio \
api_server \
inference_model_server \
indexing_model_server \
background \
-d
id: start_docker_multi_tenant
- name: Wait for service to be ready (multi-tenant)
run: |
echo "Starting wait-for-service script for multi-tenant..."
docker logs -f onyx-stack-api_server-1 &
start_time=$(date +%s)
timeout=300
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. Service did not become ready in 5 minutes."
exit 1
fi
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
if [ "$response" = "200" ]; then
echo "Service is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error; retrying..."
else
echo "Service not ready yet (HTTP $response). Retrying in 5 seconds..."
fi
sleep 5
done
echo "Finished waiting for service."
- name: Run Multi-Tenant Integration Tests
run: |
echo "Running multi-tenant integration tests..."
docker run --rm --network onyx-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e DB_READONLY_USER=db_readonly_user \
-e DB_READONLY_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
-e AUTH_TYPE=cloud \
-e MULTI_TENANT=true \
-e SKIP_RESET=true \
-e REQUIRE_EMAIL_VERIFICATION=false \
-e DISABLE_TELEMETRY=true \
-e IMAGE_TAG=test \
-e DEV_MODE=true \
onyxdotapp/onyx-integration:test \
/app/tests/integration/multitenant_tests
- name: Dump API server logs (multi-tenant)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server_multitenant.log || true
- name: Dump all-container logs (multi-tenant)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose-multitenant.log || true
- name: Upload logs (multi-tenant)
if: always()
uses: actions/upload-artifact@v4
with:
name: docker-all-logs-multitenant
path: ${{ github.workspace }}/docker-compose-multitenant.log
- name: Stop multi-tenant Docker containers
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack down -v
required:
runs-on: blacksmith-2vcpu-ubuntu-2404-arm
needs: [integration-tests, multitenant-tests]
if: ${{ always() }}
steps:
- uses: actions/github-script@v7
with:
script: |
const needs = ${{ toJSON(needs) }};
const failed = Object.values(needs).some(n => n.result !== 'success');
if (failed) {
core.setFailed('One or more upstream jobs failed or were cancelled.');
} else {
core.notice('All required jobs succeeded.');
}


@@ -5,12 +5,15 @@ concurrency:
on:
merge_group:
pull_request:
branches:
- main
- "release/**"
types: [checks_requested]
env:
# Private Registry Configuration
PRIVATE_REGISTRY: experimental-registry.blacksmith.sh:5000
PRIVATE_REGISTRY_USERNAME: ${{ secrets.PRIVATE_REGISTRY_USERNAME }}
PRIVATE_REGISTRY_PASSWORD: ${{ secrets.PRIVATE_REGISTRY_PASSWORD }}
# Test Environment Variables
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
@@ -23,21 +26,42 @@ env:
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
PLATFORM_PAIR: linux-amd64
jobs:
integration-tests-mit:
# See https://runs-on.com/runners/linux/
runs-on:
[
runs-on,
runner=32cpu-linux-x64,
disk=large,
"run-id=${{ github.run_id }}",
]
discover-test-dirs:
runs-on: blacksmith-2vcpu-ubuntu-2404-arm
outputs:
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Discover test directories
id: set-matrix
run: |
# Find all leaf-level directories in both test directories
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
connector_dirs=$(find backend/tests/integration/connector_job_tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
# Create JSON array with directory info
all_dirs=""
for dir in $tests_dirs; do
all_dirs="$all_dirs{\"path\":\"tests/$dir\",\"name\":\"tests-$dir\"},"
done
for dir in $connector_dirs; do
all_dirs="$all_dirs{\"path\":\"connector_job_tests/$dir\",\"name\":\"connector-$dir\"},"
done
# Remove trailing comma and wrap in array
all_dirs="[${all_dirs%,}]"
echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT
prepare-build:
runs-on: blacksmith-2vcpu-ubuntu-2404-arm
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Setup Python
uses: actions/setup-python@v5
with:
@@ -46,7 +70,9 @@ jobs:
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
- run: |
- name: Install Python dependencies
run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
@@ -67,72 +93,156 @@ jobs:
-i /local/openapi.json \
-g python \
-o /local/onyx_openapi_client \
--package-name onyx_openapi_client
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
--package-name onyx_openapi_client \
--skip-validate-spec \
--openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"
- name: Upload OpenAPI artifacts
uses: actions/upload-artifact@v4
with:
name: openapi-artifacts
path: backend/generated/
build-backend-image:
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Login to Private Registry
uses: docker/login-action@v3
with:
registry: ${{ env.PRIVATE_REGISTRY }}
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
- name: Set up Docker Buildx
uses: useblacksmith/setup-docker-builder@v1
- name: Build and push Backend Docker image
uses: useblacksmith/build-push-action@v2
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/arm64
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}
push: true
build-model-server-image:
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Login to Private Registry
uses: docker/login-action@v3
with:
registry: ${{ env.PRIVATE_REGISTRY }}
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
- name: Set up Docker Buildx
uses: useblacksmith/setup-docker-builder@v1
- name: Build and push Model Server Docker image
uses: useblacksmith/build-push-action@v2
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/arm64
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}
push: true
outputs: type=registry
provenance: false
build-integration-image:
needs: prepare-build
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Login to Private Registry
uses: docker/login-action@v3
with:
registry: ${{ env.PRIVATE_REGISTRY }}
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
- name: Download OpenAPI artifacts
uses: actions/download-artifact@v4
with:
name: openapi-artifacts
path: backend/generated/
- name: Set up Docker Buildx
uses: useblacksmith/setup-docker-builder@v1
- name: Build and push integration test Docker image
uses: useblacksmith/build-push-action@v2
with:
context: ./backend
file: ./backend/tests/integration/Dockerfile
platforms: linux/arm64
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}
push: true
integration-tests-mit:
needs:
[
discover-test-dirs,
build-backend-image,
build-model-server-image,
build-integration-image,
]
# See https://docs.blacksmith.sh/blacksmith-runners/overview
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
strategy:
fail-fast: false
matrix:
test-dir: ${{ fromJson(needs.discover-test-dirs.outputs.test-dirs) }}
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Login to Private Registry
uses: docker/login-action@v3
with:
registry: ${{ env.PRIVATE_REGISTRY }}
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# tag every docker image with "test" so that we can spin up the correct set
# of images during testing
# We don't need to build the Web Docker image since it's not yet used
# in the integration tests. We have a separate action to verify that it builds
# successfully.
- name: Pull Web Docker image
- name: Pull Docker images
run: |
docker pull onyxdotapp/onyx-web-server:latest
docker tag onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:test
# Pull all images from registry in parallel
echo "Pulling Docker images in parallel..."
# Pull images from private registry
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}) &
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}) &
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}) &
# we use the runs-on cache for docker builds
# in conjunction with runs-on runners, it has better speed and unlimited caching
# https://runs-on.com/caching/s3-cache-for-github-actions/
# https://runs-on.com/caching/docker/
# https://github.com/moby/buildkit#s3-cache-experimental
# Wait for all background jobs to complete
wait
echo "All Docker images pulled successfully"
# images are built and run locally for testing purposes. Not pushed.
- name: Build Backend Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64
tags: onyxdotapp/onyx-backend:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/backend-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/backend-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build Model Server Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64
tags: onyxdotapp/onyx-model-server:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build integration test Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/tests/integration/Dockerfile
platforms: linux/amd64
tags: onyxdotapp/onyx-integration:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/integration-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/integration-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
# Re-tag to remove registry prefix for docker-compose
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }} onyxdotapp/onyx-backend:test
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }} onyxdotapp/onyx-model-server:test
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }} onyxdotapp/onyx-integration:test
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
# NOTE: don't need web server for integration tests
- name: Start Docker containers
run: |
cd deployment/docker_compose
@@ -143,7 +253,16 @@ jobs:
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
INTEGRATION_TESTS_MODE=true \
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
docker compose -f docker-compose.dev.yml -p onyx-stack up \
relational_db \
index \
cache \
minio \
api_server \
inference_model_server \
indexing_model_server \
background \
-d
id: start_docker
- name: Wait for service to be ready
@@ -187,51 +306,44 @@ jobs:
-p mock-it-services-stack up -d
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
- name: Run Standard Integration Tests
run: |
echo "Running integration tests..."
docker run --rm --network onyx-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e DB_READONLY_USER=db_readonly_user \
-e DB_READONLY_PASSWORD=password \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
onyxdotapp/onyx-integration:test \
/app/tests/integration/tests \
/app/tests/integration/connector_job_tests
continue-on-error: true
id: run_tests
- name: Check test results
run: |
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
echo "Integration tests failed. Exiting with error."
exit 1
else
echo "All integration tests passed successfully."
fi
- name: Run Integration Tests for ${{ matrix.test-dir.name }}
uses: nick-fields/retry@v3
with:
timeout_minutes: 20
max_attempts: 3
retry_wait_seconds: 10
command: |
echo "Running integration tests for ${{ matrix.test-dir.path }}..."
docker run --rm --network onyx-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e DB_READONLY_USER=db_readonly_user \
-e DB_READONLY_PASSWORD=password \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
onyxdotapp/onyx-integration:test \
/app/tests/integration/${{ matrix.test-dir.path }}
# ------------------------------------------------------------
# Always gather logs BEFORE "down":
@@ -251,7 +363,7 @@ jobs:
if: always()
uses: actions/upload-artifact@v4
with:
name: docker-all-logs
name: docker-all-logs-${{ matrix.test-dir.name }}
path: ${{ github.workspace }}/docker-compose.log
# ------------------------------------------------------------
@@ -260,3 +372,20 @@ jobs:
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack down -v
required:
runs-on: blacksmith-2vcpu-ubuntu-2404-arm
needs: [integration-tests-mit]
if: ${{ always() }}
steps:
- uses: actions/github-script@v7
with:
script: |
const needs = ${{ toJSON(needs) }};
const failed = Object.values(needs).some(n => n.result !== 'success');
if (failed) {
core.setFailed('One or more upstream jobs failed or were cancelled.');
} else {
core.notice('All required jobs succeeded.');
}


@@ -6,44 +6,165 @@ concurrency:
on: push
env:
# AWS ECR Configuration
AWS_REGION: ${{ secrets.AWS_REGION || 'us-west-2' }}
ECR_REGISTRY: ${{ secrets.ECR_REGISTRY }}
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID_ECR }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY_ECR }}
BUILDX_NO_DEFAULT_ATTESTATIONS: 1
# Test Environment Variables
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
# for federated slack tests
SLACK_CLIENT_ID: ${{ secrets.SLACK_CLIENT_ID }}
SLACK_CLIENT_SECRET: ${{ secrets.SLACK_CLIENT_SECRET }}
MOCK_LLM_RESPONSE: true
PYTEST_PLAYWRIGHT_SKIP_INITIAL_RESET: true
jobs:
playwright-tests:
name: Playwright Tests
build-web-image:
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
steps:
- name: Checkout code
uses: actions/checkout@v4
# See https://runs-on.com/runners/linux/
runs-on:
[
runs-on,
runner=32cpu-linux-x64,
disk=large,
"run-id=${{ github.run_id }}",
]
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
with:
aws-access-key-id: ${{ env.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ env.AWS_SECRET_ACCESS_KEY }}
aws-region: ${{ env.AWS_REGION }}
- name: Login to Amazon ECR
id: login-ecr
uses: aws-actions/amazon-ecr-login@v2
- name: Set up Docker Buildx
uses: useblacksmith/setup-docker-builder@v1
- name: Build and push Web Docker image
uses: useblacksmith/build-push-action@v2
with:
context: ./web
file: ./web/Dockerfile
platforms: linux/arm64
tags: ${{ env.ECR_REGISTRY }}/integration-test-onyx-web-server:playwright-test-${{ github.run_id }}
provenance: false
sbom: false
push: true
build-backend-image:
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
with:
aws-access-key-id: ${{ env.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ env.AWS_SECRET_ACCESS_KEY }}
aws-region: ${{ env.AWS_REGION }}
- name: Login to Amazon ECR
id: login-ecr
uses: aws-actions/amazon-ecr-login@v2
- name: Set up Docker Buildx
uses: useblacksmith/setup-docker-builder@v1
- name: Build and push Backend Docker image
uses: useblacksmith/build-push-action@v2
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/arm64
tags: ${{ env.ECR_REGISTRY }}/integration-test-onyx-backend:playwright-test-${{ github.run_id }}
provenance: false
sbom: false
push: true
build-model-server-image:
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
with:
aws-access-key-id: ${{ env.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ env.AWS_SECRET_ACCESS_KEY }}
aws-region: ${{ env.AWS_REGION }}
- name: Login to Amazon ECR
id: login-ecr
uses: aws-actions/amazon-ecr-login@v2
- name: Set up Docker Buildx
uses: useblacksmith/setup-docker-builder@v1
- name: Build and push Model Server Docker image
uses: useblacksmith/build-push-action@v2
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/arm64
tags: ${{ env.ECR_REGISTRY }}/integration-test-onyx-model-server:playwright-test-${{ github.run_id }}
provenance: false
sbom: false
push: true
playwright-tests:
needs: [build-web-image, build-backend-image, build-model-server-image]
name: Playwright Tests
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Python
uses: actions/setup-python@v5
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
with:
python-version: "3.11"
cache: "pip"
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
aws-access-key-id: ${{ env.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ env.AWS_SECRET_ACCESS_KEY }}
aws-region: ${{ env.AWS_REGION }}
- name: Login to Amazon ECR
id: login-ecr
uses: aws-actions/amazon-ecr-login@v2
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Pull Docker images
run: |
# Pull all images from ECR in parallel
echo "Pulling Docker images in parallel..."
(docker pull ${{ env.ECR_REGISTRY }}/integration-test-onyx-web-server:playwright-test-${{ github.run_id }}) &
(docker pull ${{ env.ECR_REGISTRY }}/integration-test-onyx-backend:playwright-test-${{ github.run_id }}) &
(docker pull ${{ env.ECR_REGISTRY }}/integration-test-onyx-model-server:playwright-test-${{ github.run_id }}) &
# Wait for all background jobs to complete
wait
echo "All Docker images pulled successfully"
# Re-tag with expected names for docker-compose
docker tag ${{ env.ECR_REGISTRY }}/integration-test-onyx-web-server:playwright-test-${{ github.run_id }} onyxdotapp/onyx-web-server:test
docker tag ${{ env.ECR_REGISTRY }}/integration-test-onyx-backend:playwright-test-${{ github.run_id }} onyxdotapp/onyx-backend:test
docker tag ${{ env.ECR_REGISTRY }}/integration-test-onyx-model-server:playwright-test-${{ github.run_id }} onyxdotapp/onyx-model-server:test
- name: Setup node
uses: actions/setup-node@v4
@@ -58,68 +179,13 @@ jobs:
working-directory: ./web
run: npx playwright install --with-deps
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# tag every docker image with "test" so that we can spin up the correct set
# of images during testing
# we use the runs-on cache for docker builds
# in conjunction with runs-on runners, it has better speed and unlimited caching
# https://runs-on.com/caching/s3-cache-for-github-actions/
# https://runs-on.com/caching/docker/
# https://github.com/moby/buildkit#s3-cache-experimental
# images are built and run locally for testing purposes. Not pushed.
- name: Build Web Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./web
file: ./web/Dockerfile
platforms: linux/amd64
tags: onyxdotapp/onyx-web-server:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/web-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/web-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build Backend Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64
tags: onyxdotapp/onyx-backend:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build Model Server Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64
tags: onyxdotapp/onyx-model-server:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Start Docker containers
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
AUTH_TYPE=basic \
GEN_AI_API_KEY=${{ secrets.OPENAI_API_KEY }} \
GEN_AI_API_KEY=${{ env.OPENAI_API_KEY }} \
EXA_API_KEY=${{ env.EXA_API_KEY }} \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
@@ -160,12 +226,6 @@ jobs:
done
echo "Finished waiting for service."
- name: Run pytest playwright test init
working-directory: ./backend
env:
PYTEST_IGNORE_SKIP: true
run: pytest -s tests/integration/tests/playwright/test_playwright.py
- name: Run Playwright tests
working-directory: ./web
run: npx playwright test


@@ -48,6 +48,8 @@ jobs:
-g python \
-o /local/onyx_openapi_client \
--package-name onyx_openapi_client \
--skip-validate-spec \
--openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"
- name: Run MyPy
run: |

.gitignore

@@ -17,12 +17,24 @@ backend/tests/regression/answer_quality/test_data.json
backend/tests/regression/search_quality/eval-*
backend/tests/regression/search_quality/search_eval_config.yaml
backend/tests/regression/search_quality/*.json
*.log
# secret files
.env
jira_test_env
settings.json
# others
/deployment/data/nginx/app.conf
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml
# Local .terraform directories
**/.terraform/*
# Local .tfstate files
*.tfstate
*.tfstate.*
# Local .terraform.lock.hcl file
.terraform.lock.hcl


@@ -23,6 +23,9 @@ DISABLE_LLM_DOC_RELEVANCE=False
# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically)
OAUTH_CLIENT_ID=<REPLACE THIS>
OAUTH_CLIENT_SECRET=<REPLACE THIS>
OPENID_CONFIG_URL=<REPLACE THIS>
SAML_CONF_DIR=/<ABSOLUTE PATH TO ONYX>/onyx/backend/ee/onyx/configs/saml_config
# Generally not useful for dev; we don't want to set up an SMTP server for dev
REQUIRE_EMAIL_VERIFICATION=False
@@ -46,7 +49,6 @@ PYTHONUNBUFFERED=1
# Internet Search
BING_API_KEY=<REPLACE THIS>
EXA_API_KEY=<REPLACE THIS>
@@ -65,3 +67,12 @@ S3_ENDPOINT_URL=http://localhost:9004
S3_FILE_STORE_BUCKET_NAME=onyx-file-store-bucket
S3_AWS_ACCESS_KEY_ID=minioadmin
S3_AWS_SECRET_ACCESS_KEY=minioadmin
# Show extra/uncommon connectors
SHOW_EXTRA_CONNECTORS=True
# Local langsmith tracing
LANGSMITH_TRACING="true"
LANGSMITH_ENDPOINT="https://api.smith.langchain.com"
LANGSMITH_API_KEY=<REPLACE_THIS>
LANGSMITH_PROJECT=<REPLACE_THIS>


@@ -31,14 +31,16 @@
],
"presentation": {
"group": "1"
}
},
"stopAll": true
},
{
"name": "Web / Model / API",
"configurations": ["Web Server", "Model Server", "API Server"],
"presentation": {
"group": "1"
}
},
"stopAll": true
},
{
"name": "Celery (all)",
@@ -53,7 +55,8 @@
],
"presentation": {
"group": "1"
}
},
"stopAll": true
}
],
"configurations": [
@@ -189,7 +192,7 @@
"--loglevel=INFO",
"--hostname=light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert"
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,index_attempt_cleanup"
],
"presentation": {
"group": "2"

AGENTS.md

@@ -0,0 +1,295 @@
# AGENTS.md
This file provides guidance to Codex when working with code in this repository.
## KEY NOTES
- If you run into any missing Python dependency errors, try prefixing your command with `workon onyx &&`
to activate the Python virtualenv.
- To make tests work, check the `.env` file at the root of the project to find an OpenAI key.
- If using `playwright` to explore the frontend, you can usually log in with username `a@test.com` and password
`a`. The app can be accessed at `http://localhost:3000`.
- You should assume that all Onyx services are running. To verify, check the `backend/log` directory and
confirm that logs are coming from the relevant service.
- To connect to the Postgres database, use: `docker exec -it onyx-stack-relational_db-1 psql -U postgres -c "<SQL>"`
- When making calls to the backend, always go through the frontend. E.g. make a call to `http://localhost:3000/api/persona` not `http://localhost:8080/api/persona`
- Put ALL db operations under the `backend/onyx/db` / `backend/ee/onyx/db` directories. Don't run queries
outside of those directories.
## Project Overview
**Onyx** (formerly Danswer) is an open-source Gen-AI and Enterprise Search platform that connects to company documents, apps, and people. It features a modular architecture with both Community Edition (MIT licensed) and Enterprise Edition offerings.
### Background Workers (Celery)
Onyx uses Celery for asynchronous task processing with multiple specialized workers:
#### Worker Types
1. **Primary Worker** (`celery_app.py`)
- Coordinates core background tasks and system-wide operations
- Handles connector management, document sync, pruning, and periodic checks
- Runs with 4 threads concurrency
- Tasks: connector deletion, vespa sync, pruning, LLM model updates, user file sync
2. **Docfetching Worker** (`docfetching`)
- Fetches documents from external data sources (connectors)
- Spawns docprocessing tasks for each document batch
- Implements watchdog monitoring for stuck connectors
- Configurable concurrency (default from env)
3. **Docprocessing Worker** (`docprocessing`)
- Processes fetched documents through the indexing pipeline:
- Upserts documents to PostgreSQL
- Chunks documents and adds contextual information
- Embeds chunks via model server
- Writes chunks to Vespa vector database
- Updates document metadata
- Configurable concurrency (default from env)
4. **Light Worker** (`light`)
- Handles lightweight, fast operations
- Tasks: vespa operations, document permissions sync, external group sync
- Higher concurrency for quick tasks
5. **Heavy Worker** (`heavy`)
- Handles resource-intensive operations
- Primary task: document pruning operations
- Runs with 4 threads concurrency
6. **KG Processing Worker** (`kg_processing`)
- Handles Knowledge Graph processing and clustering
- Builds relationships between documents
- Runs clustering algorithms
- Configurable concurrency
7. **Monitoring Worker** (`monitoring`)
- System health monitoring and metrics collection
- Monitors Celery queues, process memory, and system status
- Single thread (monitoring doesn't need parallelism)
- Cloud-specific monitoring tasks
8. **Beat Worker** (`beat`)
- Celery's scheduler for periodic tasks
- Uses DynamicTenantScheduler for multi-tenant support
- Schedules tasks like:
- Indexing checks (every 15 seconds)
- Connector deletion checks (every 20 seconds)
- Vespa sync checks (every 20 seconds)
- Pruning checks (every 20 seconds)
- KG processing (every 60 seconds)
- Monitoring tasks (every 5 minutes)
- Cleanup tasks (hourly)
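As a rough illustration of these cadences, here is what one entry of a static Beat schedule could look like. This is a hypothetical sketch: real Onyx schedules go through `DynamicTenantScheduler`, and the task name below is made up.
```python
from celery.schedules import schedule

# Hypothetical static beat schedule entry, for illustration only.
beat_schedule = {
    "check-for-indexing": {
        "task": "check_for_indexing",  # illustrative task name
        "schedule": schedule(run_every=15.0),  # every 15 seconds
    },
}
```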
#### Key Features
- **Thread-based Workers**: All workers use thread pools (not processes) for stability
- **Tenant Awareness**: Multi-tenant support with per-tenant task isolation. There is a
middleware layer that automatically finds the appropriate tenant ID when sending tasks
via Celery Beat.
- **Task Prioritization**: High, Medium, Low priority queues
- **Monitoring**: Built-in heartbeat and liveness checking
- **Failure Handling**: Automatic retry and failure recovery mechanisms
- **Redis Coordination**: Inter-process communication via Redis
- **PostgreSQL State**: Task state and metadata stored in PostgreSQL
#### Important Notes
**Defining Tasks**:
- Always use `@shared_task` rather than `@celery_app`
- Put tasks under `background/celery/tasks/` or `ee/background/celery/tasks`
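A minimal sketch of a task following these conventions (the module path, task name, and body are illustrative, not actual Onyx code):
```python
# e.g. backend/onyx/background/celery/tasks/example/tasks.py (hypothetical path)
from celery import shared_task


@shared_task(bind=True)  # use @shared_task, not @celery_app.task
def cleanup_stale_entries(self, tenant_id: str) -> int:
    """Hypothetical task: returns the number of entries cleaned up."""
    removed_count = 0
    # ... actual cleanup logic would go here ...
    return removed_count
```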
**Defining APIs**:
When creating new FastAPI APIs, do NOT use the `response_model` field. Instead, just type the
function.
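For example, a minimal sketch of the preferred shape (the route and model here are hypothetical):
```python
from fastapi import APIRouter
from pydantic import BaseModel

router = APIRouter()


class PersonaSnapshot(BaseModel):  # hypothetical response model
    id: int
    name: str


# Preferred: annotate the return type and let FastAPI infer the schema...
@router.get("/persona/{persona_id}")
def get_persona(persona_id: int) -> PersonaSnapshot:
    return PersonaSnapshot(id=persona_id, name="example")


# ...rather than passing response_model=PersonaSnapshot to the decorator.
```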
**Testing Updates**:
If you make any updates to a celery worker and you want to test these changes, you will need
to ask me to restart the celery worker. There is no auto-restart on code-change mechanism.
### Code Quality
```bash
# Install and run pre-commit hooks
pre-commit install
pre-commit run --all-files
```
NOTE: Always make sure everything is strictly typed (both in Python and Typescript).
## Architecture Overview
### Technology Stack
- **Backend**: Python 3.11, FastAPI, SQLAlchemy, Alembic, Celery
- **Frontend**: Next.js 15+, React 18, TypeScript, Tailwind CSS
- **Database**: PostgreSQL with Redis caching
- **Search**: Vespa vector database
- **Auth**: OAuth2, SAML, multi-provider support
- **AI/ML**: LangChain, LiteLLM, multiple embedding models
### Directory Structure
```
backend/
├── onyx/
│ ├── auth/ # Authentication & authorization
│ ├── chat/ # Chat functionality & LLM interactions
│ ├── connectors/ # Data source connectors
│ ├── db/ # Database models & operations
│ ├── document_index/ # Vespa integration
│ ├── federated_connectors/ # External search connectors
│ ├── llm/ # LLM provider integrations
│ └── server/ # API endpoints & routers
├── ee/ # Enterprise Edition features
├── alembic/ # Database migrations
└── tests/ # Test suites
web/
├── src/app/ # Next.js app router pages
├── src/components/ # Reusable React components
└── src/lib/ # Utilities & business logic
```
## Database & Migrations
### Running Migrations
```bash
# Standard migrations
alembic upgrade head
# Multi-tenant (Enterprise)
alembic -n schema_private upgrade head
```
### Creating Migrations
```bash
# Auto-generate migration
alembic revision --autogenerate -m "description"
# Multi-tenant migration
alembic -n schema_private revision --autogenerate -m "description"
```
## Testing Strategy
There are 4 main types of tests within Onyx:
### Unit Tests
These should not assume any Onyx/external services are available to be called.
Interactions with the outside world should be mocked using `unittest.mock`. Generally, only
write these for complex, isolated modules e.g. `citation_processing.py`.
To run them:
```bash
python -m dotenv -f .vscode/.env run -- pytest -xv backend/tests/unit
```
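A tiny sketch of this style, mocking a collaborator with `unittest.mock` (everything here is illustrative rather than real Onyx code):
```python
from unittest.mock import MagicMock


def test_with_mocked_collaborator() -> None:
    # Hypothetical: a function that receives its external client as an
    # argument can be exercised with a MagicMock instead of a live service.
    llm_client = MagicMock()
    llm_client.complete.return_value = "mocked answer"

    answer = llm_client.complete("What is Onyx?")

    assert answer == "mocked answer"
    llm_client.complete.assert_called_once_with("What is Onyx?")
```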
### External Dependency Unit Tests
These tests assume that all external dependencies of Onyx are available and callable (e.g. Postgres, Redis,
MinIO/S3, Vespa are running + OpenAI can be called + any request to the internet is fine + etc.).
However, the actual Onyx containers are not running; in these tests, we call the function under test directly.
We can also mock components/calls at will.
The goal with these tests is to minimize mocking while giving some flexibility to mock things that are flaky,
need strictly controlled behavior, or need to have their internal behavior validated (e.g. verifying that a function is called
with certain args, something that would be impossible with proper integration tests).
A great example of this type of test is `backend/tests/external_dependency_unit/connectors/confluence/test_confluence_group_sync.py`.
To run them:
```bash
python -m dotenv -f .vscode/.env run -- pytest backend/tests/external_dependency_unit
```
### Integration Tests
Standard integration tests. Every test in `backend/tests/integration` runs against a real Onyx deployment. We cannot
mock anything in these tests. Prefer writing integration tests (or External Dependency Unit Tests if mocking/internal
verification is necessary) over any other type of test.
Tests are parallelized at a directory level.
When writing integration tests, make sure to check the root `conftest.py` for useful fixtures + the `backend/tests/integration/common_utils` directory for utilities. Prefer calling the appropriate Manager
class in the utils (if one exists) over directly calling the APIs with a library like `requests`. Prefer using fixtures rather than
calling the utilities directly (e.g. do NOT create admin users with
`admin_user = UserManager.create(name="admin_user")`; instead, use the `admin_user` fixture).
A great example of this type of test is `backend/tests/integration/dev_apis/test_simple_chat_api.py`.
To run them:
```bash
python -m dotenv -f .vscode/.env run -- pytest backend/tests/integration
```
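As a rough sketch of the fixture-first style described above (the import path and manager signature are assumptions based on the conventions mentioned, not verified Onyx code):
```python
# Sketch only: `admin_user` is the fixture mentioned above; the import path
# and manager signature are hypothetical stand-ins.
from tests.integration.common_utils.managers.credential import CredentialManager


def test_admin_can_create_credential(admin_user) -> None:
    # Use the fixture-provided admin rather than UserManager.create(...)
    credential = CredentialManager.create(user_performing_action=admin_user)
    assert credential.id is not None
```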
### Playwright (E2E) Tests
These tests are an even more complete version of the Integration Tests mentioned above. They run all Onyx services,
*including* the Web Server.
Use these tests for anything that requires significant frontend <-> backend coordination.
Tests are located at `web/tests/e2e`. Tests are written in TypeScript.
To run them:
```bash
npx playwright test <TEST_NAME>
```
## Logs
When (1) writing integration tests or (2) doing live tests (e.g. curl / playwright) you can get access
to logs via the `backend/log/<service_name>_debug.log` files. All Onyx services (api_server, web_server, celery_X)
write their logs to these files.
## Security Considerations
- Never commit API keys or secrets to repository
- Use encrypted credential storage for connector credentials
- Follow RBAC patterns for new features
- Implement proper input validation with Pydantic models
- Use parameterized queries to prevent SQL injection
## AI/LLM Integration
- Multiple LLM providers supported via LiteLLM
- Configurable models per feature (chat, search, embeddings)
- Streaming support for real-time responses
- Token management and rate limiting
- Custom prompts and agent actions
## UI/UX Patterns
- Tailwind CSS with design system in `web/src/components/ui/`
- Radix UI and Headless UI for accessible components
- SWR for data fetching and caching
- Form validation with react-hook-form
- Error handling with popup notifications
## Creating a Plan
When creating a plan in the `plans` directory, make sure to include at least these elements:
**Issues to Address**
What the change is meant to do.
**Important Notes**
Things you come across in your research that are important to the implementation.
**Implementation strategy**
How you are going to make the changes happen. High level approach.
**Tests**
What unit (use rarely), external dependency unit, integration, and playwright tests you plan to write to
verify the correct behavior. Don't overtest. Usually, a given change only needs one type of test.
Do NOT include these: *Timeline*, *Rollback plan*
This is a minimal list - feel free to include more. Do NOT write code as part of your plan.
Keep it high level. You can reference certain files or functions though.
Before writing your plan, make sure to do research. Explore the relevant sections in the codebase.

CLAUDE.md

@@ -0,0 +1,295 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
(The remainder of CLAUDE.md is a verbatim copy of the AGENTS.md content shown above.)


@@ -103,10 +103,10 @@ If using PowerShell, the command slightly differs:
Install the required python dependencies:
```bash
pip install -r onyx/backend/requirements/default.txt
pip install -r onyx/backend/requirements/dev.txt
pip install -r onyx/backend/requirements/ee.txt
pip install -r onyx/backend/requirements/model_server.txt
pip install -r backend/requirements/default.txt
pip install -r backend/requirements/dev.txt
pip install -r backend/requirements/ee.txt
pip install -r backend/requirements/model_server.txt
```
Install Playwright for Python (headless browser required by the Web Connector)


@@ -5,7 +5,7 @@ This guide explains how to set up and use VSCode's debugging capabilities with t
## Initial Setup
1. **Environment Setup**:
- Copy `.vscode/.env.template` to `.vscode/.env`
- Copy `.vscode/env_template.txt` to `.vscode/.env`
- Fill in the necessary environment variables in `.vscode/.env`
2. **launch.json**:
- Copy `.vscode/launch.template.jsonc` to `.vscode/launch.json`
@@ -17,10 +17,9 @@ Before starting, make sure the Docker Daemon is running.
1. Open the Debug view in VSCode (Cmd+Shift+D on macOS)
2. From the dropdown at the top, select "Clear and Restart External Volumes and Containers" and press the green play button
3. From the dropdown at the top, select "Run All Onyx Services" and press the green play button
4. CD into web, run "npm i" followed by npm run dev.
5. Now, you can navigate to onyx in your browser (default is http://localhost:3000) and start using the app
6. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
7. Use the debug toolbar to step through code, inspect variables, etc.
4. Now, you can navigate to onyx in your browser (default is http://localhost:3000) and start using the app
5. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
6. Use the debug toolbar to step through code, inspect variables, etc.
## Features


@@ -57,7 +57,7 @@ https://private-user-images.githubusercontent.com/32520769/414509312-48392e83-95
**To try it out for free and get started in seconds, check out [Onyx Cloud](https://cloud.onyx.app/signup)**.
Onyx can also be run locally (even on a laptop) or deployed on a virtual machine with a single
`docker compose` command. Check out our [docs](https://docs.onyx.app/quickstart) to learn more.
`docker compose` command. Check out our [docs](https://docs.onyx.app/deployment/getting_started/quickstart) to learn more.
We also have built-in support for high-availability/scalable deployment on Kubernetes.
References [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment).
@@ -97,7 +97,7 @@ Keep knowledge and access up to sync across 40+ connectors:
- Websites
- And more ...
See the full list [here](https://docs.onyx.app/connectors).
See the full list [here](https://docs.onyx.app/admin/connectors/overview).
## 📚 Licensing


@@ -12,7 +12,8 @@ ARG ONYX_VERSION=0.0.0-dev
# DO_NOT_TRACK is used to disable telemetry for Unstructured
ENV ONYX_VERSION=${ONYX_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true" \
DO_NOT_TRACK="true"
DO_NOT_TRACK="true" \
PLAYWRIGHT_BROWSERS_PATH="/app/.cache/ms-playwright"
RUN echo "ONYX_VERSION: ${ONYX_VERSION}"
@@ -116,6 +117,14 @@ COPY ./assets /app/assets
ENV PYTHONPATH=/app
# Create non-root user for security best practices
RUN groupadd -g 1001 onyx && \
useradd -u 1001 -g onyx -m -s /bin/bash onyx && \
chown -R onyx:onyx /app && \
mkdir -p /var/log/onyx && \
chmod 755 /var/log/onyx && \
chown onyx:onyx /var/log/onyx
# Default command which does nothing
# This container is used by api server and background which specify their own CMD
CMD ["tail", "-f", "/dev/null"]


@@ -9,11 +9,36 @@ visit https://github.com/onyx-dot-app/onyx."
# Default ONYX_VERSION, typically overridden during builds by GitHub Actions.
ARG ONYX_VERSION=0.0.0-dev
ENV ONYX_VERSION=${ONYX_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true"
DANSWER_RUNNING_IN_DOCKER="true" \
HF_HOME=/app/.cache/huggingface
RUN echo "ONYX_VERSION: ${ONYX_VERSION}"
# Create non-root user for security best practices
RUN mkdir -p /app && \
groupadd -g 1001 onyx && \
useradd -u 1001 -g onyx -m -s /bin/bash onyx && \
chown -R onyx:onyx /app && \
mkdir -p /var/log/onyx && \
chmod 755 /var/log/onyx && \
chown onyx:onyx /var/log/onyx
# --- add toolchain needed for Rust/Python builds (fastuuid) ---
ENV RUSTUP_HOME=/usr/local/rustup \
CARGO_HOME=/usr/local/cargo \
PATH=/usr/local/cargo/bin:$PATH
RUN set -eux; \
apt-get update && apt-get install -y --no-install-recommends \
build-essential \
pkg-config \
curl \
ca-certificates \
&& rm -rf /var/lib/apt/lists/* \
# Install latest stable Rust (supports Cargo.lock v4)
&& curl -sSf https://sh.rustup.rs | sh -s -- -y --profile minimal --default-toolchain stable \
&& rustc --version && cargo --version
COPY ./requirements/model_server.txt /tmp/requirements.txt
RUN pip install --no-cache-dir --upgrade \
--retries 5 \
@@ -38,9 +63,11 @@ snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
from sentence_transformers import SentenceTransformer; \
SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True);"
# In case the user has volumes mounted to /root/.cache/huggingface that they've downloaded while
# running Onyx, don't overwrite it with the built-in cache folder
RUN mv /root/.cache/huggingface /root/.cache/temp_huggingface
# In case the user has volumes mounted to /app/.cache/huggingface that they've downloaded while
# running Onyx, move the current contents of the cache folder to a temporary location to ensure
# it's preserved in order to combine with the user's cache contents
RUN mv /app/.cache/huggingface /app/.cache/temp_huggingface && \
chown -R onyx:onyx /app
WORKDIR /app


@@ -0,0 +1,380 @@
"""merge_default_assistants_into_unified
Revision ID: 505c488f6662
Revises: d09fc20a3c66
Create Date: 2025-09-09 19:00:56.816626
"""
import json
from typing import Any
from typing import NamedTuple
from uuid import UUID
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "505c488f6662"
down_revision = "d09fc20a3c66"
branch_labels = None
depends_on = None
# Constants for the unified assistant
UNIFIED_ASSISTANT_NAME = "Assistant"
UNIFIED_ASSISTANT_DESCRIPTION = (
"Your AI assistant with search, web browsing, and image generation capabilities."
)
UNIFIED_ASSISTANT_NUM_CHUNKS = 25
UNIFIED_ASSISTANT_DISPLAY_PRIORITY = 0
UNIFIED_ASSISTANT_LLM_FILTER_EXTRACTION = True
UNIFIED_ASSISTANT_LLM_RELEVANCE_FILTER = False
UNIFIED_ASSISTANT_RECENCY_BIAS = "AUTO" # NOTE: needs to be capitalized
UNIFIED_ASSISTANT_CHUNKS_ABOVE = 0
UNIFIED_ASSISTANT_CHUNKS_BELOW = 0
UNIFIED_ASSISTANT_DATETIME_AWARE = True
# NOTE: tool specific prompts are handled on the fly and automatically injected
# into the prompt before passing to the LLM.
DEFAULT_SYSTEM_PROMPT = """
You are a highly capable, thoughtful, and precise assistant. Your goal is to deeply understand the \
user's intent, ask clarifying questions when needed, think step-by-step through complex problems, \
provide clear and accurate answers, and proactively anticipate helpful follow-up information. Always \
prioritize being truthful, nuanced, insightful, and efficient.
The current date is [[CURRENT_DATETIME]]
You use different text styles, bolding, emojis (sparingly), block quotes, and other formatting to make \
your responses more readable and engaging.
You use proper Markdown and LaTeX to format your responses for math, scientific, and chemical formulas, \
symbols, etc.: '$$\\n[expression]\\n$$' for standalone cases and '\\( [expression] \\)' when inline.
For code you prefer to use Markdown and specify the language.
You can use Markdown horizontal rules (---) to separate sections of your responses.
You can use Markdown tables to format your responses for data, lists, and other structured information.
""".strip()
INSERT_DICT: dict[str, Any] = {
"name": UNIFIED_ASSISTANT_NAME,
"description": UNIFIED_ASSISTANT_DESCRIPTION,
"system_prompt": DEFAULT_SYSTEM_PROMPT,
"num_chunks": UNIFIED_ASSISTANT_NUM_CHUNKS,
"display_priority": UNIFIED_ASSISTANT_DISPLAY_PRIORITY,
"llm_filter_extraction": UNIFIED_ASSISTANT_LLM_FILTER_EXTRACTION,
"llm_relevance_filter": UNIFIED_ASSISTANT_LLM_RELEVANCE_FILTER,
"recency_bias": UNIFIED_ASSISTANT_RECENCY_BIAS,
"chunks_above": UNIFIED_ASSISTANT_CHUNKS_ABOVE,
"chunks_below": UNIFIED_ASSISTANT_CHUNKS_BELOW,
"datetime_aware": UNIFIED_ASSISTANT_DATETIME_AWARE,
}
GENERAL_ASSISTANT_ID = -1
ART_ASSISTANT_ID = -3
class UserRow(NamedTuple):
"""Typed representation of user row from database query."""
id: UUID
chosen_assistants: list[int] | None
visible_assistants: list[int] | None
hidden_assistants: list[int] | None
pinned_assistants: list[int] | None
def upgrade() -> None:
conn = op.get_bind()
# Start transaction
conn.execute(sa.text("BEGIN"))
try:
# Step 1: Create or update the unified assistant (ID 0)
search_assistant = conn.execute(
sa.text("SELECT * FROM persona WHERE id = 0")
).fetchone()
if search_assistant:
# Update existing Search assistant to be the unified assistant
conn.execute(
sa.text(
"""
UPDATE persona
SET name = :name,
description = :description,
system_prompt = :system_prompt,
num_chunks = :num_chunks,
is_default_persona = true,
is_visible = true,
deleted = false,
display_priority = :display_priority,
llm_filter_extraction = :llm_filter_extraction,
llm_relevance_filter = :llm_relevance_filter,
recency_bias = :recency_bias,
chunks_above = :chunks_above,
chunks_below = :chunks_below,
datetime_aware = :datetime_aware,
starter_messages = null
WHERE id = 0
"""
),
INSERT_DICT,
)
else:
# Create new unified assistant with ID 0
conn.execute(
sa.text(
"""
INSERT INTO persona (
id, name, description, system_prompt, num_chunks,
is_default_persona, is_visible, deleted, display_priority,
llm_filter_extraction, llm_relevance_filter, recency_bias,
chunks_above, chunks_below, datetime_aware, starter_messages,
builtin_persona
) VALUES (
0, :name, :description, :system_prompt, :num_chunks,
true, true, false, :display_priority, :llm_filter_extraction,
:llm_relevance_filter, :recency_bias, :chunks_above, :chunks_below,
:datetime_aware, null, true
)
"""
),
INSERT_DICT,
)
# Step 2: Mark ALL builtin assistants as deleted (except the unified assistant ID 0)
conn.execute(
sa.text(
"""
UPDATE persona
SET deleted = true, is_visible = false, is_default_persona = false
WHERE builtin_persona = true AND id != 0
"""
)
)
# Step 3: Add all built-in tools to the unified assistant
# First, get the tool IDs for SearchTool, ImageGenerationTool, and WebSearchTool
search_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'SearchTool'")
).fetchone()
if not search_tool:
raise ValueError(
"SearchTool not found in database. Ensure tools migration has run first."
)
image_gen_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'ImageGenerationTool'")
).fetchone()
if not image_gen_tool:
raise ValueError(
"ImageGenerationTool not found in database. Ensure tools migration has run first."
)
# WebSearchTool is optional - may not be configured
web_search_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'WebSearchTool'")
).fetchone()
# Clear existing tool associations for persona 0
conn.execute(sa.text("DELETE FROM persona__tool WHERE persona_id = 0"))
# Add tools to the unified assistant
conn.execute(
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (0, :tool_id)
ON CONFLICT DO NOTHING
"""
),
{"tool_id": search_tool[0]},
)
conn.execute(
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (0, :tool_id)
ON CONFLICT DO NOTHING
"""
),
{"tool_id": image_gen_tool[0]},
)
if web_search_tool:
conn.execute(
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (0, :tool_id)
ON CONFLICT DO NOTHING
"""
),
{"tool_id": web_search_tool[0]},
)
# Step 4: Migrate existing chat sessions from all builtin assistants to unified assistant
conn.execute(
sa.text(
"""
UPDATE chat_session
SET persona_id = 0
WHERE persona_id IN (
SELECT id FROM persona WHERE builtin_persona = true AND id != 0
)
"""
)
)
# Step 5: Migrate user preferences - remove references to all builtin assistants
# First, get all builtin assistant IDs (except 0)
builtin_assistants_result = conn.execute(
sa.text(
"""
SELECT id FROM persona
WHERE builtin_persona = true AND id != 0
"""
)
).fetchall()
builtin_assistant_ids = [row[0] for row in builtin_assistants_result]
# Get all users with preferences
users_result = conn.execute(
sa.text(
"""
SELECT id, chosen_assistants, visible_assistants,
hidden_assistants, pinned_assistants
FROM "user"
"""
)
).fetchall()
for user_row in users_result:
user = UserRow(*user_row)
user_id: UUID = user.id
updates: dict[str, Any] = {}
# Remove all builtin assistants from chosen_assistants
if user.chosen_assistants:
new_chosen: list[int] = [
assistant_id
for assistant_id in user.chosen_assistants
if assistant_id not in builtin_assistant_ids
]
if new_chosen != user.chosen_assistants:
updates["chosen_assistants"] = json.dumps(new_chosen)
# Remove all builtin assistants from visible_assistants
if user.visible_assistants:
new_visible: list[int] = [
assistant_id
for assistant_id in user.visible_assistants
if assistant_id not in builtin_assistant_ids
]
if new_visible != user.visible_assistants:
updates["visible_assistants"] = json.dumps(new_visible)
# Add all builtin assistants to hidden_assistants
if user.hidden_assistants:
new_hidden: list[int] = list(user.hidden_assistants)
for old_id in builtin_assistant_ids:
if old_id not in new_hidden:
new_hidden.append(old_id)
if new_hidden != user.hidden_assistants:
updates["hidden_assistants"] = json.dumps(new_hidden)
else:
updates["hidden_assistants"] = json.dumps(builtin_assistant_ids)
# Remove all builtin assistants from pinned_assistants
if user.pinned_assistants:
new_pinned: list[int] = [
assistant_id
for assistant_id in user.pinned_assistants
if assistant_id not in builtin_assistant_ids
]
if new_pinned != user.pinned_assistants:
updates["pinned_assistants"] = json.dumps(new_pinned)
# Apply updates if any
if updates:
set_clause = ", ".join([f"{k} = :{k}" for k in updates.keys()])
updates["user_id"] = str(user_id) # Convert UUID to string for SQL
conn.execute(
sa.text(f'UPDATE "user" SET {set_clause} WHERE id = :user_id'),
updates,
)
# Commit transaction
conn.execute(sa.text("COMMIT"))
except Exception as e:
# Rollback on error
conn.execute(sa.text("ROLLBACK"))
raise e
def downgrade() -> None:
conn = op.get_bind()
# Start transaction
conn.execute(sa.text("BEGIN"))
try:
# Only restore General (ID -1) and Art (ID -3) assistants
# Step 1: Keep Search assistant (ID 0) as default but restore original state
conn.execute(
sa.text(
"""
UPDATE persona
SET is_default_persona = true,
is_visible = true,
deleted = false
WHERE id = 0
"""
)
)
# Step 2: Restore General assistant (ID -1)
conn.execute(
sa.text(
"""
UPDATE persona
SET deleted = false,
is_visible = true,
is_default_persona = true
WHERE id = :general_assistant_id
"""
),
{"general_assistant_id": GENERAL_ASSISTANT_ID},
)
# Step 3: Restore Art assistant (ID -3)
conn.execute(
sa.text(
"""
UPDATE persona
SET deleted = false,
is_visible = true,
is_default_persona = true
WHERE id = :art_assistant_id
"""
),
{"art_assistant_id": ART_ASSISTANT_ID},
)
# Note: We don't restore the original tool associations, names, or descriptions
# as those would require more complex logic to determine original state.
# We also cannot restore original chat session persona_ids as we don't
# have the original mappings.
# Other builtin assistants remain deleted as per the requirement.
# Commit transaction
conn.execute(sa.text("COMMIT"))
except Exception as e:
# Rollback on error
conn.execute(sa.text("ROLLBACK"))
raise e
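
The DEFAULT_SYSTEM_PROMPT above embeds a [[CURRENT_DATETIME]] placeholder that, per the NOTE, is filled in on the fly. A minimal sketch (hypothetical helper, not part of this migration) of how such a placeholder could be rendered before the prompt reaches the LLM:

from datetime import datetime, timezone

def render_system_prompt(template: str) -> str:
    # Substitute the datetime placeholder at request time; tool-specific
    # sections would be injected the same way.
    now = datetime.now(timezone.utc).strftime("%A, %B %d, %Y %H:%M UTC")
    return template.replace("[[CURRENT_DATETIME]]", now)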

View File

@@ -0,0 +1,115 @@
"""add research agent database tables and chat message research fields
Revision ID: 5ae8240accb3
Revises: b558f51620b4
Create Date: 2025-08-06 14:29:24.691388
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "5ae8240accb3"
down_revision = "b558f51620b4"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add research_type and research_plan columns to chat_message table
op.add_column(
"chat_message",
sa.Column("research_type", sa.String(), nullable=True),
)
op.add_column(
"chat_message",
sa.Column("research_plan", postgresql.JSONB(), nullable=True),
)
# Create research_agent_iteration table
op.create_table(
"research_agent_iteration",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column(
"primary_question_id",
sa.Integer(),
sa.ForeignKey("chat_message.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("iteration_nr", sa.Integer(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column("purpose", sa.String(), nullable=True),
sa.Column("reasoning", sa.String(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint(
"primary_question_id",
"iteration_nr",
name="_research_agent_iteration_unique_constraint",
),
)
# Create research_agent_iteration_sub_step table
op.create_table(
"research_agent_iteration_sub_step",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column(
"primary_question_id",
sa.Integer(),
sa.ForeignKey("chat_message.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"parent_question_id",
sa.Integer(),
sa.ForeignKey("research_agent_iteration_sub_step.id", ondelete="CASCADE"),
nullable=True,
),
sa.Column("iteration_nr", sa.Integer(), nullable=False),
sa.Column("iteration_sub_step_nr", sa.Integer(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column("sub_step_instructions", sa.String(), nullable=True),
sa.Column(
"sub_step_tool_id",
sa.Integer(),
sa.ForeignKey("tool.id"),
nullable=True,
),
sa.Column("reasoning", sa.String(), nullable=True),
sa.Column("sub_answer", sa.String(), nullable=True),
sa.Column("cited_doc_results", postgresql.JSONB(), nullable=True),
sa.Column("claims", postgresql.JSONB(), nullable=True),
sa.Column("generated_images", postgresql.JSONB(), nullable=True),
sa.Column("additional_data", postgresql.JSONB(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(
["primary_question_id", "iteration_nr"],
[
"research_agent_iteration.primary_question_id",
"research_agent_iteration.iteration_nr",
],
ondelete="CASCADE",
),
)
def downgrade() -> None:
# Drop tables in reverse order
op.drop_table("research_agent_iteration_sub_step")
op.drop_table("research_agent_iteration")
# Remove columns from chat_message table
op.drop_column("chat_message", "research_plan")
op.drop_column("chat_message", "research_type")
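
The sub-step table links to its parent iteration through the composite pair (primary_question_id, iteration_nr), which is only a legal foreign-key target because of the unique constraint declared above. A minimal sketch (hypothetical read path, assuming a live connection conn) of reassembling an agent run:

from sqlalchemy import text

rows = conn.execute(
    text(
        """
        SELECT i.iteration_nr, i.purpose, s.iteration_sub_step_nr, s.sub_answer
        FROM research_agent_iteration i
        JOIN research_agent_iteration_sub_step s
          ON s.primary_question_id = i.primary_question_id
         AND s.iteration_nr = i.iteration_nr
        WHERE i.primary_question_id = :message_id
        ORDER BY i.iteration_nr, s.iteration_sub_step_nr
        """
    ),
    {"message_id": 42},  # hypothetical chat_message id
).fetchall()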

View File

@@ -0,0 +1,249 @@
"""add_mcp_server_and_connection_config_models
Revision ID: 7ed603b64d5a
Revises: b329d00a9ea6
Create Date: 2025-07-28 17:35:59.900680
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from onyx.db.enums import MCPAuthenticationType
# revision identifiers, used by Alembic.
revision = "7ed603b64d5a"
down_revision = "b329d00a9ea6"
branch_labels = None
depends_on = None
def upgrade() -> None:
"""Create tables and columns for MCP Server support"""
# 1. MCP Server main table (no FK constraints yet to avoid circular refs)
op.create_table(
"mcp_server",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("owner", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("description", sa.String(), nullable=True),
sa.Column("server_url", sa.String(), nullable=False),
sa.Column(
"auth_type",
sa.Enum(
MCPAuthenticationType,
name="mcp_authentication_type",
native_enum=False,
),
nullable=False,
),
sa.Column("admin_connection_config_id", sa.Integer(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"), # type: ignore
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"), # type: ignore
nullable=False,
),
)
# 2. MCP Connection Config table (can reference mcp_server now that it exists)
op.create_table(
"mcp_connection_config",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("mcp_server_id", sa.Integer(), nullable=True),
sa.Column("user_email", sa.String(), nullable=False, default=""),
sa.Column("config", sa.LargeBinary(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"), # type: ignore
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"), # type: ignore
nullable=False,
),
sa.ForeignKeyConstraint(
["mcp_server_id"], ["mcp_server.id"], ondelete="CASCADE"
),
)
# Helpful indexes
op.create_index(
"ix_mcp_connection_config_server_user",
"mcp_connection_config",
["mcp_server_id", "user_email"],
)
op.create_index(
"ix_mcp_connection_config_user_email",
"mcp_connection_config",
["user_email"],
)
# 3. Add the back-references from mcp_server to connection configs
op.create_foreign_key(
"mcp_server_admin_config_fk",
"mcp_server",
"mcp_connection_config",
["admin_connection_config_id"],
["id"],
ondelete="SET NULL",
)
# 4. Association / access-control tables
op.create_table(
"mcp_server__user",
sa.Column("mcp_server_id", sa.Integer(), primary_key=True),
sa.Column("user_id", sa.UUID(), primary_key=True),
sa.ForeignKeyConstraint(
["mcp_server_id"], ["mcp_server.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
)
op.create_table(
"mcp_server__user_group",
sa.Column("mcp_server_id", sa.Integer(), primary_key=True),
sa.Column("user_group_id", sa.Integer(), primary_key=True),
sa.ForeignKeyConstraint(
["mcp_server_id"], ["mcp_server.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["user_group_id"], ["user_group.id"]),
)
# 5. Update existing `tool` table to allow tools to belong to an MCP server
op.add_column(
"tool",
sa.Column("mcp_server_id", sa.Integer(), nullable=True),
)
# Add column for MCP tool input schema
op.add_column(
"tool",
sa.Column("mcp_input_schema", postgresql.JSONB(), nullable=True),
)
op.create_foreign_key(
"tool_mcp_server_fk",
"tool",
"mcp_server",
["mcp_server_id"],
["id"],
ondelete="CASCADE",
)
# 6. Update persona__tool foreign keys to cascade delete
# This ensures that when a tool is deleted (including via MCP server deletion),
# the corresponding persona__tool rows are also deleted
op.drop_constraint(
"persona__tool_tool_id_fkey", "persona__tool", type_="foreignkey"
)
op.drop_constraint(
"persona__tool_persona_id_fkey", "persona__tool", type_="foreignkey"
)
op.create_foreign_key(
"persona__tool_persona_id_fkey",
"persona__tool",
"persona",
["persona_id"],
["id"],
ondelete="CASCADE",
)
op.create_foreign_key(
"persona__tool_tool_id_fkey",
"persona__tool",
"tool",
["tool_id"],
["id"],
ondelete="CASCADE",
)
# 7. Update research_agent_iteration_sub_step foreign key to SET NULL on delete
# This ensures that when a tool is deleted, the sub_step_tool_id is set to NULL
# instead of causing a foreign key constraint violation
op.drop_constraint(
"research_agent_iteration_sub_step_sub_step_tool_id_fkey",
"research_agent_iteration_sub_step",
type_="foreignkey",
)
op.create_foreign_key(
"research_agent_iteration_sub_step_sub_step_tool_id_fkey",
"research_agent_iteration_sub_step",
"tool",
["sub_step_tool_id"],
["id"],
ondelete="SET NULL",
)
def downgrade() -> None:
"""Drop all MCP-related tables / columns"""
# 1. Drop FK & columns from tool
# op.drop_constraint("tool_mcp_server_fk", "tool", type_="foreignkey")
op.execute("DELETE FROM tool WHERE mcp_server_id IS NOT NULL")
op.drop_constraint(
"research_agent_iteration_sub_step_sub_step_tool_id_fkey",
"research_agent_iteration_sub_step",
type_="foreignkey",
)
op.create_foreign_key(
"research_agent_iteration_sub_step_sub_step_tool_id_fkey",
"research_agent_iteration_sub_step",
"tool",
["sub_step_tool_id"],
["id"],
)
# Restore original persona__tool foreign keys (without CASCADE)
op.drop_constraint(
"persona__tool_persona_id_fkey", "persona__tool", type_="foreignkey"
)
op.drop_constraint(
"persona__tool_tool_id_fkey", "persona__tool", type_="foreignkey"
)
op.create_foreign_key(
"persona__tool_persona_id_fkey",
"persona__tool",
"persona",
["persona_id"],
["id"],
)
op.create_foreign_key(
"persona__tool_tool_id_fkey",
"persona__tool",
"tool",
["tool_id"],
["id"],
)
op.drop_column("tool", "mcp_input_schema")
op.drop_column("tool", "mcp_server_id")
# 2. Drop association tables
op.drop_table("mcp_server__user_group")
op.drop_table("mcp_server__user")
# 3. Drop FK from mcp_server to connection configs
op.drop_constraint("mcp_server_admin_config_fk", "mcp_server", type_="foreignkey")
# 4. Drop connection config indexes & table
op.drop_index(
"ix_mcp_connection_config_user_email", table_name="mcp_connection_config"
)
op.drop_index(
"ix_mcp_connection_config_server_user", table_name="mcp_connection_config"
)
op.drop_table("mcp_connection_config")
# 5. Finally drop mcp_server table
op.drop_table("mcp_server")
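
Because mcp_server and mcp_connection_config reference each other, the upgrade creates mcp_server without the FK and adds mcp_server_admin_config_fk once both tables exist. A sketch of the equivalent one-step declaration (an alternative, not what this migration uses), where SQLAlchemy emits the constraint as a separate ALTER:

sa.ForeignKeyConstraint(
    ["admin_connection_config_id"],
    ["mcp_connection_config.id"],
    name="mcp_server_admin_config_fk",
    ondelete="SET NULL",
    use_alter=True,  # defer constraint creation until both tables are present
)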

View File

@@ -0,0 +1,38 @@
"""drop include citations
Revision ID: 8818cf73fa1a
Revises: 7ed603b64d5a
Create Date: 2025-09-02 19:43:50.060680
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "8818cf73fa1a"
down_revision = "7ed603b64d5a"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_column("prompt", "include_citations")
def downgrade() -> None:
op.add_column(
"prompt",
sa.Column(
"include_citations",
sa.BOOLEAN(),
autoincrement=False,
nullable=True,
),
)
# Set include_citations based on prompt name: FALSE for ImageGeneration, TRUE for others
op.execute(
sa.text(
"UPDATE prompt SET include_citations = CASE WHEN name = 'ImageGeneration' THEN FALSE ELSE TRUE END"
)
)

View File

@@ -0,0 +1,225 @@
"""merge prompt into persona
Revision ID: abbfec3a5ac5
Revises: 8818cf73fa1a
Create Date: 2024-12-19 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "abbfec3a5ac5"
down_revision = "8818cf73fa1a"
branch_labels = None
depends_on = None
MAX_PROMPT_LENGTH = 5_000_000
def upgrade() -> None:
"""NOTE: Prompts without any Personas will just be lost."""
# Step 1: Add new columns to persona table (only if they don't exist)
# Check if columns exist before adding them
connection = op.get_bind()
inspector = sa.inspect(connection)
existing_columns = [col["name"] for col in inspector.get_columns("persona")]
if "system_prompt" not in existing_columns:
op.add_column(
"persona",
sa.Column(
"system_prompt", sa.String(length=MAX_PROMPT_LENGTH), nullable=True
),
)
if "task_prompt" not in existing_columns:
op.add_column(
"persona",
sa.Column(
"task_prompt", sa.String(length=MAX_PROMPT_LENGTH), nullable=True
),
)
if "datetime_aware" not in existing_columns:
op.add_column(
"persona",
sa.Column(
"datetime_aware", sa.Boolean(), nullable=False, server_default="true"
),
)
# Step 2: Migrate data from prompt table to persona table (only if tables exist)
existing_tables = inspector.get_table_names()
if "prompt" in existing_tables and "persona__prompt" in existing_tables:
# For personas that have associated prompts, copy the prompt data
op.execute(
"""
UPDATE persona
SET
system_prompt = p.system_prompt,
task_prompt = p.task_prompt,
datetime_aware = p.datetime_aware
FROM (
-- Get the first prompt for each persona (in case there are multiple)
SELECT DISTINCT ON (pp.persona_id)
pp.persona_id,
pr.system_prompt,
pr.task_prompt,
pr.datetime_aware
FROM persona__prompt pp
JOIN prompt pr ON pp.prompt_id = pr.id
) p
WHERE persona.id = p.persona_id
"""
)
# Step 3: Update chat_message references
# chat_message rows referenced prompt_id directly; mapping each prompt_id back to a
# persona_id would be complex, so the prompt_id column (and its FK) is dropped instead
# Check if chat_message has prompt_id column
chat_message_columns = [
col["name"] for col in inspector.get_columns("chat_message")
]
if "prompt_id" in chat_message_columns:
op.execute(
"""
ALTER TABLE chat_message
DROP CONSTRAINT IF EXISTS chat_message__prompt_fk
"""
)
op.drop_column("chat_message", "prompt_id")
# Step 4: Handle personas without prompts - set default values if needed (always run this)
op.execute(
"""
UPDATE persona
SET
system_prompt = COALESCE(system_prompt, ''),
task_prompt = COALESCE(task_prompt, '')
WHERE system_prompt IS NULL OR task_prompt IS NULL
"""
)
# Step 5: Drop the persona__prompt association table (if it exists)
if "persona__prompt" in existing_tables:
op.drop_table("persona__prompt")
# Step 6: Drop the prompt table (if it exists)
if "prompt" in existing_tables:
op.drop_table("prompt")
# Step 7: Make system_prompt and task_prompt non-nullable after migration
op.alter_column(
"persona",
"system_prompt",
existing_type=sa.String(length=MAX_PROMPT_LENGTH),
nullable=False,
server_default=None,
)
op.alter_column(
"persona",
"task_prompt",
existing_type=sa.String(length=MAX_PROMPT_LENGTH),
nullable=False,
server_default=None,
)
def downgrade() -> None:
# Step 1: Recreate the prompt table
op.create_table(
"prompt",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=True),
sa.Column("name", sa.String(), nullable=False),
sa.Column("description", sa.String(), nullable=False),
sa.Column("system_prompt", sa.String(length=MAX_PROMPT_LENGTH), nullable=False),
sa.Column("task_prompt", sa.String(length=MAX_PROMPT_LENGTH), nullable=False),
sa.Column(
"datetime_aware", sa.Boolean(), nullable=False, server_default="true"
),
sa.Column(
"default_prompt", sa.Boolean(), nullable=False, server_default="false"
),
sa.Column("deleted", sa.Boolean(), nullable=False, server_default="false"),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
# Step 2: Recreate the persona__prompt association table
op.create_table(
"persona__prompt",
sa.Column("persona_id", sa.Integer(), nullable=False),
sa.Column("prompt_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.ForeignKeyConstraint(
["prompt_id"],
["prompt.id"],
),
sa.PrimaryKeyConstraint("persona_id", "prompt_id"),
)
# Step 3: Migrate data back from persona to prompt table
op.execute(
"""
INSERT INTO prompt (
name,
description,
system_prompt,
task_prompt,
datetime_aware,
default_prompt,
deleted,
user_id
)
SELECT
CONCAT('Prompt for ', name),
description,
system_prompt,
task_prompt,
datetime_aware,
is_default_persona,
deleted,
user_id
FROM persona
WHERE system_prompt IS NOT NULL AND system_prompt != ''
RETURNING id, name
"""
)
# Step 4: Re-establish persona__prompt relationships
op.execute(
"""
INSERT INTO persona__prompt (persona_id, prompt_id)
SELECT
p.id as persona_id,
pr.id as prompt_id
FROM persona p
JOIN prompt pr ON pr.name = CONCAT('Prompt for ', p.name)
WHERE p.system_prompt IS NOT NULL AND p.system_prompt != ''
"""
)
# Step 5: Add prompt_id column back to chat_message
op.add_column("chat_message", sa.Column("prompt_id", sa.Integer(), nullable=True))
# Step 6: Re-establish foreign key constraint
op.create_foreign_key(
"chat_message__prompt_fk", "chat_message", "prompt", ["prompt_id"], ["id"]
)
# Step 7: Remove columns from persona table
op.drop_column("persona", "datetime_aware")
op.drop_column("persona", "task_prompt")
op.drop_column("persona", "system_prompt")
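
One subtlety in Step 2: DISTINCT ON (pp.persona_id) returns an arbitrary matching row unless an ORDER BY pins down which prompt counts as "first". A deterministic variant of that subquery (a sketch, not what the migration actually runs) would order by prompt id:

FIRST_PROMPT_PER_PERSONA_SQL = """
SELECT DISTINCT ON (pp.persona_id)
       pp.persona_id,
       pr.system_prompt,
       pr.task_prompt,
       pr.datetime_aware
FROM persona__prompt pp
JOIN prompt pr ON pp.prompt_id = pr.id
ORDER BY pp.persona_id, pp.prompt_id  -- lowest prompt id wins, deterministically
"""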

View File

@@ -0,0 +1,38 @@
"""Adding assistant-specific user preferences
Revision ID: b329d00a9ea6
Revises: f9b8c7d6e5a4
Create Date: 2025-08-26 23:14:44.592985
"""
from alembic import op
import fastapi_users_db_sqlalchemy
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "b329d00a9ea6"
down_revision = "f9b8c7d6e5a4"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"assistant__user_specific_config",
sa.Column("assistant_id", sa.Integer(), nullable=False),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=False,
),
sa.Column("disabled_tool_ids", postgresql.ARRAY(sa.Integer()), nullable=False),
sa.ForeignKeyConstraint(["assistant_id"], ["persona.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("assistant_id", "user_id"),
)
def downgrade() -> None:
op.drop_table("assistant__user_specific_config")
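
With (assistant_id, user_id) as the composite primary key, a user holds at most one config row per assistant. A minimal sketch (hypothetical lookup, assuming a connection conn and a user_id UUID) of reading a user's disabled tools:

import sqlalchemy as sa

disabled_tool_ids = conn.execute(
    sa.text(
        """
        SELECT disabled_tool_ids
        FROM assistant__user_specific_config
        WHERE assistant_id = :assistant_id AND user_id = :user_id
        """
    ),
    {"assistant_id": 0, "user_id": str(user_id)},
).scalar()  # None if the user has no overrides for this assistant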

View File

@@ -0,0 +1,43 @@
"""adjust prompt length
Revision ID: b7ec9b5b505f
Revises: abbfec3a5ac5
Create Date: 2025-09-10 18:51:15.629197
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "b7ec9b5b505f"
down_revision = "abbfec3a5ac5"
branch_labels = None
depends_on = None
MAX_PROMPT_LENGTH = 5_000_000
def upgrade() -> None:
# NOTE: needed because an earlier version of the previous migration set the length to 8000
op.alter_column(
"persona",
"system_prompt",
existing_type=sa.String(length=8000),
type_=sa.String(length=MAX_PROMPT_LENGTH),
existing_nullable=False,
)
op.alter_column(
"persona",
"task_prompt",
existing_type=sa.String(length=8000),
type_=sa.String(length=MAX_PROMPT_LENGTH),
existing_nullable=False,
)
def downgrade() -> None:
# Downgrade not necessary
pass

View File

@@ -0,0 +1,147 @@
"""migrate_agent_sub_questions_to_research_iterations
Revision ID: bd7c3bf8beba
Revises: f8a9b2c3d4e5
Create Date: 2025-08-18 11:33:27.098287
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "bd7c3bf8beba"
down_revision = "f8a9b2c3d4e5"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Get connection to execute raw SQL
connection = op.get_bind()
# First, insert data into research_agent_iteration table
# This creates one iteration record per primary_question_id using the earliest time_created
connection.execute(
sa.text(
"""
INSERT INTO research_agent_iteration (primary_question_id, created_at, iteration_nr, purpose, reasoning)
SELECT
primary_question_id,
MIN(time_created) as created_at,
1 as iteration_nr,
'Generating and researching subquestions' as purpose,
'(No previous reasoning)' as reasoning
FROM agent__sub_question
JOIN chat_message on agent__sub_question.primary_question_id = chat_message.id
WHERE primary_question_id IS NOT NULL
AND chat_message.is_agentic = true
GROUP BY primary_question_id
ON CONFLICT DO NOTHING;
"""
)
)
# Then, insert data into research_agent_iteration_sub_step table
# This migrates each sub-question as a sub-step
connection.execute(
sa.text(
"""
INSERT INTO research_agent_iteration_sub_step (
primary_question_id,
iteration_nr,
iteration_sub_step_nr,
created_at,
sub_step_instructions,
sub_step_tool_id,
sub_answer,
cited_doc_results
)
SELECT
primary_question_id,
1 as iteration_nr,
level_question_num as iteration_sub_step_nr,
time_created as created_at,
sub_question as sub_step_instructions,
1 as sub_step_tool_id,
sub_answer,
sub_question_doc_results as cited_doc_results
FROM agent__sub_question
JOIN chat_message on agent__sub_question.primary_question_id = chat_message.id
WHERE chat_message.is_agentic = true
AND primary_question_id IS NOT NULL
ON CONFLICT DO NOTHING;
"""
)
)
# Update chat_message records: set legacy agentic type and answer purpose for existing agentic messages
connection.execute(
sa.text(
"""
UPDATE chat_message
SET research_answer_purpose = 'ANSWER'
WHERE is_agentic = true
AND research_type IS NULL
AND message_type = 'ASSISTANT';
"""
)
)
connection.execute(
sa.text(
"""
UPDATE chat_message
SET research_type = 'LEGACY_AGENTIC'
WHERE is_agentic = true
AND research_type IS NULL;
"""
)
)
def downgrade() -> None:
# Get connection to execute raw SQL
connection = op.get_bind()
# Note: This downgrade removes all research agent iteration data
# There's no way to perfectly restore the original agent__sub_question data
# if it was deleted after this migration
# Delete all research_agent_iteration_sub_step records that were migrated
connection.execute(
sa.text(
"""
DELETE FROM research_agent_iteration_sub_step
USING chat_message
WHERE research_agent_iteration_sub_step.primary_question_id = chat_message.id
AND chat_message.research_type = 'LEGACY_AGENTIC';
"""
)
)
# Delete all research_agent_iteration records that were migrated
connection.execute(
sa.text(
"""
DELETE FROM research_agent_iteration
USING chat_message
WHERE research_agent_iteration.primary_question_id = chat_message.id
AND chat_message.research_type = 'LEGACY_AGENTIC';
"""
)
)
# Revert chat_message updates: clear research fields for legacy agentic messages
connection.execute(
sa.text(
"""
UPDATE chat_message
SET research_type = NULL,
research_answer_purpose = NULL
WHERE is_agentic = true
AND research_type = 'LEGACY_AGENTIC'
AND message_type = 'ASSISTANT';
"""
)
)
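
Note that the sub-step insert above hardcodes 1 as sub_step_tool_id, which assumes the SearchTool row was assigned id 1. A defensive variant (a sketch, not what this migration does) would resolve the id first and pass it as a bind parameter:

search_tool_id = connection.execute(
    sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'SearchTool'")
).scalar_one()  # raises if SearchTool is missing rather than silently mis-linking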

View File

@@ -0,0 +1,125 @@
"""seed_builtin_tools
Revision ID: d09fc20a3c66
Revises: b7ec9b5b505f
Create Date: 2025-09-09 19:32:16.824373
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "d09fc20a3c66"
down_revision = "b7ec9b5b505f"
branch_labels = None
depends_on = None
# Tool definitions - core tools that should always be seeded
# Names/in_code_tool_id are the same as the class names in the tool_implementations package
BUILT_IN_TOOLS = [
{
"name": "SearchTool",
"display_name": "Internal Search",
"description": "The Search Action allows the Assistant to search through connected knowledge to help build an answer.",
"in_code_tool_id": "SearchTool",
},
{
"name": "ImageGenerationTool",
"display_name": "Image Generation",
"description": (
"The Image Generation Action allows the assistant to use DALL-E 3 or GPT-IMAGE-1 to generate images. "
"The action will be used when the user asks the assistant to generate an image."
),
"in_code_tool_id": "ImageGenerationTool",
},
{
"name": "WebSearchTool",
"display_name": "Web Search",
"description": (
"The Web Search Action allows the assistant "
"to perform internet searches for up-to-date information."
),
"in_code_tool_id": "WebSearchTool",
},
{
"name": "KnowledgeGraphTool",
"display_name": "Knowledge Graph Search",
"description": (
"The Knowledge Graph Search Action allows the assistant to search the "
"Knowledge Graph for information. This tool can (for now) only be active in the KG Beta Assistant, "
"and it requires the Knowledge Graph to be enabled."
),
"in_code_tool_id": "KnowledgeGraphTool",
},
{
"name": "OktaProfileTool",
"display_name": "Okta Profile",
"description": (
"The Okta Profile Action allows the assistant to fetch the current user's information from Okta. "
"This may include the user's name, email, phone number, address, and other details such as their "
"manager and direct reports."
),
"in_code_tool_id": "OktaProfileTool",
},
]
def upgrade() -> None:
conn = op.get_bind()
# Start transaction
conn.execute(sa.text("BEGIN"))
try:
# Get existing tools to check what already exists
existing_tools = conn.execute(
sa.text(
"SELECT in_code_tool_id FROM tool WHERE in_code_tool_id IS NOT NULL"
)
).fetchall()
existing_tool_ids = {row[0] for row in existing_tools}
# Insert or update built-in tools
for tool in BUILT_IN_TOOLS:
if tool["in_code_tool_id"] in existing_tool_ids:
# Update existing tool
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :name,
display_name = :display_name,
description = :description
WHERE in_code_tool_id = :in_code_tool_id
"""
),
tool,
)
else:
# Insert new tool
conn.execute(
sa.text(
"""
INSERT INTO tool (name, display_name, description, in_code_tool_id)
VALUES (:name, :display_name, :description, :in_code_tool_id)
"""
),
tool,
)
# Commit transaction
conn.execute(sa.text("COMMIT"))
except Exception as e:
# Rollback on error
conn.execute(sa.text("ROLLBACK"))
raise e
def downgrade() -> None:
# We intentionally leave the tools in place on downgrade; having them around is
# harmless, and re-running the upgrade is a no-op.
pass
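
The select-then-branch loop above stays portable because it does not require a unique index on in_code_tool_id. If such an index existed, each tool could be seeded in a single statement (a sketch under that assumption):

conn.execute(
    sa.text(
        """
        INSERT INTO tool (name, display_name, description, in_code_tool_id)
        VALUES (:name, :display_name, :description, :in_code_tool_id)
        ON CONFLICT (in_code_tool_id) DO UPDATE
        SET name = EXCLUDED.name,
            display_name = EXCLUDED.display_name,
            description = EXCLUDED.description
        """
    ),
    tool,
)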

View File

@@ -0,0 +1,30 @@
"""add research_answer_purpose to chat_message
Revision ID: f8a9b2c3d4e5
Revises: 5ae8240accb3
Create Date: 2025-01-27 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "f8a9b2c3d4e5"
down_revision = "5ae8240accb3"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add research_answer_purpose column to chat_message table
op.add_column(
"chat_message",
sa.Column("research_answer_purpose", sa.String(), nullable=True),
)
def downgrade() -> None:
# Remove research_answer_purpose column from chat_message table
op.drop_column("chat_message", "research_answer_purpose")

View File

@@ -0,0 +1,69 @@
"""remove foreign key constraints from research_agent_iteration_sub_step
Revision ID: f9b8c7d6e5a4
Revises: bd7c3bf8beba
Create Date: 2025-01-27 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "f9b8c7d6e5a4"
down_revision = "bd7c3bf8beba"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Drop the existing foreign key constraint for parent_question_id
op.drop_constraint(
"research_agent_iteration_sub_step_parent_question_id_fkey",
"research_agent_iteration_sub_step",
type_="foreignkey",
)
# Drop the parent_question_id column entirely
op.drop_column("research_agent_iteration_sub_step", "parent_question_id")
# Drop the foreign key constraint for primary_question_id to chat_message.id
# (keep the column as it's needed for the composite foreign key)
op.drop_constraint(
"research_agent_iteration_sub_step_primary_question_id_fkey",
"research_agent_iteration_sub_step",
type_="foreignkey",
)
def downgrade() -> None:
# Restore the foreign key constraint for primary_question_id to chat_message.id
op.create_foreign_key(
"research_agent_iteration_sub_step_primary_question_id_fkey",
"research_agent_iteration_sub_step",
"chat_message",
["primary_question_id"],
["id"],
ondelete="CASCADE",
)
# Add back the parent_question_id column
op.add_column(
"research_agent_iteration_sub_step",
sa.Column(
"parent_question_id",
sa.Integer(),
nullable=True,
),
)
# Restore the foreign key constraint pointing to research_agent_iteration_sub_step.id
op.create_foreign_key(
"research_agent_iteration_sub_step_parent_question_id_fkey",
"research_agent_iteration_sub_step",
"research_agent_iteration_sub_step",
["parent_question_id"],
["id"],
ondelete="CASCADE",
)

View File

@@ -1,133 +1,4 @@
from datetime import datetime
from datetime import timezone
from uuid import UUID
from celery import shared_task
from celery import Task
from ee.onyx.background.celery_utils import should_perform_chat_ttl_check
from ee.onyx.background.task_name_builders import name_chat_ttl_task
from ee.onyx.server.reporting.usage_export_generation import create_new_usage_report
from onyx.background.celery.apps.primary import celery_app
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.chat import delete_chat_session
from onyx.db.chat import get_chat_sessions_older_than
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import TaskStatus
from onyx.db.tasks import mark_task_as_finished_with_id
from onyx.db.tasks import register_task
from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger
logger = setup_logger()
# mark as EE for all tasks in this file
@shared_task(
name=OnyxCeleryTask.PERFORM_TTL_MANAGEMENT_TASK,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
bind=True,
trail=False,
)
def perform_ttl_management_task(
self: Task, retention_limit_days: int, *, tenant_id: str
) -> None:
task_id = self.request.id
if not task_id:
raise RuntimeError("No task id defined for this task; cannot identify it")
start_time = datetime.now(tz=timezone.utc)
user_id: UUID | None = None
session_id: UUID | None = None
try:
with get_session_with_current_tenant() as db_session:
# we generally want to move off this, but keeping for now
register_task(
db_session=db_session,
task_name=name_chat_ttl_task(retention_limit_days, tenant_id),
task_id=task_id,
status=TaskStatus.STARTED,
start_time=start_time,
)
old_chat_sessions = get_chat_sessions_older_than(
retention_limit_days, db_session
)
for user_id, session_id in old_chat_sessions:
# one session per delete so that we don't blow up if a deletion fails.
with get_session_with_current_tenant() as db_session:
delete_chat_session(
user_id,
session_id,
db_session,
include_deleted=True,
hard_delete=True,
)
with get_session_with_current_tenant() as db_session:
mark_task_as_finished_with_id(
db_session=db_session,
task_id=task_id,
success=True,
)
except Exception:
logger.exception(
"delete_chat_session exceptioned. "
f"user_id={user_id} session_id={session_id}"
)
with get_session_with_current_tenant() as db_session:
mark_task_as_finished_with_id(
db_session=db_session,
task_id=task_id,
success=False,
)
raise
#####
# Periodic Tasks
#####
@celery_app.task(
name=OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def check_ttl_management_task(*, tenant_id: str) -> None:
"""Runs periodically to check if any ttl tasks should be run and adds them
to the queue"""
settings = load_settings()
retention_limit_days = settings.maximum_chat_retention_days
with get_session_with_current_tenant() as db_session:
if should_perform_chat_ttl_check(retention_limit_days, db_session):
perform_ttl_management_task.apply_async(
kwargs=dict(
retention_limit_days=retention_limit_days, tenant_id=tenant_id
),
)
@celery_app.task(
name=OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def autogenerate_usage_report_task(*, tenant_id: str) -> None:
"""This generates usage report under the /admin/generate-usage/report endpoint"""
with get_session_with_current_tenant() as db_session:
create_new_usage_report(
db_session=db_session,
user_id=None,
period=None,
)
celery_app.autodiscover_tasks(
@@ -135,5 +6,7 @@ celery_app.autodiscover_tasks(
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
"ee.onyx.background.celery.tasks.cloud",
"ee.onyx.background.celery.tasks.ttl_management",
"ee.onyx.background.celery.tasks.usage_reporting",
]
)

View File

@@ -23,7 +23,7 @@ ee_beat_system_tasks: list[dict] = []
ee_beat_task_templates: list[dict] = [
{
"name": "autogenerate-usage-report",
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
"task": OnyxCeleryTask.GENERATE_USAGE_REPORT_TASK,
"schedule": timedelta(days=30),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
@@ -57,7 +57,7 @@ if not MULTI_TENANT:
ee_tasks_to_schedule = [
{
"name": "autogenerate-usage-report",
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
"task": OnyxCeleryTask.GENERATE_USAGE_REPORT_TASK,
"schedule": timedelta(days=30), # TODO: change this to config flag
"options": {
"priority": OnyxCeleryPriority.MEDIUM,

View File

@@ -0,0 +1,106 @@
from datetime import datetime
from datetime import timezone
from uuid import UUID
from celery import shared_task
from celery import Task
from ee.onyx.background.celery_utils import should_perform_chat_ttl_check
from ee.onyx.background.task_name_builders import name_chat_ttl_task
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.chat import delete_chat_session
from onyx.db.chat import get_chat_sessions_older_than
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import TaskStatus
from onyx.db.tasks import mark_task_as_finished_with_id
from onyx.db.tasks import register_task
from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger
logger = setup_logger()
@shared_task(
name=OnyxCeleryTask.PERFORM_TTL_MANAGEMENT_TASK,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
bind=True,
trail=False,
)
def perform_ttl_management_task(
self: Task, retention_limit_days: int, *, tenant_id: str
) -> None:
task_id = self.request.id
if not task_id:
raise RuntimeError("No task id defined for this task; cannot identify it")
start_time = datetime.now(tz=timezone.utc)
user_id: UUID | None = None
session_id: UUID | None = None
try:
with get_session_with_current_tenant() as db_session:
# we generally want to move off this, but keeping for now
register_task(
db_session=db_session,
task_name=name_chat_ttl_task(retention_limit_days, tenant_id),
task_id=task_id,
status=TaskStatus.STARTED,
start_time=start_time,
)
old_chat_sessions = get_chat_sessions_older_than(
retention_limit_days, db_session
)
for user_id, session_id in old_chat_sessions:
# one session per delete so that we don't blow up if a deletion fails.
with get_session_with_current_tenant() as db_session:
delete_chat_session(
user_id,
session_id,
db_session,
include_deleted=True,
hard_delete=True,
)
with get_session_with_current_tenant() as db_session:
mark_task_as_finished_with_id(
db_session=db_session,
task_id=task_id,
success=True,
)
except Exception:
logger.exception(
"delete_chat_session exceptioned. "
f"user_id={user_id} session_id={session_id}"
)
with get_session_with_current_tenant() as db_session:
mark_task_as_finished_with_id(
db_session=db_session,
task_id=task_id,
success=False,
)
raise
@shared_task(
name=OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def check_ttl_management_task(*, tenant_id: str) -> None:
"""Runs periodically to check if any ttl tasks should be run and adds them
to the queue"""
settings = load_settings()
retention_limit_days = settings.maximum_chat_retention_days
with get_session_with_current_tenant() as db_session:
if should_perform_chat_ttl_check(retention_limit_days, db_session):
perform_ttl_management_task.apply_async(
kwargs=dict(
retention_limit_days=retention_limit_days, tenant_id=tenant_id
),
)

View File

@@ -0,0 +1,46 @@
from datetime import datetime
from uuid import UUID
from celery import shared_task
from celery import Task
from ee.onyx.server.reporting.usage_export_generation import create_new_usage_report
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.utils.logger import setup_logger
logger = setup_logger()
@shared_task(
name=OnyxCeleryTask.GENERATE_USAGE_REPORT_TASK,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
bind=True,
trail=False,
)
def generate_usage_report_task(
self: Task,
*,
tenant_id: str,
user_id: str | None = None,
period_from: str | None = None,
period_to: str | None = None,
) -> None:
"""User-initiated usage report generation task"""
# Parse period if provided
period = None
if period_from and period_to:
period = (
datetime.fromisoformat(period_from),
datetime.fromisoformat(period_to),
)
# Generate the report
with get_session_with_current_tenant() as db_session:
create_new_usage_report(
db_session=db_session,
user_id=UUID(user_id) if user_id else None,
period=period,
)
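
Since the task parses its period bounds with datetime.fromisoformat, a caller (hypothetical, e.g. an admin usage-report endpoint) would enqueue it roughly like:

generate_usage_report_task.apply_async(
    kwargs=dict(
        tenant_id="tenant_abc",           # assumed to come from request context
        user_id=str(requesting_user_id),  # UUID serialized to str for the broker
        period_from="2025-08-01T00:00:00+00:00",
        period_to="2025-09-01T00:00:00+00:00",
    ),
)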

View File

@@ -1,38 +0,0 @@
from ee.onyx.server.query_and_chat.models import OneShotQAResponse
from onyx.chat.models import AllCitations
from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import QADocsResponse
from onyx.chat.models import StreamingError
from onyx.chat.process_message import ChatPacketStream
from onyx.server.query_and_chat.models import ChatMessageDetail
from onyx.utils.timing import log_function_time
@log_function_time()
def gather_stream_for_answer_api(
packets: ChatPacketStream,
) -> OneShotQAResponse:
response = OneShotQAResponse()
answer = ""
for packet in packets:
if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
answer += packet.answer_piece
elif isinstance(packet, QADocsResponse):
response.docs = packet
# Extraneous, provided for backwards compatibility
response.rephrase = packet.rephrased_query
elif isinstance(packet, StreamingError):
response.error_msg = packet.error
elif isinstance(packet, ChatMessageDetail):
response.chat_message_id = packet.message_id
elif isinstance(packet, LLMRelevanceFilterResponse):
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
elif isinstance(packet, AllCitations):
response.citations = packet.citations
if answer:
response.answer = answer
return response

View File

@@ -2,18 +2,14 @@ from collections.abc import Callable
from collections.abc import Generator
from typing import Optional
from typing import Protocol
from typing import TYPE_CHECKING
from ee.onyx.db.external_perm import ExternalUserGroup # noqa
from onyx.access.models import DocExternalAccess # noqa
from onyx.context.search.models import InferenceChunk
from onyx.db.models import ConnectorCredentialPair # noqa
from onyx.db.utils import DocumentRow
from onyx.db.utils import SortOrder
# Avoid circular imports
if TYPE_CHECKING:
from ee.onyx.db.external_perm import ExternalUserGroup # noqa
from onyx.access.models import DocExternalAccess # noqa
from onyx.db.models import ConnectorCredentialPair # noqa
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface # noqa
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface # noqa
class FetchAllDocumentsFunction(Protocol):
@@ -52,20 +48,20 @@ class FetchAllDocumentsIdsFunction(Protocol):
# Defining the input/output types for the sync functions
DocSyncFuncType = Callable[
[
"ConnectorCredentialPair",
ConnectorCredentialPair,
FetchAllDocumentsFunction,
FetchAllDocumentsIdsFunction,
Optional["IndexingHeartbeatInterface"],
Optional[IndexingHeartbeatInterface],
],
Generator["DocExternalAccess", None, None],
Generator[DocExternalAccess, None, None],
]
GroupSyncFuncType = Callable[
[
str, # tenant_id
"ConnectorCredentialPair", # cc_pair
ConnectorCredentialPair, # cc_pair
],
Generator["ExternalUserGroup", None, None],
Generator[ExternalUserGroup, None, None],
]
# list of chunks to be censored and the user email. returns censored chunks

View File

@@ -1,9 +1,12 @@
import re
from collections import deque
from typing import Any
from urllib.parse import unquote
from urllib.parse import urlparse
from office365.graph_client import GraphClient # type: ignore[import-untyped]
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
from office365.runtime.client_request import ClientRequestException # type: ignore
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
from office365.sharepoint.permissions.securable_object import RoleAssignmentCollection # type: ignore[import-untyped]
from pydantic import BaseModel
@@ -231,6 +234,7 @@ def _get_sharepoint_groups(
nonlocal groups, user_emails
for user in users:
logger.debug(f"User: {user.to_json()}")
if user.principal_type == USER_PRINCIPAL_TYPE and hasattr(
user, "user_principal_name"
):
@@ -285,7 +289,7 @@ def _get_azuread_groups(
for member in members:
member_data = member.to_json()
logger.debug(f"Member: {member_data}")
# Check for user-specific attributes
user_principal_name = member_data.get("userPrincipalName")
mail = member_data.get("mail")
@@ -366,6 +370,7 @@ def _get_groups_and_members_recursively(
client_context: ClientContext,
graph_client: GraphClient,
groups: set[SharepointGroup],
is_group_sync: bool = False,
) -> GroupsResult:
"""
Get all groups and their members recursively.
@@ -373,6 +378,7 @@ def _get_groups_and_members_recursively(
group_queue: deque[SharepointGroup] = deque(groups)
visited_groups: set[str] = set()
visited_group_name_to_emails: dict[str, set[str]] = {}
found_public_group = False
while group_queue:
group = group_queue.popleft()
if group.login_name in visited_groups:
@@ -390,19 +396,35 @@ def _get_groups_and_members_recursively(
if group_info:
group_queue.extend(group_info)
if group.principal_type == AZURE_AD_GROUP_PRINCIPAL_TYPE:
# if the site is public, we have default groups assigned to it, so we return early
if _is_public_login_name(group.login_name):
return GroupsResult(groups_to_emails={}, found_public_group=True)
group_info, user_emails = _get_azuread_groups(
graph_client, group.login_name
)
visited_group_name_to_emails[group.name].update(user_emails)
if group_info:
group_queue.extend(group_info)
try:
# if the site is public, we have default groups assigned to it, so we return early
if _is_public_login_name(group.login_name):
found_public_group = True
if not is_group_sync:
return GroupsResult(
groups_to_emails={}, found_public_group=True
)
else:
# we don't want to sync public groups, so we skip them
continue
group_info, user_emails = _get_azuread_groups(
graph_client, group.login_name
)
visited_group_name_to_emails[group.name].update(user_emails)
if group_info:
group_queue.extend(group_info)
except ClientRequestException as e:
# If the group is not found, we skip it. A group can still be referenced in
# SharePoint after it has been removed from Azure AD; there is no official
# documentation on this, but we have observed it in testing.
if e.response is not None and e.response.status_code == 404:
logger.warning(f"Group {group.login_name} not found")
continue
raise e
return GroupsResult(
groups_to_emails=visited_group_name_to_emails, found_public_group=False
groups_to_emails=visited_group_name_to_emails,
found_public_group=found_public_group,
)
@@ -427,6 +449,7 @@ def get_external_access_from_sharepoint(
) -> None:
nonlocal user_emails, groups
for assignment in role_assignments:
logger.debug(f"Assignment: {assignment.to_json()}")
if assignment.role_definition_bindings:
is_limited_access = True
for role_definition_binding in assignment.role_definition_bindings:
@@ -503,12 +526,19 @@ def get_external_access_from_sharepoint(
)
elif site_page:
site_url = site_page.get("webUrl")
site_pages = client_context.web.lists.get_by_title("Site Pages")
client_context.load(site_pages)
client_context.execute_query()
site_pages.items.get_by_url(site_url).role_assignments.expand(
["Member", "RoleDefinitionBindings"]
).get_all(page_loaded=add_user_and_group_to_sets).execute_query()
# Prefer server-relative URL to avoid OData filters that break on apostrophes
server_relative_url = unquote(urlparse(site_url).path)
file_obj = client_context.web.get_file_by_server_relative_url(
server_relative_url
)
item = file_obj.listItemAllFields
sleep_and_retry(
item.role_assignments.expand(["Member", "RoleDefinitionBindings"]).get_all(
page_loaded=add_user_and_group_to_sets,
),
"get_external_access_from_sharepoint",
)
else:
raise RuntimeError("No drive item or site page provided")
@@ -595,13 +625,9 @@ def get_sharepoint_external_groups(
"get_sharepoint_external_groups",
)
groups_and_members: GroupsResult = _get_groups_and_members_recursively(
client_context, graph_client, groups
client_context, graph_client, groups, is_group_sync=True
)
# We don't have any direct way to check if the site is public, so we check if any public group is present
if groups_and_members.found_public_group:
return []
# Get all Azure AD groups, because if any group is assigned to the drive item, we don't want to miss it.
# We can't assign SharePoint groups to drive items or drives, so we don't need to get all SharePoint groups.
azure_ad_groups = sleep_and_retry(

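The apostrophe fix above works because get_file_by_server_relative_url addresses the item by path instead of through an OData filter string. A minimal sketch of the normalization (illustrative values):

from urllib.parse import unquote, urlparse

site_url = "https://contoso.sharepoint.com/sites/HR/SitePages/Team's%20Page.aspx"
server_relative_url = unquote(urlparse(site_url).path)
# -> "/sites/HR/SitePages/Team's Page.aspx"; the apostrophe never needs to be
# escaped inside an OData $filter expression.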
View File

@@ -14,7 +14,6 @@ from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_chat_messages_by_sessions
from onyx.db.chat import get_chat_sessions_by_slack_thread_id
from onyx.db.chat import get_or_create_root_message
from onyx.db.models import Prompt
from onyx.db.models import SlackChannelConfig
from onyx.db.models import StandardAnswer as StandardAnswerModel
from onyx.onyxbot.slack.blocks import get_restate_blocks
@@ -81,7 +80,6 @@ def _handle_standard_answers(
message_info: SlackMessageInfo,
receiver_ids: list[str] | None,
slack_channel_config: SlackChannelConfig,
prompt: Prompt | None,
logger: OnyxLoggingAdapter,
client: WebClient,
db_session: Session,
@@ -161,7 +159,6 @@ def _handle_standard_answers(
new_user_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=root_message,
prompt_id=prompt.id if prompt else None,
message=query_msg.message,
token_count=0,
message_type=MessageType.USER,
@@ -182,7 +179,6 @@ def _handle_standard_answers(
chat_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=new_user_message,
prompt_id=prompt.id if prompt else None,
message=answer_message,
token_count=0,
message_type=MessageType.ASSISTANT,

View File

@@ -1,43 +1,22 @@
import re
from typing import cast
from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from ee.onyx.server.query_and_chat.models import AgentAnswer
from ee.onyx.server.query_and_chat.models import AgentSubQuery
from ee.onyx.server.query_and_chat.models import AgentSubQuestion
from ee.onyx.server.query_and_chat.models import BasicCreateChatMessageRequest
from ee.onyx.server.query_and_chat.models import (
BasicCreateChatMessageWithHistoryRequest,
)
from ee.onyx.server.query_and_chat.models import ChatBasicResponse
from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AllCitations
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import FinalUsedContextDocsResponse
from onyx.chat.models import LlmDoc
from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import QADocsResponse
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamingError
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionIdentifier
from onyx.chat.models import SubQuestionPiece
from onyx.chat.process_message import ChatPacketStream
from onyx.chat.models import ChatBasicResponse
from onyx.chat.process_message import gather_stream
from onyx.chat.process_message import stream_chat_message_objects
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.constants import MessageType
from onyx.context.search.models import OptionalSearchSetting
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SavedSearchDoc
from onyx.db.chat import create_chat_session
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_or_create_root_message
@@ -46,7 +25,6 @@ from onyx.db.models import User
from onyx.llm.factory import get_llms_for_persona
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.secondary_llm_flows.query_expansion import thread_based_query_rephrase
from onyx.server.query_and_chat.models import ChatMessageDetail
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.utils.logger import setup_logger
@@ -55,180 +33,6 @@ logger = setup_logger()
router = APIRouter(prefix="/chat")
def _get_final_context_doc_indices(
final_context_docs: list[LlmDoc] | None,
top_docs: list[SavedSearchDoc] | None,
) -> list[int] | None:
"""
this function returns a list of indices of the simple search docs
that were actually fed to the LLM.
"""
if final_context_docs is None or top_docs is None:
return None
final_context_doc_ids = {doc.document_id for doc in final_context_docs}
return [
i for i, doc in enumerate(top_docs) if doc.document_id in final_context_doc_ids
]
def _convert_packet_stream_to_response(
packets: ChatPacketStream,
chat_session_id: UUID,
) -> ChatBasicResponse:
response = ChatBasicResponse()
final_context_docs: list[LlmDoc] = []
answer = ""
# accumulate stream data with these dicts
agent_sub_questions: dict[tuple[int, int], AgentSubQuestion] = {}
agent_answers: dict[tuple[int, int], AgentAnswer] = {}
agent_sub_queries: dict[tuple[int, int, int], AgentSubQuery] = {}
for packet in packets:
if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
answer += packet.answer_piece
elif isinstance(packet, QADocsResponse):
response.top_documents = packet.top_documents
# This is a no-op if agent_sub_questions hasn't already been filled
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if id in agent_sub_questions:
agent_sub_questions[id].document_ids = [
saved_search_doc.document_id
for saved_search_doc in packet.top_documents
]
elif isinstance(packet, StreamingError):
response.error_msg = packet.error
elif isinstance(packet, ChatMessageDetail):
response.message_id = packet.message_id
elif isinstance(packet, LLMRelevanceFilterResponse):
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
# TODO: deprecate `llm_chunks_indices`
response.llm_chunks_indices = packet.llm_selected_doc_indices
elif isinstance(packet, FinalUsedContextDocsResponse):
final_context_docs = packet.final_context_docs
elif isinstance(packet, AllCitations):
response.cited_documents = {
citation.citation_num: citation.document_id
for citation in packet.citations
}
# agentic packets
elif isinstance(packet, SubQuestionPiece):
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if agent_sub_questions.get(id) is None:
agent_sub_questions[id] = AgentSubQuestion(
level=packet.level,
level_question_num=packet.level_question_num,
sub_question=packet.sub_question,
document_ids=[],
)
else:
agent_sub_questions[id].sub_question += packet.sub_question
elif isinstance(packet, AgentAnswerPiece):
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if agent_answers.get(id) is None:
agent_answers[id] = AgentAnswer(
level=packet.level,
level_question_num=packet.level_question_num,
answer=packet.answer_piece,
answer_type=packet.answer_type,
)
else:
agent_answers[id].answer += packet.answer_piece
elif isinstance(packet, SubQueryPiece):
if packet.level is not None and packet.level_question_num is not None:
sub_query_id = (
packet.level,
packet.level_question_num,
packet.query_id,
)
if agent_sub_queries.get(sub_query_id) is None:
agent_sub_queries[sub_query_id] = AgentSubQuery(
level=packet.level,
level_question_num=packet.level_question_num,
sub_query=packet.sub_query,
query_id=packet.query_id,
)
else:
agent_sub_queries[sub_query_id].sub_query += packet.sub_query
elif isinstance(packet, ExtendedToolResponse):
# we shouldn't get this ... it gets intercepted and translated to QADocsResponse
logger.warning(
"_convert_packet_stream_to_response: Unexpected chat packet type ExtendedToolResponse!"
)
elif isinstance(packet, RefinedAnswerImprovement):
response.agent_refined_answer_improvement = (
packet.refined_answer_improvement
)
else:
logger.warning(
f"_convert_packet_stream_to_response - Unrecognized chat packet: type={type(packet)}"
)
response.final_context_doc_indices = _get_final_context_doc_indices(
final_context_docs, response.top_documents
)
# organize / sort agent metadata for output
if len(agent_sub_questions) > 0:
response.agent_sub_questions = cast(
dict[int, list[AgentSubQuestion]],
SubQuestionIdentifier.make_dict_by_level(agent_sub_questions),
)
if len(agent_answers) > 0:
# return the agent_level_answer from the first level or the last one depending
# on agent_refined_answer_improvement
response.agent_answers = cast(
dict[int, list[AgentAnswer]],
SubQuestionIdentifier.make_dict_by_level(agent_answers),
)
if response.agent_answers:
selected_answer_level = (
0
if not response.agent_refined_answer_improvement
else len(response.agent_answers) - 1
)
level_answers = response.agent_answers[selected_answer_level]
for level_answer in level_answers:
if level_answer.answer_type != "agent_level_answer":
continue
answer = level_answer.answer
break
if len(agent_sub_queries) > 0:
# subqueries are often emitted with trailing whitespace ... clean it up here
# perhaps fix at the source?
for v in agent_sub_queries.values():
v.sub_query = v.sub_query.strip()
response.agent_sub_queries = (
AgentSubQuery.make_dict_by_level_and_question_index(agent_sub_queries)
)
response.answer = answer
if answer:
response.answer_citationless = remove_answer_citations(answer)
response.chat_session_id = chat_session_id
return response
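The accumulation above follows one pattern throughout: streamed pieces keyed by (level, question_num) are concatenated into complete values. A toy, self-contained sketch of that pattern (plain strings stand in for the AgentSubQuestion / AgentAnswer models):

pieces = [(0, 1, "What is "), (0, 1, "Onyx?"), (0, 2, "How is it deployed?")]
accumulated: dict[tuple[int, int], str] = {}
for level, question_num, piece in pieces:
    key = (level, question_num)
    # first piece creates the entry; later pieces append to it
    accumulated[key] = accumulated.get(key, "") + piece
assert accumulated[(0, 1)] == "What is Onyx?"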
def remove_answer_citations(answer: str) -> str:
pattern = r"\s*\[\[\d+\]\]\(http[s]?://[^\s]+\)"
return re.sub(pattern, "", answer)
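For illustration, a quick self-check of the citation-stripping pattern above, with a hypothetical URL:

import re

pattern = r"\s*\[\[\d+\]\]\(http[s]?://[^\s]+\)"
answer = "Onyx supports many connectors [[1]](https://docs.onyx.app/connectors)."
# the leading space and the whole [[n]](url) citation are removed
assert re.sub(pattern, "", answer) == "Onyx supports many connectors."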
@router.post("/send-message-simple-api")
def handle_simplified_chat_message(
chat_message_req: BasicCreateChatMessageRequest,
@@ -289,7 +93,6 @@ def handle_simplified_chat_message(
parent_message_id=parent_message.id,
message=chat_message_req.message,
file_descriptors=[],
prompt_id=None,
search_doc_ids=chat_message_req.search_doc_ids,
retrieval_options=retrieval_options,
# Simple API does not support reranking, hide complexity from user
@@ -310,7 +113,7 @@ def handle_simplified_chat_message(
enforce_chat_session_id_for_search_docs=False,
)
return _convert_packet_stream_to_response(packets, chat_session_id)
return gather_stream(packets)
@router.post("/send-message-simple-with-history")
@@ -377,7 +180,6 @@ def handle_send_message_simple_with_history(
chat_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=chat_message,
prompt_id=req.prompt_id,
message=msg.message,
token_count=len(llm_tokenizer.encode(msg.message)),
message_type=msg.role,
@@ -410,7 +212,6 @@ def handle_send_message_simple_with_history(
parent_message_id=chat_message.id,
message=query,
file_descriptors=[],
prompt_id=req.prompt_id,
search_doc_ids=req.search_doc_ids,
retrieval_options=retrieval_options,
# Simple API does not support reranking, hide complexity from user
@@ -430,4 +231,4 @@ def handle_send_message_simple_with_history(
enforce_chat_session_id_for_search_docs=False,
)
return _convert_packet_stream_to_response(packets, chat_session.id)
return gather_stream(packets)

View File

@@ -6,10 +6,8 @@ from pydantic import BaseModel
from pydantic import Field
from pydantic import model_validator
from onyx.chat.models import CitationInfo
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import SubQuestionIdentifier
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DocumentSource
from onyx.context.search.enums import LLMEvaluationType
@@ -17,8 +15,9 @@ from onyx.context.search.enums import SearchType
from onyx.context.search.models import ChunkContext
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SavedSearchDoc
from onyx.server.manage.models import StandardAnswer
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import SubQuestionIdentifier
class StandardAnswerRequest(BaseModel):
@@ -74,7 +73,6 @@ class BasicCreateChatMessageRequest(ChunkContext):
class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
# Last element is the new query. All previous elements are historical context
messages: list[ThreadMessage]
prompt_id: int | None
persona_id: int
retrieval_options: RetrievalDetails | None = None
query_override: str | None = None
@@ -156,33 +154,6 @@ class AgentSubQuery(SubQuestionIdentifier):
return sorted_dict
class ChatBasicResponse(BaseModel):
# This is built piece by piece; any of these can be None because the flow could break
answer: str | None = None
answer_citationless: str | None = None
top_documents: list[SavedSearchDoc] | None = None
error_msg: str | None = None
message_id: int | None = None
llm_selected_doc_indices: list[int] | None = None
final_context_doc_indices: list[int] | None = None
# this is a map of the citation number to the document id
cited_documents: dict[int, str] | None = None
# FOR BACKWARDS COMPATIBILITY
llm_chunks_indices: list[int] | None = None
# agentic fields
agent_sub_questions: dict[int, list[AgentSubQuestion]] | None = None
agent_answers: dict[int, list[AgentAnswer]] | None = None
agent_sub_queries: dict[int, dict[int, list[AgentSubQuery]]] | None = None
agent_refined_answer_improvement: bool | None = None
# Chat session ID for tracking conversation continuity
chat_session_id: UUID | None = None
class OneShotQARequest(ChunkContext):
# Supports simpler APIs that don't deal with chat histories or message edits
# Easier for developers to work with
@@ -190,10 +161,8 @@ class OneShotQARequest(ChunkContext):
persona_id: int | None = None
messages: list[ThreadMessage]
prompt_id: int | None = None
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
rerank_settings: RerankingDetails | None = None
return_contexts: bool = False
# allows the caller to specify the exact search query they want to use
# can be used if the message sent to the LLM / query should not be the same
@@ -210,11 +179,9 @@ class OneShotQARequest(ChunkContext):
def check_persona_fields(self) -> "OneShotQARequest":
if self.persona_override_config is None and self.persona_id is None:
raise ValueError("Exactly one of persona_config or persona_id must be set")
elif self.persona_override_config is not None and (
self.persona_id is not None or self.prompt_id is not None
):
elif self.persona_override_config is not None and (self.persona_id is not None):
raise ValueError(
"If persona_override_config is set, persona_id and prompt_id cannot be set"
"If persona_override_config is set, persona_id cannot be set"
)
return self
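The validator above follows the standard pydantic v2 mutually-exclusive-fields pattern. A minimal stand-alone sketch (QARequest is a toy model, not the real request class):

from pydantic import BaseModel, model_validator

class QARequest(BaseModel):
    persona_id: int | None = None
    persona_override_config: dict | None = None

    @model_validator(mode="after")
    def check_persona_fields(self) -> "QARequest":
        # exactly one of the two ways to specify a persona must be used
        if self.persona_override_config is None and self.persona_id is None:
            raise ValueError("Exactly one of persona_override_config or persona_id must be set")
        if self.persona_override_config is not None and self.persona_id is not None:
            raise ValueError("If persona_override_config is set, persona_id cannot be set")
        return self

QARequest(persona_id=1)  # valid; passing both (or neither) raises a ValidationError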
@@ -225,6 +192,5 @@ class OneShotQAResponse(BaseModel):
rephrase: str | None = None
citations: list[CitationInfo] | None = None
docs: QADocsResponse | None = None
llm_selected_doc_indices: list[int] | None = None
error_msg: str | None = None
chat_message_id: int | None = None

View File

@@ -8,7 +8,6 @@ from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.chat.process_message import gather_stream_for_answer_api
from ee.onyx.onyxbot.slack.handlers.handle_standard_answers import (
oneoff_standard_answers,
)
@@ -20,8 +19,10 @@ from ee.onyx.server.query_and_chat.models import StandardAnswerResponse
from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import prepare_chat_message_request
from onyx.chat.models import AnswerStream
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.process_message import ChatPacketStream
from onyx.chat.models import QADocsResponse
from onyx.chat.process_message import gather_stream
from onyx.chat.process_message import stream_chat_message_objects
from onyx.configs.onyxbot_configs import MAX_THREAD_CONTEXT_PERCENTAGE
from onyx.context.search.models import SavedSearchDocWithContent
@@ -30,7 +31,6 @@ from onyx.context.search.pipeline import SearchPipeline
from onyx.context.search.utils import dedupe_documents
from onyx.context.search.utils import drop_llm_indices
from onyx.context.search.utils import relevant_sections_to_indices
from onyx.db.chat import get_prompt_by_id
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import Persona
from onyx.db.models import User
@@ -39,6 +39,7 @@ from onyx.llm.factory import get_default_llms
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.factory import get_main_llm_from_tuple
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.utils import get_json_line
from onyx.utils.logger import setup_logger
@@ -140,7 +141,7 @@ def get_answer_stream(
query_request: OneShotQARequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> ChatPacketStream:
) -> AnswerStream:
query = query_request.messages[0].message
logger.notice(f"Received query for Answer API: {query}")
@@ -150,14 +151,6 @@ def get_answer_stream(
):
raise KeyError("Must provide persona ID or Persona Config")
prompt = None
if query_request.prompt_id is not None:
prompt = get_prompt_by_id(
prompt_id=query_request.prompt_id,
user=user,
db_session=db_session,
)
persona_info: Persona | PersonaOverrideConfig | None = None
if query_request.persona_override_config is not None:
persona_info = query_request.persona_override_config
@@ -192,7 +185,6 @@ def get_answer_stream(
user=user,
persona_id=query_request.persona_id,
persona_override_config=query_request.persona_override_config,
prompt=prompt,
message_ts_to_respond_to=None,
retrieval_details=query_request.retrieval_options,
rerank_settings=query_request.rerank_settings,
@@ -205,7 +197,6 @@ def get_answer_stream(
new_msg_req=request,
user=user,
db_session=db_session,
include_contexts=query_request.return_contexts,
)
return packets
@@ -219,12 +210,28 @@ def get_answer_with_citation(
) -> OneShotQAResponse:
try:
packets = get_answer_stream(request, user, db_session)
answer = gather_stream_for_answer_api(packets)
answer = gather_stream(packets)
if answer.error_msg:
raise RuntimeError(answer.error_msg)
return answer
return OneShotQAResponse(
answer=answer.answer,
chat_message_id=answer.message_id,
error_msg=answer.error_msg,
citations=[
CitationInfo(citation_num=i, document_id=doc_id)
for i, doc_id in answer.cited_documents.items()
],
docs=QADocsResponse(
top_documents=answer.top_documents,
predicted_flow=None,
predicted_search=None,
applied_source_filters=None,
applied_time_cutoff=None,
recency_bias_multiplier=0.0,
),
)
except Exception as e:
logger.error(f"Error in get_answer_with_citation: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail="An internal server error occurred")

View File

@@ -12,11 +12,13 @@ from sqlalchemy.orm import Session
from ee.onyx.db.usage_export import get_all_usage_reports
from ee.onyx.db.usage_export import get_usage_report_data
from ee.onyx.db.usage_export import UsageReportMetadata
from ee.onyx.server.reporting.usage_export_generation import create_new_usage_report
from onyx.auth.users import current_admin_user
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.file_store.constants import STANDARD_CHUNK_SIZE
from shared_configs.contextvars import get_current_tenant_id
router = APIRouter()
@@ -26,24 +28,31 @@ class GenerateUsageReportParams(BaseModel):
period_to: str | None = None
@router.post("/admin/generate-usage-report")
@router.post("/admin/usage-report", status_code=204)
def generate_report(
params: GenerateUsageReportParams,
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> UsageReportMetadata:
period = None
) -> None:
# Validate period parameters
if params.period_from and params.period_to:
try:
period = (
datetime.fromisoformat(params.period_from),
datetime.fromisoformat(params.period_to),
)
datetime.fromisoformat(params.period_from)
datetime.fromisoformat(params.period_to)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
new_report = create_new_usage_report(db_session, user.id if user else None, period)
return new_report
tenant_id = get_current_tenant_id()
client_app.send_task(
OnyxCeleryTask.GENERATE_USAGE_REPORT_TASK,
kwargs={
"tenant_id": tenant_id,
"user_id": str(user.id) if user else None,
"period_from": params.period_from,
"period_to": params.period_to,
},
)
return None
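The validation step above leans on datetime.fromisoformat raising ValueError for malformed input, which the endpoint maps to an HTTP 400. A small sketch of that behavior:

from datetime import datetime

def validate_period(period_from: str, period_to: str) -> None:
    # fromisoformat raises ValueError on malformed input
    datetime.fromisoformat(period_from)
    datetime.fromisoformat(period_to)

validate_period("2025-01-01", "2025-02-01T00:00:00")  # accepted
try:
    validate_period("01/01/2025", "2025-02-01")
except ValueError:
    pass  # rejected, as the endpoint expects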
@router.get("/admin/usage-report/{report_name}")
@@ -54,7 +63,7 @@ def read_usage_report(
) -> Response:
try:
file = get_usage_report_data(report_name)
except ValueError as e:
except (ValueError, RuntimeError) as e:
raise HTTPException(status_code=404, detail=str(e))
def iterfile() -> Generator[bytes, None, None]:

View File

@@ -131,32 +131,35 @@ def _seed_llms(
def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:
if personas:
logger.notice("Seeding Personas")
for persona in personas:
if not persona.prompt_ids:
raise ValueError(
f"Invalid Persona with name {persona.name}; no prompts exist"
try:
for persona in personas:
upsert_persona(
user=None, # Seeding is done as admin
name=persona.name,
description=persona.description,
num_chunks=(
persona.num_chunks if persona.num_chunks is not None else 0.0
),
llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction,
recency_bias=RecencyBiasSetting.AUTO,
document_set_ids=persona.document_set_ids,
llm_model_provider_override=persona.llm_model_provider_override,
llm_model_version_override=persona.llm_model_version_override,
starter_messages=persona.starter_messages,
is_public=persona.is_public,
db_session=db_session,
tool_ids=persona.tool_ids,
display_priority=persona.display_priority,
system_prompt=persona.system_prompt,
task_prompt=persona.task_prompt,
datetime_aware=persona.datetime_aware,
commit=False,
)
upsert_persona(
user=None, # Seeding is done as admin
name=persona.name,
description=persona.description,
num_chunks=(
persona.num_chunks if persona.num_chunks is not None else 0.0
),
llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction,
recency_bias=RecencyBiasSetting.AUTO,
prompt_ids=persona.prompt_ids,
document_set_ids=persona.document_set_ids,
llm_model_provider_override=persona.llm_model_provider_override,
llm_model_version_override=persona.llm_model_version_override,
starter_messages=persona.starter_messages,
is_public=persona.is_public,
db_session=db_session,
tool_ids=persona.tool_ids,
display_priority=persona.display_priority,
)
db_session.commit()
except Exception:
logger.exception("Failed to seed personas.")
raise
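The seeding change above stages each upsert with commit=False and commits once at the end, so a failure part-way through leaves nothing half-seeded. A toy sketch of that all-or-nothing shape (FakeSession is a stand-in, not the real SQLAlchemy session):

class FakeSession:
    """Toy stand-in for a SQLAlchemy session."""

    def __init__(self) -> None:
        self.pending: list[str] = []
        self.committed: list[str] = []

    def stage(self, name: str) -> None:
        # analogous to upsert_persona(..., commit=False)
        self.pending.append(name)

    def commit(self) -> None:
        self.committed.extend(self.pending)
        self.pending.clear()

session = FakeSession()
for name in ("General", "Search Assistant"):
    session.stage(name)
session.commit()  # one commit for the whole batch
assert session.committed == ["General", "Search Assistant"]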
def _seed_settings(settings: Settings) -> None:

View File

@@ -1,34 +1,5 @@
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
MODEL_WARM_UP_STRING = "hi " * 512
INFORMATION_CONTENT_MODEL_WARM_UP_STRING = "hi " * 16
DEFAULT_OPENAI_MODEL = "text-embedding-3-small"
DEFAULT_COHERE_MODEL = "embed-english-light-v3.0"
DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"
DEFAULT_VERTEX_MODEL = "text-embedding-005"
class EmbeddingModelTextType:
PROVIDER_TEXT_TYPE_MAP = {
EmbeddingProvider.COHERE: {
EmbedTextType.QUERY: "search_query",
EmbedTextType.PASSAGE: "search_document",
},
EmbeddingProvider.VOYAGE: {
EmbedTextType.QUERY: "query",
EmbedTextType.PASSAGE: "document",
},
EmbeddingProvider.GOOGLE: {
EmbedTextType.QUERY: "RETRIEVAL_QUERY",
EmbedTextType.PASSAGE: "RETRIEVAL_DOCUMENT",
},
}
@staticmethod
def get_type(provider: EmbeddingProvider, text_type: EmbedTextType) -> str:
return EmbeddingModelTextType.PROVIDER_TEXT_TYPE_MAP[provider][text_type]
class GPUStatus:

View File

@@ -1,55 +1,30 @@
import asyncio
import json
import time
from types import TracebackType
from typing import cast
from typing import Any
from typing import Optional
import aioboto3 # type: ignore
import httpx
import openai
import vertexai # type: ignore
import voyageai # type: ignore
from cohere import AsyncClient as CohereAsyncClient
from fastapi import APIRouter
from fastapi import HTTPException
from fastapi import Request
from google.oauth2 import service_account # type: ignore
from litellm import aembedding
from litellm.exceptions import RateLimitError
from retry import retry
from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
from vertexai.language_models import TextEmbeddingInput # type: ignore
from vertexai.language_models import TextEmbeddingModel # type: ignore
from model_server.constants import DEFAULT_COHERE_MODEL
from model_server.constants import DEFAULT_OPENAI_MODEL
from model_server.constants import DEFAULT_VERTEX_MODEL
from model_server.constants import DEFAULT_VOYAGE_MODEL
from model_server.constants import EmbeddingModelTextType
from model_server.constants import EmbeddingProvider
from model_server.utils import pass_aws_key
from model_server.utils import simple_log_function_time
from onyx.utils.logger import setup_logger
from shared_configs.configs import API_BASED_EMBEDDING_TIMEOUT
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import OPENAI_EMBEDDING_TIMEOUT
from shared_configs.configs import VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
from shared_configs.enums import EmbedTextType
from shared_configs.enums import RerankerProvider
from shared_configs.model_server_models import Embedding
from shared_configs.model_server_models import EmbedRequest
from shared_configs.model_server_models import EmbedResponse
from shared_configs.model_server_models import RerankRequest
from shared_configs.model_server_models import RerankResponse
from shared_configs.utils import batch_list
logger = setup_logger()
router = APIRouter(prefix="/encoder")
_GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
_RERANK_MODEL: Optional["CrossEncoder"] = None
@@ -57,315 +32,6 @@ _RERANK_MODEL: Optional["CrossEncoder"] = None
_RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
_RETRY_TRIES = 10 if INDEXING_ONLY else 2
# OpenAI only allows 2048 embeddings to be computed at once
_OPENAI_MAX_INPUT_LEN = 2048
# Cohere allows up to 96 embeddings in a single embedding call
_COHERE_MAX_INPUT_LEN = 96
# Authentication error string constants
_AUTH_ERROR_401 = "401"
_AUTH_ERROR_UNAUTHORIZED = "unauthorized"
_AUTH_ERROR_INVALID_API_KEY = "invalid api key"
_AUTH_ERROR_PERMISSION = "permission"
def is_authentication_error(error: Exception) -> bool:
"""Check if an exception is related to authentication issues.
Args:
error: The exception to check
Returns:
bool: True if the error appears to be authentication-related
"""
error_str = str(error).lower()
return (
_AUTH_ERROR_401 in error_str
or _AUTH_ERROR_UNAUTHORIZED in error_str
or _AUTH_ERROR_INVALID_API_KEY in error_str
or _AUTH_ERROR_PERMISSION in error_str
)
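A self-contained restatement of the heuristic above for quick verification (looks_like_auth_error is a renamed sketch, not the exported function):

_AUTH_MARKERS = ("401", "unauthorized", "invalid api key", "permission")

def looks_like_auth_error(error: Exception) -> bool:
    # substring match over the lowercased error message
    message = str(error).lower()
    return any(marker in message for marker in _AUTH_MARKERS)

assert looks_like_auth_error(RuntimeError("401 Client Error: Unauthorized"))
assert not looks_like_auth_error(TimeoutError("read timed out"))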
def format_embedding_error(
error: Exception,
service_name: str,
model: str | None,
provider: EmbeddingProvider,
sanitized_api_key: str | None = None,
status_code: int | None = None,
) -> str:
"""
Format a standardized error string for embedding errors.
"""
detail = f"Status {status_code}" if status_code else f"{type(error)}"
return (
f"{'HTTP error' if status_code else 'Exception'} embedding text with {service_name} - {detail}: "
f"Model: {model} "
f"Provider: {provider} "
f"API Key: {sanitized_api_key} "
f"Exception: {error}"
)
# Custom exception for authentication errors
class AuthenticationError(Exception):
"""Raised when authentication fails with a provider."""
def __init__(self, provider: str, message: str = "API key is invalid or expired"):
self.provider = provider
self.message = message
super().__init__(f"{provider} authentication failed: {message}")
class CloudEmbedding:
def __init__(
self,
api_key: str,
provider: EmbeddingProvider,
api_url: str | None = None,
api_version: str | None = None,
timeout: int = API_BASED_EMBEDDING_TIMEOUT,
) -> None:
self.provider = provider
self.api_key = api_key
self.api_url = api_url
self.api_version = api_version
self.timeout = timeout
self.http_client = httpx.AsyncClient(timeout=timeout)
self._closed = False
self.sanitized_api_key = api_key[:4] + "********" + api_key[-4:]
async def _embed_openai(
self, texts: list[str], model: str | None, reduced_dimension: int | None
) -> list[Embedding]:
if not model:
model = DEFAULT_OPENAI_MODEL
# Use the OpenAI specific timeout for this one
client = openai.AsyncOpenAI(
api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT
)
final_embeddings: list[Embedding] = []
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
response = await client.embeddings.create(
input=text_batch,
model=model,
dimensions=reduced_dimension or openai.NOT_GIVEN,
)
final_embeddings.extend(
[embedding.embedding for embedding in response.data]
)
return final_embeddings
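The batching above assumes batch_list chunks a list into fixed-size slices; a minimal sketch of that assumed behavior:

def batch_list(items: list[str], size: int) -> list[list[str]]:
    # assumed behavior of shared_configs.utils.batch_list: fixed-size slices
    return [items[i : i + size] for i in range(0, len(items), size)]

assert batch_list(["a", "b", "c", "d", "e"], 2) == [["a", "b"], ["c", "d"], ["e"]]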
async def _embed_cohere(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[Embedding]:
if not model:
model = DEFAULT_COHERE_MODEL
client = CohereAsyncClient(api_key=self.api_key)
final_embeddings: list[Embedding] = []
for text_batch in batch_list(texts, _COHERE_MAX_INPUT_LEN):
# Does not use the same tokenizer as the Onyx API server, but it's approximately the same;
# empirically it's only off by a few tokens, so it's not a big deal
response = await client.embed(
texts=text_batch,
model=model,
input_type=embedding_type,
truncate="END",
)
final_embeddings.extend(cast(list[Embedding], response.embeddings))
return final_embeddings
async def _embed_voyage(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[Embedding]:
if not model:
model = DEFAULT_VOYAGE_MODEL
client = voyageai.AsyncClient(
api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT
)
response = await client.embed(
texts=texts,
model=model,
input_type=embedding_type,
truncation=True,
)
return response.embeddings
async def _embed_azure(
self, texts: list[str], model: str | None
) -> list[Embedding]:
response = await aembedding(
model=model,
input=texts,
timeout=API_BASED_EMBEDDING_TIMEOUT,
api_key=self.api_key,
api_base=self.api_url,
api_version=self.api_version,
)
embeddings = [embedding["embedding"] for embedding in response.data]
return embeddings
async def _embed_vertex(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[Embedding]:
if not model:
model = DEFAULT_VERTEX_MODEL
credentials = service_account.Credentials.from_service_account_info(
json.loads(self.api_key)
)
project_id = json.loads(self.api_key)["project_id"]
vertexai.init(project=project_id, credentials=credentials)
client = TextEmbeddingModel.from_pretrained(model)
inputs = [TextEmbeddingInput(text, embedding_type) for text in texts]
# Split into batches of 25 texts
max_texts_per_batch = VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
batches = [
inputs[i : i + max_texts_per_batch]
for i in range(0, len(inputs), max_texts_per_batch)
]
# Dispatch all embedding calls asynchronously at once
tasks = [
client.get_embeddings_async(batch, auto_truncate=True) for batch in batches
]
# Wait for all tasks to complete in parallel
results = await asyncio.gather(*tasks)
return [embedding.values for batch in results for embedding in batch]
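The fan-out above is plain asyncio.gather over per-batch calls. A self-contained sketch with a fake embedding call standing in for the Vertex client:

import asyncio

async def fake_embed_batch(batch: list[str]) -> list[list[float]]:
    await asyncio.sleep(0)  # stand-in for the real network call
    return [[float(len(text))] for text in batch]

async def main() -> None:
    texts = ["a", "bb", "ccc", "dddd", "eeeee"]
    # split into batches, dispatch all at once, then flatten the results
    batches = [texts[i : i + 2] for i in range(0, len(texts), 2)]
    results = await asyncio.gather(*(fake_embed_batch(b) for b in batches))
    flattened = [vector for batch in results for vector in batch]
    assert len(flattened) == len(texts)

asyncio.run(main())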
async def _embed_litellm_proxy(
self, texts: list[str], model_name: str | None
) -> list[Embedding]:
if not model_name:
raise ValueError("Model name is required for LiteLLM proxy embedding.")
if not self.api_url:
raise ValueError("API URL is required for LiteLLM proxy embedding.")
headers = (
{} if not self.api_key else {"Authorization": f"Bearer {self.api_key}"}
)
response = await self.http_client.post(
self.api_url,
json={
"model": model_name,
"input": texts,
},
headers=headers,
)
response.raise_for_status()
result = response.json()
return [embedding["embedding"] for embedding in result["data"]]
@retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
async def embed(
self,
*,
texts: list[str],
text_type: EmbedTextType,
model_name: str | None = None,
deployment_name: str | None = None,
reduced_dimension: int | None = None,
) -> list[Embedding]:
try:
if self.provider == EmbeddingProvider.OPENAI:
return await self._embed_openai(texts, model_name, reduced_dimension)
elif self.provider == EmbeddingProvider.AZURE:
return await self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return await self._embed_litellm_proxy(texts, model_name)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return await self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return await self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return await self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
except openai.AuthenticationError:
raise AuthenticationError(provider="OpenAI")
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
raise AuthenticationError(provider=str(self.provider))
error_string = format_embedding_error(
e,
str(self.provider),
model_name or deployment_name,
self.provider,
sanitized_api_key=self.sanitized_api_key,
status_code=e.response.status_code,
)
logger.error(error_string)
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
except Exception as e:
if is_authentication_error(e):
raise AuthenticationError(provider=str(self.provider))
error_string = format_embedding_error(
e,
str(self.provider),
model_name or deployment_name,
self.provider,
sanitized_api_key=self.sanitized_api_key,
)
logger.error(error_string)
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
@staticmethod
def create(
api_key: str,
provider: EmbeddingProvider,
api_url: str | None = None,
api_version: str | None = None,
) -> "CloudEmbedding":
logger.debug(f"Creating Embedding instance for provider: {provider}")
return CloudEmbedding(api_key, provider, api_url, api_version)
async def aclose(self) -> None:
"""Explicitly close the client."""
if not self._closed:
await self.http_client.aclose()
self._closed = True
async def __aenter__(self) -> "CloudEmbedding":
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.aclose()
def __del__(self) -> None:
"""Finalizer to warn about unclosed clients."""
if not self._closed:
logger.warning(
"CloudEmbedding was not properly closed. Use 'async with' or call aclose()"
)
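The lifecycle above is the standard async-context-manager pattern: aclose is idempotent and __aexit__ guarantees it runs. A toy sketch:

import asyncio

class ToyClient:
    def __init__(self) -> None:
        self._closed = False

    async def aclose(self) -> None:
        if not self._closed:
            self._closed = True  # the real class closes its httpx client here

    async def __aenter__(self) -> "ToyClient":
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        await self.aclose()

async def main() -> None:
    async with ToyClient() as client:
        assert not client._closed
    assert client._closed  # closed on exit, even if the body raised

asyncio.run(main())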
def get_embedding_model(
model_name: str,
@@ -404,20 +70,34 @@ def get_local_reranking_model(
return _RERANK_MODEL
ENCODING_RETRIES = 3
ENCODING_RETRY_DELAY = 0.1
def _concurrent_embedding(
texts: list[str], model: "SentenceTransformer", normalize_embeddings: bool
) -> Any:
"""Synchronous wrapper for concurrent_embedding to use with run_in_executor."""
for _ in range(ENCODING_RETRIES):
try:
return model.encode(texts, normalize_embeddings=normalize_embeddings)
except RuntimeError as e:
# There is a concurrency bug in the SentenceTransformer library that causes
# the model to fail to encode texts. It's pretty rare and we want to allow
# concurrent embedding, hence we retry (the specific error is
# "RuntimeError: Already borrowed" and occurs in the transformers library)
logger.error(f"Error encoding texts, retrying: {e}")
time.sleep(ENCODING_RETRY_DELAY)
return model.encode(texts, normalize_embeddings=normalize_embeddings)
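The retry loop above attempts the encode several times before letting a final attempt propagate any error. A self-contained sketch of the same shape, with a flaky stand-in for model.encode:

import time

def encode_with_retries(encode, tries: int = 3, delay: float = 0.0):
    for _ in range(tries):
        try:
            return encode()
        except RuntimeError:
            time.sleep(delay)
    return encode()  # final attempt lets any error propagate

attempts = {"n": 0}

def flaky_encode() -> str:
    attempts["n"] += 1
    if attempts["n"] < 2:
        raise RuntimeError("Already borrowed")  # the rare concurrency error
    return "ok"

assert encode_with_retries(flaky_encode) == "ok"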
@simple_log_function_time()
async def embed_text(
texts: list[str],
text_type: EmbedTextType,
model_name: str | None,
deployment_name: str | None,
max_context_length: int,
normalize_embeddings: bool,
api_key: str | None,
provider_type: EmbeddingProvider | None,
prefix: str | None,
api_url: str | None,
api_version: str | None,
reduced_dimension: int | None,
gpu_type: str = "UNKNOWN",
) -> list[Embedding]:
if not all(texts):
@@ -434,52 +114,10 @@ async def embed_text(
for text in texts:
total_chars += len(text)
if provider_type is not None:
logger.info(
f"Embedding {len(texts)} texts with {total_chars} total characters with provider: {provider_type}"
)
# Only local models should call this function now
# API providers should go directly to API server
if api_key is None:
logger.error("API key not provided for cloud model")
raise RuntimeError("API key not provided for cloud model")
if prefix:
logger.warning("Prefix provided for cloud model, which is not supported")
raise ValueError(
"Prefix string is not valid for cloud models. "
"Cloud models take an explicit text type instead."
)
async with CloudEmbedding(
api_key=api_key,
provider=provider_type,
api_url=api_url,
api_version=api_version,
) as cloud_model:
embeddings = await cloud_model.embed(
texts=texts,
model_name=model_name,
deployment_name=deployment_name,
text_type=text_type,
reduced_dimension=reduced_dimension,
)
if any(embedding is None for embedding in embeddings):
error_message = "Embeddings contain None values\n"
error_message += "Corresponding texts:\n"
error_message += "\n".join(texts)
logger.error(error_message)
raise ValueError(error_message)
elapsed = time.monotonic() - start
logger.info(
f"event=embedding_provider "
f"texts={len(texts)} "
f"chars={total_chars} "
f"provider={provider_type} "
f"elapsed={elapsed:.2f}"
)
elif model_name is not None:
if model_name is not None:
logger.info(
f"Embedding {len(texts)} texts with {total_chars} total characters with local model: {model_name}"
)
@@ -492,8 +130,8 @@ async def embed_text(
# Run CPU-bound embedding in a thread pool
embeddings_vectors = await asyncio.get_event_loop().run_in_executor(
None,
lambda: local_model.encode(
prefixed_texts, normalize_embeddings=normalize_embeddings
lambda: _concurrent_embedding(
prefixed_texts, local_model, normalize_embeddings
),
)
embeddings = [
@@ -515,10 +153,8 @@ async def embed_text(
f"elapsed={elapsed:.2f}"
)
else:
logger.error("Neither model name nor provider specified for embedding")
raise ValueError(
"Either model name or provider must be provided to run embeddings."
)
logger.error("Model name not specified for embedding")
raise ValueError("Model name must be provided to run embeddings.")
return embeddings
@@ -533,77 +169,6 @@ async def local_rerank(query: str, docs: list[str], model_name: str) -> list[flo
)
async def cohere_rerank_api(
query: str, docs: list[str], model_name: str, api_key: str
) -> list[float]:
cohere_client = CohereAsyncClient(api_key=api_key)
response = await cohere_client.rerank(query=query, documents=docs, model=model_name)
results = response.results
sorted_results = sorted(results, key=lambda item: item.index)
return [result.relevance_score for result in sorted_results]
async def cohere_rerank_aws(
query: str,
docs: list[str],
model_name: str,
region_name: str,
aws_access_key_id: str,
aws_secret_access_key: str,
) -> list[float]:
session = aioboto3.Session(
aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key
)
async with session.client(
"bedrock-runtime", region_name=region_name
) as bedrock_client:
body = json.dumps(
{
"query": query,
"documents": docs,
"api_version": 2,
}
)
# Invoke the Bedrock model asynchronously
response = await bedrock_client.invoke_model(
modelId=model_name,
accept="application/json",
contentType="application/json",
body=body,
)
# Read the response asynchronously
response_body = json.loads(await response["body"].read())
# Extract and sort the results
results = response_body.get("results", [])
sorted_results = sorted(results, key=lambda item: item["index"])
return [result["relevance_score"] for result in sorted_results]
async def litellm_rerank(
query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None
) -> list[float]:
headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"}
async with httpx.AsyncClient() as client:
response = await client.post(
api_url,
json={
"model": model_name,
"query": query,
"documents": docs,
},
headers=headers,
)
response.raise_for_status()
result = response.json()
return [
item["relevance_score"]
for item in sorted(result["results"], key=lambda x: x["index"])
]
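Both rerank helpers above end the same way: sort the results back into document order by index, then extract the scores. A quick sketch:

results = [
    {"index": 2, "relevance_score": 0.12},
    {"index": 0, "relevance_score": 0.91},
    {"index": 1, "relevance_score": 0.47},
]
scores = [item["relevance_score"] for item in sorted(results, key=lambda x: x["index"])]
assert scores == [0.91, 0.47, 0.12]  # scores in original document order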
@router.post("/bi-encoder-embed")
async def route_bi_encoder_embed(
request: Request,
@@ -615,6 +180,13 @@ async def route_bi_encoder_embed(
async def process_embed_request(
embed_request: EmbedRequest, gpu_type: str = "UNKNOWN"
) -> EmbedResponse:
# Only local models should use this endpoint - API providers should make direct API calls
if embed_request.provider_type is not None:
raise ValueError(
f"Model server embedding endpoint should only be used for local models. "
f"API provider '{embed_request.provider_type}' should make direct API calls instead."
)
if not embed_request.texts:
raise HTTPException(status_code=400, detail="No texts to be embedded")
@@ -632,26 +204,12 @@ async def process_embed_request(
embeddings = await embed_text(
texts=embed_request.texts,
model_name=embed_request.model_name,
deployment_name=embed_request.deployment_name,
max_context_length=embed_request.max_context_length,
normalize_embeddings=embed_request.normalize_embeddings,
api_key=embed_request.api_key,
provider_type=embed_request.provider_type,
text_type=embed_request.text_type,
api_url=embed_request.api_url,
api_version=embed_request.api_version,
reduced_dimension=embed_request.reduced_dimension,
prefix=prefix,
gpu_type=gpu_type,
)
return EmbedResponse(embeddings=embeddings)
except AuthenticationError as e:
# Handle authentication errors consistently
logger.error(f"Authentication error: {e.provider}")
raise HTTPException(
status_code=401,
detail=f"Authentication failed: {e.message}",
)
except RateLimitError as e:
raise HTTPException(
status_code=429,
@@ -669,6 +227,13 @@ async def process_embed_request(
@router.post("/cross-encoder-scores")
async def process_rerank_request(rerank_request: RerankRequest) -> RerankResponse:
"""Cross encoders can be purely black box from the app perspective"""
# Only local models should use this endpoint - API providers should make direct API calls
if rerank_request.provider_type is not None:
raise ValueError(
f"Model server reranking endpoint should only be used for local models. "
f"API provider '{rerank_request.provider_type}' should make direct API calls instead."
)
if INDEXING_ONLY:
raise RuntimeError("Indexing model server should not call intent endpoint")
@@ -680,55 +245,13 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
raise ValueError("Empty documents cannot be reranked.")
try:
if rerank_request.provider_type is None:
sim_scores = await local_rerank(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,
)
return RerankResponse(scores=sim_scores)
elif rerank_request.provider_type == RerankerProvider.LITELLM:
if rerank_request.api_url is None:
raise ValueError("API URL is required for LiteLLM reranking.")
sim_scores = await litellm_rerank(
query=rerank_request.query,
docs=rerank_request.documents,
api_url=rerank_request.api_url,
model_name=rerank_request.model_name,
api_key=rerank_request.api_key,
)
return RerankResponse(scores=sim_scores)
elif rerank_request.provider_type == RerankerProvider.COHERE:
if rerank_request.api_key is None:
raise RuntimeError("Cohere Rerank Requires an API Key")
sim_scores = await cohere_rerank_api(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,
api_key=rerank_request.api_key,
)
return RerankResponse(scores=sim_scores)
elif rerank_request.provider_type == RerankerProvider.BEDROCK:
if rerank_request.api_key is None:
raise RuntimeError("Bedrock Rerank Requires an API Key")
aws_access_key_id, aws_secret_access_key, aws_region = pass_aws_key(
rerank_request.api_key
)
sim_scores = await cohere_rerank_aws(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,
region_name=aws_region,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
)
return RerankResponse(scores=sim_scores)
else:
raise ValueError(f"Unsupported provider: {rerank_request.provider_type}")
# At this point, provider_type is None, so handle local reranking
sim_scores = await local_rerank(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,
)
return RerankResponse(scores=sim_scores)
except Exception as e:
logger.exception(f"Error during reranking process:\n{str(e)}")

View File

@@ -34,8 +34,8 @@ from shared_configs.configs import SENTRY_DSN
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
HF_CACHE_PATH = Path(os.path.expanduser("~")) / ".cache/huggingface"
TEMP_HF_CACHE_PATH = Path(os.path.expanduser("~")) / ".cache/temp_huggingface"
HF_CACHE_PATH = Path(".cache/huggingface")
TEMP_HF_CACHE_PATH = Path(".cache/temp_huggingface")
transformer_logging.set_verbosity_error()

View File

@@ -70,32 +70,3 @@ def get_gpu_type() -> str:
return GPUStatus.MAC_MPS
return GPUStatus.NONE
def pass_aws_key(api_key: str) -> tuple[str, str, str]:
"""Parse AWS API key string into components.
Args:
api_key: String in format 'aws_ACCESSKEY_SECRETKEY_REGION'
Returns:
Tuple of (access_key, secret_key, region)
Raises:
ValueError: If key format is invalid
"""
if not api_key.startswith("aws"):
raise ValueError("API key must start with 'aws' prefix")
parts = api_key.split("_")
if len(parts) != 4:
raise ValueError(
f"API key must be in format 'aws_ACCESSKEY_SECRETKEY_REGION', got {len(parts) - 1} parts"
"this is an onyx specific format for formatting the aws secrets for bedrock"
)
try:
_, aws_access_key_id, aws_secret_access_key, aws_region = parts
return aws_access_key_id, aws_secret_access_key, aws_region
except Exception as e:
raise ValueError(f"Failed to parse AWS key components: {str(e)}")

View File

@@ -1,97 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.basic.states import BasicInput
from onyx.agents.agent_search.basic.states import BasicOutput
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.orchestration.nodes.call_tool import call_tool
from onyx.agents.agent_search.orchestration.nodes.choose_tool import choose_tool
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
prepare_tool_input,
)
from onyx.agents.agent_search.orchestration.nodes.use_tool_response import (
basic_use_tool_response,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def basic_graph_builder() -> StateGraph:
graph = StateGraph(
state_schema=BasicState,
input=BasicInput,
output=BasicOutput,
)
### Add nodes ###
graph.add_node(
node="prepare_tool_input",
action=prepare_tool_input,
)
graph.add_node(
node="choose_tool",
action=choose_tool,
)
graph.add_node(
node="call_tool",
action=call_tool,
)
graph.add_node(
node="basic_use_tool_response",
action=basic_use_tool_response,
)
### Add edges ###
graph.add_edge(start_key=START, end_key="prepare_tool_input")
graph.add_edge(start_key="prepare_tool_input", end_key="choose_tool")
graph.add_conditional_edges("choose_tool", should_continue, ["call_tool", END])
graph.add_edge(
start_key="call_tool",
end_key="basic_use_tool_response",
)
graph.add_edge(
start_key="basic_use_tool_response",
end_key=END,
)
return graph
def should_continue(state: BasicState) -> str:
return (
# If there are no tool calls, basic graph already streamed the answer
END
if state.tool_choice is None
else "call_tool"
)
if __name__ == "__main__":
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.context.search.models import SearchRequest
from onyx.llm.factory import get_default_llms
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
graph = basic_graph_builder()
compiled_graph = graph.compile()
input = BasicInput(unused=True)
primary_llm, fast_llm = get_default_llms()
with get_session_with_current_tenant() as db_session:
config, _ = get_test_config(
db_session=db_session,
primary_llm=primary_llm,
fast_llm=fast_llm,
search_request=SearchRequest(query="How does onyx use FastAPI?"),
)
compiled_graph.invoke(input, config={"metadata": {"config": config}})
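For reference, a toy end-to-end version of the conditional-edge pattern this graph uses, with a TypedDict state instead of the app's models (ToyState and the node bodies are illustrative only):

from typing import Optional, TypedDict
from langgraph.graph import END, START, StateGraph

class ToyState(TypedDict):
    tool_choice: Optional[str]
    result: str

def choose_tool(state: ToyState) -> dict:
    # in the real graph, an LLM decides here; this sketch passes through
    return {"tool_choice": state["tool_choice"]}

def call_tool(state: ToyState) -> dict:
    return {"result": f"ran {state['tool_choice']}"}

def should_continue(state: ToyState) -> str:
    # no tool chosen -> end the graph; otherwise route to call_tool
    return END if state["tool_choice"] is None else "call_tool"

graph = StateGraph(ToyState)
graph.add_node("choose_tool", choose_tool)
graph.add_node("call_tool", call_tool)
graph.add_edge(START, "choose_tool")
graph.add_conditional_edges("choose_tool", should_continue, ["call_tool", END])
graph.add_edge("call_tool", END)
app = graph.compile()
assert app.invoke({"tool_choice": "search", "result": ""})["result"] == "ran search"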

View File

@@ -1,35 +0,0 @@
from typing import TypedDict
from langchain_core.messages import AIMessageChunk
from pydantic import BaseModel
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
# States contain values that change over the course of graph execution,
# Config is for values that are set at the start and never change.
# If you are using a value from the config and realize it needs to change,
# you should add it to the state and use/update the version in the state.
## Graph Input State
class BasicInput(BaseModel):
# Langgraph needs a nonempty input, but we pass in all static
# data through a RunnableConfig.
unused: bool = True
## Graph Output State
class BasicOutput(TypedDict):
tool_call_chunk: AIMessageChunk
## Graph State
class BasicState(
BasicInput,
ToolChoiceInput,
ToolCallUpdate,
ToolChoiceUpdate,
):
pass

View File

@@ -1,64 +0,0 @@
from collections.abc import Iterator
from typing import cast
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langgraph.types import StreamWriter
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import LlmDoc
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
PassThroughAnswerResponseHandler,
)
from onyx.chat.stream_processing.utils import map_document_id_order
from onyx.utils.logger import setup_logger
logger = setup_logger()
def process_llm_stream(
messages: Iterator[BaseMessage],
should_stream_answer: bool,
writer: StreamWriter,
final_search_results: list[LlmDoc] | None = None,
displayed_search_results: list[LlmDoc] | None = None,
) -> AIMessageChunk:
tool_call_chunk = AIMessageChunk(content="")
if final_search_results and displayed_search_results:
answer_handler: AnswerResponseHandler = CitationResponseHandler(
context_docs=final_search_results,
final_doc_id_to_rank_map=map_document_id_order(final_search_results),
display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
)
else:
answer_handler = PassThroughAnswerResponseHandler()
full_answer = ""
# This stream will be the llm answer if no tool is chosen. When a tool is chosen,
# the stream will contain AIMessageChunks with tool call information.
for message in messages:
answer_piece = message.content
if not isinstance(answer_piece, str):
# this is only used for logging, so fine to
# just add the string representation
answer_piece = str(answer_piece)
full_answer += answer_piece
if isinstance(message, AIMessageChunk) and (
message.tool_call_chunks or message.tool_calls
):
tool_call_chunk += message # type: ignore
elif should_stream_answer:
for response_part in answer_handler.handle_response_part(message, []):
write_custom_event(
"basic_response",
response_part,
writer,
)
logger.debug(f"Full answer: {full_answer}")
return cast(AIMessageChunk, tool_call_chunk)
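The tool_call_chunk accumulation above relies on AIMessageChunk supporting + to merge streamed deltas. A minimal check of that behavior:

from langchain_core.messages import AIMessageChunk

accumulated = AIMessageChunk(content="")
for piece in ("Hello", ", ", "world"):
    accumulated += AIMessageChunk(content=piece)  # chunks merge on addition
assert accumulated.content == "Hello, world"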

View File

@@ -10,6 +10,7 @@ class CoreState(BaseModel):
"""
log_messages: Annotated[list[str], add] = []
current_step_nr: int = 1
class SubgraphCoreState(BaseModel):

View File

@@ -14,8 +14,6 @@ from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.configs.constants import DocumentSource
from onyx.prompts.agents.dc_prompts import DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT
from onyx.prompts.agents.dc_prompts import DC_OBJECT_SEPARATOR
@@ -41,7 +39,7 @@ def search_objects(
raise ValueError("Search tool and persona must be provided for DivCon search")
try:
instructions = graph_config.inputs.persona.prompts[0].system_prompt
instructions = graph_config.inputs.persona.system_prompt or ""
agent_1_instructions = extract_section(
instructions, "Agent Step 1:", "Agent Step 2:"
@@ -139,17 +137,6 @@ def search_objects(
except Exception as e:
raise ValueError(f"Error in search_objects: {e}")
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=" Researching the individual objects for each source type... ",
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
return SearchSourcesObjectsUpdate(
analysis_objects=object_list,
analysis_sources=document_sources,

View File

@@ -43,7 +43,7 @@ def research_object_source(
raise ValueError("Search tool and persona must be provided for DivCon search")
try:
instructions = graph_config.inputs.persona.prompts[0].system_prompt
instructions = graph_config.inputs.persona.system_prompt or ""
agent_2_instructions = extract_section(
instructions, "Agent Step 2:", "Agent Step 3:"

View File

@@ -9,8 +9,6 @@ from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.agents.agent_search.dc_search_analysis.states import (
ObjectResearchInformationUpdate,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -23,17 +21,6 @@ def structure_research_by_object(
LangGraph node to start the agentic search process.
"""
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=" consolidating the information across source types for each object...",
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
object_source_research_results = state.object_source_research_results
object_research_information_results: List[Dict[str, str]] = []

View File

@@ -33,7 +33,7 @@ def consolidate_object_research(
if search_tool is None or graph_config.inputs.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")
instructions = graph_config.inputs.persona.prompts[0].system_prompt
instructions = graph_config.inputs.persona.system_prompt or ""
agent_4_instructions = extract_section(
instructions, "Agent Step 4:", "Agent Step 5:"

View File

@@ -12,8 +12,6 @@ from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.prompts.agents.dc_prompts import DC_FORMATTING_NO_BASE_DATA_PROMPT
from onyx.prompts.agents.dc_prompts import DC_FORMATTING_WITH_BASE_DATA_PROMPT
from onyx.utils.logger import setup_logger
@@ -33,22 +31,11 @@ def consolidate_research(
search_tool = graph_config.tooling.search_tool
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=" generating the answer\n\n\n",
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
if search_tool is None or graph_config.inputs.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")
# Populate prompt
instructions = graph_config.inputs.persona.prompts[0].system_prompt
instructions = graph_config.inputs.persona.system_prompt or ""
try:
agent_5_instructions = extract_section(

View File

@@ -1,31 +0,0 @@
from collections.abc import Hashable
from datetime import datetime
from langgraph.types import Send
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def send_to_expanded_retrieval(state: SubQuestionAnsweringInput) -> Send | Hashable:
"""
LangGraph edge to send a sub-question to the expanded retrieval.
"""
edge_start_time = datetime.now()
return Send(
"initial_sub_question_expanded_retrieval",
ExpandedRetrievalInput(
question=state.question,
base_search=False,
sub_question_id=state.question_id,
log_messages=[f"{edge_start_time} -- Sending to expanded retrieval"],
),
)

View File

@@ -1,137 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.edges import (
send_to_expanded_retrieval,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.check_sub_answer import (
check_sub_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.format_sub_answer import (
format_sub_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.generate_sub_answer import (
generate_sub_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.ingest_retrieved_documents import (
ingest_retrieved_documents,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.graph_builder import (
expanded_retrieval_graph_builder,
)
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger
logger = setup_logger()
def answer_query_graph_builder() -> StateGraph:
"""
LangGraph sub-graph builder for the initial individual sub-answer generation.
"""
graph = StateGraph(
state_schema=AnswerQuestionState,
input=SubQuestionAnsweringInput,
output=AnswerQuestionOutput,
)
### Add nodes ###
# The sub-graph that executes the expanded retrieval process for a sub-question
expanded_retrieval = expanded_retrieval_graph_builder().compile()
graph.add_node(
node="initial_sub_question_expanded_retrieval",
action=expanded_retrieval,
)
# The node that ingests the retrieved documents and puts them into the proper
# state keys.
graph.add_node(
node="ingest_retrieval",
action=ingest_retrieved_documents,
)
# The node that generates the sub-answer
graph.add_node(
node="generate_sub_answer",
action=generate_sub_answer,
)
# The node that checks the sub-answer
graph.add_node(
node="answer_check",
action=check_sub_answer,
)
# The node that formats the sub-answer for the following initial answer generation
graph.add_node(
node="format_answer",
action=format_sub_answer,
)
### Add edges ###
graph.add_conditional_edges(
source=START,
path=send_to_expanded_retrieval,
path_map=["initial_sub_question_expanded_retrieval"],
)
graph.add_edge(
start_key="initial_sub_question_expanded_retrieval",
end_key="ingest_retrieval",
)
graph.add_edge(
start_key="ingest_retrieval",
end_key="generate_sub_answer",
)
graph.add_edge(
start_key="generate_sub_answer",
end_key="answer_check",
)
graph.add_edge(
start_key="answer_check",
end_key="format_answer",
)
graph.add_edge(
start_key="format_answer",
end_key=END,
)
return graph
if __name__ == "__main__":
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
graph = answer_query_graph_builder()
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
search_request = SearchRequest(
query="what can you do with onyx or danswer?",
)
with get_session_with_current_tenant() as db_session:
graph_config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)
inputs = SubQuestionAnsweringInput(
question="what can you do with onyx?",
question_id="0_0",
log_messages=[],
)
for thing in compiled_graph.stream(
input=inputs,
config={"configurable": {"config": graph_config}},
):
logger.debug(thing)

View File

@@ -1,136 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnswerCheckUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
binary_string_test,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_POSITIVE_VALUE_STR,
)
from onyx.agents.agent_search.shared_graph_utils.constants import AgentLLMErrorType
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_CHECK
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import SUB_ANSWER_CHECK_PROMPT
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="LLM Timeout Error. The sub-answer will be treated as 'relevant'",
rate_limit="LLM Rate Limit Error. The sub-answer will be treated as 'relevant'",
general_error="General LLM Error. The sub-answer will be treated as 'relevant'",
)
@log_function_time(print_only=True)
def check_sub_answer(
state: AnswerQuestionState, config: RunnableConfig
) -> SubQuestionAnswerCheckUpdate:
"""
LangGraph node to check the quality of the sub-answer. The verdict is
represented as a boolean value.
"""
node_start_time = datetime.now()
level, question_num = parse_question_id(state.question_id)
if state.answer == UNKNOWN_ANSWER:
return SubQuestionAnswerCheckUpdate(
answer_quality=False,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate individual sub answer",
node_name="check sub answer",
node_start_time=node_start_time,
result="unknown answer",
)
],
)
msg = [
HumanMessage(
content=SUB_ANSWER_CHECK_PROMPT.format(
question=state.question,
base_answer=state.answer,
)
)
]
graph_config = cast(GraphConfig, config["metadata"]["config"])
fast_llm = graph_config.tooling.fast_llm
agent_error: AgentErrorLog | None = None
response: BaseMessage | None = None
try:
response = run_with_timeout(
AGENT_TIMEOUT_LLM_SUBANSWER_CHECK,
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
quality_str: str = cast(str, response.content)
answer_quality = binary_string_test(
text=quality_str, positive_value=AGENT_POSITIVE_VALUE_STR
)
log_result = f"Answer quality: {quality_str}"
except (LLMTimeoutError, TimeoutError):
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
answer_quality = True
log_result = agent_error.error_result
logger.error("LLM Timeout Error - check sub answer")
except LLMRateLimitError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
answer_quality = True
log_result = agent_error.error_result
logger.error("LLM Rate Limit Error - check sub answer")
return SubQuestionAnswerCheckUpdate(
answer_quality=answer_quality,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate individual sub answer",
node_name="check sub answer",
node_start_time=node_start_time,
result=log_result,
)
],
)
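The error handling above encodes a deliberate policy: on timeout or rate limit, the sub-answer is assumed relevant (answer_quality = True) rather than failing the graph. A simplified stand-in for that pattern is sketched below; onyx's run_with_timeout helper is internal, so concurrent.futures is used in its place and all names are illustrative.

from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FutureTimeoutError


def check_with_fallback(invoke, timeout_s: float, default: bool = True) -> bool:
    # Run the (potentially slow) LLM check in a worker thread; if it does not
    # finish in time, fall back to the permissive default instead of raising.
    pool = ThreadPoolExecutor(max_workers=1)
    future = pool.submit(invoke)
    try:
        return "yes" in future.result(timeout=timeout_s).lower()
    except FutureTimeoutError:
        return default
    finally:
        # Don't block on the possibly still-running worker thread.
        pool.shutdown(wait=False)


# Usage sketch: treat a hung check as "relevant" after 5 seconds.
print(check_with_fallback(lambda: "yes, the answer addresses the question", 5.0))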

View File

@@ -1,30 +0,0 @@
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.shared_graph_utils.models import (
SubQuestionAnswerResults,
)
def format_sub_answer(state: AnswerQuestionState) -> AnswerQuestionOutput:
"""
LangGraph node to format the sub-answer for the subsequent initial answer generation.
"""
return AnswerQuestionOutput(
answer_results=[
SubQuestionAnswerResults(
question=state.question,
question_id=state.question_id,
verified_high_quality=state.answer_quality,
answer=state.answer,
sub_query_retrieval_results=state.expanded_retrieval_results,
verified_reranked_documents=state.verified_reranked_documents,
context_documents=state.context_documents,
cited_documents=state.cited_documents,
sub_question_retrieval_stats=state.sub_question_retrieval_stats,
)
],
)

View File

@@ -1,185 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnswerGenerationUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_sub_question_answer_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.calculations import (
dedup_sort_inference_section_list,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
LLM_ANSWER_ERROR_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import get_answer_citation_ids
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_persona_agent_prompt_expressions,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import NO_RECOVERED_DOCS
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="LLM Timeout Error. A sub-answer could not be constructed and the sub-question will be ignored.",
rate_limit="LLM Rate Limit Error. A sub-answer could not be constructed and the sub-question will be ignored.",
general_error="General LLM Error. A sub-answer could not be constructed and the sub-question will be ignored.",
)
@log_function_time(print_only=True)
def generate_sub_answer(
state: AnswerQuestionState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubQuestionAnswerGenerationUpdate:
"""
LangGraph node to generate a sub-answer.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = state.question
level, question_num = parse_question_id(state.question_id)
context_docs = state.context_documents[:AGENT_MAX_ANSWER_CONTEXT_DOCS]
context_docs = dedup_sort_inference_section_list(context_docs)
persona_contextualized_prompt = get_persona_agent_prompt_expressions(
graph_config.inputs.persona
).contextualized_prompt
if len(context_docs) == 0:
answer_str = NO_RECOVERED_DOCS
cited_documents: list = []
log_results = "No documents retrieved"
write_custom_event(
"sub_answers",
AgentAnswerPiece(
answer_piece=answer_str,
level=level,
level_question_num=question_num,
answer_type="agent_sub_answer",
),
writer,
)
else:
fast_llm = graph_config.tooling.fast_llm
msg = build_sub_question_answer_prompt(
question=question,
original_question=graph_config.inputs.prompt_builder.raw_user_query,
docs=context_docs,
persona_specification=persona_contextualized_prompt,
config=fast_llm.config,
)
agent_error: AgentErrorLog | None = None
response: list[str] = []
try:
response, _ = run_with_timeout(
AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION,
lambda: stream_llm_answer(
llm=fast_llm,
prompt=msg,
event_name="sub_answers",
writer=writer,
agent_answer_level=level,
agent_answer_question_num=question_num,
agent_answer_type="agent_sub_answer",
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBANSWER_GENERATION,
),
)
except (LLMTimeoutError, TimeoutError):
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - generate sub answer")
except LLMRateLimitError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - generate sub answer")
if agent_error:
answer_str = LLM_ANSWER_ERROR_MESSAGE
cited_documents = []
log_results = (
agent_error.error_result
or "Sub-answer generation failed due to LLM error"
)
else:
answer_str = merge_message_runs(response, chunk_separator="")[0].content
answer_citation_ids = get_answer_citation_ids(answer_str)
cited_documents = [
context_docs[id] for id in answer_citation_ids if id < len(context_docs)
]
log_results = None
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type=StreamType.SUB_ANSWER,
level=level,
level_question_num=question_num,
)
write_custom_event("stream_finished", stop_event, writer)
return SubQuestionAnswerGenerationUpdate(
answer=answer_str,
cited_documents=cited_documents,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate individual sub answer",
node_name="generate sub answer",
node_start_time=node_start_time,
result=log_results or "",
)
],
)

View File

@@ -1,25 +0,0 @@
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionRetrievalIngestionUpdate,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalOutput,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
def ingest_retrieved_documents(
state: ExpandedRetrievalOutput,
) -> SubQuestionRetrievalIngestionUpdate:
"""
LangGraph node to ingest the retrieved documents and format them for the sub-answer.
"""
sub_question_retrieval_stats = state.expanded_retrieval_result.retrieval_stats
if sub_question_retrieval_stats is None:
sub_question_retrieval_stats = AgentChunkRetrievalStats()  # single stats object, matching SubQuestionRetrievalIngestionUpdate
return SubQuestionRetrievalIngestionUpdate(
expanded_retrieval_results=state.expanded_retrieval_result.expanded_query_results,
verified_reranked_documents=state.expanded_retrieval_result.verified_reranked_documents,
context_documents=state.expanded_retrieval_result.context_documents,
sub_question_retrieval_stats=sub_question_retrieval_stats,
)

View File

@@ -1,73 +0,0 @@
from operator import add
from typing import Annotated
from pydantic import BaseModel
from onyx.agents.agent_search.core_state import SubgraphCoreState
from onyx.agents.agent_search.deep_search.main.states import LoggerUpdate
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
from onyx.agents.agent_search.shared_graph_utils.models import (
SubQuestionAnswerResults,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.context.search.models import InferenceSection
## Update States
class SubQuestionAnswerCheckUpdate(LoggerUpdate, BaseModel):
answer_quality: bool = False
log_messages: list[str] = []
class SubQuestionAnswerGenerationUpdate(LoggerUpdate, BaseModel):
answer: str = ""
log_messages: list[str] = []
cited_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
# answer_stat: AnswerStats
class SubQuestionRetrievalIngestionUpdate(LoggerUpdate, BaseModel):
expanded_retrieval_results: list[QueryRetrievalResult] = []
verified_reranked_documents: Annotated[
list[InferenceSection], dedup_inference_sections
] = []
context_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
sub_question_retrieval_stats: AgentChunkRetrievalStats = AgentChunkRetrievalStats()
## Graph Input State
class SubQuestionAnsweringInput(SubgraphCoreState):
question: str
question_id: str
# level 0 is original question and first decomposition, level 1 is follow up, etc
# question_num is a unique number per original question per level.
## Graph State
class AnswerQuestionState(
SubQuestionAnsweringInput,
SubQuestionAnswerGenerationUpdate,
SubQuestionAnswerCheckUpdate,
SubQuestionRetrievalIngestionUpdate,
):
pass
## Graph Output State
class AnswerQuestionOutput(LoggerUpdate, BaseModel):
"""
This is a list of results even though each call of this subgraph returns only one result:
when the answer-query subgraph is parallelized, each branch contributes one entry and the
add operator merges them into a single list.
"""
answer_results: Annotated[list[SubQuestionAnswerResults], add] = []
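As the docstring notes, the Annotated reducers are what make the parallel fan-in work: when several branches write to the same key, LangGraph merges their updates with the annotated function instead of overwriting. Below is a minimal sketch of the idea, with a stand-in reducer in place of dedup_inference_sections (the real reducer dedups InferenceSections).

from operator import add
from typing import Annotated, TypedDict


def dedup_merge(existing: list, new: list) -> list:
    # Keep the first occurrence of each item; stand-in for the real
    # InferenceSection deduplication.
    merged = list(existing)
    for item in new:
        if item not in merged:
            merged.append(item)
    return merged


class FanInState(TypedDict):
    # Concatenated across parallel branches via operator.add.
    answers: Annotated[list[str], add]
    # Deduplicated across parallel branches via the custom reducer.
    documents: Annotated[list[str], dedup_merge]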

View File

@@ -1,50 +0,0 @@
from collections.abc import Hashable
from datetime import datetime
from langgraph.types import Send
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
SubQuestionRetrievalState,
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
def parallelize_initial_sub_question_answering(
state: SubQuestionRetrievalState,
) -> list[Send | Hashable]:
"""
LangGraph edge to parallelize the initial sub-question answering. If there are no sub-questions,
we send an empty answer list to the initial answer generation, and that answer is then
generated solely from the documents retrieved for the original question.
"""
edge_start_time = datetime.now()
if len(state.initial_sub_questions) > 0:
return [
Send(
"answer_query_subgraph",
SubQuestionAnsweringInput(
question=question,
question_id=make_question_id(0, question_num + 1),
log_messages=[
f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
],
),
)
for question_num, question in enumerate(state.initial_sub_questions)
]
else:
return [
Send(
"ingest_answers",
AnswerQuestionOutput(
answer_results=[],
),
)
]
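The Send objects above are LangGraph's map primitive: each Send routes its own private input into one instance of the target node, so the sub-questions are answered in parallel. A minimal, runnable sketch of the same fan-out/fan-in shape follows; the node names and state are illustrative, not onyx's.

from operator import add
from typing import Annotated, TypedDict

from langgraph.graph import END, START, StateGraph
from langgraph.types import Send


class FanOutState(TypedDict):
    questions: list[str]
    answers: Annotated[list[str], add]  # merged across parallel branches


def fan_out(state: FanOutState) -> list[Send]:
    # One Send per question; each carries its own input to the "answer" node.
    return [
        Send("answer", {"questions": [q], "answers": []})
        for q in state["questions"]
    ]


def answer(state: FanOutState) -> dict:
    return {"answers": [f"answered: {state['questions'][0]}"]}


builder = StateGraph(FanOutState)
builder.add_node("answer", answer)
builder.add_conditional_edges(START, fan_out, ["answer"])
builder.add_edge("answer", END)

print(builder.compile().invoke({"questions": ["q1", "q2"], "answers": []}))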

View File

@@ -1,96 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.nodes.generate_initial_answer import (
generate_initial_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.nodes.validate_initial_answer import (
validate_initial_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
SubQuestionRetrievalInput,
)
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
SubQuestionRetrievalState,
)
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.graph_builder import (
generate_sub_answers_graph_builder,
)
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.graph_builder import (
retrieve_orig_question_docs_graph_builder,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def generate_initial_answer_graph_builder(test_mode: bool = False) -> StateGraph:
"""
LangGraph graph builder for the initial answer generation.
"""
graph = StateGraph(
state_schema=SubQuestionRetrievalState,
input=SubQuestionRetrievalInput,
)
# The sub-graph that generates the initial sub-answers
generate_sub_answers = generate_sub_answers_graph_builder().compile()
graph.add_node(
node="generate_sub_answers_subgraph",
action=generate_sub_answers,
)
# The sub-graph that retrieves the original question documents. This is run
# in parallel with the sub-answer generation process
retrieve_orig_question_docs = retrieve_orig_question_docs_graph_builder().compile()
graph.add_node(
node="retrieve_orig_question_docs_subgraph_wrapper",
action=retrieve_orig_question_docs,
)
# Node that generates the initial answer using the results of the previous
# two sub-graphs
graph.add_node(
node="generate_initial_answer",
action=generate_initial_answer,
)
# Node that validates the initial answer
graph.add_node(
node="validate_initial_answer",
action=validate_initial_answer,
)
### Add edges ###
graph.add_edge(
start_key=START,
end_key="retrieve_orig_question_docs_subgraph_wrapper",
)
graph.add_edge(
start_key=START,
end_key="generate_sub_answers_subgraph",
)
# Wait for both the original question docs and the sub-answers to be generated before proceeding
graph.add_edge(
start_key=[
"retrieve_orig_question_docs_subgraph_wrapper",
"generate_sub_answers_subgraph",
],
end_key="generate_initial_answer",
)
graph.add_edge(
start_key="generate_initial_answer",
end_key="validate_initial_answer",
)
graph.add_edge(
start_key="validate_initial_answer",
end_key=END,
)
return graph
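The list-valued start_key above is LangGraph's join: the target node runs only once every listed source node has finished, which is how the original-question retrieval and the sub-answer generation are synchronized here. A small self-contained sketch of the same shape, with illustrative node names:

from operator import add
from typing import Annotated, TypedDict

from langgraph.graph import END, START, StateGraph


class JoinState(TypedDict):
    log: Annotated[list[str], add]  # reducer needed for parallel writes


builder = StateGraph(JoinState)
builder.add_node("branch_a", lambda state: {"log": ["a done"]})
builder.add_node("branch_b", lambda state: {"log": ["b done"]})
builder.add_node("join", lambda state: {"log": ["joined"]})
builder.add_edge(START, "branch_a")
builder.add_edge(START, "branch_b")
# "join" waits for both branches before it runs.
builder.add_edge(["branch_a", "branch_b"], "join")
builder.add_edge("join", END)

print(builder.compile().invoke({"log": []}))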

View File

@@ -1,405 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
SubQuestionRetrievalState,
)
from onyx.agents.agent_search.deep_search.main.models import AgentBaseMetrics
from onyx.agents.agent_search.deep_search.main.operations import (
calculate_initial_agent_stats,
)
from onyx.agents.agent_search.deep_search.main.operations import get_query_info
from onyx.agents.agent_search.deep_search.main.operations import logger
from onyx.agents.agent_search.deep_search.main.states import (
InitialAnswerUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
get_prompt_enrichment_components,
)
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.calculations import (
get_answer_generation_documents,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_section_list,
)
from onyx.agents.agent_search.shared_graph_utils.utils import _should_restrict_tokens
from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_deduplicated_structured_subquestion_documents,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
from onyx.agents.agent_search.shared_graph_utils.utils import remove_document_citations
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS
from onyx.prompts.agent_search import (
INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS,
)
from onyx.prompts.agent_search import (
SUB_QUESTION_ANSWER_TEMPLATE,
)
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="LLM Timeout Error. The initial answer could not be generated.",
rate_limit="LLM Rate Limit Error. The initial answer could not be generated.",
general_error="General LLM Error. The initial answer could not be generated.",
)
@log_function_time(print_only=True)
def generate_initial_answer(
state: SubQuestionRetrievalState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> InitialAnswerUpdate:
"""
LangGraph node to generate the initial answer, using the initial sub-questions/sub-answers and the
documents retrieved for the original question.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.prompt_builder.raw_user_query
prompt_enrichment_components = get_prompt_enrichment_components(graph_config)
# get all documents cited in sub-questions
structured_subquestion_docs = get_deduplicated_structured_subquestion_documents(
state.sub_question_results
)
orig_question_retrieval_documents = state.orig_question_retrieved_documents
consolidated_context_docs = structured_subquestion_docs.cited_documents
counter = 0
for original_doc in orig_question_retrieval_documents:
if original_doc in structured_subquestion_docs.cited_documents:
continue
if (
counter <= AGENT_MIN_ORIG_QUESTION_DOCS
or len(consolidated_context_docs) < AGENT_MAX_ANSWER_CONTEXT_DOCS
):
consolidated_context_docs.append(original_doc)
counter += 1
# sort docs by their scores - though the scores refer to different questions
relevant_docs = dedup_inference_section_list(consolidated_context_docs)
sub_questions: list[str] = []
# Create the list of documents to stream out. Start with the
# ones that will be in the context (or, if len == 0, use docs
# that were retrieved for the original question)
answer_generation_documents = get_answer_generation_documents(
relevant_docs=relevant_docs,
context_documents=structured_subquestion_docs.context_documents,
original_question_docs=orig_question_retrieval_documents,
max_docs=AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER,
)
# Use the query info from the base document retrieval
query_info = get_query_info(state.orig_question_sub_query_retrieval_results)
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
relevance_list = relevance_from_docs(
answer_generation_documents.streaming_documents
)
for tool_response in yield_search_responses(
query=question,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
):
write_custom_event(
"tool_response",
ExtendedToolResponse(
id=tool_response.id,
response=tool_response.response,
level=0,
level_question_num=0, # 0, 0 is the base question
),
writer,
)
if len(answer_generation_documents.context_documents) == 0:
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=UNKNOWN_ANSWER,
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
dispatch_main_answer_stop_info(0, writer)
answer = UNKNOWN_ANSWER
initial_agent_stats = InitialAgentResultStats(
sub_questions={},
original_question={},
agent_effectiveness={},
)
else:
sub_question_answer_results = state.sub_question_results
# Collect the sub-questions and sub-answers and construct an appropriate
# prompt string.
# Consider extracting this into a helper function.
answered_sub_questions: list[str] = []
all_sub_questions: list[str] = [] # Separate list for tracking all questions
for idx, sub_question_answer_result in enumerate(
sub_question_answer_results, start=1
):
all_sub_questions.append(sub_question_answer_result.question)
is_valid_answer = (
sub_question_answer_result.verified_high_quality
and sub_question_answer_result.answer
and sub_question_answer_result.answer != UNKNOWN_ANSWER
)
if is_valid_answer:
answered_sub_questions.append(
SUB_QUESTION_ANSWER_TEMPLATE.format(
sub_question=sub_question_answer_result.question,
sub_answer=sub_question_answer_result.answer,
sub_question_num=idx,
)
)
sub_question_answer_str = (
"\n\n------\n\n".join(answered_sub_questions)
if answered_sub_questions
else ""
)
# Use the appropriate prompt based on whether there are sub-questions.
base_prompt = (
INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS
if answered_sub_questions
else INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS
)
sub_questions = all_sub_questions # Replace the original assignment
model = (
graph_config.tooling.fast_llm
if AGENT_ANSWER_GENERATION_BY_FAST_LLM
else graph_config.tooling.primary_llm
)
doc_context = format_docs(answer_generation_documents.context_documents)
doc_context = trim_prompt_piece(
config=model.config,
prompt_piece=doc_context,
reserved_str=(
base_prompt
+ sub_question_answer_str
+ prompt_enrichment_components.persona_prompts.contextualized_prompt
+ prompt_enrichment_components.history
+ prompt_enrichment_components.date_str
),
)
msg = [
HumanMessage(
content=base_prompt.format(
question=question,
answered_sub_questions=remove_document_citations(
sub_question_answer_str
),
relevant_docs=doc_context,
persona_specification=prompt_enrichment_components.persona_prompts.contextualized_prompt,
history=prompt_enrichment_components.history,
date_prompt=prompt_enrichment_components.date_str,
)
)
]
streamed_tokens: list[str] = [""]
dispatch_timings: list[float] = []
agent_error: AgentErrorLog | None = None
try:
streamed_tokens, dispatch_timings = run_with_timeout(
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
lambda: stream_llm_answer(
llm=model,
prompt=msg,
event_name="initial_agent_answer",
writer=writer,
agent_answer_level=0,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
max_tokens=(
AGENT_MAX_TOKENS_ANSWER_GENERATION
if _should_restrict_tokens(model.config)
else None
),
),
)
except (LLMTimeoutError, TimeoutError):
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - generate initial answer")
except LLMRateLimitError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - generate initial answer")
if agent_error:
write_custom_event(
"initial_agent_answer",
StreamingError(
error=AGENT_LLM_TIMEOUT_MESSAGE,
),
writer,
)
return InitialAnswerUpdate(
initial_answer=None,
answer_error=AgentErrorLog(
error_message=agent_error.error_message or "An LLM error occurred",
error_type=agent_error.error_type,
error_result=agent_error.error_result,
),
initial_agent_stats=None,
generated_sub_questions=sub_questions,
agent_base_end_time=None,
agent_base_metrics=None,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate initial answer",
node_name="generate initial answer",
node_start_time=node_start_time,
result=agent_error.error_result or "An LLM error occurred",
)
],
)
if dispatch_timings:  # guard against division by zero when nothing was dispatched
logger.debug(
f"Average dispatch time for initial answer: {sum(dispatch_timings) / len(dispatch_timings)}"
)
dispatch_main_answer_stop_info(0, writer)
response = merge_content(*streamed_tokens)
answer = cast(str, response)
initial_agent_stats = calculate_initial_agent_stats(
state.sub_question_results, state.orig_question_retrieval_stats
)
logger.debug(
f"\n\nYYYYY--Sub-Questions:\n\n{sub_question_answer_str}\n\nStats:\n\n"
)
if initial_agent_stats:
logger.debug(initial_agent_stats.original_question)
logger.debug(initial_agent_stats.sub_questions)
logger.debug(initial_agent_stats.agent_effectiveness)
agent_base_end_time = datetime.now()
if agent_base_end_time and state.agent_start_time:
duration_s = (agent_base_end_time - state.agent_start_time).total_seconds()
else:
duration_s = None
agent_base_metrics = AgentBaseMetrics(
num_verified_documents_total=len(relevant_docs),
num_verified_documents_core=state.orig_question_retrieval_stats.verified_count,
verified_avg_score_core=state.orig_question_retrieval_stats.verified_avg_scores,
num_verified_documents_base=initial_agent_stats.sub_questions.get(
"num_verified_documents"
),
verified_avg_score_base=initial_agent_stats.sub_questions.get(
"verified_avg_score"
),
base_doc_boost_factor=initial_agent_stats.agent_effectiveness.get(
"utilized_chunk_ratio"
),
support_boost_factor=initial_agent_stats.agent_effectiveness.get(
"support_ratio"
),
duration_s=duration_s,
)
return InitialAnswerUpdate(
initial_answer=answer,
initial_agent_stats=initial_agent_stats,
generated_sub_questions=sub_questions,
agent_base_end_time=agent_base_end_time,
agent_base_metrics=agent_base_metrics,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate initial answer",
node_name="generate initial answer",
node_start_time=node_start_time,
result="",
)
],
)
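The consolidation loop near the top of this node implements a simple padding policy: keep every document cited by a sub-question, then top up with original-question documents until a minimum number of original docs has been added or the overall context limit is reached. Extracted below as a standalone sketch; the names are illustrative, and the real code operates on InferenceSections.

def consolidate_context(
    cited_docs: list,
    original_docs: list,
    min_orig_docs: int,
    max_context_docs: int,
) -> list:
    # Sub-question citations always make it in; original-question docs pad
    # the context up to the configured limits.
    consolidated = list(cited_docs)
    added = 0
    for doc in original_docs:
        if doc in cited_docs:
            continue
        if added <= min_orig_docs or len(consolidated) < max_context_docs:
            consolidated.append(doc)
            added += 1
    return consolidated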

View File

@@ -1,42 +0,0 @@
from datetime import datetime
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
SubQuestionRetrievalState,
)
from onyx.agents.agent_search.deep_search.main.operations import logger
from onyx.agents.agent_search.deep_search.main.states import (
InitialAnswerQualityUpdate,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.timing import log_function_time
@log_function_time(print_only=True)
def validate_initial_answer(
state: SubQuestionRetrievalState,
) -> InitialAnswerQualityUpdate:
"""
Check whether the initial answer sufficiently addresses the original user question.
"""
node_start_time = datetime.now()
logger.debug(
f"--------{node_start_time}--------Checking for base answer validity - for not set True/False manually"
)
verdict = True  # not actually required, as the answer has already been streamed out; refinement performs a similar check
return InitialAnswerQualityUpdate(
initial_answer_quality_eval=verdict,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate initial answer",
node_name="validate initial answer",
node_start_time=node_start_time,
result="",
)
],
)

View File

@@ -1,51 +0,0 @@
from operator import add
from typing import Annotated
from typing import TypedDict
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.deep_search.main.states import (
ExploratorySearchUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
InitialAnswerQualityUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
InitialAnswerUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
InitialQuestionDecompositionUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
OrigQuestionRetrievalUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
SubQuestionResultsUpdate,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.models import (
QuestionRetrievalResult,
)
from onyx.context.search.models import InferenceSection
### States ###
class SubQuestionRetrievalInput(CoreState):
exploratory_search_results: list[InferenceSection]
## Graph State
class SubQuestionRetrievalState(
# This includes the core state
SubQuestionRetrievalInput,
InitialQuestionDecompositionUpdate,
InitialAnswerUpdate,
SubQuestionResultsUpdate,
OrigQuestionRetrievalUpdate,
InitialAnswerQualityUpdate,
ExploratorySearchUpdate,
):
base_raw_search_result: Annotated[list[QuestionRetrievalResult], add]
## Graph Output State
class SubQuestionRetrievalOutput(TypedDict):
log_messages: list[str]

View File

@@ -1,48 +0,0 @@
from collections.abc import Hashable
from datetime import datetime
from langgraph.types import Send
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
SubQuestionRetrievalState,
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
def parallelize_initial_sub_question_answering(
state: SubQuestionRetrievalState,
) -> list[Send | Hashable]:
"""
LangGraph edge to parallelize the initial sub-question answering.
"""
edge_start_time = datetime.now()
if len(state.initial_sub_questions) > 0:
return [
Send(
"answer_sub_question_subgraphs",
SubQuestionAnsweringInput(
question=question,
question_id=make_question_id(0, question_num + 1),
log_messages=[
f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
],
),
)
for question_num, question in enumerate(state.initial_sub_questions)
]
else:
return [
Send(
"ingest_answers",
AnswerQuestionOutput(
answer_results=[],
),
)
]

View File

@@ -1,81 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.graph_builder import (
answer_query_graph_builder,
)
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.edges import (
parallelize_initial_sub_question_answering,
)
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.nodes.decompose_orig_question import (
decompose_orig_question,
)
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.nodes.format_initial_sub_answers import (
format_initial_sub_answers,
)
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.states import (
SubQuestionAnsweringState,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
test_mode = False
def generate_sub_answers_graph_builder() -> StateGraph:
"""
LangGraph graph builder for the initial sub-answer generation process.
It generates the initial sub-questions and produces the answers.
"""
graph = StateGraph(
state_schema=SubQuestionAnsweringState,
input=SubQuestionAnsweringInput,
)
# Decompose the original question into sub-questions
graph.add_node(
node="decompose_orig_question",
action=decompose_orig_question,
)
# The sub-graph that executes the initial sub-question answering for
# each of the sub-questions.
answer_sub_question_subgraphs = answer_query_graph_builder().compile()
graph.add_node(
node="answer_sub_question_subgraphs",
action=answer_sub_question_subgraphs,
)
# Node that collects and formats the initial sub-question answers
graph.add_node(
node="format_initial_sub_question_answers",
action=format_initial_sub_answers,
)
graph.add_edge(
start_key=START,
end_key="decompose_orig_question",
)
graph.add_conditional_edges(
source="decompose_orig_question",
path=parallelize_initial_sub_question_answering,
path_map=["answer_sub_question_subgraphs"],
)
graph.add_edge(
start_key=["answer_sub_question_subgraphs"],
end_key="format_initial_sub_question_answers",
)
graph.add_edge(
start_key="format_initial_sub_question_answers",
end_key=END,
)
return graph

View File

@@ -1,190 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
SubQuestionRetrievalState,
)
from onyx.agents.agent_search.deep_search.main.models import (
AgentRefinedMetrics,
)
from onyx.agents.agent_search.deep_search.main.operations import dispatch_subquestion
from onyx.agents.agent_search.deep_search.main.operations import (
dispatch_subquestion_sep,
)
from onyx.agents.agent_search.deep_search.main.states import (
InitialQuestionDecompositionUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUESTION_GENERATION
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH_ASSUMING_REFINEMENT,
)
from onyx.prompts.agent_search import (
INITIAL_QUESTION_DECOMPOSITION_PROMPT_ASSUMING_REFINEMENT,
)
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="LLM Timeout Error. Sub-questions could not be generated.",
rate_limit="LLM Rate Limit Error. Sub-questions could not be generated.",
general_error="General LLM Error. Sub-questions could not be generated.",
)
@log_function_time(print_only=True)
def decompose_orig_question(
state: SubQuestionRetrievalState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> InitialQuestionDecompositionUpdate:
"""
LangGraph node to decompose the original question into sub-questions.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.prompt_builder.raw_user_query
perform_initial_search_decomposition = (
graph_config.behavior.perform_initial_search_decomposition
)
# Get the rewritten queries in a defined format
model = graph_config.tooling.fast_llm
history = build_history_prompt(graph_config, question)
# Use the initial search results to inform the decomposition
agent_start_time = datetime.now()
# Run the initial search and keep only the top AGENT_NUM_DOCS_FOR_DECOMPOSITION hits
if perform_initial_search_decomposition:
# Due to the state representation in LangGraph, we need to double-check here that the
# retrieval has actually happened before this point. We allow a silent failure since the
# initial search is not critical for decomposition in all queries.
if not state.exploratory_search_results:
logger.error("Initial search for decomposition failed")
sample_doc_str = "\n\n".join(
[
doc.combined_content
for doc in state.exploratory_search_results[
:AGENT_NUM_DOCS_FOR_DECOMPOSITION
]
]
)
decomposition_prompt = INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH_ASSUMING_REFINEMENT.format(
question=question, sample_doc_str=sample_doc_str, history=history
)
else:
decomposition_prompt = (
INITIAL_QUESTION_DECOMPOSITION_PROMPT_ASSUMING_REFINEMENT.format(
question=question, history=history
)
)
# Start decomposition
msg = [HumanMessage(content=decomposition_prompt)]
# Send the initial question as a subquestion with number 0
write_custom_event(
"decomp_qs",
SubQuestionPiece(
sub_question=question,
level=0,
level_question_num=0,
),
writer,
)
# Dispatch custom events for sub-question tokens, adding in the sub-question ids.
streamed_tokens: list[BaseMessage_Content] = []
try:
streamed_tokens = run_with_timeout(
AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION,
dispatch_separated,
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBQUESTION_GENERATION,
),
dispatch_subquestion(0, writer),
sep_callback=dispatch_subquestion_sep(0, writer),
)
decomposition_response = merge_content(*streamed_tokens)
list_of_subqs = cast(str, decomposition_response).split("\n")
initial_sub_questions = [sq.strip() for sq in list_of_subqs if sq.strip() != ""]
log_result = f"decomposed original question into {len(initial_sub_questions)} subquestions"
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type=StreamType.SUB_QUESTIONS,
level=0,
)
write_custom_event("stream_finished", stop_event, writer)
except (LLMTimeoutError, TimeoutError) as e:
logger.error("LLM Timeout Error - decompose orig question")
raise e # fail loudly on this critical step
except LLMRateLimitError as e:
logger.error("LLM Rate Limit Error - decompose orig question")
raise e
return InitialQuestionDecompositionUpdate(
initial_sub_questions=initial_sub_questions,
agent_start_time=agent_start_time,
agent_refined_start_time=None,
agent_refined_end_time=None,
agent_refined_metrics=AgentRefinedMetrics(
refined_doc_boost_factor=None,
refined_question_boost_factor=None,
duration_s=None,
),
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate sub answers",
node_name="decompose original question",
node_start_time=node_start_time,
result=log_result,
)
],
)
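The streaming decomposition above relies on the internal dispatch_separated helper: tokens are forwarded to the writer as they arrive, while separator characters delimit the individual sub-questions. A simplified stand-in for that splitting logic is sketched below (the real helper also emits separator callbacks, which this omits).

from collections.abc import Iterable


def split_streamed_subquestions(
    tokens: Iterable[str],
    on_token,
    sep: str = "\n",
) -> list[str]:
    # Forward each token as it arrives, then split the accumulated text
    # into individual, non-empty sub-questions.
    pieces: list[str] = []
    for token in tokens:
        on_token(token)
        pieces.append(token)
    return [q.strip() for q in "".join(pieces).split(sep) if q.strip()]


# Usage sketch: tokens as an LLM might stream them.
print(split_streamed_subquestions(["what is", " onyx?\nwho", " made it?"], print))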

View File

@@ -1,50 +0,0 @@
from datetime import datetime
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.main.states import (
SubQuestionResultsUpdate,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
def format_initial_sub_answers(
state: AnswerQuestionOutput,
) -> SubQuestionResultsUpdate:
"""
LangGraph node to format the answers to the initial sub-questions, including
deduping verified documents and context documents.
"""
node_start_time = datetime.now()
documents = []
context_documents = []
cited_documents = []
answer_results = state.answer_results
for answer_result in answer_results:
documents.extend(answer_result.verified_reranked_documents)
context_documents.extend(answer_result.context_documents)
cited_documents.extend(answer_result.cited_documents)
return SubQuestionResultsUpdate(
# Deduping is done by the documents operator for the main graph
# so we might not need to dedup here
verified_reranked_documents=dedup_inference_sections(documents, []),
context_documents=dedup_inference_sections(context_documents, []),
cited_documents=dedup_inference_sections(cited_documents, []),
sub_question_results=answer_results,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate sub answers",
node_name="format initial sub answers",
node_start_time=node_start_time,
result="",
)
],
)

View File

@@ -1,34 +0,0 @@
from typing import TypedDict
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.deep_search.main.states import (
InitialAnswerUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
InitialQuestionDecompositionUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
SubQuestionResultsUpdate,
)
from onyx.context.search.models import InferenceSection
### States ###
class SubQuestionAnsweringInput(CoreState):
exploratory_search_results: list[InferenceSection]
## Graph State
class SubQuestionAnsweringState(
# This includes the core state
SubQuestionAnsweringInput,
InitialQuestionDecompositionUpdate,
InitialAnswerUpdate,
SubQuestionResultsUpdate,
):
pass
## Graph Output State
class SubQuestionAnsweringOutput(TypedDict):
log_messages: list[str]

View File

@@ -1,81 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.nodes.format_orig_question_search_input import (
format_orig_question_search_input,
)
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.nodes.format_orig_question_search_output import (
format_orig_question_search_output,
)
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.states import (
BaseRawSearchInput,
)
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.states import (
BaseRawSearchOutput,
)
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.states import (
BaseRawSearchState,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.graph_builder import (
expanded_retrieval_graph_builder,
)
def retrieve_orig_question_docs_graph_builder() -> StateGraph:
"""
LangGraph graph builder for the retrieval of documents
that are relevant to the original question. This is
largely a wrapper around the expanded retrieval process to
ensure parallelism with the sub-question answer process.
"""
graph = StateGraph(
state_schema=BaseRawSearchState,
input=BaseRawSearchInput,
output=BaseRawSearchOutput,
)
### Add nodes ###
# Format the original question search output
graph.add_node(
node="format_orig_question_search_output",
action=format_orig_question_search_output,
)
# The sub-graph that executes the expanded retrieval process
expanded_retrieval = expanded_retrieval_graph_builder().compile()
graph.add_node(
node="retrieve_orig_question_docs_subgraph",
action=expanded_retrieval,
)
# Format the original question search input
graph.add_node(
node="format_orig_question_search_input",
action=format_orig_question_search_input,
)
### Add edges ###
graph.add_edge(start_key=START, end_key="format_orig_question_search_input")
graph.add_edge(
start_key="format_orig_question_search_input",
end_key="retrieve_orig_question_docs_subgraph",
)
graph.add_edge(
start_key="retrieve_orig_question_docs_subgraph",
end_key="format_orig_question_search_output",
)
graph.add_edge(
start_key="format_orig_question_search_output",
end_key=END,
)
return graph
if __name__ == "__main__":
pass

View File

@@ -1,28 +0,0 @@
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.utils.logger import setup_logger
logger = setup_logger()
def format_orig_question_search_input(
state: CoreState, config: RunnableConfig
) -> ExpandedRetrievalInput:
"""
LangGraph node to format the search input for the original question.
"""
logger.debug("generate_raw_search_data")
graph_config = cast(GraphConfig, config["metadata"]["config"])
return ExpandedRetrievalInput(
question=graph_config.inputs.prompt_builder.raw_user_query,
base_search=True,
sub_question_id=None, # This graph is always and only used for the original question
log_messages=[],
)

View File

@@ -1,30 +0,0 @@
from onyx.agents.agent_search.deep_search.main.states import OrigQuestionRetrievalUpdate
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalOutput,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.utils.logger import setup_logger
logger = setup_logger()
def format_orig_question_search_output(
state: ExpandedRetrievalOutput,
) -> OrigQuestionRetrievalUpdate:
"""
LangGraph node to convert the search result for the original question into the
proper update format.
"""
sub_question_retrieval_stats = state.expanded_retrieval_result.retrieval_stats
if sub_question_retrieval_stats is None:
sub_question_retrieval_stats = AgentChunkRetrievalStats()
else:
sub_question_retrieval_stats = sub_question_retrieval_stats
return OrigQuestionRetrievalUpdate(
orig_question_verified_reranked_documents=state.expanded_retrieval_result.verified_reranked_documents,
orig_question_sub_query_retrieval_results=state.expanded_retrieval_result.expanded_query_results,
orig_question_retrieved_documents=state.retrieved_documents,
orig_question_retrieval_stats=sub_question_retrieval_stats,
log_messages=[],
)

View File

@@ -1,29 +0,0 @@
from onyx.agents.agent_search.deep_search.main.states import (
OrigQuestionRetrievalUpdate,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
## Graph Input State
class BaseRawSearchInput(ExpandedRetrievalInput):
pass
## Graph Output State
class BaseRawSearchOutput(OrigQuestionRetrievalUpdate):
"""
Graph output state for the original-question document retrieval sub-graph.
"""
# base_expanded_retrieval_result: QuestionRetrievalResult = QuestionRetrievalResult()
## Graph State
class BaseRawSearchState(
BaseRawSearchInput, BaseRawSearchOutput, OrigQuestionRetrievalUpdate
):
pass

View File

@@ -1,113 +0,0 @@
from collections.abc import Hashable
from datetime import datetime
from typing import cast
from typing import Literal
from langchain_core.runnables import RunnableConfig
from langgraph.types import Send
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.main.states import (
RequireRefinemenEvalUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.utils.logger import setup_logger
logger = setup_logger()
def route_initial_tool_choice(
state: MainState, config: RunnableConfig
) -> Literal["call_tool", "start_agent_search", "logging_node"]:
"""
LangGraph edge to route to a direct tool call, the agent search path, or the logging node, depending on the tool choice.
"""
agent_config = cast(GraphConfig, config["metadata"]["config"])
if state.tool_choice is None:
return "logging_node"
if (
agent_config.behavior.use_agentic_search
and agent_config.tooling.search_tool is not None
and state.tool_choice.tool.name == agent_config.tooling.search_tool.name
):
return "start_agent_search"
else:
return "call_tool"
def parallelize_initial_sub_question_answering(
state: MainState,
) -> list[Send | Hashable]:
edge_start_time = datetime.now()
if len(state.initial_sub_questions) > 0:
return [
Send(
"answer_query_subgraph",
SubQuestionAnsweringInput(
question=question,
question_id=make_question_id(0, question_num + 1),
log_messages=[
f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
],
),
)
for question_num, question in enumerate(state.initial_sub_questions)
]
else:
return [
Send(
"ingest_answers",
AnswerQuestionOutput(
answer_results=[],
),
)
]
# Define the function that determines whether to continue or not
def continue_to_refined_answer_or_end(
state: RequireRefinemenEvalUpdate,
) -> Literal["create_refined_sub_questions", "logging_node"]:
if state.require_refined_answer_eval:
return "create_refined_sub_questions"
else:
return "logging_node"
def parallelize_refined_sub_question_answering(
state: MainState,
) -> list[Send | Hashable]:
edge_start_time = datetime.now()
if len(state.refined_sub_questions) > 0:
return [
Send(
"answer_refined_question_subgraphs",
SubQuestionAnsweringInput(
question=question_data.sub_question,
question_id=make_question_id(1, question_num),
log_messages=[
f"{edge_start_time} -- Main Edge - Parallelize Refined Sub-question Answering"
],
),
)
for question_num, question_data in state.refined_sub_questions.items()
]
else:
return [
Send(
"ingest_refined_sub_answers",
AnswerQuestionOutput(
answer_results=[],
),
)
]

View File

@@ -1,263 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.graph_builder import (
generate_initial_answer_graph_builder,
)
from onyx.agents.agent_search.deep_search.main.edges import (
continue_to_refined_answer_or_end,
)
from onyx.agents.agent_search.deep_search.main.edges import (
parallelize_refined_sub_question_answering,
)
from onyx.agents.agent_search.deep_search.main.edges import (
route_initial_tool_choice,
)
from onyx.agents.agent_search.deep_search.main.nodes.compare_answers import (
compare_answers,
)
from onyx.agents.agent_search.deep_search.main.nodes.create_refined_sub_questions import (
create_refined_sub_questions,
)
from onyx.agents.agent_search.deep_search.main.nodes.decide_refinement_need import (
decide_refinement_need,
)
from onyx.agents.agent_search.deep_search.main.nodes.extract_entities_terms import (
extract_entities_terms,
)
from onyx.agents.agent_search.deep_search.main.nodes.generate_validate_refined_answer import (
generate_validate_refined_answer,
)
from onyx.agents.agent_search.deep_search.main.nodes.ingest_refined_sub_answers import (
ingest_refined_sub_answers,
)
from onyx.agents.agent_search.deep_search.main.nodes.persist_agent_results import (
persist_agent_results,
)
from onyx.agents.agent_search.deep_search.main.nodes.start_agent_search import (
start_agent_search,
)
from onyx.agents.agent_search.deep_search.main.states import MainInput
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.refinement.consolidate_sub_answers.graph_builder import (
answer_refined_query_graph_builder,
)
from onyx.agents.agent_search.orchestration.nodes.call_tool import call_tool
from onyx.agents.agent_search.orchestration.nodes.choose_tool import choose_tool
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
prepare_tool_input,
)
from onyx.agents.agent_search.orchestration.nodes.use_tool_response import (
basic_use_tool_response,
)
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger
logger = setup_logger()
test_mode = False
def agent_search_graph_builder() -> StateGraph:
"""
LangGraph graph builder for the main agent search process.
"""
graph = StateGraph(
state_schema=MainState,
input=MainInput,
)
# Prepare the tool input
graph.add_node(
node="prepare_tool_input",
action=prepare_tool_input,
)
# Choose the initial tool
graph.add_node(
node="choose_tool",
action=choose_tool,
)
# Call the tool, if required
graph.add_node(
node="call_tool",
action=call_tool,
)
# Use the tool response
graph.add_node(
node="basic_use_tool_response",
action=basic_use_tool_response,
)
# Start the agent search process
graph.add_node(
node="start_agent_search",
action=start_agent_search,
)
# The sub-graph for the initial answer generation
generate_initial_answer_subgraph = generate_initial_answer_graph_builder().compile()
graph.add_node(
node="generate_initial_answer_subgraph",
action=generate_initial_answer_subgraph,
)
# Create the refined sub-questions
graph.add_node(
node="create_refined_sub_questions",
action=create_refined_sub_questions,
)
# Subgraph for the refined sub-answer generation
answer_refined_question = answer_refined_query_graph_builder().compile()
graph.add_node(
node="answer_refined_question_subgraphs",
action=answer_refined_question,
)
# Ingest the refined sub-answers
graph.add_node(
node="ingest_refined_sub_answers",
action=ingest_refined_sub_answers,
)
# Node to generate the refined answer
graph.add_node(
node="generate_validate_refined_answer",
action=generate_validate_refined_answer,
)
# Early node to extract the entities and terms from the initial answer.
# This information is used to inform the creation of the refined sub-questions
graph.add_node(
node="extract_entity_term",
action=extract_entities_terms,
)
# Decide if the answer needs to be refined (currently always true)
graph.add_node(
node="decide_refinement_need",
action=decide_refinement_need,
)
# Compare the initial and refined answers, and determine whether
# the refined answer is sufficiently better
graph.add_node(
node="compare_answers",
action=compare_answers,
)
# Log the results. This will log the stats as well as the answers, sub-questions, and sub-answers
graph.add_node(
node="logging_node",
action=persist_agent_results,
)
### Add edges ###
graph.add_edge(start_key=START, end_key="prepare_tool_input")
graph.add_edge(
start_key="prepare_tool_input",
end_key="choose_tool",
)
graph.add_conditional_edges(
"choose_tool",
route_initial_tool_choice,
["call_tool", "start_agent_search", "logging_node"],
)
graph.add_edge(
start_key="call_tool",
end_key="basic_use_tool_response",
)
graph.add_edge(
start_key="basic_use_tool_response",
end_key="logging_node",
)
graph.add_edge(
start_key="start_agent_search",
end_key="generate_initial_answer_subgraph",
)
graph.add_edge(
start_key="start_agent_search",
end_key="extract_entity_term",
)
# Wait for the initial answer generation and the entity/term extraction to be complete
# before deciding if a refinement is needed.
graph.add_edge(
start_key=["generate_initial_answer_subgraph", "extract_entity_term"],
end_key="decide_refinement_need",
)
graph.add_conditional_edges(
source="decide_refinement_need",
path=continue_to_refined_answer_or_end,
path_map=["create_refined_sub_questions", "logging_node"],
)
graph.add_conditional_edges(
source="create_refined_sub_questions",
path=parallelize_refined_sub_question_answering,
path_map=["answer_refined_question_subgraphs"],
)
graph.add_edge(
start_key="answer_refined_question_subgraphs",
end_key="ingest_refined_sub_answers",
)
graph.add_edge(
start_key="ingest_refined_sub_answers",
end_key="generate_validate_refined_answer",
)
graph.add_edge(
start_key="generate_validate_refined_answer",
end_key="compare_answers",
)
graph.add_edge(
start_key="compare_answers",
end_key="logging_node",
)
graph.add_edge(
start_key="logging_node",
end_key=END,
)
return graph
if __name__ == "__main__":
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
graph = agent_search_graph_builder()
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
with get_session_with_current_tenant() as db_session:
search_request = SearchRequest(query="Who created Excel?")
graph_config = get_test_config(
db_session, primary_llm, fast_llm, search_request
)
inputs = MainInput(log_messages=[])
for thing in compiled_graph.stream(
input=inputs,
config={"configurable": {"config": graph_config}},
stream_mode="custom",
subgraphs=True,
):
logger.debug(thing)
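Two structural tricks in the builder above are easy to miss: a compiled subgraph can be mounted as an ordinary node, and a list-valued start_key makes an edge wait for all of its parents (the join of the initial-answer subgraph and entity/term extraction before decide_refinement_need). A minimal sketch of both, with hypothetical node names:

from operator import add
from typing import Annotated, TypedDict

from langgraph.graph import END, START, StateGraph


class JoinState(TypedDict):
    log: Annotated[list[str], add]


# Inner graph, compiled and then mounted as a single node.
inner = StateGraph(JoinState)
inner.add_node("inner_work", lambda s: {"log": ["inner done"]})
inner.add_edge(START, "inner_work")
inner.add_edge("inner_work", END)

outer = StateGraph(JoinState)
outer.add_node("subgraph", inner.compile())  # compiled graph as a node
outer.add_node("extract", lambda s: {"log": ["extract done"]})
outer.add_node("decide", lambda s: {"log": ["decide done"]})
outer.add_edge(START, "subgraph")
outer.add_edge(START, "extract")
# List-valued start_key: "decide" runs only after BOTH parents finish.
outer.add_edge(["subgraph", "extract"], "decide")
outer.add_edge("decide", END)
print(outer.compile().invoke({"log": []})["log"])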


@@ -1,168 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.main.states import (
InitialRefinedAnswerComparisonUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
binary_string_test,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_POSITIVE_VALUE_STR,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import RefinedAnswerImprovement
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_COMPARE_ANSWERS
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
INITIAL_REFINED_ANSWER_COMPARISON_PROMPT,
)
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="The LLM timed out, and the answers could not be compared.",
rate_limit="The LLM encountered a rate limit, and the answers could not be compared.",
general_error="The LLM encountered an error, and the answers could not be compared.",
)
_ANSWER_QUALITY_NOT_SUFFICIENT_MESSAGE = (
"Answer quality is not sufficient, so stay with the initial answer."
)
@log_function_time(print_only=True)
def compare_answers(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> InitialRefinedAnswerComparisonUpdate:
"""
LangGraph node to compare the initial answer and the refined answer and determine if the
refined answer is sufficiently better than the initial answer.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.prompt_builder.raw_user_query
initial_answer = state.initial_answer
refined_answer = state.refined_answer
# if answer quality is not sufficient, then stay with the initial answer
if not state.refined_answer_quality:
write_custom_event(
"refined_answer_improvement",
RefinedAnswerImprovement(
refined_answer_improvement=False,
),
writer,
)
return InitialRefinedAnswerComparisonUpdate(
refined_answer_improvement_eval=False,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="compare answers",
node_start_time=node_start_time,
result=_ANSWER_QUALITY_NOT_SUFFICIENT_MESSAGE,
)
],
)
compare_answers_prompt = INITIAL_REFINED_ANSWER_COMPARISON_PROMPT.format(
question=question, initial_answer=initial_answer, refined_answer=refined_answer
)
msg = [HumanMessage(content=compare_answers_prompt)]
agent_error: AgentErrorLog | None = None
# Compare the two answers with the fast LLM
model = graph_config.tooling.fast_llm
resp: BaseMessage | None = None
refined_answer_improvement: bool | None = None
# no need to stream this
try:
resp = run_with_timeout(
AGENT_TIMEOUT_LLM_COMPARE_ANSWERS,
model.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
except (LLMTimeoutError, TimeoutError):
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - compare answers")
# continue as True in this support step
except LLMRateLimitError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - compare answers")
# continue as True in this support step
if agent_error or resp is None:
refined_answer_improvement = True
if agent_error:
log_result = agent_error.error_result
else:
log_result = "An answer could not be generated."
else:
refined_answer_improvement = binary_string_test(
text=cast(str, resp.content),
positive_value=AGENT_POSITIVE_VALUE_STR,
)
log_result = f"Answer comparison: {refined_answer_improvement}"
write_custom_event(
"refined_answer_improvement",
RefinedAnswerImprovement(
refined_answer_improvement=refined_answer_improvement,
),
writer,
)
return InitialRefinedAnswerComparisonUpdate(
refined_answer_improvement_eval=refined_answer_improvement,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="compare answers",
node_start_time=node_start_time,
result=log_result,
)
],
)
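The error handling above "fails open": on a timeout or rate limit, the node assumes the refined answer is an improvement rather than blocking the response. A generic sketch of the same timeout-plus-fallback pattern using concurrent.futures (onyx's run_with_timeout helper is similar in spirit; the function and parameter names below are hypothetical):

from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FuturesTimeout


def call_with_timeout(fn, timeout_s, *args, **kwargs):
    pool = ThreadPoolExecutor(max_workers=1)
    try:
        return pool.submit(fn, *args, **kwargs).result(timeout=timeout_s)
    finally:
        pool.shutdown(wait=False)  # don't block on a stuck worker


def compare_with_fallback(invoke, prompt: str, positive_value: str = "yes") -> bool:
    try:
        text = call_with_timeout(invoke, 10.0, prompt)
    except (FuturesTimeout, TimeoutError):
        return True  # fail open: keep the refined answer, as above
    return text.strip().lower().startswith(positive_value)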


@@ -1,213 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.main.models import (
RefinementSubQuestion,
)
from onyx.agents.agent_search.deep_search.main.operations import dispatch_subquestion
from onyx.agents.agent_search.deep_search.main.operations import (
dispatch_subquestion_sep,
)
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.main.states import (
RefinedQuestionDecompositionUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
format_entity_term_extraction,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUESTION_GENERATION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
REFINEMENT_QUESTION_DECOMPOSITION_PROMPT_W_INITIAL_SUBQUESTION_ANSWERS,
)
from onyx.tools.models import ToolCallKickoff
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
_ANSWERED_SUBQUESTIONS_DIVIDER = "\n\n---\n\n"
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="The LLM timed out. The sub-questions could not be generated.",
rate_limit="The LLM encountered a rate limit. The sub-questions could not be generated.",
general_error="The LLM encountered an error. The sub-questions could not be generated.",
)
@log_function_time(print_only=True)
def create_refined_sub_questions(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> RefinedQuestionDecompositionUpdate:
"""
LangGraph node to create refined sub-questions based on the initial answer, the history,
the entity term extraction results found earlier, and the sub-questions that were answered and failed.
"""
graph_config = cast(GraphConfig, config["metadata"]["config"])
write_custom_event(
"start_refined_answer_creation",
ToolCallKickoff(
tool_name="agent_search_1",
tool_args={
"query": graph_config.inputs.prompt_builder.raw_user_query,
"answer": state.initial_answer,
},
),
writer,
)
node_start_time = datetime.now()
agent_refined_start_time = datetime.now()
question = graph_config.inputs.prompt_builder.raw_user_query
base_answer = state.initial_answer
history = build_history_prompt(graph_config, question)
# get the entity term extraction dict and properly format it
entity_relation_term_extractions = state.entity_relation_term_extractions
entity_term_extraction_str = format_entity_term_extraction(
entity_relation_term_extractions
)
initial_question_answers = state.sub_question_results
addressed_subquestions_with_answers = [
f"Subquestion: {x.question}\nSubanswer:\n{x.answer}"
for x in initial_question_answers
if x.verified_high_quality and x.answer
]
failed_question_list = [
x.question for x in initial_question_answers if not x.verified_high_quality
]
msg = [
HumanMessage(
content=REFINEMENT_QUESTION_DECOMPOSITION_PROMPT_W_INITIAL_SUBQUESTION_ANSWERS.format(
question=question,
history=history,
entity_term_extraction_str=entity_term_extraction_str,
base_answer=base_answer,
answered_subquestions_with_answers=_ANSWERED_SUBQUESTIONS_DIVIDER.join(
addressed_subquestions_with_answers
),
failed_sub_questions="\n - ".join(failed_question_list),
),
)
]
# Generate the refined sub-questions with the fast LLM
model = graph_config.tooling.fast_llm
agent_error: AgentErrorLog | None = None
streamed_tokens: list[BaseMessage_Content] = []
try:
streamed_tokens = run_with_timeout(
AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION,
dispatch_separated,
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBQUESTION_GENERATION,
),
dispatch_subquestion(1, writer),
sep_callback=dispatch_subquestion_sep(1, writer),
)
except (LLMTimeoutError, TimeoutError):
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - create refined sub questions")
except LLMRateLimitError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - create refined sub questions")
if agent_error:
refined_sub_question_dict: dict[int, RefinementSubQuestion] = {}
log_result = agent_error.error_result
write_custom_event(
"refined_sub_question_creation_error",
StreamingError(
error="Your LLM was not able to create refined sub questions in time and timed out. Please try again.",
),
writer,
)
else:
response = merge_content(*streamed_tokens)
if isinstance(response, str):
parsed_response = [q for q in response.split("\n") if q.strip() != ""]
else:
raise ValueError("LLM response is not a string")
refined_sub_question_dict = {}
for sub_question_num, sub_question in enumerate(parsed_response):
refined_sub_question = RefinementSubQuestion(
sub_question=sub_question,
sub_question_id=make_question_id(1, sub_question_num + 1),
verified=False,
answered=False,
answer="",
)
refined_sub_question_dict[sub_question_num + 1] = refined_sub_question
log_result = f"Created {len(refined_sub_question_dict)} refined sub questions"
return RefinedQuestionDecompositionUpdate(
refined_sub_questions=refined_sub_question_dict,
agent_refined_start_time=agent_refined_start_time,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="create refined sub questions",
node_start_time=node_start_time,
result=log_result,
)
],
)
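The parsing step above relies on the LLM emitting one sub-question per line. A minimal sketch of that contract, numbering the non-empty lines from 1 the way the node does (the helper name is hypothetical):

def parse_sub_questions(response: str) -> dict[int, str]:
    # Split the merged LLM output on newlines and keep non-empty lines.
    questions = [q.strip() for q in response.split("\n") if q.strip()]
    return {num: q for num, q in enumerate(questions, start=1)}


assert parse_sub_questions("q1\n\n q2 \n") == {1: "q1", 2: "q2"}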


@@ -1,56 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.main.states import (
RequireRefinemenEvalUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.timing import log_function_time
@log_function_time(print_only=True)
def decide_refinement_need(
state: MainState, config: RunnableConfig
) -> RequireRefinemenEvalUpdate:
"""
LangGraph node to decide if refinement is needed based on the initial answer and the question.
At present, we always refine.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
decision = graph_config.behavior.allow_refinement
if state.answer_error:
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=False,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="decide refinement need",
node_start_time=node_start_time,
result="Timeout Error",
)
],
)
log_messages = [
get_langgraph_node_log_string(
graph_component="main",
node_name="decide refinement need",
node_start_time=node_start_time,
result=f"Refinement decision: {decision}",
)
]
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=graph_config.behavior.allow_refinement and decision,
log_messages=log_messages,
)


@@ -1,142 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search.main.operations import logger
from onyx.agents.agent_search.deep_search.main.states import (
EntityTermExtractionUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.models import EntityExtractionResult
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
)
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION,
)
from onyx.configs.constants import NUM_EXPLORATORY_DOCS
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
@log_function_time(print_only=True)
def extract_entities_terms(
state: MainState, config: RunnableConfig
) -> EntityTermExtractionUpdate:
"""
LangGraph node to extract entities, relationships, and terms from the initial search results.
This data is used in particular to inform the sub-questions that are created for the refined answer.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
if not graph_config.behavior.allow_refinement:
return EntityTermExtractionUpdate(
entity_relation_term_extractions=EntityRelationshipTermExtraction(
entities=[],
relationships=[],
terms=[],
),
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="extract entities terms",
node_start_time=node_start_time,
result="Refinement is not allowed",
)
],
)
# first four lines are duplicated from generate_initial_answer
question = graph_config.inputs.prompt_builder.raw_user_query
initial_search_docs = state.exploratory_search_results[:NUM_EXPLORATORY_DOCS]
# start with the entity/term/extraction
doc_context = format_docs(initial_search_docs)
# Calculation here is only approximate
doc_context = trim_prompt_piece(
config=graph_config.tooling.fast_llm.config,
prompt_piece=doc_context,
reserved_str=ENTITY_TERM_EXTRACTION_PROMPT
+ question
+ ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE,
)
msg = [
HumanMessage(
content=ENTITY_TERM_EXTRACTION_PROMPT.format(
question=question, context=doc_context
)
+ ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE,
)
]
fast_llm = graph_config.tooling.fast_llm
# Run the entity/term extraction with the fast LLM
try:
llm_response = run_with_timeout(
AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION,
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
max_tokens=AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION,
)
cleaned_response = (
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
)
first_bracket = cleaned_response.find("{")
last_bracket = cleaned_response.rfind("}")
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
try:
entity_extraction_result = EntityExtractionResult.model_validate_json(
cleaned_response
)
except ValueError:
logger.error(
"Failed to parse LLM response as JSON in Entity-Term Extraction"
)
entity_extraction_result = EntityExtractionResult(
retrieved_entities_relationships=EntityRelationshipTermExtraction(),
)
except (LLMTimeoutError, TimeoutError):
logger.error("LLM Timeout Error - extract entities terms")
entity_extraction_result = EntityExtractionResult(
retrieved_entities_relationships=EntityRelationshipTermExtraction(),
)
except LLMRateLimitError:
logger.error("LLM Rate Limit Error - extract entities terms")
entity_extraction_result = EntityExtractionResult(
retrieved_entities_relationships=EntityRelationshipTermExtraction(),
)
return EntityTermExtractionUpdate(
entity_relation_term_extractions=entity_extraction_result.retrieved_entities_relationships,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="extract entities terms",
node_start_time=node_start_time,
)
],
)
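The JSON-recovery logic above is defensive by design: strip markdown fences, slice from the first "{" to the last "}", validate, and fall back to an empty extraction on any parse failure. A self-contained sketch with a simplified pydantic model (the real EntityExtractionResult is richer):

from pydantic import BaseModel


class Extraction(BaseModel):
    entities: list[str] = []
    relationships: list[str] = []
    terms: list[str] = []


def parse_extraction(raw: str) -> Extraction:
    cleaned = raw.replace("```json\n", "").replace("\n```", "")
    start, end = cleaned.find("{"), cleaned.rfind("}")
    try:
        return Extraction.model_validate_json(cleaned[start : end + 1])
    except ValueError:  # pydantic.ValidationError subclasses ValueError
        return Extraction()  # fall back to an empty result, as above


print(parse_extraction('```json\n{"entities": ["Excel"]}\n```'))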


@@ -1,445 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.main.models import (
AgentRefinedMetrics,
)
from onyx.agents.agent_search.deep_search.main.operations import get_query_info
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.main.states import (
RefinedAnswerUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
binary_string_test_after_answer_separator,
)
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
get_prompt_enrichment_components,
)
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.calculations import (
get_answer_generation_documents,
)
from onyx.agents.agent_search.shared_graph_utils.constants import AGENT_ANSWER_SEPARATOR
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_POSITIVE_VALUE_STR,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_section_list,
)
from onyx.agents.agent_search.shared_graph_utils.utils import _should_restrict_tokens
from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_deduplicated_structured_subquestion_documents,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
remove_document_citations,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
REFINED_ANSWER_PROMPT_W_SUB_QUESTIONS,
)
from onyx.prompts.agent_search import (
REFINED_ANSWER_PROMPT_WO_SUB_QUESTIONS,
)
from onyx.prompts.agent_search import (
REFINED_ANSWER_VALIDATION_PROMPT,
)
from onyx.prompts.agent_search import (
SUB_QUESTION_ANSWER_TEMPLATE_REFINED,
)
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="The LLM timed out. The refined answer could not be generated.",
rate_limit="The LLM encountered a rate limit. The refined answer could not be generated.",
general_error="The LLM encountered an error. The refined answer could not be generated.",
)
@log_function_time(print_only=True)
def generate_validate_refined_answer(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> RefinedAnswerUpdate:
"""
LangGraph node to generate the refined answer and validate it.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.prompt_builder.raw_user_query
prompt_enrichment_components = get_prompt_enrichment_components(graph_config)
persona_contextualized_prompt = (
prompt_enrichment_components.persona_prompts.contextualized_prompt
)
verified_reranked_documents = state.verified_reranked_documents
# get all documents cited in sub-questions
structured_subquestion_docs = get_deduplicated_structured_subquestion_documents(
state.sub_question_results
)
original_question_verified_documents = (
state.orig_question_verified_reranked_documents
)
original_question_retrieved_documents = state.orig_question_retrieved_documents
consolidated_context_docs = structured_subquestion_docs.cited_documents
counter = 0
for original_doc in original_question_verified_documents:
if original_doc not in structured_subquestion_docs.cited_documents:
if (
counter <= AGENT_MIN_ORIG_QUESTION_DOCS
or len(consolidated_context_docs)
< 1.5
* AGENT_MAX_ANSWER_CONTEXT_DOCS # allow for larger context in refinement
):
consolidated_context_docs.append(original_doc)
counter += 1
# sort docs by their scores - though the scores refer to different questions
relevant_docs = dedup_inference_section_list(consolidated_context_docs)
# Create the list of documents to stream out. Start with the
# ones that will be in the context (or, if len == 0, use docs
# that were retrieved for the original question)
answer_generation_documents = get_answer_generation_documents(
relevant_docs=relevant_docs,
context_documents=structured_subquestion_docs.context_documents,
original_question_docs=original_question_retrieved_documents,
max_docs=AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER,
)
query_info = get_query_info(state.orig_question_sub_query_retrieval_results)
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
# stream refined answer docs, or original question docs if no relevant docs are found
relevance_list = relevance_from_docs(
answer_generation_documents.streaming_documents
)
for tool_response in yield_search_responses(
query=question,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
):
write_custom_event(
"tool_response",
ExtendedToolResponse(
id=tool_response.id,
response=tool_response.response,
level=1,
level_question_num=0, # num 0 is the base question at this level
),
writer,
)
if len(verified_reranked_documents) > 0:
refined_doc_effectiveness = len(relevant_docs) / len(
verified_reranked_documents
)
else:
refined_doc_effectiveness = 10.0
sub_question_answer_results = state.sub_question_results
answered_sub_question_answer_list: list[str] = []
sub_questions: list[str] = []
initial_answered_sub_questions: set[str] = set()
refined_answered_sub_questions: set[str] = set()
for i, result in enumerate(sub_question_answer_results, 1):
question_level, _ = parse_question_id(result.question_id)
sub_questions.append(result.question)
if (
result.verified_high_quality
and result.answer
and result.answer != UNKNOWN_ANSWER
):
sub_question_type = "initial" if question_level == 0 else "refined"
question_set = (
initial_answered_sub_questions
if question_level == 0
else refined_answered_sub_questions
)
question_set.add(result.question)
answered_sub_question_answer_list.append(
SUB_QUESTION_ANSWER_TEMPLATE_REFINED.format(
sub_question=result.question,
sub_answer=result.answer,
sub_question_num=i,
sub_question_type=sub_question_type,
)
)
# Calculate efficiency
total_answered_questions = (
initial_answered_sub_questions | refined_answered_sub_questions
)
revision_question_efficiency = (
len(total_answered_questions) / len(initial_answered_sub_questions)
if initial_answered_sub_questions
else 10.0 if refined_answered_sub_questions else 1.0
)
sub_question_answer_str = "\n\n------\n\n".join(
set(answered_sub_question_answer_list)
)
initial_answer = state.initial_answer or ""
# Choose appropriate prompt template
base_prompt = (
REFINED_ANSWER_PROMPT_W_SUB_QUESTIONS
if answered_sub_question_answer_list
else REFINED_ANSWER_PROMPT_WO_SUB_QUESTIONS
)
model = (
graph_config.tooling.fast_llm
if AGENT_ANSWER_GENERATION_BY_FAST_LLM
else graph_config.tooling.primary_llm
)
relevant_docs_str = format_docs(answer_generation_documents.context_documents)
relevant_docs_str = trim_prompt_piece(
config=model.config,
prompt_piece=relevant_docs_str,
reserved_str=base_prompt
+ question
+ sub_question_answer_str
+ initial_answer
+ persona_contextualized_prompt
+ prompt_enrichment_components.history,
)
msg = [
HumanMessage(
content=base_prompt.format(
question=question,
history=prompt_enrichment_components.history,
answered_sub_questions=remove_document_citations(
sub_question_answer_str
),
relevant_docs=relevant_docs_str,
initial_answer=(
remove_document_citations(initial_answer)
if initial_answer
else None
),
persona_specification=persona_contextualized_prompt,
date_prompt=prompt_enrichment_components.date_str,
)
)
]
streamed_tokens: list[str] = [""]
dispatch_timings: list[float] = []
agent_error: AgentErrorLog | None = None
try:
streamed_tokens, dispatch_timings = run_with_timeout(
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION,
lambda: stream_llm_answer(
llm=model,
prompt=msg,
event_name="refined_agent_answer",
writer=writer,
agent_answer_level=1,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
max_tokens=(
AGENT_MAX_TOKENS_ANSWER_GENERATION
if _should_restrict_tokens(model.config)
else None
),
),
)
except (LLMTimeoutError, TimeoutError):
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - generate refined answer")
except LLMRateLimitError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - generate refined answer")
if agent_error:
write_custom_event(
"initial_agent_answer",
StreamingError(
error=AGENT_LLM_TIMEOUT_MESSAGE,
),
writer,
)
return RefinedAnswerUpdate(
refined_answer=None,
refined_answer_quality=False, # TODO: replace this with the actual check value
refined_agent_stats=None,
agent_refined_end_time=None,
agent_refined_metrics=AgentRefinedMetrics(
refined_doc_boost_factor=0.0,
refined_question_boost_factor=0.0,
duration_s=None,
),
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="generate refined answer",
node_start_time=node_start_time,
result=agent_error.error_result or "An LLM error occurred",
)
],
)
logger.debug(
f"Average dispatch time for refined answer: {sum(dispatch_timings) / len(dispatch_timings)}"
)
dispatch_main_answer_stop_info(1, writer)
response = merge_content(*streamed_tokens)
answer = cast(str, response)
# run a validation step for the refined answer only
msg = [
HumanMessage(
content=REFINED_ANSWER_VALIDATION_PROMPT.format(
question=question,
history=prompt_enrichment_components.history,
answered_sub_questions=sub_question_answer_str,
relevant_docs=relevant_docs_str,
proposed_answer=answer,
persona_specification=persona_contextualized_prompt,
)
)
]
validation_model = graph_config.tooling.fast_llm
try:
validation_response = run_with_timeout(
AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION,
validation_model.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
refined_answer_quality = binary_string_test_after_answer_separator(
text=cast(str, validation_response.content),
positive_value=AGENT_POSITIVE_VALUE_STR,
separator=AGENT_ANSWER_SEPARATOR,
)
except (LLMTimeoutError, TimeoutError):
refined_answer_quality = True
logger.error("LLM Timeout Error - validate refined answer")
except LLMRateLimitError:
refined_answer_quality = True
logger.error("LLM Rate Limit Error - validate refined answer")
refined_agent_stats = RefinedAgentStats(
revision_doc_efficiency=refined_doc_effectiveness,
revision_question_efficiency=revision_question_efficiency,
)
agent_refined_end_time = datetime.now()
if state.agent_refined_start_time:
agent_refined_duration = (
agent_refined_end_time - state.agent_refined_start_time
).total_seconds()
else:
agent_refined_duration = None
agent_refined_metrics = AgentRefinedMetrics(
refined_doc_boost_factor=refined_agent_stats.revision_doc_efficiency,
refined_question_boost_factor=refined_agent_stats.revision_question_efficiency,
duration_s=agent_refined_duration,
)
return RefinedAnswerUpdate(
refined_answer=answer,
refined_answer_quality=refined_answer_quality,
refined_agent_stats=refined_agent_stats,
agent_refined_end_time=agent_refined_end_time,
agent_refined_metrics=agent_refined_metrics,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="generate refined answer",
node_start_time=node_start_time,
)
],
)
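The context-consolidation loop near the top of this node keeps every document cited by a sub-question and then tops up with original-question documents, bounded by a minimum count and a 1.5x context cap. A simplified sketch of the heuristic with illustrative caps standing in for AGENT_MIN_ORIG_QUESTION_DOCS and AGENT_MAX_ANSWER_CONTEXT_DOCS:

def consolidate_context(
    cited_docs: list[str],
    original_docs: list[str],
    min_orig_docs: int = 3,  # stand-in for AGENT_MIN_ORIG_QUESTION_DOCS
    max_context_docs: int = 10,  # stand-in for AGENT_MAX_ANSWER_CONTEXT_DOCS
) -> list[str]:
    consolidated = list(cited_docs)
    added = 0
    for doc in original_docs:
        if doc in cited_docs:
            continue
        # Allow a larger (1.5x) context during refinement, as above.
        if added <= min_orig_docs or len(consolidated) < 1.5 * max_context_docs:
            consolidated.append(doc)
            added += 1
    return consolidated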


@@ -1,42 +0,0 @@
from datetime import datetime
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.main.states import (
SubQuestionResultsUpdate,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
def ingest_refined_sub_answers(
state: AnswerQuestionOutput,
) -> SubQuestionResultsUpdate:
"""
LangGraph node to ingest and format the refined sub-answers and retrieved documents.
"""
node_start_time = datetime.now()
documents = []
answer_results = state.answer_results
for answer_result in answer_results:
documents.extend(answer_result.verified_reranked_documents)
return SubQuestionResultsUpdate(
# Deduping is done by the documents operator for the main graph
# so we might not need to dedup here
verified_reranked_documents=dedup_inference_sections(documents, []),
sub_question_results=answer_results,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="ingest refined answers",
node_start_time=node_start_time,
)
],
)


@@ -1,129 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search.main.models import (
AgentAdditionalMetrics,
)
from onyx.agents.agent_search.deep_search.main.models import AgentTimings
from onyx.agents.agent_search.deep_search.main.operations import logger
from onyx.agents.agent_search.deep_search.main.states import MainOutput
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.models import CombinedAgentMetrics
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.db.chat import log_agent_metrics
from onyx.db.chat import log_agent_sub_question_results
def persist_agent_results(state: MainState, config: RunnableConfig) -> MainOutput:
"""
LangGraph node to persist the agent results, including agent logging data.
"""
node_start_time = datetime.now()
agent_start_time = state.agent_start_time
agent_base_end_time = state.agent_base_end_time
agent_refined_start_time = state.agent_refined_start_time
agent_refined_end_time = state.agent_refined_end_time
agent_end_time = agent_refined_end_time or agent_base_end_time
agent_base_duration = None
if agent_base_end_time and agent_start_time:
agent_base_duration = (agent_base_end_time - agent_start_time).total_seconds()
agent_refined_duration = None
if agent_refined_start_time and agent_refined_end_time:
agent_refined_duration = (
agent_refined_end_time - agent_refined_start_time
).total_seconds()
agent_full_duration = None
if agent_end_time and agent_start_time:
agent_full_duration = (agent_end_time - agent_start_time).total_seconds()
agent_type = "refined" if agent_refined_duration else "base"
agent_base_metrics = state.agent_base_metrics
agent_refined_metrics = state.agent_refined_metrics
combined_agent_metrics = CombinedAgentMetrics(
timings=AgentTimings(
base_duration_s=agent_base_duration,
refined_duration_s=agent_refined_duration,
full_duration_s=agent_full_duration,
),
base_metrics=agent_base_metrics,
refined_metrics=agent_refined_metrics,
additional_metrics=AgentAdditionalMetrics(),
)
persona_id = None
graph_config = cast(GraphConfig, config["metadata"]["config"])
if graph_config.inputs.persona:
persona_id = graph_config.inputs.persona.id
user_id = None
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
user = graph_config.tooling.search_tool.user
if user:
user_id = user.id
# log the agent metrics
if graph_config.persistence:
if agent_base_duration is not None:
log_agent_metrics(
db_session=graph_config.persistence.db_session,
user_id=user_id,
persona_id=persona_id,
agent_type=agent_type,
start_time=agent_start_time,
agent_metrics=combined_agent_metrics,
)
# Persist the sub-answer in the database
db_session = graph_config.persistence.db_session
chat_session_id = graph_config.persistence.chat_session_id
primary_message_id = graph_config.persistence.message_id
sub_question_answer_results = state.sub_question_results
log_agent_sub_question_results(
db_session=db_session,
chat_session_id=chat_session_id,
primary_message_id=primary_message_id,
sub_question_answer_results=sub_question_answer_results,
)
main_output = MainOutput(
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="persist agent results",
node_start_time=node_start_time,
)
],
)
for log_message in state.log_messages:
logger.debug(log_message)
if state.agent_base_metrics:
logger.debug(f"Initial loop: {state.agent_base_metrics.duration_s}")
if state.agent_refined_metrics:
logger.debug(f"Refined loop: {state.agent_refined_metrics.duration_s}")
if (
state.agent_base_metrics
and state.agent_refined_metrics
and state.agent_base_metrics.duration_s
and state.agent_refined_metrics.duration_s
):
logger.debug(
f"Total time: {float(state.agent_base_metrics.duration_s) + float(state.agent_refined_metrics.duration_s)}"
)
return main_output


@@ -1,52 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search.main.states import (
ExploratorySearchUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import retrieve_search_docs
from onyx.configs.agent_configs import AGENT_EXPLORATORY_SEARCH_RESULTS
from onyx.context.search.models import InferenceSection
def start_agent_search(
state: MainState, config: RunnableConfig
) -> ExploratorySearchUpdate:
"""
LangGraph node to start the agentic search process.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.prompt_builder.raw_user_query
history = build_history_prompt(graph_config, question)
# Initial search to inform decomposition. Keep only the top AGENT_EXPLORATORY_SEARCH_RESULTS hits
search_tool = graph_config.tooling.search_tool
assert search_tool, "search_tool must be provided for agentic search"
retrieved_docs: list[InferenceSection] = retrieve_search_docs(search_tool, question)
exploratory_search_results = retrieved_docs[:AGENT_EXPLORATORY_SEARCH_RESULTS]
return ExploratorySearchUpdate(
exploratory_search_results=exploratory_search_results,
previous_history_summary=history,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="start agent search",
node_start_time=node_start_time,
)
],
)


@@ -1,148 +0,0 @@
from collections.abc import Callable
from langgraph.types import StreamWriter
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
from onyx.agents.agent_search.shared_graph_utils.models import (
SubQuestionAnswerResults,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.tools.models import SearchQueryInfo
from onyx.utils.logger import setup_logger
logger = setup_logger()
def dispatch_subquestion(
level: int, writer: StreamWriter
) -> Callable[[str, int], None]:
def _helper(sub_question_part: str, sep_num: int) -> None:
write_custom_event(
"decomp_qs",
SubQuestionPiece(
sub_question=sub_question_part,
level=level,
level_question_num=sep_num,
),
writer,
)
return _helper
def dispatch_subquestion_sep(level: int, writer: StreamWriter) -> Callable[[int], None]:
def _helper(sep_num: int) -> None:
write_custom_event(
"stream_finished",
StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type=StreamType.SUB_QUESTIONS,
level=level,
level_question_num=sep_num,
),
writer,
)
return _helper
def calculate_initial_agent_stats(
decomp_answer_results: list[SubQuestionAnswerResults],
original_question_stats: AgentChunkRetrievalStats,
) -> InitialAgentResultStats:
initial_agent_result_stats: InitialAgentResultStats = InitialAgentResultStats(
sub_questions={},
original_question={},
agent_effectiveness={},
)
orig_verified = original_question_stats.verified_count
orig_support_score = original_question_stats.verified_avg_scores
verified_document_chunk_ids = []
support_scores = 0.0
for decomp_answer_result in decomp_answer_results:
verified_document_chunk_ids += (
decomp_answer_result.sub_question_retrieval_stats.verified_doc_chunk_ids
)
if (
decomp_answer_result.sub_question_retrieval_stats.verified_avg_scores
is not None
):
support_scores += (
decomp_answer_result.sub_question_retrieval_stats.verified_avg_scores
)
verified_document_chunk_ids = list(set(verified_document_chunk_ids))
# Calculate sub-question stats
if (
verified_document_chunk_ids
and len(verified_document_chunk_ids) > 0
and support_scores is not None
):
sub_question_stats: dict[str, float | int | None] = {
"num_verified_documents": len(verified_document_chunk_ids),
"verified_avg_score": float(support_scores / len(decomp_answer_results)),
}
else:
sub_question_stats = {"num_verified_documents": 0, "verified_avg_score": None}
initial_agent_result_stats.sub_questions.update(sub_question_stats)
# Get original question stats
initial_agent_result_stats.original_question.update(
{
"num_verified_documents": original_question_stats.verified_count,
"verified_avg_score": original_question_stats.verified_avg_scores,
}
)
# Calculate chunk utilization ratio
sub_verified = initial_agent_result_stats.sub_questions["num_verified_documents"]
chunk_ratio: float | None = None
if sub_verified is not None and orig_verified is not None and orig_verified > 0:
chunk_ratio = (float(sub_verified) / orig_verified) if sub_verified > 0 else 0.0
elif sub_verified is not None and sub_verified > 0:
chunk_ratio = 10.0
initial_agent_result_stats.agent_effectiveness["utilized_chunk_ratio"] = chunk_ratio
if (
(orig_support_score is None or orig_support_score == 0.0)
and initial_agent_result_stats.sub_questions["verified_avg_score"] is None
):
initial_agent_result_stats.agent_effectiveness["support_ratio"] = None
elif orig_support_score is None or orig_support_score == 0.0:
initial_agent_result_stats.agent_effectiveness["support_ratio"] = 10
elif initial_agent_result_stats.sub_questions["verified_avg_score"] is None:
initial_agent_result_stats.agent_effectiveness["support_ratio"] = 0
else:
initial_agent_result_stats.agent_effectiveness["support_ratio"] = (
initial_agent_result_stats.sub_questions["verified_avg_score"]
/ orig_support_score
)
return initial_agent_result_stats
def get_query_info(results: list[QueryRetrievalResult]) -> SearchQueryInfo:
# Use the query info from the base document retrieval;
# it supplies fields that are identical across the searches performed
query_info = None
for result in results:
if result.query_info is not None:
query_info = result.query_info
break
assert query_info is not None, "must have query info"
return query_info
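The stats helpers above encode a sentinel convention for ratios: None when neither side has data, 10.0 (or 10) when only the agent side does, 0 when only the original question does, and the plain ratio otherwise. A compact sketch of that convention (hypothetical helper; note the explicit parentheses, which the original condition needed):

def support_ratio(sub_score: float | None, orig_score: float | None) -> float | None:
    if (orig_score is None or orig_score == 0.0) and sub_score is None:
        return None  # no data on either side
    if orig_score is None or orig_score == 0.0:
        return 10.0  # only the sub-question side has support
    if sub_score is None:
        return 0.0  # only the original question has support
    return sub_score / orig_score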


@@ -1,175 +0,0 @@
from datetime import datetime
from operator import add
from typing import Annotated
from typing import TypedDict
from pydantic import BaseModel
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.deep_search.main.models import AgentBaseMetrics
from onyx.agents.agent_search.deep_search.main.models import (
AgentRefinedMetrics,
)
from onyx.agents.agent_search.deep_search.main.models import (
RefinementSubQuestion,
)
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
)
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
from onyx.agents.agent_search.shared_graph_utils.models import (
SubQuestionAnswerResults,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_question_answer_results,
)
from onyx.context.search.models import InferenceSection
### States ###
class LoggerUpdate(BaseModel):
log_messages: Annotated[list[str], add] = []
class RefinedAgentStartStats(BaseModel):
agent_refined_start_time: datetime | None = None
class RefinedAgentEndStats(BaseModel):
agent_refined_end_time: datetime | None = None
agent_refined_metrics: AgentRefinedMetrics = AgentRefinedMetrics()
class InitialQuestionDecompositionUpdate(
RefinedAgentStartStats, RefinedAgentEndStats, LoggerUpdate
):
agent_start_time: datetime | None = None
previous_history: str | None = None
initial_sub_questions: list[str] = []
class ExploratorySearchUpdate(LoggerUpdate):
exploratory_search_results: list[InferenceSection] = []
previous_history_summary: str | None = None
class InitialRefinedAnswerComparisonUpdate(LoggerUpdate):
"""
Evaluation of whether the refined answer is better than the initial answer
"""
refined_answer_improvement_eval: bool = False
class InitialAnswerUpdate(LoggerUpdate):
"""
Initial answer information
"""
initial_answer: str | None = None
answer_error: AgentErrorLog | None = None
initial_agent_stats: InitialAgentResultStats | None = None
generated_sub_questions: list[str] = []
agent_base_end_time: datetime | None = None
agent_base_metrics: AgentBaseMetrics | None = None
class RefinedAnswerUpdate(RefinedAgentEndStats, LoggerUpdate):
"""
Refined answer information
"""
refined_answer: str | None = None
answer_error: AgentErrorLog | None = None
refined_agent_stats: RefinedAgentStats | None = None
refined_answer_quality: bool = False
class InitialAnswerQualityUpdate(LoggerUpdate):
"""
Initial answer quality evaluation
"""
initial_answer_quality_eval: bool = False
class RequireRefinemenEvalUpdate(LoggerUpdate):
require_refined_answer_eval: bool = True
class SubQuestionResultsUpdate(LoggerUpdate):
verified_reranked_documents: Annotated[
list[InferenceSection], dedup_inference_sections
] = []
context_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
cited_documents: Annotated[list[InferenceSection], dedup_inference_sections] = (
[]
) # cited docs from sub-answers are used for answer context
sub_question_results: Annotated[
list[SubQuestionAnswerResults], dedup_question_answer_results
] = []
class OrigQuestionRetrievalUpdate(LoggerUpdate):
orig_question_retrieved_documents: Annotated[
list[InferenceSection], dedup_inference_sections
]
orig_question_verified_reranked_documents: Annotated[
list[InferenceSection], dedup_inference_sections
]
orig_question_sub_query_retrieval_results: list[QueryRetrievalResult] = []
orig_question_retrieval_stats: AgentChunkRetrievalStats = AgentChunkRetrievalStats()
class EntityTermExtractionUpdate(LoggerUpdate):
entity_relation_term_extractions: EntityRelationshipTermExtraction = (
EntityRelationshipTermExtraction()
)
class RefinedQuestionDecompositionUpdate(RefinedAgentStartStats, LoggerUpdate):
refined_sub_questions: dict[int, RefinementSubQuestion] = {}
## Graph Input State
class MainInput(CoreState):
pass
## Graph State
class MainState(
# This includes the core state
MainInput,
ToolChoiceInput,
ToolCallUpdate,
ToolChoiceUpdate,
InitialQuestionDecompositionUpdate,
InitialAnswerUpdate,
SubQuestionResultsUpdate,
OrigQuestionRetrievalUpdate,
EntityTermExtractionUpdate,
InitialAnswerQualityUpdate,
RequireRefinemenEvalUpdate,
RefinedQuestionDecompositionUpdate,
RefinedAnswerUpdate,
RefinedAgentStartStats,
RefinedAgentEndStats,
InitialRefinedAnswerComparisonUpdate,
ExploratorySearchUpdate,
):
pass
## Graph Output State - presently not used
class MainOutput(TypedDict):
log_messages: list[str]
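Most of the state classes above rely on Annotated reducers: the second element of Annotated is a merge function that LangGraph applies whenever several nodes write the same key, which is what lets the parallel sub-answer branches accumulate documents without clobbering each other. A toy sketch, where dedup_merge stands in for dedup_inference_sections:

from operator import add
from typing import Annotated

from pydantic import BaseModel


def dedup_merge(current: list[str], new: list[str]) -> list[str]:
    merged = list(current)
    merged.extend(item for item in new if item not in merged)
    return merged


class ExampleState(BaseModel):
    log_messages: Annotated[list[str], add] = []  # concatenate updates
    documents: Annotated[list[str], dedup_merge] = []  # dedup on merge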


@@ -1,33 +0,0 @@
from collections.abc import Hashable
from datetime import datetime
from langgraph.types import Send
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def send_to_expanded_refined_retrieval(
state: SubQuestionAnsweringInput,
) -> Send | Hashable:
"""
LangGraph edge to send a refined sub-question to expanded retrieval.
"""
logger.debug("sending to expanded retrieval for follow-up question via edge")
return Send(
"refined_sub_question_expanded_retrieval",
ExpandedRetrievalInput(
question=state.question,
sub_question_id=state.question_id,
base_search=False,
log_messages=[f"{datetime.now()} -- Sending to expanded retrieval"],
),
)


@@ -1,132 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.check_sub_answer import (
check_sub_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.format_sub_answer import (
format_sub_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.generate_sub_answer import (
generate_sub_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.ingest_retrieved_documents import (
ingest_retrieved_documents,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.refinement.consolidate_sub_answers.edges import (
send_to_expanded_refined_retrieval,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.graph_builder import (
expanded_retrieval_graph_builder,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def answer_refined_query_graph_builder() -> StateGraph:
"""
LangGraph graph builder for the refined sub-answer generation process.
"""
graph = StateGraph(
state_schema=AnswerQuestionState,
input=SubQuestionAnsweringInput,
output=AnswerQuestionOutput,
)
### Add nodes ###
# Subgraph for the expanded retrieval process
expanded_retrieval = expanded_retrieval_graph_builder().compile()
graph.add_node(
node="refined_sub_question_expanded_retrieval",
action=expanded_retrieval,
)
# Ingest the retrieved documents
graph.add_node(
node="ingest_refined_retrieval",
action=ingest_retrieved_documents,
)
# Generate the refined sub-answer
graph.add_node(
node="generate_refined_sub_answer",
action=generate_sub_answer,
)
# Check if the refined sub-answer is correct
graph.add_node(
node="refined_sub_answer_check",
action=check_sub_answer,
)
# Format the refined sub-answer
graph.add_node(
node="format_refined_sub_answer",
action=format_sub_answer,
)
### Add edges ###
graph.add_conditional_edges(
source=START,
path=send_to_expanded_refined_retrieval,
path_map=["refined_sub_question_expanded_retrieval"],
)
graph.add_edge(
start_key="refined_sub_question_expanded_retrieval",
end_key="ingest_refined_retrieval",
)
graph.add_edge(
start_key="ingest_refined_retrieval",
end_key="generate_refined_sub_answer",
)
graph.add_edge(
start_key="generate_refined_sub_answer",
end_key="refined_sub_answer_check",
)
graph.add_edge(
start_key="refined_sub_answer_check",
end_key="format_refined_sub_answer",
)
graph.add_edge(
start_key="format_refined_sub_answer",
end_key=END,
)
    return graph


if __name__ == "__main__":
    from onyx.db.engine.sql_engine import get_session_with_current_tenant
    from onyx.llm.factory import get_default_llms
    from onyx.context.search.models import SearchRequest
    from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config

    graph = answer_refined_query_graph_builder()
    compiled_graph = graph.compile()
    primary_llm, fast_llm = get_default_llms()
    search_request = SearchRequest(
        query="what can you do with onyx or danswer?",
    )
    with get_session_with_current_tenant() as db_session:
        # The embedded expanded-retrieval subgraph reads its GraphConfig from
        # the run config, so build a test config and pass it to stream().
        graph_config, search_tool = get_test_config(
            db_session, primary_llm, fast_llm, search_request
        )
        inputs = SubQuestionAnsweringInput(
            question="what can you do with onyx?",
            question_id="0_0",
            log_messages=[],
        )
        for thing in compiled_graph.stream(
            input=inputs,
            config={"configurable": {"config": graph_config}},
            stream_mode="custom",
        ):
            logger.debug(thing)
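A quick way to sanity-check the wiring above is to render the compiled graph's topology; a small sketch, assuming the `get_graph()`/`draw_mermaid()` helpers available in recent langgraph releases:

```python
# Hedged inspection snippet; assumes recent langgraph where compiled
# graphs expose get_graph() and draw_mermaid().
compiled = answer_refined_query_graph_builder().compile()
print(list(compiled.get_graph().nodes))     # node names wired above
print(compiled.get_graph().draw_mermaid())  # Mermaid rendering of the edges
```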

View File

@@ -1,44 +0,0 @@
from collections.abc import Hashable
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import Send
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
RetrievalInput,
)
from onyx.agents.agent_search.models import GraphConfig


def parallel_retrieval_edge(
state: ExpandedRetrievalState, config: RunnableConfig
) -> list[Send | Hashable]:
"""
LangGraph edge to parallelize the retrieval process for each of the
generated sub-queries and the original question.
"""
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = (
state.question
if state.question
else graph_config.inputs.prompt_builder.raw_user_query
)
query_expansions = state.expanded_queries + [question]
return [
Send(
"retrieve_documents",
RetrievalInput(
query_to_retrieve=query,
question=question,
base_search=False,
sub_question_id=state.sub_question_id,
log_messages=[],
),
)
for query in query_expansions
]
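Returning a list of `Send` objects makes langgraph execute one `retrieve_documents` invocation per query in parallel, then merge the branch outputs through state reducers. A minimal sketch of that fan-out/fan-in shape, with illustrative names (`FanState`, `retrieve`) rather than Onyx code:

```python
import operator
from typing import Annotated, TypedDict

from langgraph.graph import StateGraph, START, END
from langgraph.types import Send


class FanState(TypedDict):
    queries: list[str]
    # Reducer: parallel branches append their results instead of overwriting.
    results: Annotated[list[str], operator.add]


def fan_out(state: FanState) -> list[Send]:
    # One Send per query, mirroring parallel_retrieval_edge above.
    return [Send("retrieve", {"queries": [q], "results": []}) for q in state["queries"]]


def retrieve(state: FanState) -> dict:
    return {"results": [f"docs for: {state['queries'][0]}"]}


builder = StateGraph(FanState)
builder.add_node("retrieve", retrieve)
builder.add_conditional_edges(START, fan_out, ["retrieve"])
builder.add_edge("retrieve", END)
print(builder.compile().invoke({"queries": ["a", "b"], "results": []}))
```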

View File

@@ -1,161 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.edges import (
parallel_retrieval_edge,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.expand_queries import (
expand_queries,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.format_queries import (
format_queries,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.format_results import (
format_results,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.kickoff_verification import (
kickoff_verification,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.rerank_documents import (
rerank_documents,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.retrieve_documents import (
retrieve_documents,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.verify_documents import (
verify_documents,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalOutput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger

logger = setup_logger()


def expanded_retrieval_graph_builder() -> StateGraph:
"""
LangGraph graph builder for the expanded retrieval process.
"""
graph = StateGraph(
state_schema=ExpandedRetrievalState,
input=ExpandedRetrievalInput,
output=ExpandedRetrievalOutput,
)

    ### Add nodes ###
# Convert the question into multiple sub-queries
graph.add_node(
node="expand_queries",
action=expand_queries,
)
# Format the sub-queries into a list of strings
graph.add_node(
node="format_queries",
action=format_queries,
)
# Retrieve the documents for each sub-query
graph.add_node(
node="retrieve_documents",
action=retrieve_documents,
)
# Start verification process that the documents are relevant to the question (not the query)
graph.add_node(
node="kickoff_verification",
action=kickoff_verification,
)
# Verify that a given document is relevant to the question (not the query)
graph.add_node(
node="verify_documents",
action=verify_documents,
)
# Rerank the documents that have been verified
graph.add_node(
node="rerank_documents",
action=rerank_documents,
)
# Format the results into a list of strings
graph.add_node(
node="format_results",
action=format_results,
)

    ### Add edges ###
graph.add_edge(
start_key=START,
end_key="expand_queries",
)
graph.add_edge(
start_key="expand_queries",
end_key="format_queries",
)
graph.add_conditional_edges(
source="format_queries",
path=parallel_retrieval_edge,
path_map=["retrieve_documents"],
)
graph.add_edge(
start_key="retrieve_documents",
end_key="kickoff_verification",
)
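    # Note: no explicit kickoff_verification -> verify_documents edge is added;
    # the Command returned by kickoff_verification performs that routing itself.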
graph.add_edge(
start_key="verify_documents",
end_key="rerank_documents",
)
graph.add_edge(
start_key="rerank_documents",
end_key="format_results",
)
graph.add_edge(
start_key="format_results",
end_key=END,
)
    return graph


if __name__ == "__main__":
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
graph = expanded_retrieval_graph_builder()
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
search_request = SearchRequest(
query="what can you do with onyx or danswer?",
)
with get_session_with_current_tenant() as db_session:
graph_config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)
inputs = ExpandedRetrievalInput(
question="what can you do with onyx?",
base_search=False,
sub_question_id=None,
log_messages=[],
)
for thing in compiled_graph.stream(
input=inputs,
config={"configurable": {"config": graph_config}},
stream_mode="custom",
subgraphs=True,
):
logger.debug(thing)

View File

@@ -1,13 +0,0 @@
from pydantic import BaseModel
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
from onyx.context.search.models import InferenceSection


class QuestionRetrievalResult(BaseModel):
expanded_query_results: list[QueryRetrievalResult] = []
retrieved_documents: list[InferenceSection] = []
verified_reranked_documents: list[InferenceSection] = []
context_documents: list[InferenceSection] = []
retrieval_stats: AgentChunkRetrievalStats = AgentChunkRetrievalStats()
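One detail worth noting about the defaults above: pydantic builds field defaults per instance (unlike plain class attributes), so the mutable list and model defaults are not shared across results. A quick sketch, assuming pydantic's usual default handling:

```python
# Hedged sketch: pydantic creates field defaults per instance.
a = QuestionRetrievalResult()
b = QuestionRetrievalResult()
assert a.retrieved_documents is not b.retrieved_documents  # separate lists
assert a.retrieval_stats is not b.retrieval_stats  # separate stats objects
```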

View File

@@ -1,139 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.operations import (
dispatch_subquery,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
QueryExpansionUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUERY_GENERATION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
)
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
QUERY_REWRITING_PROMPT,
)
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time

logger = setup_logger()

_llm_node_error_strings = LLMNodeErrorStrings(
timeout="Query rewriting failed due to LLM timeout - the original question will be used.",
rate_limit="Query rewriting failed due to LLM rate limit - the original question will be used.",
general_error="Query rewriting failed due to LLM error - the original question will be used.",
)


@log_function_time(print_only=True)
def expand_queries(
state: ExpandedRetrievalInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> QueryExpansionUpdate:
"""
LangGraph node to expand a question into multiple search queries.
"""
# Sometimes we want to expand the original question, sometimes we want to expand a sub-question.
    # When we are running this node on the original question, no question is explicitly passed in.
# Instead, we use the original question from the search request.
graph_config = cast(GraphConfig, config["metadata"]["config"])
node_start_time = datetime.now()
question = state.question
model = graph_config.tooling.fast_llm
sub_question_id = state.sub_question_id
if sub_question_id is None:
level, question_num = 0, 0
else:
level, question_num = parse_question_id(sub_question_id)
msg = [
HumanMessage(
content=QUERY_REWRITING_PROMPT.format(question=question),
)
]
agent_error: AgentErrorLog | None = None
llm_response_list: list[BaseMessage_Content] = []
llm_response = ""
rewritten_queries = []
try:
llm_response_list = run_with_timeout(
AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION,
dispatch_separated,
model.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBQUERY_GENERATION,
),
dispatch_subquery(level, question_num, writer),
)
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[
0
].content
rewritten_queries = llm_response.split("\n")
log_result = f"Number of expanded queries: {len(rewritten_queries)}"
except (LLMTimeoutError, TimeoutError):
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - expand queries")
log_result = agent_error.error_result
except LLMRateLimitError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - expand queries")
log_result = agent_error.error_result
    # If query generation failed, the sub-question itself is still used as a query downstream
return QueryExpansionUpdate(
expanded_queries=rewritten_queries,
log_messages=[
get_langgraph_node_log_string(
graph_component="shared - expanded retrieval",
node_name="expand queries",
node_start_time=node_start_time,
result=log_result,
)
],
)
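`run_with_timeout` and `dispatch_separated` are Onyx utilities, but the underlying pattern is generic: drain a streaming LLM call on a worker thread, bound the total wall time, and fall back to a safe default on timeout. A library-agnostic sketch of that shape (`consume_stream` and the 15-second budget are illustrative assumptions):

```python
import concurrent.futures


def consume_stream() -> str:
    # Stand-in for draining model.stream(...); in Onyx, dispatch_separated
    # also forwards each chunk to the UI as it arrives.
    return "rewrite one\nrewrite two"


rewritten_queries: list[str] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
    future = pool.submit(consume_stream)
    try:
        # Plays the role of run_with_timeout(...) above.
        rewritten_queries = future.result(timeout=15).split("\n")
    except concurrent.futures.TimeoutError:
        # Mirror the fallback above: leave rewritten_queries empty so the
        # original question is used as the only query downstream.
        pass
print(rewritten_queries)
```

One caveat: unlike a dedicated timeout helper, exiting the `with` block still joins the worker thread, so this sketch illustrates the control flow rather than hard cancellation.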

View File

@@ -1,19 +0,0 @@
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
QueryExpansionUpdate,
)


def format_queries(
state: ExpandedRetrievalState, config: RunnableConfig
) -> QueryExpansionUpdate:
"""
LangGraph node to format the expanded queries into a list of strings.
"""
return QueryExpansionUpdate(
expanded_queries=state.expanded_queries,
)

View File

@@ -1,91 +0,0 @@
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.main.operations import get_query_info
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.models import (
QuestionRetrievalResult,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.operations import (
calculate_sub_question_retrieval_stats,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import ExtendedToolResponse
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses


def format_results(
state: ExpandedRetrievalState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> ExpandedRetrievalUpdate:
"""
LangGraph node that constructs the proper expanded retrieval format.
"""
level, question_num = parse_question_id(state.sub_question_id or "0_0")
query_info = get_query_info(state.query_retrieval_results)
graph_config = cast(GraphConfig, config["metadata"]["config"])
# Main question docs will be sent later after aggregation and deduping with sub-question docs
reranked_documents = state.reranked_documents
if not (level == 0 and question_num == 0):
if len(reranked_documents) == 0:
# The sub-question is used as the last query. If no verified documents are found, stream
# the top 3 for that one. We may want to revisit this.
reranked_documents = state.query_retrieval_results[-1].retrieved_documents[
:3
]
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
relevance_list = relevance_from_docs(reranked_documents)
for tool_response in yield_search_responses(
query=state.question,
get_retrieved_sections=lambda: reranked_documents,
get_final_context_sections=lambda: reranked_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
):
write_custom_event(
"tool_response",
ExtendedToolResponse(
id=tool_response.id,
response=tool_response.response,
level=level,
level_question_num=question_num,
),
writer,
)
sub_question_retrieval_stats = calculate_sub_question_retrieval_stats(
verified_documents=state.verified_documents,
expanded_retrieval_results=state.query_retrieval_results,
)
if sub_question_retrieval_stats is None:
sub_question_retrieval_stats = AgentChunkRetrievalStats()
return ExpandedRetrievalUpdate(
expanded_retrieval_result=QuestionRetrievalResult(
expanded_query_results=state.query_retrieval_results,
retrieved_documents=state.retrieved_documents,
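            # Note: verified_reranked_documents may contain the top-3 fallback
            # from above, while context_documents always mirrors the state's
            # reranked set (which may be empty).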
verified_reranked_documents=reranked_documents,
context_documents=state.reranked_documents,
retrieval_stats=sub_question_retrieval_stats,
),
)

View File

@@ -1,45 +0,0 @@
from typing import Literal
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import Command
from langgraph.types import Send
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
DocVerificationInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.configs.agent_configs import AGENT_MAX_VERIFICATION_HITS


def kickoff_verification(
state: ExpandedRetrievalState,
config: RunnableConfig,
) -> Command[Literal["verify_documents"]]:
"""
LangGraph node (Command node!) that kicks off the verification process for the retrieved documents.
Note that this is a Command node and does the routing as well. (At present, no state updates
are done here, so this could be replaced with an edge. But we may choose to make state
updates later.)
"""
retrieved_documents = state.retrieved_documents[:AGENT_MAX_VERIFICATION_HITS]
verification_question = state.question
sub_question_id = state.sub_question_id
return Command(
update={},
goto=[
Send(
node="verify_documents",
arg=DocVerificationInput(
retrieved_document_to_verify=document,
question=verification_question,
base_search=False,
sub_question_id=sub_question_id,
log_messages=[],
),
)
for document in retrieved_documents
],
)
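`Command` nodes bundle a state update and routing into one return value; here the update is empty and the routing fans out one `verify_documents` invocation per document. A minimal sketch of the same shape, with illustrative names (`VState`, `verify`) rather than Onyx code:

```python
import operator
from typing import Annotated, Literal, TypedDict

from langgraph.graph import StateGraph, START, END
from langgraph.types import Command, Send


class VState(TypedDict):
    docs: list[str]
    verified: Annotated[list[str], operator.add]  # merged across branches


def kickoff(state: VState) -> Command[Literal["verify"]]:
    # Like kickoff_verification: no state update, routing via Send fan-out.
    return Command(
        update={},
        goto=[Send("verify", {"docs": [d], "verified": []}) for d in state["docs"]],
    )


def verify(state: VState) -> dict:
    return {"verified": [f"ok: {state['docs'][0]}"]}


builder = StateGraph(VState)
builder.add_node("kickoff", kickoff)
builder.add_node("verify", verify)
builder.add_edge(START, "kickoff")
builder.add_edge("verify", END)
print(builder.compile().invoke({"docs": ["d1", "d2"], "verified": []}))
```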

Some files were not shown because too many files have changed in this diff.