Compare commits

..

138 Commits

Author SHA1 Message Date
Justin Tahara
3f8ef8b465 fix(celery): Guardrail for User File Processing (#8633) 2026-03-01 10:29:55 -08:00
Justin Tahara
ed46504a1a fix(gong): Respecting Retry Timeout Header (#8866) 2026-02-27 14:22:34 -08:00
Nikolas Garza
7a24b34516 fix(slack): sanitize HTML tags and broken citation links in bot responses (#8767) 2026-02-26 17:27:31 -08:00
dependabot[bot]
7a7ffa9051 chore(deps): Bump mistune from 0.8.4 to 3.1.4 in /backend (#6407)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-02-26 17:27:31 -08:00
Jamison Lahman
3053ab518c chore(devtools): upgrade ods: v0.6.1->v0.6.2 (#8773) 2026-02-26 16:26:55 -08:00
justin-tahara
be38d3500f Fixing mypy 2026-02-09 15:48:53 -08:00
Justin Tahara
753a3bc093 fix(posthog): Chat metrics for Cloud (#8278) 2026-02-09 15:48:53 -08:00
Raunak Bhagat
2ba8fafe78 fix: Add explicit sizings to icons (#8018) 2026-02-06 18:15:47 -08:00
Raunak Bhagat
b77b580ebd Cherry-pick card fix 2026-02-06 18:15:40 -08:00
Nikolas Garza
3eee98b932 fix: make it more clear how to add channels to fed slack config form (#8227) 2026-02-06 16:35:46 -08:00
Nikolas Garza
a97eb02fef fix(db): null out document set and persona ownership on user deletion (#8219) 2026-02-06 16:35:46 -08:00
Justin Tahara
c5061495a2 fix(ui): Inconsistent LLM Provider Logo (#8220) 2026-02-06 13:56:57 -08:00
Justin Tahara
c20b0789ae fix(ui): Additional LLM Config update (#8174) 2026-02-06 13:56:49 -08:00
Justin Tahara
d99848717b fix(ui): Ollama Model Selection (#8091) 2026-02-06 13:53:52 -08:00
Evan Lohn
aaca55c415 fix(salesforce): cleanup logic (#8175) 2026-02-06 13:52:46 -08:00
Justin Tahara
9d7ffd1e4a fix(ui): Updating Dropdown Modal component (#8033) 2026-02-06 11:39:48 -08:00
Justin Tahara
a249161827 chore(chat): Cleaning Error Codes + Tests (#8186) 2026-02-06 11:39:36 -08:00
Justin Tahara
e126346a91 fix(agents): Removing Label Dependency (#8189) 2026-02-06 11:03:16 -08:00
Justin Tahara
a96682fa73 fix(ui): Agent Saving with other people files (#8095) 2026-02-02 10:30:46 -08:00
Justin Tahara
3920371d56 feat(desktop): Ensure that UI reflects Light/Dark Toggle (#7684) 2026-02-02 10:30:36 -08:00
Wenxi Onyx
e5a257345c 2nd dummy commit (noop README change) to fix beta tag on docker 2026-01-31 11:17:12 -08:00
Wenxi Onyx
a49df511e2 dummy commit (noop README change) to fix beta tag on docker 2026-01-31 11:09:41 -08:00
Justin Tahara
d5d2a8a1a6 fix(desktop): Remove Global Shortcuts (#7914) 2026-01-30 13:46:26 -08:00
Justin Tahara
b2f46b264c fix(asana): Workspace Team ID mismatch (#7674) 2026-01-30 13:19:07 -08:00
Jamison Lahman
c6ad363fbd chore(mypy): fix mypy cache issues switching between HEAD and release (#7732) 2026-01-27 15:52:53 -08:00
Jamison Lahman
e313119f9a fix(citations): enable citation sidebar w/ web_search-only assistants (#7888) 2026-01-27 14:50:00 -08:00
Wenxi
3a2a542a03 fix: connector details back button should nav back (#7869) 2026-01-27 14:35:15 -08:00
Yuhong Sun
413aeba4a1 fix: Project Creation (#7851) 2026-01-27 14:34:59 -08:00
Wenxi
46028aa2bb fix: user count check (#7811) 2026-01-27 14:34:29 -08:00
Justin Tahara
454943c4a6 fix(llm): Hide private models from Agent Creation (#7873) 2026-01-27 14:33:40 -08:00
Justin Tahara
87946266de fix(redis): Adding more TTLs (#7886) 2026-01-27 14:32:14 -08:00
Jamison Lahman
144030c5ca chore(vscode): add non-clean seeded db restore (#7795) 2026-01-26 08:55:19 -08:00
SubashMohan
a557d76041 feat(ui): add new icons and enhance FadeDiv, Modal, Tabs, ExpandableTextDisplay (#7563)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-26 10:26:09 +00:00
SubashMohan
605e808158 fix(layout): adjust footer margin and prevent page refresh on chatsession drop (#7759) 2026-01-26 04:45:40 +00:00
roshan
8fec88c90d chore(deployment): remove no auth option from setup script (#7784) 2026-01-26 04:42:45 +00:00
Yuhong Sun
e54969a693 fix: LiteLLM Azure models don't stream (#7761) 2026-01-25 07:46:51 +00:00
Raunak Bhagat
1da2b2f28f fix: Some new fixes that were discovered by AI reviewers during 2.9-hotfixing (#7757)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-25 04:44:30 +00:00
Nikolas Garza
eb7b91e08e fix(tests): use crawler-friendly search query in Exa integration test (#7746) 2026-01-24 20:58:02 +00:00
Yuhong Sun
3339000968 fix: Spacing issue on Feedback (#7747) 2026-01-24 12:59:00 -08:00
Nikolas Garza
d9db849e94 fix(chat): prevent streaming text from appearing in bursts after citations (#7745) 2026-01-24 11:48:34 -08:00
Yuhong Sun
046408359c fix: Azure OpenAI Tool Calls (#7727) 2026-01-24 01:47:03 +00:00
acaprau
4b8cca190f feat(opensearch): Implement complete retrieval filtering (#7691) 2026-01-23 23:27:42 +00:00
Justin Tahara
52a312a63b feat: onyx discord bot - supervisord and kube deployment (#7706) 2026-01-23 20:55:06 +00:00
Danelegend
0594fd17de chore(tests): add more packet tests (#7677) 2026-01-23 19:49:41 +00:00
Jamison Lahman
fded81dc28 chore(extensions): pull in chrome extension (#7703) 2026-01-23 10:17:05 -08:00
Danelegend
31db112de9 feat(url): Open url around snippet (#7488) 2026-01-23 17:02:38 +00:00
Jamison Lahman
a3e2da2c51 chore(vscode): add useful database operations (#7702) 2026-01-23 08:49:59 -08:00
Evan Lohn
f4d33bcc0d feat: basic user MCP action attaching (#7681) 2026-01-23 05:50:49 +00:00
Jamison Lahman
464d957494 chore(devtools): upgrade ods v0.4.0; vscode to restore seeded db (#7696) 2026-01-23 05:21:46 +00:00
Jamison Lahman
be12de9a44 chore(devtools): ods db restore --fetch-seeded (#7689) 2026-01-22 20:41:28 -08:00
Yuhong Sun
3e4a1f8a09 feat: Maintain correct docs on replay (#7683) 2026-01-22 19:24:10 -08:00
Raunak Bhagat
af9b7826ab fix: Remove cursor pointer from view-only field (#7688)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-01-23 02:47:08 +00:00
Danelegend
cb16eb13fc chore(tests): Mock LLM (#7590) 2026-01-23 01:48:54 +00:00
Jamison Lahman
20a73bdd2e chore(desktop): make artifact filename version-agnostic (#7679) 2026-01-22 15:15:52 -08:00
Justin Tahara
85cc2b99b7 fix(fastapi): Resolve CVE-2025-68481 (#7661) 2026-01-22 20:07:25 +00:00
Jamison Lahman
1208a3ee2b chore(fe): disable blur when there is not a custom background (#7673)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-01-22 11:26:16 -08:00
Justin Tahara
900fcef9dd feat(desktop): Domain Configuration (#7655) 2026-01-22 18:15:44 +00:00
Justin Tahara
d4ed25753b fix(ui): Coda Logo (#7656) 2026-01-22 10:10:02 -08:00
Justin Tahara
0ee58333b4 fix(ui): User Groups Connectors Fix (#7658) 2026-01-22 17:59:12 +00:00
Justin Tahara
11b7e0d571 fix(ui): First Connector Result (#7657) 2026-01-22 17:52:02 +00:00
acaprau
a35831f328 fix(opensearch): Release Onyx Helm Charts was failing (#7672) 2026-01-22 17:41:47 +00:00
Justin Tahara
048a6d5259 fix(ui): Fix Token Rate Limits Page (#7659) 2026-01-22 17:20:21 +00:00
Ciaran Sweet
e4bdb15910 docs: enhance send-chat-message docs to also show ChatFullResponse (#7430) 2026-01-22 16:48:26 +00:00
Jamison Lahman
3517d59286 chore(fe): add custom backgrounds to the settings page (#7668) 2026-01-21 21:32:56 -08:00
Jamison Lahman
4bc08e5d88 chore(fe): remove Text pseudo-element padding (#7665) 2026-01-21 19:50:42 -08:00
Yuhong Sun
4bd080cf62 chore: Redirect user to create account (#7654) 2026-01-22 02:44:58 +00:00
Raunak Bhagat
b0a8625ffc feat: Add confirmation modal for connector disconnect (#7637)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 02:08:19 +00:00
Yuhong Sun
f94baf6143 fix: DR Language Tuning (#7660) 2026-01-21 17:36:50 -08:00
Wenxi
9e1867638a feat: onyx discord bot - frontend (#7497) 2026-01-22 00:00:12 +00:00
Yuhong Sun
5b6d7c9f0d chore: Onboarding Image Generation (#7653) 2026-01-21 15:49:15 -08:00
Danelegend
e5dcf31f10 fix(image): Emit error to user (#7644) 2026-01-21 22:50:12 +00:00
Nikolas Garza
8ca06ef3e7 fix: deflake chat user journey test (#7646) 2026-01-21 22:33:30 +00:00
Justin Tahara
6897dbd610 feat(desktop): Properly Sign Mac App (#7608) 2026-01-21 22:17:45 +00:00
Evan Lohn
7f3cb77466 chore: remove prompt caching from chat history (#7636) 2026-01-21 21:35:11 +00:00
acaprau
267042a5aa fix(opensearch): Use the same method for getting title that the title embedding logic uses; small cleanup for content embedding (#7638) 2026-01-21 21:34:38 +00:00
Yuhong Sun
d02b3ae6ac chore: Remove default prompt shortcuts (#7639) 2026-01-21 21:28:53 +00:00
Yuhong Sun
683c3f7a7e fix: color mode and memories (#7642) 2026-01-21 13:29:33 -08:00
Nikolas Garza
008b4d2288 fix(slack): Extract person names and filter garbage in query expansion (#7632) 2026-01-21 21:09:50 +00:00
Jamison Lahman
8be261405a chore(deployments): fix region (#7640) 2026-01-21 13:14:42 -08:00
acaprau
61f2c48ebc feat(opensearch): Add helm charts (#7606) 2026-01-21 19:34:18 +00:00
acaprau
dbde2e6d6d chore(opensearch): Create OpenSearch docker compose, enabling test_opensearch_client.py to run in CI (#7611) 2026-01-21 18:41:23 +00:00
Raunak Bhagat
2860136214 feat: Refreshed user settings page (#7455)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-21 16:41:56 +00:00
Raunak Bhagat
49ec5994d3 refactor: Improve refresh-components with cleanup and truncation (#7622)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-21 00:29:25 -08:00
Raunak Bhagat
8d5fb67f0f feat: improve prompt shortcuts with uniqueness constraints and enhancements (#7619)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-21 07:31:35 +00:00
Raunak Bhagat
15d02f6e3c fix: Prevent description duplication in Modal header (#7609)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-21 04:32:22 +00:00
Jamison Lahman
e58974c419 chore(fe): move chatpage footer inside background element (#7618) 2026-01-21 04:21:49 +00:00
Yuhong Sun
6b66c07952 chore: Delete multilingual docker compose file (#7616) 2026-01-20 19:50:01 -08:00
Jamison Lahman
cae058a3ac chore(extensions): simplify and de-dupe NRFPage (#7607) 2026-01-21 03:42:19 +00:00
Nikolas Garza
aa3b21a191 fix: scroll to bottom when loading existing conversations (#7614) 2026-01-20 19:19:18 -08:00
Raunak Bhagat
7a07a78696 fix: Set width to fit for rightChildren section in LineItem (#7604)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-21 01:55:03 +00:00
Nikolas Garza
a8db236e37 feat(billing): fetch Stripe publishable key from S3 (#7595) 2026-01-21 01:32:57 +00:00
Raunak Bhagat
8a2e4ed36f fix: Fix flashing in progress-circle icon (#7605)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-21 01:03:52 +00:00
Evan Lohn
216f2c95a7 chore: add dialog description to modal (#7603) 2026-01-21 00:41:35 +00:00
Evan Lohn
67081efe08 fix: modal header in index attempt errors (#7601) 2026-01-21 00:37:23 +00:00
Yuhong Sun
9d40b8336f feat: Allow no system prompt (#7600) 2026-01-20 16:16:39 -08:00
Evan Lohn
23f0033302 chore: bg services launch.json (#7597) 2026-01-21 00:05:20 +00:00
Raunak Bhagat
9011b76eb0 refactor: Add new layout component (#7588)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 23:36:18 +00:00
Yuhong Sun
45e436bafc fix: prompt tunings (#7594) 2026-01-20 15:13:05 -08:00
Justin Tahara
010bc36d61 Revert "chore(deps): Bump fastapi-users from 14.0.1 to 15.0.2 in /backend/requirements" (#7593) 2026-01-20 14:44:21 -08:00
dependabot[bot]
468e488bdb chore(deps): bump docker/setup-buildx-action from 3.11.1 to 3.12.0 (#7527)
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-01-20 22:36:39 +00:00
dependabot[bot]
9104c0ffce chore(deps): Bump fastapi-users from 14.0.1 to 15.0.2 in /backend/requirements (#6897)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: justin-tahara <justintahara@gmail.com>
2026-01-20 22:31:02 +00:00
Jamison Lahman
d36a6bd0b4 feat(fe): custom chat backgrounds (#7486)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-01-20 14:29:06 -08:00
Jamison Lahman
a3603c498c chore(deployments): fetch secrets from AWS (#7584) 2026-01-20 22:10:19 +00:00
Jamison Lahman
8f274e34c9 chore(blame): unignore checked in .vscode/ files (#7592) 2026-01-20 14:07:27 -08:00
Justin Tahara
5c256760ff fix(vertex ai): Extra Args for Opus 4.5 (#7586) 2026-01-20 14:07:14 -08:00
Nikolas Garza
258e1372b3 fix(billing): remove grandfathered pricing option when subscription lapses (#7583) 2026-01-20 21:55:37 +00:00
Yuhong Sun
83a543a265 chore: NLTK and stopwords (#7587) 2026-01-20 13:36:04 -08:00
Evan Lohn
f9719d199d fix: drive connector creation ui (#7578) 2026-01-20 21:10:06 +00:00
Raunak Bhagat
1c7bb6e56a fix: Input variant refactor (#7579)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 13:04:16 -08:00
acaprau
982ad7d329 feat(opensearch): Add dual document indices (#7539) 2026-01-20 20:53:24 +00:00
Jamison Lahman
f94292808b chore(vscode): launch.template.jsonc -> launch.json (#7440) 2026-01-20 20:32:46 +00:00
Justin Tahara
293553a2e2 fix(tests): Anthropic Prompt Caching Test (#7585) 2026-01-20 20:32:24 +00:00
Justin Tahara
ba906ae6fa chore(llm): Removing Claude Haiku 3.5 (#7577) 2026-01-20 19:06:14 +00:00
Raunak Bhagat
c84c7a354e refactor: refactor to use string-enum props instead of boolean props (#7575)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 18:59:54 +00:00
Jamison Lahman
2187b0dd82 chore(pre-commit): disallow large files (#7576) 2026-01-20 11:02:00 -08:00
acaprau
d88a417bf9 feat(opensearch): Formally disable secondary indices in the backend (#7541) 2026-01-20 18:21:47 +00:00
Jamison Lahman
f2d32b0b3b fix(fe): inline code text wraps (#7574) 2026-01-20 17:11:42 +00:00
Nikolas Garza
f89432009f fix(fe): show scroll-down button when user scrolls up during streaming (#7562) 2026-01-20 07:07:55 +00:00
Jamison Lahman
8ab2bab34e chore(fe): fix sticky header parent height (#7561) 2026-01-20 06:18:32 +00:00
Jamison Lahman
59e0d62512 chore(fe): align assistant icon with chat bar (#7537) 2026-01-19 19:47:18 -08:00
Jamison Lahman
a1471b16a5 fix(fe): chat header is sticky and transparent (#7487) 2026-01-19 19:20:03 -08:00
Yuhong Sun
9d3811cb58 fix: prompt tuning (#7550) 2026-01-19 19:04:18 -08:00
Yuhong Sun
3cd9505383 feat: Memory initial (#7547) 2026-01-19 18:57:13 -08:00
Nikolas Garza
d11829b393 refactor: proxy customer portal session through control plane (#7544) 2026-01-20 01:24:30 +00:00
Nikolas Garza
f6e068e914 feat(billing): add annual pricing support to subscription checkout (#7506) 2026-01-20 00:17:18 +00:00
roshan
0c84edd980 feat: onyx embeddable widget (#7427)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-01-20 00:01:10 +00:00
Wenxi
2b274a7683 feat: onyx discord bot - discord client (#7496) 2026-01-20 00:00:20 +00:00
Wenxi
ddd91f2d71 feat: onyx discord bot - api client and cache manager (#7495) 2026-01-19 23:15:17 +00:00
Yuhong Sun
a7c7da0dfc fix: tool call handling for weak models (#7538) 2026-01-19 13:37:00 -08:00
Evan Lohn
b00a3e8b5d fix(test): confluence group sync (#7536) 2026-01-19 21:20:48 +00:00
Raunak Bhagat
d77d1a48f1 fix: Line item fixes (#7513) 2026-01-19 20:25:35 +00:00
Raunak Bhagat
7b4fc6729c fix: Popover size fix (#7521) 2026-01-19 18:44:29 +00:00
Nikolas Garza
1f113c86ef feat(ee): license enforcement middleware (#7483) 2026-01-19 18:03:39 +00:00
Raunak Bhagat
8e38ba3e21 refactor: Fix some onboarding inaccuracies (#7511) 2026-01-19 04:33:27 +00:00
Raunak Bhagat
bb9708a64f refactor: Small styling / prop-naming refactors (#7503) 2026-01-19 02:49:27 +00:00
Raunak Bhagat
8cae97e145 fix: Fix connector-setup modal (#7502) 2026-01-19 00:29:36 +00:00
Wenxi
7e4abca224 feat: onyx discord bot - backend, crud, and apis (#7494) 2026-01-18 23:13:58 +00:00
Yuhong Sun
233a91ea65 chore: drop dead table (#7500) 2026-01-17 20:05:22 -08:00
515 changed files with 42886 additions and 8031 deletions

View File

@@ -8,7 +8,9 @@ on:
# Set restrictive default permissions for all jobs. Jobs that need more permissions
# should explicitly declare them.
permissions: {}
permissions:
# Required for OIDC authentication with AWS
id-token: write # zizmor: ignore[excessive-permissions]
env:
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
@@ -150,16 +152,30 @@ jobs:
if: always() && needs.check-version-tag.result == 'failure' && github.event_name != 'workflow_dispatch'
runs-on: ubuntu-slim
timeout-minutes: 10
environment: release
steps:
- name: Checkout
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
MONITOR_DEPLOYMENTS_WEBHOOK, deploy/monitor-deployments-webhook
parse-json-secrets: true
- name: Send Slack notification
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
webhook-url: ${{ env.MONITOR_DEPLOYMENTS_WEBHOOK }}
failed-jobs: "• check-version-tag"
title: "🚨 Version Tag Check Failed"
ref-name: ${{ github.ref_name }}
@@ -168,6 +184,7 @@ jobs:
needs: determine-builds
if: needs.determine-builds.outputs.build-desktop == 'true'
permissions:
id-token: write
contents: write
actions: read
strategy:
@@ -185,12 +202,33 @@ jobs:
runs-on: ${{ matrix.platform }}
timeout-minutes: 90
environment: release
steps:
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6.0.1
with:
# NOTE: persist-credentials is needed for tauri-action to create GitHub releases.
persist-credentials: true # zizmor: ignore[artipacked]
- name: Configure AWS credentials
if: startsWith(matrix.platform, 'macos-')
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
if: startsWith(matrix.platform, 'macos-')
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
APPLE_ID, deploy/apple-id
APPLE_PASSWORD, deploy/apple-password
APPLE_CERTIFICATE, deploy/apple-certificate
APPLE_CERTIFICATE_PASSWORD, deploy/apple-certificate-password
KEYCHAIN_PASSWORD, deploy/keychain-password
APPLE_TEAM_ID, deploy/apple-team-id
parse-json-secrets: true
- name: install dependencies (ubuntu only)
if: startsWith(matrix.platform, 'ubuntu-')
run: |
@@ -285,15 +323,40 @@ jobs:
Write-Host "Versions set to: $VERSION"
- name: Import Apple Developer Certificate
if: startsWith(matrix.platform, 'macos-')
run: |
echo $APPLE_CERTIFICATE | base64 --decode > certificate.p12
security create-keychain -p "$KEYCHAIN_PASSWORD" build.keychain
security default-keychain -s build.keychain
security unlock-keychain -p "$KEYCHAIN_PASSWORD" build.keychain
security set-keychain-settings -t 3600 -u build.keychain
security import certificate.p12 -k build.keychain -P "$APPLE_CERTIFICATE_PASSWORD" -T /usr/bin/codesign
security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k "$KEYCHAIN_PASSWORD" build.keychain
security find-identity -v -p codesigning build.keychain
- name: Verify Certificate
if: startsWith(matrix.platform, 'macos-')
run: |
CERT_INFO=$(security find-identity -v -p codesigning build.keychain | grep -E "(Developer ID Application|Apple Distribution|Apple Development)" | head -n 1)
CERT_ID=$(echo "$CERT_INFO" | awk -F'"' '{print $2}')
echo "CERT_ID=$CERT_ID" >> $GITHUB_ENV
echo "Certificate imported."
- uses: tauri-apps/tauri-action@73fb865345c54760d875b94642314f8c0c894afa # ratchet:tauri-apps/tauri-action@action-v0.6.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
APPLE_ID: ${{ env.APPLE_ID }}
APPLE_PASSWORD: ${{ env.APPLE_PASSWORD }}
APPLE_SIGNING_IDENTITY: ${{ env.CERT_ID }}
APPLE_TEAM_ID: ${{ env.APPLE_TEAM_ID }}
with:
tagName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
releaseName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
releaseBody: "See the assets to download this version and install."
releaseDraft: true
prerelease: false
assetNamePattern: "[name]_[arch][ext]"
args: ${{ matrix.args }}
build-web-amd64:
@@ -305,6 +368,7 @@ jobs:
- run-id=${{ github.run_id }}-web-amd64
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -317,6 +381,20 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
@@ -326,13 +404,13 @@ jobs:
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push AMD64
id: build
@@ -363,6 +441,7 @@ jobs:
- run-id=${{ github.run_id }}-web-arm64
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -375,6 +454,20 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
@@ -384,13 +477,13 @@ jobs:
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push ARM64
id: build
@@ -423,19 +516,34 @@ jobs:
- run-id=${{ github.run_id }}-merge-web
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Docker meta
id: meta
@@ -471,6 +579,7 @@ jobs:
- run-id=${{ github.run_id }}-web-cloud-amd64
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -483,6 +592,20 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
@@ -492,13 +615,13 @@ jobs:
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push AMD64
id: build
@@ -537,6 +660,7 @@ jobs:
- run-id=${{ github.run_id }}-web-cloud-arm64
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -549,6 +673,20 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
@@ -558,13 +696,13 @@ jobs:
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push ARM64
id: build
@@ -605,19 +743,34 @@ jobs:
- run-id=${{ github.run_id }}-merge-web-cloud
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Docker meta
id: meta
@@ -650,6 +803,7 @@ jobs:
- run-id=${{ github.run_id }}-backend-amd64
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -662,6 +816,20 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
@@ -671,13 +839,13 @@ jobs:
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push AMD64
id: build
@@ -707,6 +875,7 @@ jobs:
- run-id=${{ github.run_id }}-backend-arm64
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -719,6 +888,20 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
@@ -728,13 +911,13 @@ jobs:
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push ARM64
id: build
@@ -766,19 +949,34 @@ jobs:
- run-id=${{ github.run_id }}-merge-backend
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Docker meta
id: meta
@@ -815,6 +1013,7 @@ jobs:
- volume=40gb
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -827,6 +1026,20 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
@@ -836,15 +1049,15 @@ jobs:
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
with:
buildkitd-flags: ${{ vars.DOCKER_DEBUG == 'true' && '--debug' || '' }}
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push AMD64
id: build
@@ -879,6 +1092,7 @@ jobs:
- volume=40gb
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -891,6 +1105,20 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
@@ -900,15 +1128,15 @@ jobs:
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
with:
buildkitd-flags: ${{ vars.DOCKER_DEBUG == 'true' && '--debug' || '' }}
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push ARM64
id: build
@@ -944,19 +1172,34 @@ jobs:
- run-id=${{ github.run_id }}-merge-model-server
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Docker meta
id: meta
@@ -994,11 +1237,26 @@ jobs:
- run-id=${{ github.run_id }}-trivy-scan-web
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Run Trivy vulnerability scanner
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
with:
@@ -1014,8 +1272,8 @@ jobs:
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
image \
--skip-version-check \
@@ -1034,11 +1292,26 @@ jobs:
- run-id=${{ github.run_id }}-trivy-scan-web-cloud
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Run Trivy vulnerability scanner
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
with:
@@ -1054,8 +1327,8 @@ jobs:
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
image \
--skip-version-check \
@@ -1074,6 +1347,7 @@ jobs:
- run-id=${{ github.run_id }}-trivy-scan-backend
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
steps:
@@ -1084,6 +1358,20 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Run Trivy vulnerability scanner
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
with:
@@ -1100,8 +1388,8 @@ jobs:
-v ${{ github.workspace }}/backend/.trivyignore:/tmp/.trivyignore:ro \
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
image \
--skip-version-check \
@@ -1121,11 +1409,26 @@ jobs:
- run-id=${{ github.run_id }}-trivy-scan-model-server
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Run Trivy vulnerability scanner
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
with:
@@ -1141,8 +1444,8 @@ jobs:
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
image \
--skip-version-check \
@@ -1170,12 +1473,26 @@ jobs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 90
environment: release
steps:
- name: Checkout
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
MONITOR_DEPLOYMENTS_WEBHOOK, deploy/monitor-deployments-webhook
parse-json-secrets: true
- name: Determine failed jobs
id: failed-jobs
shell: bash
@@ -1241,7 +1558,7 @@ jobs:
- name: Send Slack notification
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
webhook-url: ${{ env.MONITOR_DEPLOYMENTS_WEBHOOK }}
failed-jobs: ${{ steps.failed-jobs.outputs.jobs }}
title: "🚨 Deployment Workflow Failed"
ref-name: ${{ github.ref_name }}

View File

@@ -21,7 +21,7 @@ jobs:
timeout-minutes: 45
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3

View File

@@ -21,7 +21,7 @@ jobs:
timeout-minutes: 45
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3

View File

@@ -29,6 +29,7 @@ jobs:
run: |
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
helm repo add onyx-vespa https://onyx-dot-app.github.io/vespa-helm-charts
helm repo add opensearch https://opensearch-project.github.io/helm-charts
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/

View File

@@ -94,7 +94,7 @@ jobs:
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3

View File

@@ -45,6 +45,9 @@ env:
# TODO: debug why this is failing and enable
CODE_INTERPRETER_BASE_URL: http://localhost:8000
# OpenSearch
OPENSEARCH_ADMIN_PASSWORD: "StrongPassword123!"
jobs:
discover-test-dirs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
@@ -125,11 +128,13 @@ jobs:
docker compose \
-f docker-compose.yml \
-f docker-compose.dev.yml \
-f docker-compose.opensearch.yml \
up -d \
minio \
relational_db \
cache \
index \
opensearch \
code-interpreter
- name: Run migrations
@@ -158,7 +163,7 @@ jobs:
cd deployment/docker_compose
# Get list of running containers
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml ps -q)
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml -f docker-compose.opensearch.yml ps -q)
# Collect logs from each container
for container in $containers; do

View File

@@ -88,6 +88,7 @@ jobs:
echo "=== Adding Helm repositories ==="
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
helm repo add vespa https://onyx-dot-app.github.io/vespa-helm-charts
helm repo add opensearch https://opensearch-project.github.io/helm-charts
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
@@ -180,6 +181,11 @@ jobs:
trap cleanup EXIT
# Run the actual installation with detailed logging
# Note that opensearch.enabled is true whereas others in this install
# are false. There is some work that needs to be done to get this
# entire step working in CI, enabling opensearch here is a small step
# in that direction. If this is causing issues, disabling it in this
# step should be ok in the short term.
echo "=== Starting ct install ==="
set +e
ct install --all \
@@ -187,6 +193,8 @@ jobs:
--set=nginx.enabled=false \
--set=minio.enabled=false \
--set=vespa.enabled=false \
--set=opensearch.enabled=true \
--set=auth.opensearch.enabled=true \
--set=slackbot.enabled=false \
--set=postgresql.enabled=true \
--set=postgresql.nameOverride=cloudnative-pg \

View File

@@ -103,7 +103,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
@@ -163,7 +163,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
@@ -208,7 +208,7 @@ jobs:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling openapitools/openapi-generator-cli
# otherwise, we hit the "Unauthenticated users" limit

View File

@@ -95,7 +95,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
@@ -155,7 +155,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
@@ -214,7 +214,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling openapitools/openapi-generator-cli
# otherwise, we hit the "Unauthenticated users" limit

View File

@@ -85,7 +85,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
@@ -146,7 +146,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
@@ -207,7 +207,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/

View File

@@ -50,8 +50,9 @@ jobs:
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: backend/.mypy_cache
key: mypy-${{ runner.os }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
restore-keys: |
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
mypy-${{ runner.os }}-
- name: Run MyPy

View File

@@ -70,7 +70,7 @@ jobs:
password: ${{ secrets.DOCKER_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f
- name: Build and load
uses: docker/bake-action@5be5f02ff8819ecd3092ea6b2e6261c31774f2b4 # ratchet:docker/bake-action@v6

3
.gitignore vendored
View File

@@ -1,5 +1,8 @@
# editors
.vscode
!/.vscode/env_template.txt
!/.vscode/launch.json
!/.vscode/tasks.template.jsonc
.zed
.cursor

View File

@@ -74,6 +74,13 @@ repos:
# pass_filenames: true
# files: ^backend/.*\.py$
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: 3e8a8703264a2f4a69428a0aa4dcb512790b2c8c # frozen: v6.0.0
hooks:
- id: check-added-large-files
name: Check for added large files
args: ["--maxkb=1500"]
- repo: https://github.com/rhysd/actionlint
rev: a443f344ff32813837fa49f7aa6cbc478d770e62 # frozen: v1.7.9
hooks:

View File

@@ -1,5 +1,3 @@
/* Copy this file into '.vscode/launch.json' or merge its contents into your existing configurations. */
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
@@ -24,7 +22,7 @@
"Slack Bot",
"Celery primary",
"Celery light",
"Celery background",
"Celery heavy",
"Celery docfetching",
"Celery docprocessing",
"Celery beat"
@@ -151,6 +149,24 @@
},
"consoleTitle": "Slack Bot Console"
},
{
"name": "Discord Bot",
"consoleName": "Discord Bot",
"type": "debugpy",
"request": "launch",
"program": "onyx/onyxbot/discord/client.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"presentation": {
"group": "2"
},
"consoleTitle": "Discord Bot Console"
},
{
"name": "MCP Server",
"consoleName": "MCP Server",
@@ -579,6 +595,120 @@
"group": "3"
}
},
{
// Dummy entry used to label the group
"name": "--- Database ---",
"type": "node",
"request": "launch",
"presentation": {
"group": "4",
"order": 0
}
},
{
"name": "Restore seeded database dump",
"type": "node",
"request": "launch",
"runtimeExecutable": "uv",
"runtimeArgs": [
"run",
"--with",
"onyx-devtools",
"ods",
"db",
"restore",
"--fetch-seeded",
"--yes"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "4"
}
},
{
"name": "Clean restore seeded database dump (destructive)",
"type": "node",
"request": "launch",
"runtimeExecutable": "uv",
"runtimeArgs": [
"run",
"--with",
"onyx-devtools",
"ods",
"db",
"restore",
"--fetch-seeded",
"--clean",
"--yes"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "4"
}
},
{
"name": "Create database snapshot",
"type": "node",
"request": "launch",
"runtimeExecutable": "uv",
"runtimeArgs": [
"run",
"--with",
"onyx-devtools",
"ods",
"db",
"dump",
"backup.dump"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "4"
}
},
{
"name": "Clean restore database snapshot (destructive)",
"type": "node",
"request": "launch",
"runtimeExecutable": "uv",
"runtimeArgs": [
"run",
"--with",
"onyx-devtools",
"ods",
"db",
"restore",
"--clean",
"--yes",
"backup.dump"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "4"
}
},
{
"name": "Upgrade database to head revision",
"type": "node",
"request": "launch",
"runtimeExecutable": "uv",
"runtimeArgs": [
"run",
"--with",
"onyx-devtools",
"ods",
"db",
"upgrade"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "4"
}
},
{
// script to generate the openapi schema
"name": "Onyx OpenAPI Schema Generator",

View File

@@ -37,10 +37,6 @@ CVE-2023-50868
CVE-2023-52425
CVE-2024-28757
# sqlite, only used by NLTK library to grab word lemmatizer and stopwords
# No impact in our settings
CVE-2023-7104
# libharfbuzz0b, O(n^2) growth, worst case is denial of service
# Accept the risk
CVE-2023-25193

View File

@@ -89,12 +89,6 @@ RUN uv pip install --system --no-cache-dir --upgrade \
RUN python -c "from tokenizers import Tokenizer; \
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
# Pre-downloading NLTK for setups with limited egress
RUN python -c "import nltk; \
nltk.download('stopwords', quiet=True); \
nltk.download('punkt_tab', quiet=True);"
# nltk.download('wordnet', quiet=True); introduce this back if lemmatization is needed
# Pre-downloading tiktoken for setups with limited egress
RUN python -c "import tiktoken; \
tiktoken.get_encoding('cl100k_base')"

View File

@@ -0,0 +1,42 @@
"""add_unique_constraint_to_inputprompt_prompt_user_id
Revision ID: 2c2430828bdf
Revises: fb80bdd256de
Create Date: 2026-01-20 16:01:54.314805
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "2c2430828bdf"
down_revision = "fb80bdd256de"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create unique constraint on (prompt, user_id) for user-owned prompts
# This ensures each user can only have one shortcut with a given name
op.create_unique_constraint(
"uq_inputprompt_prompt_user_id",
"inputprompt",
["prompt", "user_id"],
)
# Create partial unique index for public prompts (where user_id IS NULL)
# PostgreSQL unique constraints don't enforce uniqueness for NULL values,
# so we need a partial index to ensure public prompt names are also unique
op.execute(
"""
CREATE UNIQUE INDEX uq_inputprompt_prompt_public
ON inputprompt (prompt)
WHERE user_id IS NULL
"""
)
def downgrade() -> None:
op.execute("DROP INDEX IF EXISTS uq_inputprompt_prompt_public")
op.drop_constraint("uq_inputprompt_prompt_user_id", "inputprompt", type_="unique")

View File

@@ -0,0 +1,29 @@
"""remove default prompt shortcuts
Revision ID: 41fa44bef321
Revises: 2c2430828bdf
Create Date: 2025-01-21
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "41fa44bef321"
down_revision = "2c2430828bdf"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Delete any user associations for the default prompts first (foreign key constraint)
op.execute(
"DELETE FROM inputprompt__user WHERE input_prompt_id IN (SELECT id FROM inputprompt WHERE id < 0)"
)
# Delete the pre-seeded default prompt shortcuts (they have negative IDs)
op.execute("DELETE FROM inputprompt WHERE id < 0")
def downgrade() -> None:
# We don't restore the default prompts on downgrade
pass

View File

@@ -0,0 +1,116 @@
"""Add Discord bot tables
Revision ID: 8b5ce697290e
Revises: a1b2c3d4e5f7
Create Date: 2025-01-14
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "8b5ce697290e"
down_revision = "a1b2c3d4e5f7"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
# DiscordBotConfig (singleton table - one per tenant)
op.create_table(
"discord_bot_config",
sa.Column(
"id",
sa.String(),
primary_key=True,
server_default=sa.text("'SINGLETON'"),
),
sa.Column("bot_token", sa.LargeBinary(), nullable=False), # EncryptedString
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.CheckConstraint("id = 'SINGLETON'", name="ck_discord_bot_config_singleton"),
)
# DiscordGuildConfig
op.create_table(
"discord_guild_config",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("guild_id", sa.BigInteger(), nullable=True, unique=True),
sa.Column("guild_name", sa.String(), nullable=True),
sa.Column("registration_key", sa.String(), nullable=False, unique=True),
sa.Column("registered_at", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"default_persona_id",
sa.Integer(),
sa.ForeignKey("persona.id", ondelete="SET NULL"),
nullable=True,
),
sa.Column(
"enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False
),
)
# DiscordChannelConfig
op.create_table(
"discord_channel_config",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column(
"guild_config_id",
sa.Integer(),
sa.ForeignKey("discord_guild_config.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("channel_id", sa.BigInteger(), nullable=False),
sa.Column("channel_name", sa.String(), nullable=False),
sa.Column(
"channel_type",
sa.String(20),
server_default=sa.text("'text'"),
nullable=False,
),
sa.Column(
"is_private",
sa.Boolean(),
server_default=sa.text("false"),
nullable=False,
),
sa.Column(
"thread_only_mode",
sa.Boolean(),
server_default=sa.text("false"),
nullable=False,
),
sa.Column(
"require_bot_invocation",
sa.Boolean(),
server_default=sa.text("true"),
nullable=False,
),
sa.Column(
"persona_override_id",
sa.Integer(),
sa.ForeignKey("persona.id", ondelete="SET NULL"),
nullable=True,
),
sa.Column(
"enabled", sa.Boolean(), server_default=sa.text("false"), nullable=False
),
)
# Unique constraint: one config per channel per guild
op.create_unique_constraint(
"uq_discord_channel_guild_channel",
"discord_channel_config",
["guild_config_id", "channel_id"],
)
def downgrade() -> None:
op.drop_table("discord_channel_config")
op.drop_table("discord_guild_config")
op.drop_table("discord_bot_config")

View File

@@ -0,0 +1,47 @@
"""drop agent_search_metrics table
Revision ID: a1b2c3d4e5f7
Revises: 73e9983e5091
Create Date: 2026-01-17
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "a1b2c3d4e5f7"
down_revision = "73e9983e5091"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_table("agent__search_metrics")
def downgrade() -> None:
op.create_table(
"agent__search_metrics",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.UUID(), nullable=True),
sa.Column("persona_id", sa.Integer(), nullable=True),
sa.Column("agent_type", sa.String(), nullable=False),
sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
sa.Column("base_duration_s", sa.Float(), nullable=False),
sa.Column("full_duration_s", sa.Float(), nullable=False),
sa.Column("base_metrics", postgresql.JSONB(), nullable=True),
sa.Column("refined_metrics", postgresql.JSONB(), nullable=True),
sa.Column("all_metrics", postgresql.JSONB(), nullable=True),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.PrimaryKeyConstraint("id"),
)

View File

@@ -1,42 +0,0 @@
"""Add SET NULL cascade to chat_session.persona_id foreign key
Revision ID: ac9c7b76419b
Revises: 73e9983e5091
Create Date: 2026-01-17 18:10:00.000000
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "ac9c7b76419b"
down_revision = "73e9983e5091"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Drop the existing foreign key constraint (no cascade behavior)
op.drop_constraint("fk_chat_session_persona_id", "chat_session", type_="foreignkey")
# Recreate with SET NULL on delete, so deleting a persona sets
# chat_session.persona_id to NULL instead of blocking the delete
op.create_foreign_key(
"fk_chat_session_persona_id",
"chat_session",
"persona",
["persona_id"],
["id"],
ondelete="SET NULL",
)
def downgrade() -> None:
# Revert to original constraint without cascade behavior
op.drop_constraint("fk_chat_session_persona_id", "chat_session", type_="foreignkey")
op.create_foreign_key(
"fk_chat_session_persona_id",
"chat_session",
"persona",
["persona_id"],
["id"],
)

View File

@@ -0,0 +1,31 @@
"""add chat_background to user
Revision ID: fb80bdd256de
Revises: 8b5ce697290e
Create Date: 2026-01-16 16:15:59.222617
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "fb80bdd256de"
down_revision = "8b5ce697290e"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column(
"chat_background",
sa.String(),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("user", "chat_background")

View File

@@ -109,7 +109,6 @@ CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS = float(
STRIPE_SECRET_KEY = os.environ.get("STRIPE_SECRET_KEY")
STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE")
# JWT Public Key URL
JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)
@@ -129,3 +128,8 @@ MARKETING_POSTHOG_API_KEY = os.environ.get("MARKETING_POSTHOG_API_KEY")
HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")
GATED_TENANTS_KEY = "gated_tenants"
# License enforcement - when True, blocks API access for gated/expired licenses
LICENSE_ENFORCEMENT_ENABLED = (
os.environ.get("LICENSE_ENFORCEMENT_ENABLED", "").lower() == "true"
)

View File

@@ -16,6 +16,9 @@ from ee.onyx.server.enterprise_settings.api import (
from ee.onyx.server.evals.api import router as evals_router
from ee.onyx.server.license.api import router as license_router
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
from ee.onyx.server.middleware.license_enforcement import (
add_license_enforcement_middleware,
)
from ee.onyx.server.middleware.tenant_tracking import (
add_api_server_tenant_id_middleware,
)
@@ -83,6 +86,10 @@ def get_application() -> FastAPI:
if MULTI_TENANT:
add_api_server_tenant_id_middleware(application, logger)
# Add license enforcement middleware (runs after tenant tracking)
# This blocks access when license is expired/gated
add_license_enforcement_middleware(application, logger)
if AUTH_TYPE == AuthType.CLOUD:
# For Google OAuth, refresh tokens are requested by:
# 1. Adding the right scopes

View File

@@ -17,7 +17,8 @@ from onyx.context.search.models import InferenceChunk
from onyx.context.search.pipeline import merge_individual_chunks
from onyx.context.search.pipeline import search_pipeline
from onyx.db.models import User
from onyx.document_index.factory import get_current_primary_default_document_index
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import DocumentIndex
from onyx.llm.factory import get_default_llm
from onyx.secondary_llm_flows.document_filter import select_sections_for_expansion
@@ -42,11 +43,13 @@ def _run_single_search(
document_index: DocumentIndex,
user: User | None,
db_session: Session,
num_hits: int | None = None,
) -> list[InferenceChunk]:
"""Execute a single search query and return chunks."""
chunk_search_request = ChunkSearchRequest(
query=query,
user_selected_filters=filters,
limit=num_hits,
)
return search_pipeline(
@@ -72,7 +75,9 @@ def stream_search_query(
Used by both streaming and non-streaming endpoints.
"""
# Get document index
document_index = get_current_primary_default_document_index(db_session)
search_settings = get_current_search_settings(db_session)
# This flow is for search so we do not get all indices.
document_index = get_default_document_index(search_settings, None)
# Determine queries to execute
original_query = request.search_query
@@ -114,6 +119,7 @@ def stream_search_query(
document_index=document_index,
user=user,
db_session=db_session,
num_hits=request.num_hits,
)
else:
# Multiple queries - run in parallel and merge with RRF
@@ -121,7 +127,14 @@ def stream_search_query(
search_functions = [
(
_run_single_search,
(query, request.filters, document_index, user, db_session),
(
query,
request.filters,
document_index,
user,
db_session,
request.num_hits,
),
)
for query in all_executed_queries
]
@@ -168,6 +181,9 @@ def stream_search_query(
# Merge chunks into sections
sections = merge_individual_chunks(chunks)
# Truncate to the requested number of hits
sections = sections[: request.num_hits]
# Apply LLM document selection if requested
# num_docs_fed_to_llm_selection specifies how many sections to feed to the LLM for selection
# The LLM will always try to select TARGET_NUM_SECTIONS_FOR_LLM_SELECTION sections from those fed to it

View File

@@ -10,6 +10,8 @@ EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [
("/enterprise-settings/logo", {"GET"}),
("/enterprise-settings/logotype", {"GET"}),
("/enterprise-settings/custom-analytics-script", {"GET"}),
# Stripe publishable key is safe to expose publicly
("/tenants/stripe-publishable-key", {"GET"}),
]

View File

@@ -0,0 +1,102 @@
"""Middleware to enforce license status application-wide."""
import logging
from collections.abc import Awaitable
from collections.abc import Callable
from fastapi import FastAPI
from fastapi import Request
from fastapi import Response
from fastapi.responses import JSONResponse
from redis.exceptions import RedisError
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.db.license import get_cached_license_metadata
from ee.onyx.server.tenants.product_gating import is_tenant_gated
from onyx.server.settings.models import ApplicationStatus
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
# Paths that are ALWAYS accessible, even when license is expired/gated.
# These enable users to:
# /auth - Log in/out (users can't fix billing if locked out of auth)
# /license - Fetch, upload, or check license status
# /health - Health checks for load balancers/orchestrators
# /me - Basic user info needed for UI rendering
# /settings, /enterprise-settings - View app status and branding
# /tenants/billing-* - Manage subscription to resolve gating
ALLOWED_PATH_PREFIXES = {
"/auth",
"/license",
"/health",
"/me",
"/settings",
"/enterprise-settings",
"/tenants/billing-information",
"/tenants/create-customer-portal-session",
"/tenants/create-subscription-session",
}
def _is_path_allowed(path: str) -> bool:
"""Check if path is in allowlist (prefix match)."""
return any(path.startswith(prefix) for prefix in ALLOWED_PATH_PREFIXES)
def add_license_enforcement_middleware(
app: FastAPI, logger: logging.LoggerAdapter
) -> None:
logger.info("License enforcement middleware registered")
@app.middleware("http")
async def enforce_license(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
"""Block requests when license is expired/gated."""
if not LICENSE_ENFORCEMENT_ENABLED:
return await call_next(request)
path = request.url.path
if path.startswith("/api"):
path = path[4:]
if _is_path_allowed(path):
return await call_next(request)
is_gated = False
tenant_id = get_current_tenant_id()
if MULTI_TENANT:
try:
is_gated = is_tenant_gated(tenant_id)
except RedisError as e:
logger.warning(f"Failed to check tenant gating status: {e}")
# Fail open - don't block users due to Redis connectivity issues
is_gated = False
else:
try:
metadata = get_cached_license_metadata(tenant_id)
if metadata:
if metadata.status == ApplicationStatus.GATED_ACCESS:
is_gated = True
else:
# No license metadata = gated for self-hosted EE
is_gated = True
except RedisError as e:
logger.warning(f"Failed to check license metadata: {e}")
# Fail open - don't block users due to Redis connectivity issues
is_gated = False
if is_gated:
logger.info(f"Blocking request for gated tenant: {tenant_id}, path={path}")
return JSONResponse(
status_code=402,
content={
"detail": {
"error": "license_expired",
"message": "Your subscription has expired. Please update your billing.",
}
},
)
return await call_next(request)

View File

@@ -32,6 +32,7 @@ class SendSearchQueryRequest(BaseModel):
filters: BaseFilters | None = None
num_docs_fed_to_llm_selection: int | None = None
run_query_expansion: bool = False
num_hits: int = 50
include_content: bool = False
stream: bool = False

View File

@@ -0,0 +1,54 @@
"""EE Settings API - provides license-aware settings override."""
from redis.exceptions import RedisError
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.db.license import get_cached_license_metadata
from onyx.server.settings.models import ApplicationStatus
from onyx.server.settings.models import Settings
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
# Statuses that indicate a billing/license problem - propagate these to settings
_GATED_STATUSES = frozenset(
{
ApplicationStatus.GATED_ACCESS,
ApplicationStatus.GRACE_PERIOD,
ApplicationStatus.PAYMENT_REMINDER,
}
)
def apply_license_status_to_settings(settings: Settings) -> Settings:
"""EE version: checks license status for self-hosted deployments.
For self-hosted, looks up license metadata and overrides application_status
if the license is missing or indicates a problem (expired, grace period, etc.).
For multi-tenant (cloud), the settings already have the correct status
from the control plane, so no override is needed.
If LICENSE_ENFORCEMENT_ENABLED is false, settings are returned unchanged,
allowing the product to function normally without license checks.
"""
if not LICENSE_ENFORCEMENT_ENABLED:
return settings
if MULTI_TENANT:
return settings
tenant_id = get_current_tenant_id()
try:
metadata = get_cached_license_metadata(tenant_id)
if metadata and metadata.status in _GATED_STATUSES:
settings.application_status = metadata.status
elif not metadata:
# No license = gated access for self-hosted EE
settings.application_status = ApplicationStatus.GATED_ACCESS
except RedisError as e:
logger.warning(f"Failed to check license metadata for settings: {e}")
return settings

View File

@@ -1,9 +1,9 @@
from typing import cast
from typing import Literal
import requests
import stripe
from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import BillingInformation
@@ -16,15 +16,21 @@ stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
def fetch_stripe_checkout_session(tenant_id: str) -> str:
def fetch_stripe_checkout_session(
tenant_id: str,
billing_period: Literal["monthly", "annual"] = "monthly",
) -> str:
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
url = f"{CONTROL_PLANE_API_BASE_URL}/create-checkout-session"
params = {"tenant_id": tenant_id}
response = requests.post(url, headers=headers, params=params)
payload = {
"tenant_id": tenant_id,
"billing_period": billing_period,
}
response = requests.post(url, headers=headers, json=payload)
response.raise_for_status()
return response.json()["sessionId"]
@@ -70,24 +76,46 @@ def fetch_billing_information(
return BillingInformation(**response_data)
def fetch_customer_portal_session(tenant_id: str, return_url: str | None = None) -> str:
"""
Fetch a Stripe customer portal session URL from the control plane.
NOTE: This is currently only used for multi-tenant (cloud) deployments.
Self-hosted proxy endpoints will be added in a future phase.
"""
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
url = f"{CONTROL_PLANE_API_BASE_URL}/create-customer-portal-session"
payload = {"tenant_id": tenant_id}
if return_url:
payload["return_url"] = return_url
response = requests.post(url, headers=headers, json=payload)
response.raise_for_status()
return response.json()["url"]
def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:
"""
Send a request to the control service to register the number of users for a tenant.
Update the number of seats for a tenant's subscription.
Preserves the existing price (monthly, annual, or grandfathered).
"""
if not STRIPE_PRICE_ID:
raise Exception("STRIPE_PRICE_ID is not set")
response = fetch_tenant_stripe_information(tenant_id)
stripe_subscription_id = cast(str, response.get("stripe_subscription_id"))
subscription = stripe.Subscription.retrieve(stripe_subscription_id)
subscription_item = subscription["items"]["data"][0]
# Use existing price to preserve the customer's current plan
current_price_id = subscription_item.price.id
updated_subscription = stripe.Subscription.modify(
stripe_subscription_id,
items=[
{
"id": subscription["items"]["data"][0].id,
"price": STRIPE_PRICE_ID,
"id": subscription_item.id,
"price": current_price_id,
"quantity": number_of_users,
}
],

View File

@@ -1,33 +1,41 @@
import stripe
import asyncio
import httpx
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from ee.onyx.auth.users import current_admin_user
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import control_plane_dep
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_customer_portal_session
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import CreateSubscriptionSessionRequest
from ee.onyx.server.tenants.models import ProductGatingFullSyncRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import StripePublishableKeyResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import overwrite_full_gated_set
from ee.onyx.server.tenants.product_gating import store_product_gating
from onyx.auth.users import User
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
router = APIRouter(prefix="/tenants")
# Cache for Stripe publishable key to avoid hitting S3 on every request
_stripe_publishable_key_cache: str | None = None
_stripe_key_lock = asyncio.Lock()
@router.post("/product-gating")
def gate_product(
@@ -82,21 +90,17 @@ async def billing_information(
async def create_customer_portal_session(
_: User = Depends(current_admin_user),
) -> dict:
"""
Create a Stripe customer portal session via the control plane.
NOTE: This is currently only used for multi-tenant (cloud) deployments.
Self-hosted proxy endpoints will be added in a future phase.
"""
tenant_id = get_current_tenant_id()
return_url = f"{WEB_DOMAIN}/admin/billing"
try:
stripe_info = fetch_tenant_stripe_information(tenant_id)
stripe_customer_id = stripe_info.get("stripe_customer_id")
if not stripe_customer_id:
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
logger.info(stripe_customer_id)
portal_session = stripe.billing_portal.Session.create(
customer=stripe_customer_id,
return_url=f"{WEB_DOMAIN}/admin/billing",
)
logger.info(portal_session)
return {"url": portal_session.url}
portal_url = fetch_customer_portal_session(tenant_id, return_url)
return {"url": portal_url}
except Exception as e:
logger.exception("Failed to create customer portal session")
raise HTTPException(status_code=500, detail=str(e))
@@ -104,15 +108,82 @@ async def create_customer_portal_session(
@router.post("/create-subscription-session")
async def create_subscription_session(
request: CreateSubscriptionSessionRequest | None = None,
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise HTTPException(status_code=400, detail="Tenant ID not found")
session_id = fetch_stripe_checkout_session(tenant_id)
billing_period = request.billing_period if request else "monthly"
session_id = fetch_stripe_checkout_session(tenant_id, billing_period)
return SubscriptionSessionResponse(sessionId=session_id)
except Exception as e:
logger.exception("Failed to create resubscription session")
logger.exception("Failed to create subscription session")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/stripe-publishable-key")
async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
"""
Fetch the Stripe publishable key.
Priority: env var override (for testing) > S3 bucket (production).
This endpoint is public (no auth required) since publishable keys are safe to expose.
The key is cached in memory to avoid hitting S3 on every request.
"""
global _stripe_publishable_key_cache
# Fast path: return cached value without lock
if _stripe_publishable_key_cache:
return StripePublishableKeyResponse(
publishable_key=_stripe_publishable_key_cache
)
# Use lock to prevent concurrent S3 requests
async with _stripe_key_lock:
# Double-check after acquiring lock (another request may have populated cache)
if _stripe_publishable_key_cache:
return StripePublishableKeyResponse(
publishable_key=_stripe_publishable_key_cache
)
# Check for env var override first (for local testing with pk_test_* keys)
if STRIPE_PUBLISHABLE_KEY_OVERRIDE:
key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
if not key.startswith("pk_"):
raise HTTPException(
status_code=500,
detail="Invalid Stripe publishable key format",
)
_stripe_publishable_key_cache = key
return StripePublishableKeyResponse(publishable_key=key)
# Fall back to S3 bucket
if not STRIPE_PUBLISHABLE_KEY_URL:
raise HTTPException(
status_code=500,
detail="Stripe publishable key is not configured",
)
try:
async with httpx.AsyncClient() as client:
response = await client.get(STRIPE_PUBLISHABLE_KEY_URL)
response.raise_for_status()
key = response.text.strip()
# Validate key format
if not key.startswith("pk_"):
raise HTTPException(
status_code=500,
detail="Invalid Stripe publishable key format",
)
_stripe_publishable_key_cache = key
return StripePublishableKeyResponse(publishable_key=key)
except httpx.HTTPError:
raise HTTPException(
status_code=500,
detail="Failed to fetch Stripe publishable key",
)

View File

@@ -1,4 +1,5 @@
from datetime import datetime
from typing import Literal
from pydantic import BaseModel
@@ -73,6 +74,12 @@ class SubscriptionSessionResponse(BaseModel):
sessionId: str
class CreateSubscriptionSessionRequest(BaseModel):
"""Request to create a subscription checkout session."""
billing_period: Literal["monthly", "annual"] = "monthly"
class TenantByDomainResponse(BaseModel):
tenant_id: str
number_of_users: int
@@ -98,3 +105,7 @@ class PendingUserSnapshot(BaseModel):
class ApproveUserRequest(BaseModel):
email: str
class StripePublishableKeyResponse(BaseModel):
publishable_key: str

View File

@@ -65,3 +65,9 @@ def get_gated_tenants() -> set[str]:
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
gated_tenants_bytes = cast(set[bytes], redis_client.smembers(GATED_TENANTS_KEY))
return {tenant_id.decode("utf-8") for tenant_id in gated_tenants_bytes}
def is_tenant_gated(tenant_id: str) -> bool:
"""Fast O(1) check if tenant is in gated set (multi-tenant only)."""
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
return bool(redis_client.sismember(GATED_TENANTS_KEY, tenant_id))

View File

@@ -97,10 +97,14 @@ def get_access_for_documents(
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
"""Returns a list of ACL entries that the user has access to. This is meant to be
used downstream to filter out documents that the user does not have access to. The
user should have access to a document if at least one entry in the document's ACL
matches one entry in the returned set.
"""Returns a list of ACL entries that the user has access to.
This is meant to be used downstream to filter out documents that the user
does not have access to. The user should have access to a document if at
least one entry in the document's ACL matches one entry in the returned set.
NOTE: These strings must be formatted in the same way as the output of
DocumentAccess::to_acl.
"""
if user:
return {prefix_user_email(user.email), PUBLIC_DOC_PAT}

View File

@@ -125,9 +125,11 @@ class DocumentAccess(ExternalAccess):
)
def to_acl(self) -> set[str]:
# the acl's emitted by this function are prefixed by type
# to get the native objects, access the member variables directly
"""Converts the access state to a set of formatted ACL strings.
NOTE: When querying for documents, the supplied ACL filter strings must
be formatted in the same way as this function.
"""
acl_set: set[str] = set()
for user_email in self.user_emails:
if user_email:

View File

@@ -11,6 +11,7 @@ from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Literal
from typing import Optional
from typing import Protocol
from typing import Tuple
@@ -1456,6 +1457,9 @@ def get_default_admin_user_emails_() -> list[str]:
STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state"
STATE_TOKEN_LIFETIME_SECONDS = 3600
CSRF_TOKEN_KEY = "csrftoken"
CSRF_TOKEN_COOKIE_NAME = "fastapiusersoauthcsrf"
class OAuth2AuthorizeResponse(BaseModel):
@@ -1463,13 +1467,19 @@ class OAuth2AuthorizeResponse(BaseModel):
def generate_state_token(
data: Dict[str, str], secret: SecretType, lifetime_seconds: int = 3600
data: Dict[str, str],
secret: SecretType,
lifetime_seconds: int = STATE_TOKEN_LIFETIME_SECONDS,
) -> str:
data["aud"] = STATE_TOKEN_AUDIENCE
return generate_jwt(data, secret, lifetime_seconds)
def generate_csrf_token() -> str:
return secrets.token_urlsafe(32)
# refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91
def create_onyx_oauth_router(
oauth_client: BaseOAuth2,
@@ -1498,6 +1508,13 @@ def get_oauth_router(
redirect_url: Optional[str] = None,
associate_by_email: bool = False,
is_verified_by_default: bool = False,
*,
csrf_token_cookie_name: str = CSRF_TOKEN_COOKIE_NAME,
csrf_token_cookie_path: str = "/",
csrf_token_cookie_domain: Optional[str] = None,
csrf_token_cookie_secure: Optional[bool] = None,
csrf_token_cookie_httponly: bool = True,
csrf_token_cookie_samesite: Optional[Literal["lax", "strict", "none"]] = "lax",
) -> APIRouter:
"""Generate a router with the OAuth routes."""
router = APIRouter()
@@ -1514,6 +1531,9 @@ def get_oauth_router(
route_name=callback_route_name,
)
if csrf_token_cookie_secure is None:
csrf_token_cookie_secure = WEB_DOMAIN.startswith("https")
@router.get(
"/authorize",
name=f"oauth:{oauth_client.name}.{backend.name}.authorize",
@@ -1521,8 +1541,10 @@ def get_oauth_router(
)
async def authorize(
request: Request,
response: Response,
redirect: bool = Query(False),
scopes: List[str] = Query(None),
) -> OAuth2AuthorizeResponse:
) -> Response | OAuth2AuthorizeResponse:
referral_source = request.cookies.get("referral_source", None)
if redirect_url is not None:
@@ -1532,9 +1554,11 @@ def get_oauth_router(
next_url = request.query_params.get("next", "/")
csrf_token = generate_csrf_token()
state_data: Dict[str, str] = {
"next_url": next_url,
"referral_source": referral_source or "default_referral",
CSRF_TOKEN_KEY: csrf_token,
}
state = generate_state_token(state_data, state_secret)
@@ -1551,6 +1575,31 @@ def get_oauth_router(
authorization_url, {"access_type": "offline", "prompt": "consent"}
)
if redirect:
redirect_response = RedirectResponse(authorization_url, status_code=302)
redirect_response.set_cookie(
key=csrf_token_cookie_name,
value=csrf_token,
max_age=STATE_TOKEN_LIFETIME_SECONDS,
path=csrf_token_cookie_path,
domain=csrf_token_cookie_domain,
secure=csrf_token_cookie_secure,
httponly=csrf_token_cookie_httponly,
samesite=csrf_token_cookie_samesite,
)
return redirect_response
response.set_cookie(
key=csrf_token_cookie_name,
value=csrf_token,
max_age=STATE_TOKEN_LIFETIME_SECONDS,
path=csrf_token_cookie_path,
domain=csrf_token_cookie_domain,
secure=csrf_token_cookie_secure,
httponly=csrf_token_cookie_httponly,
samesite=csrf_token_cookie_samesite,
)
return OAuth2AuthorizeResponse(authorization_url=authorization_url)
@log_function_time(print_only=True)
@@ -1600,7 +1649,33 @@ def get_oauth_router(
try:
state_data = decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE])
except jwt.DecodeError:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=getattr(
ErrorCode, "ACCESS_TOKEN_DECODE_ERROR", "ACCESS_TOKEN_DECODE_ERROR"
),
)
except jwt.ExpiredSignatureError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=getattr(
ErrorCode,
"ACCESS_TOKEN_ALREADY_EXPIRED",
"ACCESS_TOKEN_ALREADY_EXPIRED",
),
)
cookie_csrf_token = request.cookies.get(csrf_token_cookie_name)
state_csrf_token = state_data.get(CSRF_TOKEN_KEY)
if (
not cookie_csrf_token
or not state_csrf_token
or not secrets.compare_digest(cookie_csrf_token, state_csrf_token)
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=getattr(ErrorCode, "OAUTH_INVALID_STATE", "OAUTH_INVALID_STATE"),
)
next_url = state_data.get("next_url", "/")
referral_source = state_data.get("referral_source", None)

View File

@@ -26,10 +26,13 @@ from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.background.celery.celery_utils import make_probe_path
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_PREFIX
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_TASKSET_KEY
from onyx.configs.app_configs import ENABLE_OPENSEARCH_FOR_ONYX
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
from onyx.document_index.opensearch.client import (
wait_for_opensearch_with_timeout,
)
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_connector import RedisConnector
@@ -516,15 +519,17 @@ def wait_for_vespa_or_shutdown(sender: Any, **kwargs: Any) -> None:
"""Waits for Vespa to become ready subject to a timeout.
Raises WorkerShutdown if the timeout is reached."""
if ENABLE_OPENSEARCH_FOR_ONYX:
# TODO(andrei): Do some similar liveness checking for OpenSearch.
return
if not wait_for_vespa_with_timeout():
msg = "Vespa: Readiness probe did not succeed within the timeout. Exiting..."
msg = "[Vespa] Readiness probe did not succeed within the timeout. Exiting..."
logger.error(msg)
raise WorkerShutdown(msg)
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
if not wait_for_opensearch_with_timeout():
msg = "[OpenSearch] Readiness probe did not succeed within the timeout. Exiting..."
logger.error(msg)
raise WorkerShutdown(msg)
# File for validating worker liveness
class LivenessProbe(bootsteps.StartStopStep):

View File

@@ -87,7 +87,7 @@ from onyx.db.models import SearchSettings
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_secondary_search_settings
from onyx.db.swap_index import check_and_perform_index_swap
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.factory import get_all_document_indices
from onyx.file_store.document_batch_storage import DocumentBatchStorage
from onyx.file_store.document_batch_storage import get_document_batch_storage
from onyx.httpx.httpx_pool import HttpxPool
@@ -1436,7 +1436,7 @@ def _docprocessing_task(
callback=callback,
)
document_index = get_default_document_index(
document_indices = get_all_document_indices(
index_attempt.search_settings,
None,
httpx_client=HttpxPool.get("vespa"),
@@ -1473,7 +1473,7 @@ def _docprocessing_task(
# real work happens here!
index_pipeline_result = run_indexing_pipeline(
embedder=embedding_model,
document_index=document_index,
document_indices=document_indices,
ignore_time_skip=True, # Documents are already filtered during extraction
db_session=db_session,
tenant_id=tenant_id,

View File

@@ -25,7 +25,7 @@ from onyx.db.document_set import fetch_document_sets_for_document
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.relationships import delete_document_references_from_kg
from onyx.db.search_settings import get_active_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.factory import get_all_document_indices
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_pool import get_redis_client
@@ -97,13 +97,17 @@ def document_by_cc_pair_cleanup_task(
action = "skip"
active_search_settings = get_active_search_settings(db_session)
doc_index = get_default_document_index(
# This flow is for updates and deletion so we get all indices.
document_indices = get_all_document_indices(
active_search_settings.primary,
active_search_settings.secondary,
httpx_client=HttpxPool.get("vespa"),
)
retry_index = RetryDocumentIndex(doc_index)
retry_document_indices: list[RetryDocumentIndex] = [
RetryDocumentIndex(document_index)
for document_index in document_indices
]
count = get_document_connector_count(db_session, document_id)
if count == 1:
@@ -113,11 +117,12 @@ def document_by_cc_pair_cleanup_task(
chunk_count = fetch_chunk_count_for_document(document_id, db_session)
_ = retry_index.delete_single(
document_id,
tenant_id=tenant_id,
chunk_count=chunk_count,
)
for retry_document_index in retry_document_indices:
_ = retry_document_index.delete_single(
document_id,
tenant_id=tenant_id,
chunk_count=chunk_count,
)
delete_document_references_from_kg(
db_session=db_session,
@@ -155,14 +160,18 @@ def document_by_cc_pair_cleanup_task(
hidden=doc.hidden,
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
retry_index.update_single(
document_id,
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
fields=fields,
user_fields=None,
)
for retry_document_index in retry_document_indices:
# TODO(andrei): Previously there was a comment here saying
# it was ok if a doc did not exist in the document index. I
# don't agree with that claim, so keep an eye on this task
# to see if this raises.
retry_document_index.update_single(
document_id,
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
fields=fields,
user_fields=None,
)
# there are still other cc_pair references to the doc, so just resync to Vespa
delete_document_by_connector_credential_pair__no_commit(

View File

@@ -12,6 +12,7 @@ from retry import retry
from sqlalchemy import select
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.configs.app_configs import MANAGED_VESPA
@@ -19,12 +20,14 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.models import Document
from onyx.db.engine.sql_engine import get_session_with_current_tenant
@@ -32,7 +35,7 @@ from onyx.db.enums import UserFileStatus
from onyx.db.models import UserFile
from onyx.db.search_settings import get_active_search_settings
from onyx.db.search_settings import get_active_search_settings_list
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.factory import get_all_document_indices
from onyx.document_index.interfaces import VespaDocumentUserFields
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.file_store.file_store import get_default_file_store
@@ -53,6 +56,17 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
def _user_file_queued_key(user_file_id: str | UUID) -> str:
"""Key that exists while a process_single_user_file task is sitting in the queue.
The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
before enqueuing and the worker deletes it as its first action. This prevents
the beat from adding duplicate tasks for files that already have a live task
in flight.
"""
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
@@ -116,7 +130,24 @@ def _get_document_chunk_count(
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
"""Scan for user files with PROCESSING status and enqueue per-file tasks.
Uses direct Redis locks to avoid overlapping runs.
Three mechanisms prevent queue runaway:
1. **Queue depth backpressure** if the broker queue already has more than
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
entirely. Workers are clearly behind; adding more tasks would only make
the backlog worse.
2. **Per-file queued guard** before enqueuing a task we set a short-lived
Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
already exists the file already has a live task in the queue, so we skip
it. The worker deletes the key the moment it picks up the task so the
next beat cycle can re-enqueue if the file is still PROCESSING.
3. **Task expiry** every enqueued task carries an `expires` value equal to
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
the queue after that deadline, Celery discards it without touching the DB.
This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
Redis restart), stale tasks evict themselves rather than piling up forever.
"""
task_logger.info("check_user_file_processing - Starting")
@@ -131,7 +162,21 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
return None
enqueued = 0
skipped_guard = 0
try:
# --- Protection 1: queue depth backpressure ---
r_celery = self.app.broker_connection().channel().client # type: ignore
queue_len = celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
)
if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
task_logger.warning(
f"check_user_file_processing - Queue depth {queue_len} exceeds "
f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
f"tenant={tenant_id}"
)
return None
with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
@@ -144,12 +189,35 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
)
for user_file_id in user_file_ids:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
# --- Protection 2: per-file queued guard ---
queued_key = _user_file_queued_key(user_file_id)
guard_set = redis_client.set(
queued_key,
1,
ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
nx=True,
)
if not guard_set:
skipped_guard += 1
continue
# --- Protection 3: task expiry ---
# If task submission fails, clear the guard immediately so the
# next beat cycle can retry enqueuing this file.
try:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={
"user_file_id": str(user_file_id),
"tenant_id": tenant_id,
},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
)
except Exception:
redis_client.delete(queued_key)
raise
enqueued += 1
finally:
@@ -157,7 +225,8 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
lock.release()
task_logger.info(
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
f"tasks for tenant={tenant_id}"
)
return None
@@ -172,6 +241,12 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
start = time.monotonic()
redis_client = get_redis_client(tenant_id=tenant_id)
# Clear the "queued" guard set by the beat generator so that the next beat
# cycle can re-enqueue this file if it is still in PROCESSING state after
# this task completes or fails.
redis_client.delete(_user_file_queued_key(user_file_id))
file_lock: RedisLock = redis_client.lock(
_user_file_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
@@ -244,7 +319,8 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
search_settings=current_search_settings,
)
document_index = get_default_document_index(
# This flow is for indexing so we get all indices.
document_indices = get_all_document_indices(
current_search_settings,
None,
httpx_client=HttpxPool.get("vespa"),
@@ -258,7 +334,7 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
# real work happens here!
index_pipeline_result = run_indexing_pipeline(
embedder=embedding_model,
document_index=document_index,
document_indices=document_indices,
ignore_time_skip=True,
db_session=db_session,
tenant_id=tenant_id,
@@ -412,12 +488,16 @@ def process_single_user_file_delete(
httpx_init_vespa_pool(20)
active_search_settings = get_active_search_settings(db_session)
document_index = get_default_document_index(
# This flow is for deletion so we get all indices.
document_indices = get_all_document_indices(
search_settings=active_search_settings.primary,
secondary_search_settings=active_search_settings.secondary,
httpx_client=HttpxPool.get("vespa"),
)
retry_index = RetryDocumentIndex(document_index)
retry_document_indices: list[RetryDocumentIndex] = [
RetryDocumentIndex(document_index)
for document_index in document_indices
]
index_name = active_search_settings.primary.index_name
selection = f"{index_name}.document_id=='{user_file_id}'"
@@ -438,11 +518,12 @@ def process_single_user_file_delete(
else:
chunk_count = user_file.chunk_count
retry_index.delete_single(
doc_id=user_file_id,
tenant_id=tenant_id,
chunk_count=chunk_count,
)
for retry_document_index in retry_document_indices:
retry_document_index.delete_single(
doc_id=user_file_id,
tenant_id=tenant_id,
chunk_count=chunk_count,
)
# 2) Delete the user-uploaded file content from filestore (blob + metadata)
file_store = get_default_file_store()
@@ -564,12 +645,16 @@ def process_single_user_file_project_sync(
httpx_init_vespa_pool(20)
active_search_settings = get_active_search_settings(db_session)
doc_index = get_default_document_index(
# This flow is for updates so we get all indices.
document_indices = get_all_document_indices(
search_settings=active_search_settings.primary,
secondary_search_settings=active_search_settings.secondary,
httpx_client=HttpxPool.get("vespa"),
)
retry_index = RetryDocumentIndex(doc_index)
retry_document_indices: list[RetryDocumentIndex] = [
RetryDocumentIndex(document_index)
for document_index in document_indices
]
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
if not user_file:
@@ -579,13 +664,14 @@ def process_single_user_file_project_sync(
return None
project_ids = [project.id for project in user_file.projects]
retry_index.update_single(
doc_id=str(user_file.id),
tenant_id=tenant_id,
chunk_count=user_file.chunk_count,
fields=None,
user_fields=VespaDocumentUserFields(user_projects=project_ids),
)
for retry_document_index in retry_document_indices:
retry_document_index.update_single(
doc_id=str(user_file.id),
tenant_id=tenant_id,
chunk_count=user_file.chunk_count,
fields=None,
user_fields=VespaDocumentUserFields(user_projects=project_ids),
)
task_logger.info(
f"process_single_user_file_project_sync - User file id={user_file_id}"

View File

@@ -21,6 +21,8 @@ from onyx.utils.logger import setup_logger
DOCUMENT_SYNC_PREFIX = "documentsync"
DOCUMENT_SYNC_FENCE_KEY = f"{DOCUMENT_SYNC_PREFIX}_fence"
DOCUMENT_SYNC_TASKSET_KEY = f"{DOCUMENT_SYNC_PREFIX}_taskset"
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
TASKSET_TTL = FENCE_TTL
logger = setup_logger()
@@ -50,7 +52,7 @@ def set_document_sync_fence(r: Redis, payload: int | None) -> None:
r.delete(DOCUMENT_SYNC_FENCE_KEY)
return
r.set(DOCUMENT_SYNC_FENCE_KEY, payload)
r.set(DOCUMENT_SYNC_FENCE_KEY, payload, ex=FENCE_TTL)
r.sadd(OnyxRedisConstants.ACTIVE_FENCES, DOCUMENT_SYNC_FENCE_KEY)
@@ -110,6 +112,7 @@ def generate_document_sync_tasks(
# Add to the tracking taskset in Redis BEFORE creating the celery task
r.sadd(DOCUMENT_SYNC_TASKSET_KEY, custom_task_id)
r.expire(DOCUMENT_SYNC_TASKSET_KEY, TASKSET_TTL)
# Create the Celery task
celery_app.send_task(

View File

@@ -49,7 +49,7 @@ from onyx.db.search_settings import get_active_search_settings
from onyx.db.sync_record import cleanup_sync_records
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.factory import get_all_document_indices
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_document_set import RedisDocumentSet
@@ -70,6 +70,8 @@ logger = setup_logger()
# celery auto associates tasks created inside another task,
# which bloats the result metadata considerably. trail=False prevents this.
# TODO(andrei): Rename all these kinds of functions from *vespa* to a more
# generic *document_index*.
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
ignore_result=True,
@@ -465,13 +467,17 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
try:
with get_session_with_current_tenant() as db_session:
active_search_settings = get_active_search_settings(db_session)
doc_index = get_default_document_index(
# This flow is for updates so we get all indices.
document_indices = get_all_document_indices(
search_settings=active_search_settings.primary,
secondary_search_settings=active_search_settings.secondary,
httpx_client=HttpxPool.get("vespa"),
)
retry_index = RetryDocumentIndex(doc_index)
retry_document_indices: list[RetryDocumentIndex] = [
RetryDocumentIndex(document_index)
for document_index in document_indices
]
doc = get_document(document_id, db_session)
if not doc:
@@ -500,14 +506,18 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
# aggregated_boost_factor=doc.aggregated_boost_factor,
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
retry_index.update_single(
document_id,
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
fields=fields,
user_fields=None,
)
for retry_document_index in retry_document_indices:
# TODO(andrei): Previously there was a comment here saying
# it was ok if a doc did not exist in the document index. I
# don't agree with that claim, so keep an eye on this task
# to see if this raises.
retry_document_index.update_single(
document_id,
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
fields=fields,
user_fields=None,
)
# update db last. Worst case = we crash right before this and
# the sync might repeat again later

View File

@@ -7,6 +7,7 @@ from typing import Any
from onyx.chat.citation_processor import CitationMapping
from onyx.chat.emitter import Emitter
from onyx.context.search.models import SearchDoc
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
@@ -15,6 +16,11 @@ from onyx.tools.models import ToolCallInfo
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import wait_on_background
# Type alias for search doc deduplication key
# Simple key: just document_id (str)
# Full key: (document_id, chunk_ind, match_highlights)
SearchDocKey = str | tuple[str, int, tuple[str, ...]]
class ChatStateContainer:
"""Container for accumulating state during LLM loop execution.
@@ -40,6 +46,10 @@ class ChatStateContainer:
# True if this turn is a clarification question (deep research flow)
self.is_clarification: bool = False
# Note: LLM cost tracking is now handled in multi_llm.py
# Search doc collection - maps dedup key to SearchDoc for all docs from tool calls
self._all_search_docs: dict[SearchDocKey, SearchDoc] = {}
# Track which citation numbers were actually emitted during streaming
self._emitted_citations: set[int] = set()
def add_tool_call(self, tool_call: ToolCallInfo) -> None:
"""Add a tool call to the accumulated state."""
@@ -91,6 +101,54 @@ class ChatStateContainer:
with self._lock:
return self.is_clarification
@staticmethod
def create_search_doc_key(
search_doc: SearchDoc, use_simple_key: bool = True
) -> SearchDocKey:
"""Create a unique key for a SearchDoc for deduplication.
Args:
search_doc: The SearchDoc to create a key for
use_simple_key: If True (default), use only document_id for deduplication.
If False, include chunk_ind and match_highlights so that the same
document/chunk with different highlights are stored separately.
"""
if use_simple_key:
return search_doc.document_id
match_highlights_tuple = tuple(sorted(search_doc.match_highlights or []))
return (search_doc.document_id, search_doc.chunk_ind, match_highlights_tuple)
def add_search_docs(
self, search_docs: list[SearchDoc], use_simple_key: bool = True
) -> None:
"""Add search docs to the accumulated collection with deduplication.
Args:
search_docs: List of SearchDoc objects to add
use_simple_key: If True (default), deduplicate by document_id only.
If False, deduplicate by document_id + chunk_ind + match_highlights.
"""
with self._lock:
for doc in search_docs:
key = self.create_search_doc_key(doc, use_simple_key)
if key not in self._all_search_docs:
self._all_search_docs[key] = doc
def get_all_search_docs(self) -> dict[SearchDocKey, SearchDoc]:
"""Thread-safe getter for all accumulated search docs (returns a copy)."""
with self._lock:
return self._all_search_docs.copy()
def add_emitted_citation(self, citation_num: int) -> None:
"""Add a citation number that was actually emitted during streaming."""
with self._lock:
self._emitted_citations.add(citation_num)
def get_emitted_citations(self) -> set[int]:
"""Thread-safe getter for emitted citations (returns a copy)."""
with self._lock:
return self._emitted_citations.copy()
def run_chat_loop_with_state_containers(
func: Callable[..., None],

View File

@@ -53,6 +53,50 @@ def update_citation_processor_from_tool_response(
citation_processor.update_citation_mapping(citation_to_doc)
def extract_citation_order_from_text(text: str) -> list[int]:
"""Extract citation numbers from text in order of first appearance.
Parses citation patterns like [1], [1, 2], [[1]], 【1】 etc. and returns
the citation numbers in the order they first appear in the text.
Args:
text: The text containing citations
Returns:
List of citation numbers in order of first appearance (no duplicates)
"""
# Same pattern used in collapse_citations and DynamicCitationProcessor
# Group 2 captures the number in double bracket format: [[1]], 【【1】】
# Group 4 captures the numbers in single bracket format: [1], [1, 2]
citation_pattern = re.compile(
r"([\[【[]{2}(\d+)[\]】]]{2})|([\[【[]([\d]+(?: *, *\d+)*)[\]】]])"
)
seen: set[int] = set()
order: list[int] = []
for match in citation_pattern.finditer(text):
# Group 2 is for double bracket single number, group 4 is for single bracket
if match.group(2):
nums_str = match.group(2)
elif match.group(4):
nums_str = match.group(4)
else:
continue
for num_str in nums_str.split(","):
num_str = num_str.strip()
if num_str:
try:
num = int(num_str)
if num not in seen:
seen.add(num)
order.append(num)
except ValueError:
continue
return order
def collapse_citations(
answer_text: str,
existing_citation_mapping: CitationMapping,

View File

@@ -9,6 +9,7 @@ from onyx.chat.citation_processor import CitationMode
from onyx.chat.citation_processor import DynamicCitationProcessor
from onyx.chat.citation_utils import update_citation_processor_from_tool_response
from onyx.chat.emitter import Emitter
from onyx.chat.llm_step import extract_tool_calls_from_response_text
from onyx.chat.llm_step import run_llm_step
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import ExtractedProjectFiles
@@ -38,11 +39,13 @@ from onyx.tools.built_in_tools import CITEABLE_TOOLS_NAMES
from onyx.tools.built_in_tools import STOPPING_TOOLS_NAMES
from onyx.tools.interface import Tool
from onyx.tools.models import ToolCallInfo
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool_implementations.images.models import (
FinalImageGenerationResponse,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.web_search.utils import extract_url_snippet_map
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
from onyx.tools.tool_runner import run_tool_calls
from onyx.tracing.framework.create import trace
@@ -51,6 +54,78 @@ from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
def _try_fallback_tool_extraction(
llm_step_result: LlmStepResult,
tool_choice: ToolChoiceOptions,
fallback_extraction_attempted: bool,
tool_defs: list[dict],
turn_index: int,
) -> tuple[LlmStepResult, bool]:
"""Attempt to extract tool calls from response text as a fallback.
This is a last resort fallback for low quality LLMs or those that don't have
tool calling from the serving layer. Also triggers if there's reasoning but
no answer and no tool calls.
Args:
llm_step_result: The result from the LLM step
tool_choice: The tool choice option used for this step
fallback_extraction_attempted: Whether fallback extraction was already attempted
tool_defs: List of tool definitions
turn_index: The current turn index for placement
Returns:
Tuple of (possibly updated LlmStepResult, whether fallback was attempted this call)
"""
if fallback_extraction_attempted:
return llm_step_result, False
no_tool_calls = (
not llm_step_result.tool_calls or len(llm_step_result.tool_calls) == 0
)
reasoning_but_no_answer_or_tools = (
llm_step_result.reasoning and not llm_step_result.answer and no_tool_calls
)
should_try_fallback = (
tool_choice == ToolChoiceOptions.REQUIRED and no_tool_calls
) or reasoning_but_no_answer_or_tools
if not should_try_fallback:
return llm_step_result, False
# Try to extract from answer first, then fall back to reasoning
extracted_tool_calls: list[ToolCallKickoff] = []
if llm_step_result.answer:
extracted_tool_calls = extract_tool_calls_from_response_text(
response_text=llm_step_result.answer,
tool_definitions=tool_defs,
placement=Placement(turn_index=turn_index),
)
if not extracted_tool_calls and llm_step_result.reasoning:
extracted_tool_calls = extract_tool_calls_from_response_text(
response_text=llm_step_result.reasoning,
tool_definitions=tool_defs,
placement=Placement(turn_index=turn_index),
)
if extracted_tool_calls:
logger.info(
f"Extracted {len(extracted_tool_calls)} tool call(s) from response text "
f"as fallback (tool_choice was REQUIRED but no tool calls returned)"
)
return (
LlmStepResult(
reasoning=llm_step_result.reasoning,
answer=llm_step_result.answer,
tool_calls=extracted_tool_calls,
),
True,
)
return llm_step_result, True
# Hardcoded oppinionated value, might breaks down to something like:
# Cycle 1: Calls web_search for something
# Cycle 2: Calls open_url for some results
@@ -352,6 +427,7 @@ def run_llm_loop(
ran_image_gen: bool = False
just_ran_web_search: bool = False
has_called_search_tool: bool = False
fallback_extraction_attempted: bool = False
citation_mapping: dict[int, str] = {} # Maps citation_num -> document_id/URL
default_base_system_prompt: str = get_default_base_system_prompt(db_session)
@@ -378,12 +454,16 @@ def run_llm_loop(
# The section below calculates the available tokens for history a bit more accurately
# now that project files are loaded in.
if persona and persona.replace_base_system_prompt and persona.system_prompt:
if persona and persona.replace_base_system_prompt:
# Handles the case where user has checked off the "Replace base system prompt" checkbox
system_prompt = ChatMessageSimple(
message=persona.system_prompt,
token_count=token_counter(persona.system_prompt),
message_type=MessageType.SYSTEM,
system_prompt = (
ChatMessageSimple(
message=persona.system_prompt,
token_count=token_counter(persona.system_prompt),
message_type=MessageType.SYSTEM,
)
if persona.system_prompt
else None
)
custom_agent_prompt_msg = None
else:
@@ -470,10 +550,11 @@ def run_llm_loop(
# This calls the LLM, yields packets (reasoning, answers, etc.) and returns the result
# It also pre-processes the tool calls in preparation for running them
tool_defs = [tool.tool_definition() for tool in final_tools]
llm_step_result, has_reasoned = run_llm_step(
emitter=emitter,
history=truncated_message_history,
tool_definitions=[tool.tool_definition() for tool in final_tools],
tool_definitions=tool_defs,
tool_choice=tool_choice,
llm=llm,
placement=Placement(turn_index=llm_cycle_count + reasoning_cycles),
@@ -488,6 +569,19 @@ def run_llm_loop(
if has_reasoned:
reasoning_cycles += 1
# Fallback extraction for LLMs that don't support tool calling natively or are lower quality
# and might incorrectly output tool calls in other channels
llm_step_result, attempted = _try_fallback_tool_extraction(
llm_step_result=llm_step_result,
tool_choice=tool_choice,
fallback_extraction_attempted=fallback_extraction_attempted,
tool_defs=tool_defs,
turn_index=llm_cycle_count + reasoning_cycles,
)
if attempted:
# To prevent the case of excessive looping with bad models, we only allow one fallback attempt
fallback_extraction_attempted = True
# Save citation mapping after each LLM step for incremental state updates
state_container.set_citation_mapping(citation_processor.citation_to_doc)
@@ -523,6 +617,7 @@ def run_llm_loop(
next_citation_num=citation_processor.get_next_citation_number(),
max_concurrent_tools=None,
skip_search_query_expansion=has_called_search_tool,
url_snippet_map=extract_url_snippet_map(gathered_documents or []),
)
tool_responses = parallel_tool_call_results.tool_responses
citation_mapping = parallel_tool_call_results.updated_citation_mapping
@@ -561,8 +656,15 @@ def run_llm_loop(
# Extract search_docs if this is a search tool response
search_docs = None
displayed_docs = None
if isinstance(tool_response.rich_response, SearchDocsResponse):
search_docs = tool_response.rich_response.search_docs
displayed_docs = tool_response.rich_response.displayed_docs
# Add ALL search docs to state container for DB persistence
if search_docs:
state_container.add_search_docs(search_docs)
if gathered_documents:
gathered_documents.extend(search_docs)
else:
@@ -580,6 +682,12 @@ def run_llm_loop(
):
generated_images = tool_response.rich_response.generated_images
saved_response = (
tool_response.rich_response
if isinstance(tool_response.rich_response, str)
else tool_response.llm_facing_response
)
tool_call_info = ToolCallInfo(
parent_tool_call_id=None, # Top-level tool calls are attached to the chat message
turn_index=llm_cycle_count + reasoning_cycles,
@@ -589,8 +697,8 @@ def run_llm_loop(
tool_id=tool.id,
reasoning_tokens=llm_step_result.reasoning, # All tool calls from this loop share the same reasoning
tool_call_arguments=tool_call.tool_args,
tool_call_response=tool_response.llm_facing_response,
search_docs=search_docs,
tool_call_response=saved_response,
search_docs=displayed_docs or search_docs,
generated_images=generated_images,
)
# Add to state container for partial save support
@@ -645,7 +753,12 @@ def run_llm_loop(
should_cite_documents = True
if not llm_step_result or not llm_step_result.answer:
raise RuntimeError("LLM did not return an answer.")
raise RuntimeError(
"The LLM did not return an answer. "
"Typically this is an issue with LLMs that do not support tool calling natively, "
"or the model serving API is not configured correctly. "
"This may also happen with models that are lower quality outputting invalid tool calls."
)
emitter.emit(
Packet(

View File

@@ -14,6 +14,7 @@ from onyx.chat.emitter import Emitter
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import LlmStepResult
from onyx.configs.app_configs import LOG_ONYX_MODEL_INTERACTIONS
from onyx.configs.app_configs import PROMPT_CACHE_CHAT_HISTORY
from onyx.configs.constants import MessageType
from onyx.context.search.models import SearchDoc
from onyx.file_store.models import ChatFileType
@@ -49,6 +50,7 @@ from onyx.tools.models import ToolCallKickoff
from onyx.tracing.framework.create import generation_span
from onyx.utils.b64 import get_image_type_from_bytes
from onyx.utils.logger import setup_logger
from onyx.utils.text_processing import find_all_json_objects
logger = setup_logger()
@@ -278,6 +280,144 @@ def _extract_tool_call_kickoffs(
return tool_calls
def extract_tool_calls_from_response_text(
response_text: str | None,
tool_definitions: list[dict],
placement: Placement,
) -> list[ToolCallKickoff]:
"""Extract tool calls from LLM response text by matching JSON against tool definitions.
This is a fallback mechanism for when the LLM was expected to return tool calls
but didn't use the proper tool call format. It searches for JSON objects in the
response text that match the structure of available tools.
Args:
response_text: The LLM's text response to search for tool calls
tool_definitions: List of tool definitions to match against
placement: Placement information for the tool calls
Returns:
List of ToolCallKickoff objects for any matched tool calls
"""
if not response_text or not tool_definitions:
return []
# Build a map of tool names to their definitions
tool_name_to_def: dict[str, dict] = {}
for tool_def in tool_definitions:
if tool_def.get("type") == "function" and "function" in tool_def:
func_def = tool_def["function"]
tool_name = func_def.get("name")
if tool_name:
tool_name_to_def[tool_name] = func_def
if not tool_name_to_def:
return []
# Find all JSON objects in the response text
json_objects = find_all_json_objects(response_text)
tool_calls: list[ToolCallKickoff] = []
tab_index = 0
for json_obj in json_objects:
matched_tool_call = _try_match_json_to_tool(json_obj, tool_name_to_def)
if matched_tool_call:
tool_name, tool_args = matched_tool_call
tool_calls.append(
ToolCallKickoff(
tool_call_id=f"extracted_{uuid.uuid4().hex[:8]}",
tool_name=tool_name,
tool_args=tool_args,
placement=Placement(
turn_index=placement.turn_index,
tab_index=tab_index,
sub_turn_index=placement.sub_turn_index,
),
)
)
tab_index += 1
logger.info(
f"Extracted {len(tool_calls)} tool call(s) from response text as fallback"
)
return tool_calls
def _try_match_json_to_tool(
json_obj: dict[str, Any],
tool_name_to_def: dict[str, dict],
) -> tuple[str, dict[str, Any]] | None:
"""Try to match a JSON object to a tool definition.
Supports several formats:
1. Direct tool call format: {"name": "tool_name", "arguments": {...}}
2. Function call format: {"function": {"name": "tool_name", "arguments": {...}}}
3. Tool name as key: {"tool_name": {...arguments...}}
4. Arguments matching a tool's parameter schema
Args:
json_obj: The JSON object to match
tool_name_to_def: Map of tool names to their function definitions
Returns:
Tuple of (tool_name, tool_args) if matched, None otherwise
"""
# Format 1: Direct tool call format {"name": "...", "arguments": {...}}
if "name" in json_obj and json_obj["name"] in tool_name_to_def:
tool_name = json_obj["name"]
arguments = json_obj.get("arguments", json_obj.get("parameters", {}))
if isinstance(arguments, str):
try:
arguments = json.loads(arguments)
except json.JSONDecodeError:
arguments = {}
if isinstance(arguments, dict):
return (tool_name, arguments)
# Format 2: Function call format {"function": {"name": "...", "arguments": {...}}}
if "function" in json_obj and isinstance(json_obj["function"], dict):
func_obj = json_obj["function"]
if "name" in func_obj and func_obj["name"] in tool_name_to_def:
tool_name = func_obj["name"]
arguments = func_obj.get("arguments", func_obj.get("parameters", {}))
if isinstance(arguments, str):
try:
arguments = json.loads(arguments)
except json.JSONDecodeError:
arguments = {}
if isinstance(arguments, dict):
return (tool_name, arguments)
# Format 3: Tool name as key {"tool_name": {...arguments...}}
for tool_name in tool_name_to_def:
if tool_name in json_obj:
arguments = json_obj[tool_name]
if isinstance(arguments, dict):
return (tool_name, arguments)
# Format 4: Check if the JSON object matches a tool's parameter schema
for tool_name, func_def in tool_name_to_def.items():
params = func_def.get("parameters", {})
properties = params.get("properties", {})
required = params.get("required", [])
if not properties:
continue
# Check if all required parameters are present (empty required = all optional)
if all(req in json_obj for req in required):
# Check if any of the tool's properties are in the JSON object
matching_props = [prop for prop in properties if prop in json_obj]
if matching_props:
# Filter to only include known properties
filtered_args = {k: v for k, v in json_obj.items() if k in properties}
return (tool_name, filtered_args)
return None
def translate_history_to_llm_format(
history: list[ChatMessageSimple],
llm_config: LLMConfig,
@@ -293,7 +433,7 @@ def translate_history_to_llm_format(
for idx, msg in enumerate(history):
# if the message is being added to the history
if msg.message_type in [
if PROMPT_CACHE_CHAT_HISTORY and msg.message_type in [
MessageType.SYSTEM,
MessageType.USER,
MessageType.ASSISTANT,
@@ -720,6 +860,11 @@ def run_llm_step_pkt_generator(
),
obj=result,
)
# Track emitted citation for saving
if state_container:
state_container.add_emitted_citation(
result.citation_number
)
else:
# When citation_processor is None, use delta.content directly without modification
accumulated_answer += delta.content
@@ -846,6 +991,9 @@ def run_llm_step_pkt_generator(
),
obj=result,
)
# Track emitted citation for saving
if state_container:
state_container.add_emitted_citation(result.citation_number)
# Note: Content (AgentResponseDelta) doesn't need an explicit end packet - OverallStop handles it
# Tool calls are handled by tool execution code and emit their own packets (e.g., SectionEnd)

View File

@@ -42,7 +42,6 @@ from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
from onyx.configs.constants import MilestoneRecordType
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import CitationDocInfo
from onyx.context.search.models import SearchDoc
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_chat_session_by_id
@@ -86,10 +85,6 @@ from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.timing import log_function_time
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from onyx.utils.variable_functionality import noop_fallback
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -362,21 +357,20 @@ def handle_stream_message_objects(
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
)
# Track user message in PostHog for analytics
fetch_versioned_implementation_with_fallback(
module="onyx.utils.telemetry",
attribute="event_telemetry",
fallback=noop_fallback,
)(
distinct_id=user.email if user else tenant_id,
event="user_message_sent",
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=(
user.email
if user and not getattr(user, "is_anonymous", False)
else tenant_id
),
event=MilestoneRecordType.USER_MESSAGE_SENT,
properties={
"origin": new_msg_req.origin.value,
"has_files": len(new_msg_req.file_descriptors) > 0,
"has_project": chat_session.project_id is not None,
"has_persona": persona is not None and persona.id != DEFAULT_PERSONA_ID,
"deep_research": new_msg_req.deep_research,
"tenant_id": tenant_id,
},
)
@@ -744,27 +738,16 @@ def llm_loop_completion_handle(
else:
final_answer = "The generation was stopped by the user."
# Build citation_docs_info from accumulated citations in state container
citation_docs_info: list[CitationDocInfo] = []
seen_citation_nums: set[int] = set()
for citation_num, search_doc in state_container.citation_to_doc.items():
if citation_num not in seen_citation_nums:
seen_citation_nums.add(citation_num)
citation_docs_info.append(
CitationDocInfo(
search_doc=search_doc,
citation_number=citation_num,
)
)
save_chat_turn(
message_text=final_answer,
reasoning_tokens=state_container.reasoning_tokens,
citation_docs_info=citation_docs_info,
citation_to_doc=state_container.citation_to_doc,
tool_calls=state_container.tool_calls,
all_search_docs=state_container.get_all_search_docs(),
db_session=db_session,
assistant_message=assistant_message,
is_clarification=state_container.is_clarification,
emitted_citations=state_container.get_emitted_citations(),
)

View File

@@ -18,6 +18,7 @@ from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.prompts.prompt_utils import replace_citation_guidance_tag
from onyx.prompts.tool_prompts import GENERATE_IMAGE_GUIDANCE
from onyx.prompts.tool_prompts import INTERNAL_SEARCH_GUIDANCE
from onyx.prompts.tool_prompts import MEMORY_GUIDANCE
from onyx.prompts.tool_prompts import OPEN_URLS_GUIDANCE
from onyx.prompts.tool_prompts import PYTHON_TOOL_GUIDANCE
from onyx.prompts.tool_prompts import TOOL_DESCRIPTION_SEARCH_GUIDANCE
@@ -28,6 +29,7 @@ from onyx.tools.interface import Tool
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.tools.tool_implementations.memory.memory_tool import MemoryTool
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
from onyx.tools.tool_implementations.python.python_tool import PythonTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
@@ -178,8 +180,9 @@ def build_system_prompt(
site_colon_disabled=WEB_SEARCH_SITE_DISABLED_GUIDANCE
)
+ OPEN_URLS_GUIDANCE
+ GENERATE_IMAGE_GUIDANCE
+ PYTHON_TOOL_GUIDANCE
+ GENERATE_IMAGE_GUIDANCE
+ MEMORY_GUIDANCE
)
return system_prompt
@@ -193,6 +196,7 @@ def build_system_prompt(
has_generate_image = any(
isinstance(tool, ImageGenerationTool) for tool in tools
)
has_memory = any(isinstance(tool, MemoryTool) for tool in tools)
if has_web_search or has_internal_search or include_all_guidance:
system_prompt += TOOL_DESCRIPTION_SEARCH_GUIDANCE
@@ -222,4 +226,7 @@ def build_system_prompt(
if has_generate_image or include_all_guidance:
system_prompt += GENERATE_IMAGE_GUIDANCE
if has_memory or include_all_guidance:
system_prompt += MEMORY_GUIDANCE
return system_prompt

View File

@@ -2,8 +2,9 @@ import json
from sqlalchemy.orm import Session
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.chat_state import SearchDocKey
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import CitationDocInfo
from onyx.context.search.models import SearchDoc
from onyx.db.chat import add_search_docs_to_chat_message
from onyx.db.chat import add_search_docs_to_tool_call
@@ -19,22 +20,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
def _create_search_doc_key(search_doc: SearchDoc) -> tuple[str, int, tuple[str, ...]]:
"""
Create a unique key for a SearchDoc that accounts for different versions of the same
document/chunk with different match_highlights.
Args:
search_doc: The SearchDoc pydantic model to create a key for
Returns:
A tuple of (document_id, chunk_ind, sorted match_highlights) that uniquely identifies
this specific version of the document
"""
match_highlights_tuple = tuple(sorted(search_doc.match_highlights or []))
return (search_doc.document_id, search_doc.chunk_ind, match_highlights_tuple)
def _create_and_link_tool_calls(
tool_calls: list[ToolCallInfo],
assistant_message: ChatMessage,
@@ -154,38 +139,36 @@ def save_chat_turn(
message_text: str,
reasoning_tokens: str | None,
tool_calls: list[ToolCallInfo],
citation_docs_info: list[CitationDocInfo],
citation_to_doc: dict[int, SearchDoc],
all_search_docs: dict[SearchDocKey, SearchDoc],
db_session: Session,
assistant_message: ChatMessage,
is_clarification: bool = False,
emitted_citations: set[int] | None = None,
) -> None:
"""
Save a chat turn by populating the assistant_message and creating related entities.
This function:
1. Updates the ChatMessage with text, reasoning tokens, and token count
2. Creates SearchDoc entries from ToolCall search_docs (for tool calls that returned documents)
3. Collects all unique SearchDocs from all tool calls and links them to ChatMessage
4. Builds citation mapping from citation_docs_info
5. Links all unique SearchDocs from tool calls to the ChatMessage
2. Creates DB SearchDoc entries from pre-deduplicated all_search_docs
3. Builds tool_call -> search_doc mapping for displayed docs
4. Builds citation mapping from citation_to_doc
5. Links all unique SearchDocs to the ChatMessage
6. Creates ToolCall entries and links SearchDocs to them
7. Builds the citations mapping for the ChatMessage
Deduplication Logic:
- SearchDocs are deduplicated using (document_id, chunk_ind, match_highlights) as the key
- This ensures that the same document/chunk with different match_highlights (from different
queries) are stored as separate SearchDoc entries
- Each ToolCall and ChatMessage will map to the correct version of the SearchDoc that
matches its specific query highlights
Args:
message_text: The message content to save
reasoning_tokens: Optional reasoning tokens for the message
tool_calls: List of tool call information to create ToolCall entries (may include search_docs)
citation_docs_info: List of citation document information for building citations mapping
citation_to_doc: Mapping from citation number to SearchDoc for building citations
all_search_docs: Pre-deduplicated search docs from ChatStateContainer
db_session: Database session for persistence
assistant_message: The ChatMessage object to populate (should already exist in DB)
is_clarification: Whether this assistant message is a clarification question (deep research flow)
emitted_citations: Set of citation numbers that were actually emitted during streaming.
If provided, only citations in this set will be saved; others are filtered out.
"""
# 1. Update ChatMessage with message content, reasoning tokens, and token count
assistant_message.message = message_text
@@ -200,53 +183,53 @@ def save_chat_turn(
else:
assistant_message.token_count = 0
# 2. Create SearchDoc entries from tool_calls
# Build mapping from SearchDoc to DB SearchDoc ID
# Use (document_id, chunk_ind, match_highlights) as key to avoid duplicates
# while ensuring different versions with different highlights are stored separately
search_doc_key_to_id: dict[tuple[str, int, tuple[str, ...]], int] = {}
tool_call_to_search_doc_ids: dict[str, list[int]] = {}
# 2. Create DB SearchDoc entries from pre-deduplicated all_search_docs
search_doc_key_to_id: dict[SearchDocKey, int] = {}
for key, search_doc_py in all_search_docs.items():
db_search_doc = create_db_search_doc(
server_search_doc=search_doc_py,
db_session=db_session,
commit=False,
)
search_doc_key_to_id[key] = db_search_doc.id
# Process tool calls and their search docs
# 3. Build tool_call -> search_doc mapping (for displayed docs in each tool call)
tool_call_to_search_doc_ids: dict[str, list[int]] = {}
for tool_call_info in tool_calls:
if tool_call_info.search_docs:
search_doc_ids_for_tool: list[int] = []
for search_doc_py in tool_call_info.search_docs:
# Create a unique key for this SearchDoc version
search_doc_key = _create_search_doc_key(search_doc_py)
# Check if we've already created this exact SearchDoc version
if search_doc_key in search_doc_key_to_id:
search_doc_ids_for_tool.append(search_doc_key_to_id[search_doc_key])
key = ChatStateContainer.create_search_doc_key(search_doc_py)
if key in search_doc_key_to_id:
search_doc_ids_for_tool.append(search_doc_key_to_id[key])
else:
# Create new DB SearchDoc entry
# Displayed doc not in all_search_docs - create it
# This can happen if displayed_docs contains docs not in search_docs
db_search_doc = create_db_search_doc(
server_search_doc=search_doc_py,
db_session=db_session,
commit=False,
)
search_doc_key_to_id[search_doc_key] = db_search_doc.id
search_doc_key_to_id[key] = db_search_doc.id
search_doc_ids_for_tool.append(db_search_doc.id)
tool_call_to_search_doc_ids[tool_call_info.tool_call_id] = list(
set(search_doc_ids_for_tool)
)
# 3. Collect all unique SearchDoc IDs from all tool calls to link to ChatMessage
# Use a set to deduplicate by ID (since we've already deduplicated by key above)
all_search_doc_ids_set: set[int] = set()
for search_doc_ids in tool_call_to_search_doc_ids.values():
all_search_doc_ids_set.update(search_doc_ids)
# Collect all search doc IDs for ChatMessage linking
all_search_doc_ids_set: set[int] = set(search_doc_key_to_id.values())
# 4. Build citation mapping from citation_docs_info
# 4. Build a citation mapping from the citation number to the saved DB SearchDoc ID
# Only include citations that were actually emitted during streaming
citation_number_to_search_doc_id: dict[int, int] = {}
for citation_doc_info in citation_docs_info:
# Extract SearchDoc pydantic model
search_doc_py = citation_doc_info.search_doc
for citation_num, search_doc_py in citation_to_doc.items():
# Skip citations that weren't actually emitted (if emitted_citations is provided)
if emitted_citations is not None and citation_num not in emitted_citations:
continue
# Create the unique key for this SearchDoc version
search_doc_key = _create_search_doc_key(search_doc_py)
search_doc_key = ChatStateContainer.create_search_doc_key(search_doc_py)
# Get the search doc ID (should already exist from processing tool_calls)
if search_doc_key in search_doc_key_to_id:
@@ -283,10 +266,7 @@ def save_chat_turn(
all_search_doc_ids_set.add(db_search_doc_id)
# Build mapping from citation number to search doc ID
if citation_doc_info.citation_number is not None:
citation_number_to_search_doc_id[citation_doc_info.citation_number] = (
db_search_doc_id
)
citation_number_to_search_doc_id[citation_num] = db_search_doc_id
# 5. Link all unique SearchDocs (from both tool calls and citations) to ChatMessage
final_search_doc_ids: list[int] = list(all_search_doc_ids_set)
@@ -306,23 +286,10 @@ def save_chat_turn(
tool_call_to_search_doc_ids=tool_call_to_search_doc_ids,
)
# 7. Build citations mapping from citation_docs_info
# Any citation_doc_info with a citation_number appeared in the text and should be mapped
citations: dict[int, int] = {}
for citation_doc_info in citation_docs_info:
if citation_doc_info.citation_number is not None:
search_doc_id = citation_number_to_search_doc_id.get(
citation_doc_info.citation_number
)
if search_doc_id is not None:
citations[citation_doc_info.citation_number] = search_doc_id
else:
logger.warning(
f"Citation number {citation_doc_info.citation_number} found in citation_docs_info "
f"but no matching search doc ID in mapping"
)
assistant_message.citations = citations if citations else None
# 7. Build citations mapping - use the mapping we already built in step 4
assistant_message.citations = (
citation_number_to_search_doc_id if citation_number_to_search_doc_id else None
)
# Finally save the messages, tool calls, and docs
db_session.commit()

View File

@@ -208,8 +208,19 @@ OPENSEARCH_REST_API_PORT = int(os.environ.get("OPENSEARCH_REST_API_PORT") or 920
OPENSEARCH_ADMIN_USERNAME = os.environ.get("OPENSEARCH_ADMIN_USERNAME", "admin")
OPENSEARCH_ADMIN_PASSWORD = os.environ.get("OPENSEARCH_ADMIN_PASSWORD", "")
ENABLE_OPENSEARCH_FOR_ONYX = (
os.environ.get("ENABLE_OPENSEARCH_FOR_ONYX", "").lower() == "true"
# This is the "base" config for now, the idea is that at least for our dev
# environments we always want to be dual indexing into both OpenSearch and Vespa
# to stress test the new codepaths. Only enable this if there is some instance
# of OpenSearch running for the relevant Onyx instance.
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX = (
os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "").lower() == "true"
)
# Given that the "base" config above is true, this enables whether we want to
# retrieve from OpenSearch or Vespa. We want to be able to quickly toggle this
# in the event we see issues with OpenSearch retrieval in our dev environments.
ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX = (
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
and os.environ.get("ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX", "").lower() == "true"
)
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
@@ -738,6 +749,10 @@ JOB_TIMEOUT = 60 * 60 * 6 # 6 hours default
LOG_ONYX_MODEL_INTERACTIONS = (
os.environ.get("LOG_ONYX_MODEL_INTERACTIONS", "").lower() == "true"
)
PROMPT_CACHE_CHAT_HISTORY = (
os.environ.get("PROMPT_CACHE_CHAT_HISTORY", "").lower() == "true"
)
# If set to `true` will enable additional logs about Vespa query performance
# (time spent on finding the right docs + time spent fetching summaries from disk)
LOG_VESPA_TIMING_INFORMATION = (
@@ -1011,3 +1026,19 @@ INSTANCE_TYPE = (
if os.environ.get("IS_MANAGED_INSTANCE", "").lower() == "true"
else "cloud" if AUTH_TYPE == AuthType.CLOUD else "self_hosted"
)
## Discord Bot Configuration
DISCORD_BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
DISCORD_BOT_INVOKE_CHAR = os.environ.get("DISCORD_BOT_INVOKE_CHAR", "!")
## Stripe Configuration
# URL to fetch the Stripe publishable key from a public S3 bucket.
# Publishable keys are safe to expose publicly - they can only initialize
# Stripe.js and tokenize payment info, not make charges or access data.
STRIPE_PUBLISHABLE_KEY_URL = (
"https://onyx-stripe-public.s3.amazonaws.com/publishable-key.txt"
)
# Override for local testing with Stripe test keys (pk_test_*)
STRIPE_PUBLISHABLE_KEY_OVERRIDE = os.environ.get("STRIPE_PUBLISHABLE_KEY")

View File

@@ -1,6 +1,5 @@
import os
INPUT_PROMPT_YAML = "./onyx/seeding/input_prompts.yaml"
PROMPTS_YAML = "./onyx/seeding/prompts.yaml"
PERSONAS_YAML = "./onyx/seeding/personas.yaml"
NUM_RETURNED_HITS = 50

View File

@@ -93,6 +93,7 @@ SSL_CERT_FILE = "bundle.pem"
DANSWER_API_KEY_PREFIX = "API_KEY__"
DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN = "onyxapikey.ai"
UNNAMED_KEY_PLACEHOLDER = "Unnamed"
DISCORD_SERVICE_API_KEY_NAME = "discord-bot-service"
# Key-Value store keys
KV_REINDEX_KEY = "needs_reindexing"
@@ -152,6 +153,17 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
# How long a queued user-file task is valid before workers discard it.
# Should be longer than the beat interval (20 s) but short enough to prevent
# indefinite queue growth. Workers drop tasks older than this without touching
# the DB, so a shorter value = faster drain of stale duplicates.
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)
# Maximum number of tasks allowed in the user-file-processing queue before the
# beat generator stops adding more. Prevents unbounded queue growth when workers
# fall behind.
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
@@ -340,6 +352,7 @@ class MilestoneRecordType(str, Enum):
CREATED_CONNECTOR = "created_connector"
CONNECTOR_SUCCEEDED = "connector_succeeded"
RAN_QUERY = "ran_query"
USER_MESSAGE_SENT = "user_message_sent"
MULTIPLE_ASSISTANTS = "multiple_assistants"
CREATED_ASSISTANT = "created_assistant"
CREATED_ONYX_BOT = "created_onyx_bot"
@@ -422,6 +435,9 @@ class OnyxRedisLocks:
# User file processing
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
# Short-lived key set when a task is enqueued; cleared when the worker picks it up.
# Prevents the beat from re-enqueuing the same file while a task is already queued.
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"

View File

@@ -25,11 +25,17 @@ class AsanaConnector(LoadConnector, PollConnector):
batch_size: int = INDEX_BATCH_SIZE,
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
) -> None:
self.workspace_id = asana_workspace_id
self.project_ids_to_index: list[str] | None = (
asana_project_ids.split(",") if asana_project_ids is not None else None
)
self.asana_team_id = asana_team_id
self.workspace_id = asana_workspace_id.strip()
if asana_project_ids:
project_ids = [
project_id.strip()
for project_id in asana_project_ids.split(",")
if project_id.strip()
]
self.project_ids_to_index = project_ids or None
else:
self.project_ids_to_index = None
self.asana_team_id = (asana_team_id.strip() or None) if asana_team_id else None
self.batch_size = batch_size
self.continue_on_failure = continue_on_failure
logger.info(

View File

@@ -31,6 +31,8 @@ class GongConnector(LoadConnector, PollConnector):
BASE_URL = "https://api.gong.io"
MAX_CALL_DETAILS_ATTEMPTS = 6
CALL_DETAILS_DELAY = 30 # in seconds
# Gong API limit is 3 calls/sec — stay safely under it
MIN_REQUEST_INTERVAL = 0.5 # seconds between requests
def __init__(
self,
@@ -44,9 +46,13 @@ class GongConnector(LoadConnector, PollConnector):
self.continue_on_fail = continue_on_fail
self.auth_token_basic: str | None = None
self.hide_user_info = hide_user_info
self._last_request_time: float = 0.0
# urllib3 Retry already respects the Retry-After header by default
# (respect_retry_after_header=True), so on 429 it will sleep for the
# duration Gong specifies before retrying.
retry_strategy = Retry(
total=5,
total=10,
backoff_factor=2,
status_forcelist=[429, 500, 502, 503, 504],
)
@@ -60,8 +66,24 @@ class GongConnector(LoadConnector, PollConnector):
url = f"{GongConnector.BASE_URL}{endpoint}"
return url
def _throttled_request(
self, method: str, url: str, **kwargs: Any
) -> requests.Response:
"""Rate-limited request wrapper. Enforces MIN_REQUEST_INTERVAL between
calls to stay under Gong's 3 calls/sec limit and avoid triggering 429s."""
now = time.monotonic()
elapsed = now - self._last_request_time
if elapsed < self.MIN_REQUEST_INTERVAL:
time.sleep(self.MIN_REQUEST_INTERVAL - elapsed)
response = self._session.request(method, url, **kwargs)
self._last_request_time = time.monotonic()
return response
def _get_workspace_id_map(self) -> dict[str, str]:
response = self._session.get(GongConnector.make_url("/v2/workspaces"))
response = self._throttled_request(
"GET", GongConnector.make_url("/v2/workspaces")
)
response.raise_for_status()
workspaces_details = response.json().get("workspaces")
@@ -105,8 +127,8 @@ class GongConnector(LoadConnector, PollConnector):
del body["filter"]["workspaceId"]
while True:
response = self._session.post(
GongConnector.make_url("/v2/calls/transcript"), json=body
response = self._throttled_request(
"POST", GongConnector.make_url("/v2/calls/transcript"), json=body
)
# If no calls in the range, just break out
if response.status_code == 404:
@@ -141,8 +163,8 @@ class GongConnector(LoadConnector, PollConnector):
"contentSelector": {"exposedFields": {"parties": True}},
}
response = self._session.post(
GongConnector.make_url("/v2/calls/extensive"), json=body
response = self._throttled_request(
"POST", GongConnector.make_url("/v2/calls/extensive"), json=body
)
response.raise_for_status()
@@ -193,7 +215,8 @@ class GongConnector(LoadConnector, PollConnector):
# There's a likely race condition in the API where a transcript will have a
# call id but the call to v2/calls/extensive will not return all of the id's
# retry with exponential backoff has been observed to mitigate this
# in ~2 minutes
# in ~2 minutes. After max attempts, proceed with whatever we have —
# the per-call loop below will skip missing IDs gracefully.
current_attempt = 0
while True:
current_attempt += 1
@@ -212,11 +235,14 @@ class GongConnector(LoadConnector, PollConnector):
f"missing_call_ids={missing_call_ids}"
)
if current_attempt >= self.MAX_CALL_DETAILS_ATTEMPTS:
raise RuntimeError(
f"Attempt count exceeded for _get_call_details_by_ids: "
f"missing_call_ids={missing_call_ids} "
f"max_attempts={self.MAX_CALL_DETAILS_ATTEMPTS}"
logger.error(
f"Giving up on missing call id's after "
f"{self.MAX_CALL_DETAILS_ATTEMPTS} attempts: "
f"missing_call_ids={missing_call_ids}"
f"proceeding with {len(call_details_map)} of "
f"{len(transcript_call_ids)} calls"
)
break
wait_seconds = self.CALL_DETAILS_DELAY * pow(2, current_attempt - 1)
logger.warning(

View File

@@ -244,6 +244,9 @@ def convert_metadata_dict_to_list_of_strings(
Each string is a key-value pair separated by the INDEX_SEPARATOR. If a key
points to a list of values, each value generates a unique pair.
NOTE: Whatever formatting strategy is used here to generate a key-value
string must be replicated when constructing query filters.
Args:
metadata: The metadata dict to convert where values can be either a
string or a list of strings.

View File

@@ -6,6 +6,7 @@ import sys
import tempfile
import time
from collections import defaultdict
from collections.abc import Callable
from pathlib import Path
from typing import Any
from typing import cast
@@ -30,20 +31,29 @@ from onyx.connectors.salesforce.onyx_salesforce import OnyxSalesforce
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
from onyx.connectors.salesforce.utils import get_sqlite_db_path
from onyx.connectors.salesforce.utils import ID_FIELD
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
from onyx.connectors.salesforce.utils import NAME_FIELD
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
def _convert_to_metadata_value(value: Any) -> str | list[str]:
"""Convert a Salesforce field value to a valid metadata value.
Document metadata expects str | list[str], but Salesforce returns
various types (bool, float, int, etc.). This function ensures all
values are properly converted to strings.
"""
if isinstance(value, list):
return [str(item) for item in value]
return str(value)
_DEFAULT_PARENT_OBJECT_TYPES = [ACCOUNT_OBJECT_TYPE]
_DEFAULT_ATTRIBUTES_TO_KEEP: dict[str, dict[str, str]] = {
@@ -433,6 +443,88 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
# # gc.collect()
# return all_types
def _yield_doc_batches(
self,
sf_db: OnyxSalesforceSQLite,
type_to_processed: dict[str, int],
changed_ids_to_type: dict[str, str],
parent_types: set[str],
increment_parents_changed: Callable[[], None],
) -> GenerateDocumentsOutput:
""" """
docs_to_yield: list[Document] = []
docs_to_yield_bytes = 0
last_log_time = 0.0
for (
parent_type,
parent_id,
examined_ids,
) in sf_db.get_changed_parent_ids_by_type(
changed_ids=list(changed_ids_to_type.keys()),
parent_types=parent_types,
):
now = time.monotonic()
processed = examined_ids - 1
if now - last_log_time > SalesforceConnector.LOG_INTERVAL:
logger.info(
f"Processing stats: {type_to_processed} "
f"file_size={sf_db.file_size} "
f"processed={processed} "
f"remaining={len(changed_ids_to_type) - processed}"
)
last_log_time = now
type_to_processed[parent_type] = type_to_processed.get(parent_type, 0) + 1
parent_object = sf_db.get_record(parent_id, parent_type)
if not parent_object:
logger.warning(
f"Failed to get parent object {parent_id} for {parent_type}"
)
continue
# use the db to create a document we can yield
doc = convert_sf_object_to_doc(
sf_db,
sf_object=parent_object,
sf_instance=self.sf_client.sf_instance,
)
doc.metadata["object_type"] = parent_type
# Add default attributes to the metadata
for (
sf_attribute,
canonical_attribute,
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(parent_type, {}).items():
if sf_attribute in parent_object.data:
doc.metadata[canonical_attribute] = _convert_to_metadata_value(
parent_object.data[sf_attribute]
)
doc_sizeof = sys.getsizeof(doc)
docs_to_yield_bytes += doc_sizeof
docs_to_yield.append(doc)
increment_parents_changed()
# memory usage is sensitive to the input length, so we're yielding immediately
# if the batch exceeds a certain byte length
if (
len(docs_to_yield) >= self.batch_size
or docs_to_yield_bytes > SalesforceConnector.MAX_BATCH_BYTES
):
yield docs_to_yield
docs_to_yield = []
docs_to_yield_bytes = 0
# observed a memory leak / size issue with the account table if we don't gc.collect here.
gc.collect()
yield docs_to_yield
def _full_sync(
self,
temp_dir: str,
@@ -443,8 +535,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
if not self._sf_client:
raise RuntimeError("self._sf_client is None!")
docs_to_yield: list[Document] = []
changed_ids_to_type: dict[str, str] = {}
parents_changed = 0
examined_ids = 0
@@ -492,9 +582,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
f"records={num_records}"
)
# yield an empty list to keep the connector alive
yield docs_to_yield
new_ids = sf_db.update_from_csv(
object_type=object_type,
csv_download_path=csv_path,
@@ -527,79 +614,17 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
)
# Step 3 - extract and index docs
docs_to_yield_bytes = 0
last_log_time = 0.0
for (
parent_type,
parent_id,
examined_ids,
) in sf_db.get_changed_parent_ids_by_type(
changed_ids=list(changed_ids_to_type.keys()),
parent_types=ctx.parent_types,
):
now = time.monotonic()
processed = examined_ids - 1
if now - last_log_time > SalesforceConnector.LOG_INTERVAL:
logger.info(
f"Processing stats: {type_to_processed} "
f"file_size={sf_db.file_size} "
f"processed={processed} "
f"remaining={len(changed_ids_to_type) - processed}"
)
last_log_time = now
type_to_processed[parent_type] = (
type_to_processed.get(parent_type, 0) + 1
)
parent_object = sf_db.get_record(parent_id, parent_type)
if not parent_object:
logger.warning(
f"Failed to get parent object {parent_id} for {parent_type}"
)
continue
# use the db to create a document we can yield
doc = convert_sf_object_to_doc(
sf_db,
sf_object=parent_object,
sf_instance=self.sf_client.sf_instance,
)
doc.metadata["object_type"] = parent_type
# Add default attributes to the metadata
for (
sf_attribute,
canonical_attribute,
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(parent_type, {}).items():
if sf_attribute in parent_object.data:
doc.metadata[canonical_attribute] = parent_object.data[
sf_attribute
]
doc_sizeof = sys.getsizeof(doc)
docs_to_yield_bytes += doc_sizeof
docs_to_yield.append(doc)
def increment_parents_changed() -> None:
nonlocal parents_changed
parents_changed += 1
# memory usage is sensitive to the input length, so we're yielding immediately
# if the batch exceeds a certain byte length
if (
len(docs_to_yield) >= self.batch_size
or docs_to_yield_bytes > SalesforceConnector.MAX_BATCH_BYTES
):
yield docs_to_yield
docs_to_yield = []
docs_to_yield_bytes = 0
# observed a memory leak / size issue with the account table if we don't gc.collect here.
gc.collect()
yield docs_to_yield
yield from self._yield_doc_batches(
sf_db,
type_to_processed,
changed_ids_to_type,
ctx.parent_types,
increment_parents_changed,
)
except Exception:
logger.exception("Unexpected exception")
raise
@@ -801,7 +826,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
canonical_attribute,
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(actual_parent_type, {}).items():
if sf_attribute in record:
doc.metadata[canonical_attribute] = record[sf_attribute]
doc.metadata[canonical_attribute] = _convert_to_metadata_value(
record[sf_attribute]
)
doc_sizeof = sys.getsizeof(doc)
docs_to_yield_bytes += doc_sizeof
@@ -1088,36 +1115,21 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
return return_context
def load_from_state(self) -> GenerateDocumentsOutput:
if MULTI_TENANT:
# if multi tenant, we cannot expect the sqlite db to be cached/present
with tempfile.TemporaryDirectory() as temp_dir:
return self._full_sync(temp_dir)
# nuke the db since we're starting from scratch
sqlite_db_path = get_sqlite_db_path(BASE_DATA_PATH)
if os.path.exists(sqlite_db_path):
logger.info(f"load_from_state: Removing db at {sqlite_db_path}.")
os.remove(sqlite_db_path)
return self._full_sync(BASE_DATA_PATH)
# Always use a temp directory for SQLite - the database is rebuilt
# from scratch each time via CSV downloads, so there's no caching benefit
# from persisting it. Using temp dirs also avoids collisions between
# multiple CC pairs and eliminates stale WAL/SHM file issues.
# TODO(evan): make this thing checkpointed and persist/load db from filestore
with tempfile.TemporaryDirectory() as temp_dir:
yield from self._full_sync(temp_dir)
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
"""Poll source will synchronize updated parent objects one by one."""
if start == 0:
# nuke the db if we're starting from scratch
sqlite_db_path = get_sqlite_db_path(BASE_DATA_PATH)
if os.path.exists(sqlite_db_path):
logger.info(
f"poll_source: Starting at time 0, removing db at {sqlite_db_path}."
)
os.remove(sqlite_db_path)
return self._delta_sync(BASE_DATA_PATH, start, end)
# Always use a temp directory - see comment in load_from_state()
with tempfile.TemporaryDirectory() as temp_dir:
return self._delta_sync(temp_dir, start, end)
yield from self._delta_sync(temp_dir, start, end)
def retrieve_all_slim_docs_perm_sync(
self,

View File

@@ -12,6 +12,7 @@ from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
from onyx.connectors.salesforce.utils import ID_FIELD
from onyx.connectors.salesforce.utils import NAME_FIELD
from onyx.connectors.salesforce.utils import remove_sqlite_db_files
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
from onyx.connectors.salesforce.utils import validate_salesforce_id
@@ -22,6 +23,9 @@ from shared_configs.utils import batch_list
logger = setup_logger()
SQLITE_DISK_IO_ERROR = "disk I/O error"
class OnyxSalesforceSQLite:
"""Notes on context management using 'with self.conn':
@@ -99,8 +103,37 @@ class OnyxSalesforceSQLite:
def apply_schema(self) -> None:
"""Initialize the SQLite database with required tables if they don't exist.
Non-destructive operation.
Non-destructive operation. If a disk I/O error is encountered (often due
to stale WAL/SHM files from a previous crash), this method will attempt
to recover by removing the corrupted files and recreating the database.
"""
try:
self._apply_schema_impl()
except sqlite3.OperationalError as e:
if SQLITE_DISK_IO_ERROR not in str(e):
raise
logger.warning(f"SQLite disk I/O error detected, attempting recovery: {e}")
self._recover_from_corruption()
self._apply_schema_impl()
def _recover_from_corruption(self) -> None:
"""Recover from SQLite corruption by removing all database files and reconnecting."""
logger.info(f"Removing corrupted SQLite files: {self.filename}")
# Close existing connection
self.close()
# Remove all SQLite files (main db, WAL, SHM)
remove_sqlite_db_files(self.filename)
# Reconnect - this will create a fresh database
self.connect()
logger.info("SQLite recovery complete, fresh database created")
def _apply_schema_impl(self) -> None:
"""Internal implementation of apply_schema."""
if self._conn is None:
raise RuntimeError("Database connection is closed")

View File

@@ -41,6 +41,28 @@ def get_sqlite_db_path(directory: str) -> str:
return os.path.join(directory, "salesforce_db.sqlite")
def remove_sqlite_db_files(db_path: str) -> None:
"""Remove SQLite database and all associated files (WAL, SHM).
SQLite in WAL mode creates additional files:
- .sqlite-wal: Write-ahead log
- .sqlite-shm: Shared memory file
If these files become stale (e.g., after a crash), they can cause
'disk I/O error' when trying to open the database. This function
ensures all related files are removed.
"""
files_to_remove = [
db_path,
f"{db_path}-wal",
f"{db_path}-shm",
]
for file_path in files_to_remove:
if os.path.exists(file_path):
os.remove(file_path)
# NOTE: only used with shelves, deprecated at this point
def get_object_type_path(object_type: str) -> str:
"""Get the directory path for a specific object type."""
type_dir = os.path.join(BASE_DATA_PATH, object_type)

View File

@@ -15,6 +15,7 @@ from onyx.federated_connectors.slack.models import SlackEntities
from onyx.llm.interfaces import LLM
from onyx.llm.models import UserMessage
from onyx.llm.utils import llm_response_to_string
from onyx.natural_language_processing.english_stopwords import ENGLISH_STOPWORDS_SET
from onyx.onyxbot.slack.models import ChannelType
from onyx.prompts.federated_search import SLACK_DATE_EXTRACTION_PROMPT
from onyx.prompts.federated_search import SLACK_QUERY_EXPANSION_PROMPT
@@ -113,7 +114,7 @@ def is_recency_query(query: str) -> bool:
if not has_recency_keyword:
return False
# Get combined stop words (NLTK + Slack-specific)
# Get combined stop words (English + Slack-specific)
all_stop_words = _get_combined_stop_words()
# Extract content words (excluding stop words)
@@ -488,7 +489,7 @@ def build_channel_override_query(channel_references: set[str], time_filter: str)
return f"__CHANNEL_OVERRIDE__ {channel_filter}{time_filter}"
# Slack-specific stop words (in addition to standard NLTK stop words)
# Slack-specific stop words (in addition to standard English stop words)
# These include Slack-specific terms and temporal/recency keywords
SLACK_SPECIFIC_STOP_WORDS = frozenset(
RECENCY_KEYWORDS
@@ -508,27 +509,16 @@ SLACK_SPECIFIC_STOP_WORDS = frozenset(
)
def _get_combined_stop_words() -> set[str]:
"""Get combined NLTK + Slack-specific stop words.
def _get_combined_stop_words() -> frozenset[str]:
"""Get combined English + Slack-specific stop words.
Returns a set of stop words for filtering content words.
Falls back to just Slack-specific stop words if NLTK is unavailable.
Returns a frozenset of stop words for filtering content words.
Note: Currently only supports English stop words. Non-English queries
may have suboptimal content word extraction. Future enhancement could
detect query language and load appropriate stop words.
"""
try:
from nltk.corpus import stopwords # type: ignore
# TODO: Support multiple languages - currently hardcoded to English
# Could detect language or allow configuration
nltk_stop_words = set(stopwords.words("english"))
except Exception:
# Fallback if NLTK not available
nltk_stop_words = set()
return nltk_stop_words | SLACK_SPECIFIC_STOP_WORDS
return ENGLISH_STOPWORDS_SET | SLACK_SPECIFIC_STOP_WORDS
def extract_content_words_from_recency_query(
@@ -536,7 +526,7 @@ def extract_content_words_from_recency_query(
) -> list[str]:
"""Extract meaningful content words from a recency query.
Filters out NLTK stop words, Slack-specific terms, channel references, and proper nouns.
Filters out English stop words, Slack-specific terms, channel references, and proper nouns.
Args:
query_text: The user's query text
@@ -545,7 +535,7 @@ def extract_content_words_from_recency_query(
Returns:
List of content words (up to MAX_CONTENT_WORDS)
"""
# Get combined stop words (NLTK + Slack-specific)
# Get combined stop words (English + Slack-specific)
all_stop_words = _get_combined_stop_words()
words = query_text.split()
@@ -567,6 +557,23 @@ def extract_content_words_from_recency_query(
return content_words_filtered[:MAX_CONTENT_WORDS]
def _is_valid_keyword_query(line: str) -> bool:
"""Check if a line looks like a valid keyword query vs explanatory text.
Returns False for lines that appear to be LLM explanations rather than keywords.
"""
# Reject lines that start with parentheses (explanatory notes)
if line.startswith("("):
return False
# Reject lines that are too long (likely sentences, not keywords)
# Keywords should be short - reject if > 50 chars or > 6 words
if len(line) > 50 or len(line.split()) > 6:
return False
return True
def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
"""Use LLM to expand query into multiple search variations.
@@ -589,10 +596,18 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
response_clean = _parse_llm_code_block_response(response)
# Split into lines and filter out empty lines
rephrased_queries = [
raw_queries = [
line.strip() for line in response_clean.split("\n") if line.strip()
]
# Filter out lines that look like explanatory text rather than keywords
rephrased_queries = [q for q in raw_queries if _is_valid_keyword_query(q)]
# Log if we filtered out garbage
if len(raw_queries) != len(rephrased_queries):
filtered_out = set(raw_queries) - set(rephrased_queries)
logger.warning(f"Filtered out non-keyword LLM responses: {filtered_out}")
# If no queries generated, use empty query
if not rephrased_queries:
logger.debug("No content keywords extracted from query expansion")

View File

@@ -116,6 +116,8 @@ class UserFileFilters(BaseModel):
class IndexFilters(BaseFilters, UserFileFilters):
# NOTE: These strings must be formatted in the same way as the output of
# DocumentAccess::to_acl.
access_control_list: list[str] | None
tenant_id: str | None = None
@@ -144,10 +146,6 @@ class BasicChunkRequest(BaseModel):
# In case some queries favor recency more than other queries.
recency_bias_multiplier: float = 1.0
# Sometimes we may want to extract specific keywords from a more semantic query for
# a better keyword search.
query_keywords: list[str] | None = None # Not used currently
limit: int | None = None
offset: int | None = None # This one is not set currently
@@ -166,6 +164,8 @@ class ChunkIndexRequest(BasicChunkRequest):
# Calculated final filters
filters: IndexFilters
query_keywords: list[str] | None = None
class ContextExpansionType(str, Enum):
NOT_RELEVANT = "not_relevant"
@@ -372,6 +372,10 @@ class SearchDocsResponse(BaseModel):
# document id is the most staightforward way.
citation_mapping: dict[int, str]
# For cases where the frontend only needs to display a subset of the search docs
# The whole list is typically still needed for later steps but this set should be saved separately
displayed_docs: list[SearchDoc] | None = None
class SavedSearchDoc(SearchDoc):
db_doc_id: int
@@ -430,11 +434,6 @@ class SavedSearchDoc(SearchDoc):
return self_score < other_score
class CitationDocInfo(BaseModel):
search_doc: SearchDoc
citation_number: int | None
class SavedSearchDocWithContent(SavedSearchDoc):
"""Used for endpoints that need to return the actual contents of the retrieved
section in addition to the match_highlights."""

View File

@@ -19,6 +19,7 @@ from onyx.db.models import Persona
from onyx.db.models import User
from onyx.document_index.interfaces import DocumentIndex
from onyx.llm.interfaces import LLM
from onyx.natural_language_processing.english_stopwords import strip_stopwords
from onyx.secondary_llm_flows.source_filter import extract_source_filter
from onyx.secondary_llm_flows.time_filter import extract_time_filter
from onyx.utils.logger import setup_logger
@@ -278,12 +279,16 @@ def search_pipeline(
bypass_acl=chunk_search_request.bypass_acl,
)
query_keywords = strip_stopwords(chunk_search_request.query)
query_request = ChunkIndexRequest(
query=chunk_search_request.query,
hybrid_alpha=chunk_search_request.hybrid_alpha,
recency_bias_multiplier=chunk_search_request.recency_bias_multiplier,
query_keywords=chunk_search_request.query_keywords,
query_keywords=query_keywords,
filters=filters,
limit=chunk_search_request.limit,
offset=chunk_search_request.offset,
)
retrieved_chunks = search_chunks(

View File

@@ -23,45 +23,6 @@ from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
logger = setup_logger()
def _dedupe_chunks(
chunks: list[InferenceChunk],
) -> list[InferenceChunk]:
used_chunks: dict[tuple[str, int], InferenceChunk] = {}
for chunk in chunks:
key = (chunk.document_id, chunk.chunk_id)
if key not in used_chunks:
used_chunks[key] = chunk
else:
stored_chunk_score = used_chunks[key].score or 0
this_chunk_score = chunk.score or 0
if stored_chunk_score < this_chunk_score:
used_chunks[key] = chunk
return list(used_chunks.values())
def download_nltk_data() -> None:
import nltk # type: ignore[import-untyped]
resources = {
"stopwords": "corpora/stopwords",
# "wordnet": "corpora/wordnet", # Not in use
"punkt_tab": "tokenizers/punkt_tab",
}
for resource_name, resource_path in resources.items():
try:
nltk.data.find(resource_path)
logger.info(f"{resource_name} is already downloaded.")
except LookupError:
try:
logger.info(f"Downloading {resource_name}...")
nltk.download(resource_name, quiet=True)
logger.info(f"{resource_name} downloaded successfully.")
except Exception as e:
logger.error(f"Failed to download {resource_name}. Error: {e}")
def combine_retrieval_results(
chunk_sets: list[list[InferenceChunk]],
) -> list[InferenceChunk]:

View File

@@ -0,0 +1,451 @@
"""CRUD operations for Discord bot models."""
from datetime import datetime
from datetime import timezone
from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from onyx.auth.api_key import build_displayable_api_key
from onyx.auth.api_key import generate_api_key
from onyx.auth.api_key import hash_api_key
from onyx.auth.schemas import UserRole
from onyx.configs.constants import DISCORD_SERVICE_API_KEY_NAME
from onyx.db.api_key import insert_api_key
from onyx.db.models import ApiKey
from onyx.db.models import DiscordBotConfig
from onyx.db.models import DiscordChannelConfig
from onyx.db.models import DiscordGuildConfig
from onyx.db.models import User
from onyx.db.utils import DiscordChannelView
from onyx.server.api_key.models import APIKeyArgs
from onyx.utils.logger import setup_logger
logger = setup_logger()
# === DiscordBotConfig ===
def get_discord_bot_config(db_session: Session) -> DiscordBotConfig | None:
"""Get the Discord bot config for this tenant (at most one)."""
return db_session.scalar(select(DiscordBotConfig).limit(1))
def create_discord_bot_config(
db_session: Session,
bot_token: str,
) -> DiscordBotConfig:
"""Create the Discord bot config. Raises ValueError if already exists.
The check constraint on id='SINGLETON' ensures only one config per tenant.
"""
existing = get_discord_bot_config(db_session)
if existing:
raise ValueError("Discord bot config already exists")
config = DiscordBotConfig(bot_token=bot_token)
db_session.add(config)
try:
db_session.flush()
except IntegrityError:
# Race condition: another request created the config concurrently
db_session.rollback()
raise ValueError("Discord bot config already exists")
return config
def delete_discord_bot_config(db_session: Session) -> bool:
"""Delete the Discord bot config. Returns True if deleted."""
result = db_session.execute(delete(DiscordBotConfig))
db_session.flush()
return result.rowcount > 0 # type: ignore[attr-defined]
# === Discord Service API Key ===
def get_discord_service_api_key(db_session: Session) -> ApiKey | None:
"""Get the Discord service API key if it exists."""
return db_session.scalar(
select(ApiKey).where(ApiKey.name == DISCORD_SERVICE_API_KEY_NAME)
)
def get_or_create_discord_service_api_key(
db_session: Session,
tenant_id: str,
) -> str:
"""Get existing Discord service API key or create one.
The API key is used by the Discord bot to authenticate with the
Onyx API pods when sending chat requests.
Args:
db_session: Database session for the tenant.
tenant_id: The tenant ID (used for logging/context).
Returns:
The raw API key string (not hashed).
Raises:
RuntimeError: If API key creation fails.
"""
# Check for existing key
existing = get_discord_service_api_key(db_session)
if existing:
# Database only stores the hash, so we must regenerate to get the raw key.
# This is safe since the Discord bot is the only consumer of this key.
logger.debug(
f"Found existing Discord service API key for tenant {tenant_id} that isn't in cache, "
"regenerating to update cache"
)
new_api_key = generate_api_key(tenant_id)
existing.hashed_api_key = hash_api_key(new_api_key)
existing.api_key_display = build_displayable_api_key(new_api_key)
db_session.flush()
return new_api_key
# Create new API key
logger.info(f"Creating Discord service API key for tenant {tenant_id}")
api_key_args = APIKeyArgs(
name=DISCORD_SERVICE_API_KEY_NAME,
role=UserRole.LIMITED, # Limited role is sufficient for chat requests
)
api_key_descriptor = insert_api_key(
db_session=db_session,
api_key_args=api_key_args,
user_id=None, # Service account, no owner
)
if not api_key_descriptor.api_key:
raise RuntimeError(
f"Failed to create Discord service API key for tenant {tenant_id}"
)
return api_key_descriptor.api_key
def delete_discord_service_api_key(db_session: Session) -> bool:
"""Delete the Discord service API key for a tenant.
Called when:
- Bot config is deleted (self-hosted)
- All guild configs are deleted (Cloud)
Args:
db_session: Database session for the tenant.
Returns:
True if the key was deleted, False if it didn't exist.
"""
existing_key = get_discord_service_api_key(db_session)
if not existing_key:
return False
# Also delete the associated user
api_key_user = db_session.scalar(
select(User).where(User.id == existing_key.user_id) # type: ignore[arg-type]
)
db_session.delete(existing_key)
if api_key_user:
db_session.delete(api_key_user)
db_session.flush()
logger.info("Deleted Discord service API key")
return True
# === DiscordGuildConfig ===
def get_guild_configs(
db_session: Session,
include_channels: bool = False,
) -> list[DiscordGuildConfig]:
"""Get all guild configs for this tenant."""
stmt = select(DiscordGuildConfig)
if include_channels:
stmt = stmt.options(joinedload(DiscordGuildConfig.channels))
return list(db_session.scalars(stmt).unique().all())
def get_guild_config_by_internal_id(
db_session: Session,
internal_id: int,
) -> DiscordGuildConfig | None:
"""Get a specific guild config by its ID."""
return db_session.scalar(
select(DiscordGuildConfig).where(DiscordGuildConfig.id == internal_id)
)
def get_guild_config_by_discord_id(
db_session: Session,
guild_id: int,
) -> DiscordGuildConfig | None:
"""Get a guild config by Discord guild ID."""
return db_session.scalar(
select(DiscordGuildConfig).where(DiscordGuildConfig.guild_id == guild_id)
)
def get_guild_config_by_registration_key(
db_session: Session,
registration_key: str,
) -> DiscordGuildConfig | None:
"""Get a guild config by its registration key."""
return db_session.scalar(
select(DiscordGuildConfig).where(
DiscordGuildConfig.registration_key == registration_key
)
)
def create_guild_config(
db_session: Session,
registration_key: str,
) -> DiscordGuildConfig:
"""Create a new guild config with a registration key (guild_id=NULL)."""
config = DiscordGuildConfig(registration_key=registration_key)
db_session.add(config)
db_session.flush()
return config
def register_guild(
db_session: Session,
config: DiscordGuildConfig,
guild_id: int,
guild_name: str,
) -> DiscordGuildConfig:
"""Complete registration by setting guild_id and guild_name."""
config.guild_id = guild_id
config.guild_name = guild_name
config.registered_at = datetime.now(timezone.utc)
db_session.flush()
return config
def update_guild_config(
db_session: Session,
config: DiscordGuildConfig,
enabled: bool,
default_persona_id: int | None = None,
) -> DiscordGuildConfig:
"""Update guild config fields."""
config.enabled = enabled
config.default_persona_id = default_persona_id
db_session.flush()
return config
def delete_guild_config(
db_session: Session,
internal_id: int,
) -> bool:
"""Delete guild config (cascades to channel configs). Returns True if deleted."""
result = db_session.execute(
delete(DiscordGuildConfig).where(DiscordGuildConfig.id == internal_id)
)
db_session.flush()
return result.rowcount > 0 # type: ignore[attr-defined]
# === DiscordChannelConfig ===
def get_channel_configs(
db_session: Session,
guild_config_id: int,
) -> list[DiscordChannelConfig]:
"""Get all channel configs for a guild."""
return list(
db_session.scalars(
select(DiscordChannelConfig).where(
DiscordChannelConfig.guild_config_id == guild_config_id
)
).all()
)
def get_channel_config_by_discord_ids(
db_session: Session,
guild_id: int,
channel_id: int,
) -> DiscordChannelConfig | None:
"""Get a specific channel config by guild_id and channel_id."""
return db_session.scalar(
select(DiscordChannelConfig)
.join(DiscordGuildConfig)
.where(
DiscordGuildConfig.guild_id == guild_id,
DiscordChannelConfig.channel_id == channel_id,
)
)
def get_channel_config_by_internal_ids(
db_session: Session,
guild_config_id: int,
channel_config_id: int,
) -> DiscordChannelConfig | None:
"""Get a specific channel config by guild_config_id and channel_config_id"""
return db_session.scalar(
select(DiscordChannelConfig).where(
DiscordChannelConfig.guild_config_id == guild_config_id,
DiscordChannelConfig.id == channel_config_id,
)
)
def update_discord_channel_config(
db_session: Session,
config: DiscordChannelConfig,
channel_name: str,
thread_only_mode: bool,
require_bot_invocation: bool,
enabled: bool,
persona_override_id: int | None = None,
) -> DiscordChannelConfig:
"""Update channel config fields."""
config.channel_name = channel_name
config.require_bot_invocation = require_bot_invocation
config.persona_override_id = persona_override_id
config.enabled = enabled
config.thread_only_mode = thread_only_mode
db_session.flush()
return config
def delete_discord_channel_config(
db_session: Session,
guild_config_id: int,
channel_config_id: int,
) -> bool:
"""Delete a channel config. Returns True if deleted."""
result = db_session.execute(
delete(DiscordChannelConfig).where(
DiscordChannelConfig.guild_config_id == guild_config_id,
DiscordChannelConfig.id == channel_config_id,
)
)
db_session.flush()
return result.rowcount > 0 # type: ignore[attr-defined]
def create_channel_config(
db_session: Session,
guild_config_id: int,
channel_view: DiscordChannelView,
) -> DiscordChannelConfig:
"""Create a new channel config with default settings (disabled by default, admin enables via UI)."""
config = DiscordChannelConfig(
guild_config_id=guild_config_id,
channel_id=channel_view.channel_id,
channel_name=channel_view.channel_name,
channel_type=channel_view.channel_type,
is_private=channel_view.is_private,
)
db_session.add(config)
db_session.flush()
return config
def bulk_create_channel_configs(
db_session: Session,
guild_config_id: int,
channels: list[DiscordChannelView],
) -> list[DiscordChannelConfig]:
"""Create multiple channel configs at once. Skips existing channels."""
# Get existing channel IDs for this guild
existing_channel_ids = set(
db_session.scalars(
select(DiscordChannelConfig.channel_id).where(
DiscordChannelConfig.guild_config_id == guild_config_id
)
).all()
)
# Create configs for new channels only
new_configs = []
for channel_view in channels:
if channel_view.channel_id not in existing_channel_ids:
config = DiscordChannelConfig(
guild_config_id=guild_config_id,
channel_id=channel_view.channel_id,
channel_name=channel_view.channel_name,
channel_type=channel_view.channel_type,
is_private=channel_view.is_private,
)
db_session.add(config)
new_configs.append(config)
db_session.flush()
return new_configs
def sync_channel_configs(
db_session: Session,
guild_config_id: int,
current_channels: list[DiscordChannelView],
) -> tuple[int, int, int]:
"""Sync channel configs with current Discord channels.
- Creates configs for new channels (disabled by default)
- Removes configs for deleted channels
- Updates names and types for existing channels if changed
Returns: (added_count, removed_count, updated_count)
"""
current_channel_map = {
channel_view.channel_id: channel_view for channel_view in current_channels
}
current_channel_ids = set(current_channel_map.keys())
# Get existing configs
existing_configs = get_channel_configs(db_session, guild_config_id)
existing_channel_ids = {c.channel_id for c in existing_configs}
# Find channels to add, remove, and potentially update
to_add = current_channel_ids - existing_channel_ids
to_remove = existing_channel_ids - current_channel_ids
# Add new channels
added_count = 0
for channel_id in to_add:
channel_view = current_channel_map[channel_id]
create_channel_config(db_session, guild_config_id, channel_view)
added_count += 1
# Remove deleted channels
removed_count = 0
for config in existing_configs:
if config.channel_id in to_remove:
db_session.delete(config)
removed_count += 1
# Update names, types, and privacy for existing channels if changed
updated_count = 0
for config in existing_configs:
if config.channel_id in current_channel_ids:
channel_view = current_channel_map[config.channel_id]
changed = False
if config.channel_name != channel_view.channel_name:
config.channel_name = channel_view.channel_name
changed = True
if config.channel_type != channel_view.channel_type:
config.channel_type = channel_view.channel_type
changed = True
if config.is_private != channel_view.is_private:
config.is_private = channel_view.is_private
changed = True
if changed:
updated_count += 1
db_session.flush()
return added_count, removed_count, updated_count

View File

@@ -3,6 +3,8 @@ from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import aliased
from sqlalchemy.orm import Session
@@ -18,45 +20,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
def insert_input_prompt_if_not_exists(
user: User | None,
input_prompt_id: int | None,
prompt: str,
content: str,
active: bool,
is_public: bool,
db_session: Session,
commit: bool = True,
) -> InputPrompt:
if input_prompt_id is not None:
input_prompt = (
db_session.query(InputPrompt).filter_by(id=input_prompt_id).first()
)
else:
query = db_session.query(InputPrompt).filter(InputPrompt.prompt == prompt)
if user:
query = query.filter(InputPrompt.user_id == user.id)
else:
query = query.filter(InputPrompt.user_id.is_(None))
input_prompt = query.first()
if input_prompt is None:
input_prompt = InputPrompt(
id=input_prompt_id,
prompt=prompt,
content=content,
active=active,
is_public=is_public or user is None,
user_id=user.id if user else None,
)
db_session.add(input_prompt)
if commit:
db_session.commit()
return input_prompt
def insert_input_prompt(
prompt: str,
content: str,
@@ -64,16 +27,41 @@ def insert_input_prompt(
user: User | None,
db_session: Session,
) -> InputPrompt:
input_prompt = InputPrompt(
user_id = user.id if user else None
# Use atomic INSERT ... ON CONFLICT DO NOTHING with RETURNING
# to avoid race conditions with the uniqueness check
stmt = pg_insert(InputPrompt).values(
prompt=prompt,
content=content,
active=True,
is_public=is_public,
user_id=user.id if user is not None else None,
user_id=user_id,
)
db_session.add(input_prompt)
db_session.commit()
# Use the appropriate constraint based on whether this is a user-owned or public prompt
if user_id is not None:
stmt = stmt.on_conflict_do_nothing(constraint="uq_inputprompt_prompt_user_id")
else:
# Partial unique indexes cannot be targeted by constraint name;
# must use index_elements + index_where
stmt = stmt.on_conflict_do_nothing(
index_elements=[InputPrompt.prompt],
index_where=InputPrompt.user_id.is_(None),
)
stmt = stmt.returning(InputPrompt)
result = db_session.execute(stmt)
input_prompt = result.scalar_one_or_none()
if input_prompt is None:
raise HTTPException(
status_code=409,
detail=f"A prompt shortcut with the name '{prompt}' already exists",
)
db_session.commit()
return input_prompt
@@ -98,23 +86,40 @@ def update_input_prompt(
input_prompt.content = content
input_prompt.active = active
db_session.commit()
try:
db_session.commit()
except IntegrityError:
db_session.rollback()
raise HTTPException(
status_code=409,
detail=f"A prompt shortcut with the name '{prompt}' already exists",
)
return input_prompt
def validate_user_prompt_authorization(
user: User | None, input_prompt: InputPrompt
) -> bool:
"""
Check if the user is authorized to modify the given input prompt.
Returns True only if the user owns the prompt.
Returns False for public prompts (only admins can modify those),
unless auth is disabled (then anyone can manage public prompts).
"""
prompt = InputPromptSnapshot.from_model(input_prompt=input_prompt)
if prompt.user_id is not None:
if user is None:
return False
# Public prompts cannot be modified via the user API (unless auth is disabled)
if prompt.is_public or prompt.user_id is None:
return AUTH_TYPE == AuthType.DISABLED
user_details = UserInfo.from_model(user)
if str(user_details.id) != str(prompt.user_id):
return False
return True
# User must be logged in
if user is None:
return False
# User must own the prompt
user_details = UserInfo.from_model(user)
return str(user_details.id) == str(prompt.user_id)
def remove_public_input_prompt(input_prompt_id: int, db_session: Session) -> None:

View File

@@ -9,6 +9,9 @@ def get_memories(user: User | None, db_session: Session) -> list[str]:
if user is None:
return []
if not user.use_memories:
return []
user_info = [
f"User's name: {user.personal_name}" if user.personal_name else "",
f"User's role: {user.personal_role}" if user.personal_role else "",

View File

@@ -26,6 +26,7 @@ from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import Index
from sqlalchemy import Integer
from sqlalchemy import BigInteger
from sqlalchemy import Sequence
from sqlalchemy import String
@@ -187,6 +188,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
nullable=True,
default=None,
)
chat_background: Mapped[str | None] = mapped_column(String, nullable=True)
# personalization fields are exposed via the chat user settings "Personalization" tab
personal_name: Mapped[str | None] = mapped_column(String, nullable=True)
personal_role: Mapped[str | None] = mapped_column(String, nullable=True)
@@ -2045,7 +2047,7 @@ class ChatSession(Base):
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
persona_id: Mapped[int | None] = mapped_column(
ForeignKey("persona.id", ondelete="SET NULL"), nullable=True
ForeignKey("persona.id"), nullable=True
)
description: Mapped[str | None] = mapped_column(Text, nullable=True)
# This chat created by OnyxBot
@@ -2931,8 +2933,6 @@ class PersonaLabel(Base):
"Persona",
secondary=Persona__PersonaLabel.__table__,
back_populates="labels",
cascade="all, delete-orphan",
single_parent=True,
)
@@ -3038,6 +3038,124 @@ class SlackBot(Base):
)
class DiscordBotConfig(Base):
"""Global Discord bot configuration (one per tenant).
Stores the bot token when not provided via DISCORD_BOT_TOKEN env var.
Uses a fixed ID with check constraint to enforce only one row per tenant.
"""
__tablename__ = "discord_bot_config"
id: Mapped[str] = mapped_column(
String, primary_key=True, server_default=text("'SINGLETON'")
)
bot_token: Mapped[str] = mapped_column(EncryptedString(), nullable=False)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
class DiscordGuildConfig(Base):
"""Configuration for a Discord guild (server) connected to this tenant.
registration_key is a one-time key used to link a Discord server to this tenant.
Format: discord_<tenant_id>.<random_token>
guild_id is NULL until the Discord admin runs !register with the key.
"""
__tablename__ = "discord_guild_config"
id: Mapped[int] = mapped_column(primary_key=True)
# Discord snowflake - NULL until registered via command in Discord
guild_id: Mapped[int | None] = mapped_column(BigInteger, nullable=True, unique=True)
guild_name: Mapped[str | None] = mapped_column(String(256), nullable=True)
# One-time registration key: discord_<tenant_id>.<random_token>
registration_key: Mapped[str] = mapped_column(String, unique=True, nullable=False)
registered_at: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
# Configuration
default_persona_id: Mapped[int | None] = mapped_column(
ForeignKey("persona.id", ondelete="SET NULL"), nullable=True
)
enabled: Mapped[bool] = mapped_column(
Boolean, server_default=text("true"), nullable=False
)
# Relationships
default_persona: Mapped["Persona | None"] = relationship(
"Persona", foreign_keys=[default_persona_id]
)
channels: Mapped[list["DiscordChannelConfig"]] = relationship(
back_populates="guild_config", cascade="all, delete-orphan"
)
class DiscordChannelConfig(Base):
"""Per-channel configuration for Discord bot behavior.
Used to whitelist specific channels and configure per-channel behavior.
"""
__tablename__ = "discord_channel_config"
id: Mapped[int] = mapped_column(primary_key=True)
guild_config_id: Mapped[int] = mapped_column(
ForeignKey("discord_guild_config.id", ondelete="CASCADE"), nullable=False
)
# Discord snowflake
channel_id: Mapped[int] = mapped_column(BigInteger, nullable=False)
channel_name: Mapped[str] = mapped_column(String(), nullable=False)
# Channel type from Discord (text, forum)
channel_type: Mapped[str] = mapped_column(
String(20), server_default=text("'text'"), nullable=False
)
# True if @everyone cannot view the channel
is_private: Mapped[bool] = mapped_column(
Boolean, server_default=text("false"), nullable=False
)
# If true, bot only responds to messages in threads
# Otherwise, will reply in channel
thread_only_mode: Mapped[bool] = mapped_column(
Boolean, server_default=text("false"), nullable=False
)
# If true (default), bot only responds when @mentioned
# If false, bot responds to ALL messages in this channel
require_bot_invocation: Mapped[bool] = mapped_column(
Boolean, server_default=text("true"), nullable=False
)
# Override the guild's default persona for this channel
persona_override_id: Mapped[int | None] = mapped_column(
ForeignKey("persona.id", ondelete="SET NULL"), nullable=True
)
enabled: Mapped[bool] = mapped_column(
Boolean, server_default=text("false"), nullable=False
)
# Relationships
guild_config: Mapped["DiscordGuildConfig"] = relationship(back_populates="channels")
persona_override: Mapped["Persona | None"] = relationship()
# Constraints
__table_args__ = (
UniqueConstraint(
"guild_config_id", "channel_id", name="uq_discord_channel_guild_channel"
),
)
class Milestone(Base):
# This table is used to track significant events for a deployment towards finding value
# The table is currently not used for features but it may be used in the future to inform
@@ -3115,25 +3233,6 @@ class FileRecord(Base):
)
class AgentSearchMetrics(Base):
__tablename__ = "agent__search_metrics"
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
persona_id: Mapped[int | None] = mapped_column(
ForeignKey("persona.id"), nullable=True
)
agent_type: Mapped[str] = mapped_column(String)
start_time: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
base_duration_s: Mapped[float] = mapped_column(Float)
full_duration_s: Mapped[float] = mapped_column(Float)
base_metrics: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)
refined_metrics: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)
all_metrics: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)
"""
************************************************************************
Enterprise Edition Models
@@ -3526,6 +3625,18 @@ class InputPrompt(Base):
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
__table_args__ = (
# Unique constraint on (prompt, user_id) for user-owned prompts
UniqueConstraint("prompt", "user_id", name="uq_inputprompt_prompt_user_id"),
# Partial unique index for public prompts (user_id IS NULL)
Index(
"uq_inputprompt_prompt_public",
"prompt",
unique=True,
postgresql_where=text("user_id IS NULL"),
),
)
class InputPrompt__User(Base):
__tablename__ = "inputprompt__user"
@@ -3534,7 +3645,7 @@ class InputPrompt__User(Base):
ForeignKey("inputprompt.id"), primary_key=True
)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("inputprompt.id"), primary_key=True
ForeignKey("user.id"), primary_key=True
)
disabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)

View File

@@ -917,7 +917,9 @@ def upsert_persona(
existing_persona.icon_name = icon_name
existing_persona.is_visible = is_visible
existing_persona.search_start_date = search_start_date
existing_persona.labels = labels or []
if label_ids is not None:
existing_persona.labels.clear()
existing_persona.labels = labels or []
existing_persona.is_default_persona = (
is_default_persona
if is_default_persona is not None

View File

@@ -20,7 +20,7 @@ from onyx.db.models import SearchSettings
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_secondary_search_settings
from onyx.db.search_settings import update_search_settings_status
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.factory import get_all_document_indices
from onyx.key_value_store.factory import get_kv_store
from onyx.utils.logger import setup_logger
@@ -80,39 +80,43 @@ def _perform_index_swap(
db_session=db_session,
)
# remove the old index from the vector db
document_index = get_default_document_index(new_search_settings, None)
# This flow is for checking and possibly creating an index so we get all
# indices.
document_indices = get_all_document_indices(new_search_settings, None, None)
WAIT_SECONDS = 5
success = False
for x in range(VESPA_NUM_ATTEMPTS_ON_STARTUP):
try:
logger.notice(
f"Vespa index swap (attempt {x+1}/{VESPA_NUM_ATTEMPTS_ON_STARTUP})..."
)
document_index.ensure_indices_exist(
primary_embedding_dim=new_search_settings.final_embedding_dim,
primary_embedding_precision=new_search_settings.embedding_precision,
# just finished swap, no more secondary index
secondary_index_embedding_dim=None,
secondary_index_embedding_precision=None,
)
for document_index in document_indices:
success = False
for x in range(VESPA_NUM_ATTEMPTS_ON_STARTUP):
try:
logger.notice(
f"Document index {document_index.__class__.__name__} swap (attempt {x+1}/{VESPA_NUM_ATTEMPTS_ON_STARTUP})..."
)
document_index.ensure_indices_exist(
primary_embedding_dim=new_search_settings.final_embedding_dim,
primary_embedding_precision=new_search_settings.embedding_precision,
# just finished swap, no more secondary index
secondary_index_embedding_dim=None,
secondary_index_embedding_precision=None,
)
logger.notice("Vespa index swap complete.")
success = True
break
except Exception:
logger.exception(
f"Vespa index swap did not succeed. The Vespa service may not be ready yet. Retrying in {WAIT_SECONDS} seconds."
)
time.sleep(WAIT_SECONDS)
logger.notice("Document index swap complete.")
success = True
break
except Exception:
logger.exception(
f"Document index swap for {document_index.__class__.__name__} did not succeed. "
f"The document index services may not be ready yet. Retrying in {WAIT_SECONDS} seconds."
)
time.sleep(WAIT_SECONDS)
if not success:
logger.error(
f"Vespa index swap did not succeed. Attempt limit reached. ({VESPA_NUM_ATTEMPTS_ON_STARTUP})"
)
return None
if not success:
logger.error(
f"Document index swap for {document_index.__class__.__name__} did not succeed. "
f"Attempt limit reached. ({VESPA_NUM_ATTEMPTS_ON_STARTUP})"
)
return None
return current_search_settings

View File

@@ -139,6 +139,20 @@ def update_user_theme_preference(
db_session.commit()
def update_user_chat_background(
user_id: UUID,
chat_background: str | None,
db_session: Session,
) -> None:
"""Update user's chat background setting."""
db_session.execute(
update(User)
.where(User.id == user_id) # type: ignore
.values(chat_background=chat_background)
)
db_session.commit()
def update_user_personalization(
user_id: UUID,
*,

View File

@@ -15,7 +15,9 @@ from sqlalchemy.sql.elements import KeyedColumnElement
from onyx.auth.invited_users import remove_user_from_invited_users
from onyx.auth.schemas import UserRole
from onyx.db.api_key import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.db.models import DocumentSet
from onyx.db.models import DocumentSet__User
from onyx.db.models import Persona
from onyx.db.models import Persona__User
from onyx.db.models import SamlAccount
from onyx.db.models import User
@@ -327,6 +329,15 @@ def delete_user_from_db(
db_session.query(SamlAccount).filter(
SamlAccount.user_id == user_to_delete.id
).delete()
# Null out ownership on document sets and personas so they're
# preserved for other users instead of being cascade-deleted
db_session.query(DocumentSet).filter(
DocumentSet.user_id == user_to_delete.id
).update({DocumentSet.user_id: None})
db_session.query(Persona).filter(Persona.user_id == user_to_delete.id).update(
{Persona.user_id: None}
)
db_session.query(DocumentSet__User).filter(
DocumentSet__User.user_id == user_to_delete.id
).delete()

View File

@@ -40,3 +40,10 @@ class DocumentRow(BaseModel):
class SortOrder(str, Enum):
ASC = "asc"
DESC = "desc"
class DiscordChannelView(BaseModel):
channel_id: int
channel_name: str
channel_type: str = "text" # text, forum
is_private: bool = False # True if @everyone cannot view the channel

View File

@@ -2,13 +2,18 @@ from onyx.configs.app_configs import BLURB_SIZE
from onyx.configs.constants import RETURN_SEPARATOR
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceChunkUncleaned
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import DocMetadataAwareIndexChunk
def generate_enriched_content_for_chunk(chunk: DocMetadataAwareIndexChunk) -> str:
def generate_enriched_content_for_chunk_text(chunk: DocMetadataAwareIndexChunk) -> str:
return f"{chunk.title_prefix}{chunk.doc_summary}{chunk.content}{chunk.chunk_context}{chunk.metadata_suffix_keyword}"
def generate_enriched_content_for_chunk_embedding(chunk: DocAwareChunk) -> str:
return f"{chunk.title_prefix}{chunk.doc_summary}{chunk.content}{chunk.chunk_context}{chunk.metadata_suffix_semantic}"
def cleanup_content_for_chunks(
chunks: list[InferenceChunkUncleaned],
) -> list[InferenceChunk]:

View File

@@ -1,9 +1,8 @@
import httpx
from sqlalchemy.orm import Session
from onyx.configs.app_configs import ENABLE_OPENSEARCH_FOR_ONYX
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX
from onyx.db.models import SearchSettings
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.opensearch.opensearch_document_index import (
OpenSearchOldDocumentIndex,
@@ -17,17 +16,24 @@ def get_default_document_index(
secondary_search_settings: SearchSettings | None,
httpx_client: httpx.Client | None = None,
) -> DocumentIndex:
"""Primary index is the index that is used for querying/updating etc.
Secondary index is for when both the currently used index and the upcoming
index both need to be updated, updates are applied to both indices"""
"""Gets the default document index from env vars.
To be used for retrieval only. Indexing should be done through both indices
until Vespa is deprecated.
Pre-existing docstring for this function, although secondary indices are not
currently supported:
Primary index is the index that is used for querying/updating etc. Secondary
index is for when both the currently used index and the upcoming index both
need to be updated, updates are applied to both indices.
"""
secondary_index_name: str | None = None
secondary_large_chunks_enabled: bool | None = None
if secondary_search_settings:
secondary_index_name = secondary_search_settings.index_name
secondary_large_chunks_enabled = secondary_search_settings.large_chunks_enabled
if ENABLE_OPENSEARCH_FOR_ONYX:
if ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX:
return OpenSearchOldDocumentIndex(
index_name=search_settings.index_name,
secondary_index_name=secondary_index_name,
@@ -47,12 +53,48 @@ def get_default_document_index(
)
def get_current_primary_default_document_index(db_session: Session) -> DocumentIndex:
def get_all_document_indices(
search_settings: SearchSettings,
secondary_search_settings: SearchSettings | None,
httpx_client: httpx.Client | None = None,
) -> list[DocumentIndex]:
"""Gets all document indices.
NOTE: Will only return an OpenSearch index interface if
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX is True. This is so we don't break flows
where we know it won't be enabled.
Used for indexing only. Until Vespa is deprecated we will index into both
document indices. Retrieval is done through only one index however.
Large chunks and secondary indices are not currently supported so we
hardcode appropriate values.
"""
TODO: Use redis to cache this or something
"""
search_settings = get_current_search_settings(db_session)
return get_default_document_index(
search_settings,
None,
vespa_document_index = VespaIndex(
index_name=search_settings.index_name,
secondary_index_name=(
secondary_search_settings.index_name if secondary_search_settings else None
),
large_chunks_enabled=search_settings.large_chunks_enabled,
secondary_large_chunks_enabled=(
secondary_search_settings.large_chunks_enabled
if secondary_search_settings
else None
),
multitenant=MULTI_TENANT,
httpx_client=httpx_client,
)
opensearch_document_index: OpenSearchOldDocumentIndex | None = None
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
opensearch_document_index = OpenSearchOldDocumentIndex(
index_name=search_settings.index_name,
secondary_index_name=None,
large_chunks_enabled=False,
secondary_large_chunks_enabled=None,
multitenant=MULTI_TENANT,
httpx_client=httpx_client,
)
result: list[DocumentIndex] = [vespa_document_index]
if opensearch_document_index:
result.append(opensearch_document_index)
return result

View File

@@ -28,8 +28,8 @@ of "minimum value clipping".
## On time decay and boosting
Embedding models do not have a uniform distribution from 0 to 1. The values typically cluster strongly around 0.6 to 0.8 but also
varies between models and even the query. It is not a safe assumption to pre-normalize the scores so we also cannot apply any
additive or multiplicative boost to it. Ie. if results of a doc cluster around 0.6 to 0.8 and I give a 50% penalty to the score,
it doesn't bring a result from the top of the range to 50 percentile, it brings its under the 0.6 and is now the worst match.
additive or multiplicative boost to it. i.e. if results of a doc cluster around 0.6 to 0.8 and I give a 50% penalty to the score,
it doesn't bring a result from the top of the range to 50th percentile, it brings it under the 0.6 and is now the worst match.
Same logic applies to additive boosting.
So these boosts can only be applied after normalization. Unfortunately with Opensearch, the normalization processor runs last
@@ -40,7 +40,7 @@ and vector would make the docs which only came because of time filter very low s
scored documents from the union of all the `Search` phase documents to show up higher and potentially not get dropped before
being fetched and returned to the user. But there are other issues of including these:
- There is no way to sort by this field, only a filter, so there's no way to guarantee the best docs even irrespective of the
contents. If there are lots of updates, this may miss
contents. If there are lots of updates, this may miss.
- There is not a good way to normalize this field, the best is to clip it on the bottom.
- This would require using min-max norm but z-score norm is better for the other functions due to things like it being less
sensitive to outliers, better handles distribution drifts (min-max assumes stable meaningful ranges), better for comparing

View File

@@ -1,4 +1,5 @@
import logging
import time
from typing import Any
from typing import Generic
from typing import TypeVar
@@ -569,6 +570,9 @@ class OpenSearchClient:
def close(self) -> None:
"""Closes the client.
TODO(andrei): Can we have some way to auto close when the client no
longer has any references?
Raises:
Exception: There was an error closing the client.
"""
@@ -596,3 +600,55 @@ class OpenSearchClient:
)
hits_second_layer: list[Any] = hits_first_layer.get("hits", [])
return hits_second_layer
def wait_for_opensearch_with_timeout(
wait_interval_s: int = 5,
wait_limit_s: int = 60,
client: OpenSearchClient | None = None,
) -> bool:
"""Waits for OpenSearch to become ready subject to a timeout.
Will create a new dummy client if no client is provided. Will close this
client at the end of the function. Will not close the client if it was
supplied.
Args:
wait_interval_s: The interval in seconds to wait between checks.
Defaults to 5.
wait_limit_s: The total timeout in seconds to wait for OpenSearch to
become ready. Defaults to 60.
client: The OpenSearch client to use for pinging. If None, a new dummy
client will be created. Defaults to None.
Returns:
True if OpenSearch is ready, False otherwise.
"""
made_client = False
try:
if client is None:
# NOTE: index_name does not matter because we are only using this object
# to ping.
# TODO(andrei): Make this better.
client = OpenSearchClient(index_name="")
made_client = True
time_start = time.monotonic()
while True:
if client.ping():
logger.info("[OpenSearch] Readiness probe succeeded. Continuing...")
return True
time_elapsed = time.monotonic() - time_start
if time_elapsed > wait_limit_s:
logger.info(
f"[OpenSearch] Readiness probe did not succeed within the timeout "
f"({wait_limit_s} seconds)."
)
return False
logger.info(
f"[OpenSearch] Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={wait_limit_s:.1f}"
)
time.sleep(wait_interval_s)
finally:
if made_client:
assert client is not None
client.close()

View File

@@ -3,7 +3,9 @@ from typing import Any
import httpx
from onyx.access.models import DocumentAccess
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
from onyx.configs.constants import PUBLIC_DOC_PAT
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
get_experts_stores_representations,
)
@@ -17,7 +19,7 @@ from onyx.db.enums import EmbeddingPrecision
from onyx.db.models import DocumentSource
from onyx.document_index.chunk_content_enrichment import cleanup_content_for_chunks
from onyx.document_index.chunk_content_enrichment import (
generate_enriched_content_for_chunk,
generate_enriched_content_for_chunk_text,
)
from onyx.document_index.interfaces import DocumentIndex as OldDocumentIndex
from onyx.document_index.interfaces import (
@@ -68,6 +70,18 @@ from shared_configs.model_server_models import Embedding
logger = setup_logger(__name__)
def generate_opensearch_filtered_access_control_list(
access: DocumentAccess,
) -> list[str]:
"""Generates an access control list with PUBLIC_DOC_PAT removed.
In the OpenSearch schema this is represented by PUBLIC_FIELD_NAME.
"""
access_control_list = access.to_acl()
access_control_list.discard(PUBLIC_DOC_PAT)
return list(access_control_list)
def _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
chunk: DocumentChunk,
score: float | None,
@@ -140,19 +154,21 @@ def _convert_onyx_chunk_to_opensearch_document(
return DocumentChunk(
document_id=chunk.source_document.id,
chunk_index=chunk.chunk_id,
title=chunk.source_document.title,
# Use get_title_for_document_index to match the logic used when creating
# the title_embedding in the embedder. This method falls back to
# semantic_identifier when title is None (but not empty string).
title=chunk.source_document.get_title_for_document_index(),
title_vector=chunk.title_embedding,
content=generate_enriched_content_for_chunk(chunk),
content=generate_enriched_content_for_chunk_text(chunk),
content_vector=chunk.embeddings.full_embedding,
source_type=chunk.source_document.source.value,
metadata_list=chunk.source_document.get_metadata_str_attributes(),
metadata_suffix=chunk.metadata_suffix_keyword,
last_updated=chunk.source_document.doc_updated_at,
public=chunk.access.is_public,
# TODO(andrei): When going over ACL look very carefully at
# access_control_list. Notice DocumentAccess::to_acl prepends every
# string with a type.
access_control_list=list(chunk.access.to_acl()),
access_control_list=generate_opensearch_filtered_access_control_list(
chunk.access
),
global_boost=chunk.boost,
semantic_identifier=chunk.source_document.semantic_identifier,
image_file_id=chunk.image_file_id,
@@ -421,6 +437,24 @@ class OpenSearchDocumentIndex(DocumentIndex):
def verify_and_create_index_if_necessary(
self, embedding_dim: int, embedding_precision: EmbeddingPrecision
) -> None:
"""Verifies and creates the index if necessary.
Also puts the desired search pipeline state, creating the pipelines if
they do not exist and updating them otherwise.
Args:
embedding_dim: Vector dimensionality for the vector similarity part
of the search.
embedding_precision: Precision of the values of the vectors for the
similarity part of the search.
Raises:
RuntimeError: There was an error verifying or creating the index or
search pipelines.
"""
logger.debug(
f"[OpenSearchDocumentIndex] Verifying and creating index {self._index_name} if necessary."
)
expected_mappings = DocumentSchema.get_document_schema(
embedding_dim, self._tenant_state.multitenant
)
@@ -450,6 +484,9 @@ class OpenSearchDocumentIndex(DocumentIndex):
chunks: list[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata,
) -> list[DocumentInsertionRecord]:
logger.debug(
f"[OpenSearchDocumentIndex] Indexing {len(chunks)} chunks for index {self._index_name}."
)
# Set of doc IDs.
unique_docs_to_be_indexed: set[str] = set()
document_indexing_results: list[DocumentInsertionRecord] = []
@@ -494,6 +531,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
def delete(self, document_id: str, chunk_count: int | None = None) -> int:
"""Deletes all chunks for a given document.
Does nothing if the specified document ID does not exist.
TODO(andrei): Make this method require supplying source type.
TODO(andrei): Consider implementing this method to delete on document
chunk IDs vs querying for matching document chunks.
@@ -510,6 +549,9 @@ class OpenSearchDocumentIndex(DocumentIndex):
Returns:
The number of chunks successfully deleted.
"""
logger.debug(
f"[OpenSearchDocumentIndex] Deleting document {document_id} from index {self._index_name}."
)
query_body = DocumentQuery.delete_from_document_id_query(
document_id=document_id,
tenant_state=self._tenant_state,
@@ -523,6 +565,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
) -> None:
"""Updates some set of chunks.
NOTE: Will raise if the specified document chunks do not exist.
NOTE: Requires document chunk count be known; will raise if it is not.
NOTE: Each update request must have some field to update; if not it is
assumed there is a bug in the caller and this will raise.
@@ -539,14 +582,19 @@ class OpenSearchDocumentIndex(DocumentIndex):
RuntimeError: Failed to update some or all of the chunks for the
specified documents.
"""
logger.debug(
f"[OpenSearchDocumentIndex] Updating {len(update_requests)} chunks for index {self._index_name}."
)
for update_request in update_requests:
properties_to_update: dict[str, Any] = dict()
# TODO(andrei): Nit but consider if we can use DocumentChunk
# here so we don't have to think about passing in the
# appropriate types into this dict.
if update_request.access is not None:
properties_to_update[ACCESS_CONTROL_LIST_FIELD_NAME] = list(
update_request.access.to_acl()
properties_to_update[ACCESS_CONTROL_LIST_FIELD_NAME] = (
generate_opensearch_filtered_access_control_list(
update_request.access
)
)
if update_request.document_sets is not None:
properties_to_update[DOCUMENT_SETS_FIELD_NAME] = list(
@@ -592,24 +640,27 @@ class OpenSearchDocumentIndex(DocumentIndex):
def id_based_retrieval(
self,
chunk_requests: list[DocumentSectionRequest],
# TODO(andrei): When going over ACL look very carefully at
# access_control_list. Notice DocumentAccess::to_acl prepends every
# string with a type.
filters: IndexFilters,
# TODO(andrei): Remove this from the new interface at some point; we
# should not be exposing this.
batch_retrieval: bool = False,
# TODO(andrei): Add a param for whether to retrieve hidden docs.
) -> list[InferenceChunk]:
"""
TODO(andrei): Consider implementing this method to retrieve on document
chunk IDs vs querying for matching document chunks.
"""
logger.debug(
f"[OpenSearchDocumentIndex] Retrieving {len(chunk_requests)} chunks for index {self._index_name}."
)
results: list[InferenceChunk] = []
for chunk_request in chunk_requests:
search_hits: list[SearchHit[DocumentChunk]] = []
query_body = DocumentQuery.get_from_document_id_query(
document_id=chunk_request.document_id,
tenant_state=self._tenant_state,
index_filters=filters,
include_hidden=False,
max_chunk_size=chunk_request.max_chunk_size,
min_chunk_index=chunk_request.min_chunk_ind,
max_chunk_index=chunk_request.max_chunk_ind,
@@ -636,19 +687,21 @@ class OpenSearchDocumentIndex(DocumentIndex):
query_embedding: Embedding,
final_keywords: list[str] | None,
query_type: QueryType,
# TODO(andrei): When going over ACL look very carefully at
# access_control_list. Notice DocumentAccess::to_acl prepends every
# string with a type.
filters: IndexFilters,
num_to_retrieve: int,
offset: int = 0,
) -> list[InferenceChunk]:
logger.debug(
f"[OpenSearchDocumentIndex] Hybrid retrieving {num_to_retrieve} chunks for index {self._index_name}."
)
query_body = DocumentQuery.get_hybrid_search_query(
query_text=query,
query_vector=query_embedding,
num_candidates=1000, # TODO(andrei): Magic number.
num_hits=num_to_retrieve,
tenant_state=self._tenant_state,
index_filters=filters,
include_hidden=False,
)
search_hits: list[SearchHit[DocumentChunk]] = self._os_client.search(
body=query_body,

View File

@@ -172,24 +172,23 @@ class DocumentChunk(BaseModel):
return serialized_exclude_none
@field_serializer("last_updated", mode="wrap")
def serialize_datetime_fields_to_epoch_millis(
def serialize_datetime_fields_to_epoch_seconds(
self, value: datetime | None, handler: SerializerFunctionWrapHandler
) -> int | None:
"""
Serializes datetime fields to milliseconds since the Unix epoch.
Serializes datetime fields to seconds since the Unix epoch.
If there is no datetime, returns None.
"""
if value is None:
return None
value = set_or_convert_timezone_to_utc(value)
# timestamp returns a float in seconds so convert to millis.
return int(value.timestamp() * 1000)
return int(value.timestamp())
@field_validator("last_updated", mode="before")
@classmethod
def parse_epoch_millis_to_datetime(cls, value: Any) -> datetime | None:
"""Parses milliseconds since the Unix epoch to a datetime object.
def parse_epoch_seconds_to_datetime(cls, value: Any) -> datetime | None:
"""Parses seconds since the Unix epoch to a datetime object.
If the input is None, returns None.
@@ -204,7 +203,7 @@ class DocumentChunk(BaseModel):
raise ValueError(
f"Bug: Expected an int for the last_updated property from OpenSearch, got {type(value)} instead."
)
return datetime.fromtimestamp(value / 1000, tz=timezone.utc)
return datetime.fromtimestamp(value, tz=timezone.utc)
@field_serializer("tenant_id", mode="wrap")
def serialize_tenant_state(
@@ -354,11 +353,9 @@ class DocumentSchema:
},
SOURCE_TYPE_FIELD_NAME: {"type": "keyword"},
METADATA_LIST_FIELD_NAME: {"type": "keyword"},
# TODO(andrei): Check if Vespa stores seconds, we may wanna do
# seconds here not millis.
LAST_UPDATED_FIELD_NAME: {
"type": "date",
"format": "epoch_millis",
"format": "epoch_second",
# For some reason date defaults to False, even though it
# would make sense to sort by date.
"doc_values": True,
@@ -366,14 +363,21 @@ class DocumentSchema:
# Access control fields.
# Whether the doc is public. Could have fallen under access
# control list but is such a broad and critical filter that it
# is its own field.
# is its own field. If true, ACCESS_CONTROL_LIST_FIELD_NAME
# should have no effect on queries.
PUBLIC_FIELD_NAME: {"type": "boolean"},
# Access control list for the doc, excluding public access,
# which is covered above.
# If a user's access set contains at least one entry from this
# set, the user should be able to retrieve this document. This
# only applies if public is set to false; public non-hidden
# documents are always visible to anyone in a given tenancy
# regardless of this field.
ACCESS_CONTROL_LIST_FIELD_NAME: {"type": "keyword"},
# Whether the doc is hidden from search results. Should clobber
# all other search filters; up to search implementations to
# guarantee this.
# Whether the doc is hidden from search results.
# Should clobber all other access search filters, namely
# PUBLIC_FIELD_NAME and ACCESS_CONTROL_LIST_FIELD_NAME; up to
# search implementations to guarantee this.
HIDDEN_FIELD_NAME: {"type": "boolean"},
GLOBAL_BOOST_FIELD_NAME: {"type": "integer"},
# This field is only used for displaying a useful name for the
@@ -447,7 +451,6 @@ class DocumentSchema:
DOCUMENT_ID_FIELD_NAME: {"type": "keyword"},
CHUNK_INDEX_FIELD_NAME: {"type": "integer"},
# The maximum number of tokens this chunk's content can hold.
# TODO(andrei): Can we generalize this to embedding type?
MAX_CHUNK_SIZE_FIELD_NAME: {"type": "integer"},
},
}

View File

@@ -1,21 +1,36 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from uuid import UUID
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import INDEX_SEPARATOR
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import Tag
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.constants import SEARCH_CONTENT_KEYWORD_WEIGHT
from onyx.document_index.opensearch.constants import SEARCH_CONTENT_PHRASE_WEIGHT
from onyx.document_index.opensearch.constants import SEARCH_CONTENT_VECTOR_WEIGHT
from onyx.document_index.opensearch.constants import SEARCH_TITLE_KEYWORD_WEIGHT
from onyx.document_index.opensearch.constants import SEARCH_TITLE_VECTOR_WEIGHT
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
from onyx.document_index.opensearch.schema import CHUNK_INDEX_FIELD_NAME
from onyx.document_index.opensearch.schema import CONTENT_FIELD_NAME
from onyx.document_index.opensearch.schema import CONTENT_VECTOR_FIELD_NAME
from onyx.document_index.opensearch.schema import DOCUMENT_ID_FIELD_NAME
from onyx.document_index.opensearch.schema import DOCUMENT_SETS_FIELD_NAME
from onyx.document_index.opensearch.schema import HIDDEN_FIELD_NAME
from onyx.document_index.opensearch.schema import LAST_UPDATED_FIELD_NAME
from onyx.document_index.opensearch.schema import MAX_CHUNK_SIZE_FIELD_NAME
from onyx.document_index.opensearch.schema import METADATA_LIST_FIELD_NAME
from onyx.document_index.opensearch.schema import PUBLIC_FIELD_NAME
from onyx.document_index.opensearch.schema import set_or_convert_timezone_to_utc
from onyx.document_index.opensearch.schema import SOURCE_TYPE_FIELD_NAME
from onyx.document_index.opensearch.schema import TENANT_ID_FIELD_NAME
from onyx.document_index.opensearch.schema import TITLE_FIELD_NAME
from onyx.document_index.opensearch.schema import TITLE_VECTOR_FIELD_NAME
from onyx.document_index.opensearch.schema import USER_PROJECTS_FIELD_NAME
# Normalization pipelines combine document scores from multiple query clauses.
# The number and ordering of weights should match the query clauses. The values
@@ -91,6 +106,11 @@ assert (
# given search. This value is configurable in the index settings.
DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW = 10_000
# For documents which do not have a value for LAST_UPDATED_FIELD_NAME, we assume
# that the document was last updated this many days ago for the purpose of time
# cutoff filtering during retrieval.
ASSUMED_DOCUMENT_AGE_DAYS = 90
class DocumentQuery:
"""
@@ -103,6 +123,8 @@ class DocumentQuery:
def get_from_document_id_query(
document_id: str,
tenant_state: TenantState,
index_filters: IndexFilters,
include_hidden: bool,
max_chunk_size: int,
min_chunk_index: int | None,
max_chunk_index: int | None,
@@ -120,6 +142,8 @@ class DocumentQuery:
document_id: Onyx document ID. Notably not an OpenSearch document
ID, which points to what Onyx would refer to as a chunk.
tenant_state: Tenant state containing the tenant ID.
index_filters: Filters for the document retrieval query.
include_hidden: Whether to include hidden documents.
max_chunk_size: Document chunks are categorized by the maximum
number of tokens they can hold. This parameter specifies the
maximum size category of document chunks to retrieve.
@@ -136,28 +160,21 @@ class DocumentQuery:
Returns:
A dictionary representing the final ID search query.
"""
filter_clauses: list[dict[str, Any]] = [
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
]
if tenant_state.multitenant:
# TODO(andrei): Fix tenant stuff.
filter_clauses.append(
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
)
if min_chunk_index is not None or max_chunk_index is not None:
range_clause: dict[str, Any] = {"range": {CHUNK_INDEX_FIELD_NAME: {}}}
if min_chunk_index is not None:
range_clause["range"][CHUNK_INDEX_FIELD_NAME]["gte"] = min_chunk_index
if max_chunk_index is not None:
range_clause["range"][CHUNK_INDEX_FIELD_NAME]["lte"] = max_chunk_index
filter_clauses.append(range_clause)
filter_clauses.append(
{"term": {MAX_CHUNK_SIZE_FIELD_NAME: {"value": max_chunk_size}}}
filter_clauses = DocumentQuery._get_search_filters(
tenant_state=tenant_state,
include_hidden=include_hidden,
access_control_list=index_filters.access_control_list,
source_types=index_filters.source_type or [],
tags=index_filters.tags or [],
document_sets=index_filters.document_set or [],
user_file_ids=index_filters.user_file_ids or [],
project_id=index_filters.project_id,
time_cutoff=index_filters.time_cutoff,
min_chunk_index=min_chunk_index,
max_chunk_index=max_chunk_index,
max_chunk_size=max_chunk_size,
document_id=document_id,
)
final_get_ids_query: dict[str, Any] = {
"query": {"bool": {"filter": filter_clauses}},
# We include this to make sure OpenSearch does not revert to
@@ -195,15 +212,22 @@ class DocumentQuery:
Returns:
A dictionary representing the final delete query.
"""
filter_clauses: list[dict[str, Any]] = [
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
]
if tenant_state.multitenant:
filter_clauses.append(
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
)
filter_clauses = DocumentQuery._get_search_filters(
tenant_state=tenant_state,
# Delete hidden docs too.
include_hidden=True,
access_control_list=None,
source_types=[],
tags=[],
document_sets=[],
user_file_ids=[],
project_id=None,
time_cutoff=None,
min_chunk_index=None,
max_chunk_index=None,
max_chunk_size=None,
document_id=document_id,
)
final_delete_query: dict[str, Any] = {
"query": {"bool": {"filter": filter_clauses}},
}
@@ -217,19 +241,25 @@ class DocumentQuery:
num_candidates: int,
num_hits: int,
tenant_state: TenantState,
index_filters: IndexFilters,
include_hidden: bool,
) -> dict[str, Any]:
"""Returns a final hybrid search query.
This query can be directly supplied to the OpenSearch client.
NOTE: This query can be directly supplied to the OpenSearch client, but
it MUST be supplied in addition to a search pipeline. The results from
hybrid search are not meaningful without that step.
Args:
query_text: The text to query for.
query_vector: The vector embedding of the text to query for.
num_candidates: The number of candidates to consider for vector
num_candidates: The number of neighbors to consider for vector
similarity search. Generally more candidates improves search
quality at the cost of performance.
num_hits: The final number of hits to return.
tenant_state: Tenant state containing the tenant ID.
index_filters: Filters for the hybrid search query.
include_hidden: Whether to include hidden documents.
Returns:
A dictionary representing the final hybrid search query.
@@ -243,31 +273,47 @@ class DocumentQuery:
hybrid_search_subqueries = DocumentQuery._get_hybrid_search_subqueries(
query_text, query_vector, num_candidates
)
hybrid_search_filters = DocumentQuery._get_hybrid_search_filters(tenant_state)
hybrid_search_filters = DocumentQuery._get_search_filters(
tenant_state=tenant_state,
include_hidden=include_hidden,
# TODO(andrei): We've done no filtering for PUBLIC_DOC_PAT up to
# now. This should not cause any issues but it can introduce
# redundant filters in queries that may affect performance.
access_control_list=index_filters.access_control_list,
source_types=index_filters.source_type or [],
tags=index_filters.tags or [],
document_sets=index_filters.document_set or [],
user_file_ids=index_filters.user_file_ids or [],
project_id=index_filters.project_id,
time_cutoff=index_filters.time_cutoff,
min_chunk_index=None,
max_chunk_index=None,
)
match_highlights_configuration = (
DocumentQuery._get_match_highlights_configuration()
)
# See https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
hybrid_search_query: dict[str, Any] = {
"bool": {
"must": [
{
"hybrid": {
"queries": hybrid_search_subqueries,
}
}
],
# TODO(andrei): When revisiting our hybrid query logic see if
# this needs to be nested one level down.
"filter": hybrid_search_filters,
"hybrid": {
"queries": hybrid_search_subqueries,
# Applied to all the sub-queries. Source:
# https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
# Does AND for each filter in the list.
"filter": {"bool": {"filter": hybrid_search_filters}},
}
}
# NOTE: By default, hybrid search retrieves "size"-many results from
# each OpenSearch shard before aggregation. Source:
# https://docs.opensearch.org/latest/vector-search/ai-search/hybrid-search/pagination/
final_hybrid_search_body: dict[str, Any] = {
"query": hybrid_search_query,
"size": num_hits,
"highlight": match_highlights_configuration,
}
return final_hybrid_search_body
@staticmethod
@@ -294,7 +340,8 @@ class DocumentQuery:
pipeline.
NOTE: For OpenSearch, 5 is the maximum number of query clauses allowed
in a single hybrid query.
in a single hybrid query. Source:
https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
Args:
query_text: The text of the query to search for.
@@ -305,6 +352,7 @@ class DocumentQuery:
hybrid_search_queries: list[dict[str, Any]] = [
{
"knn": {
# Match on semantic similarity of the title.
TITLE_VECTOR_FIELD_NAME: {
"vector": query_vector,
"k": num_candidates,
@@ -313,6 +361,7 @@ class DocumentQuery:
},
{
"knn": {
# Match on semantic similarity of the content.
CONTENT_VECTOR_FIELD_NAME: {
"vector": query_vector,
"k": num_candidates,
@@ -322,36 +371,273 @@ class DocumentQuery:
{
"multi_match": {
"query": query_text,
# TODO(andrei): Ask Yuhong do we want this?
# Either fuzzy match on the analyzed title (boosted 2x), or
# exact match on exact title keywords (no OpenSearch
# analysis done on the title). See
# https://docs.opensearch.org/latest/mappings/supported-field-types/keyword/
"fields": [f"{TITLE_FIELD_NAME}^2", f"{TITLE_FIELD_NAME}.keyword"],
# Returns the score of the best match of the fields above.
# See
# https://docs.opensearch.org/latest/query-dsl/full-text/multi-match/
"type": "best_fields",
}
},
# Fuzzy match on the OpenSearch-analyzed content. See
# https://docs.opensearch.org/latest/query-dsl/full-text/match/
{"match": {CONTENT_FIELD_NAME: {"query": query_text}}},
# Exact match on the OpenSearch-analyzed content. See
# https://docs.opensearch.org/latest/query-dsl/full-text/match-phrase/
{"match_phrase": {CONTENT_FIELD_NAME: {"query": query_text, "boost": 1.5}}},
]
return hybrid_search_queries
@staticmethod
def _get_hybrid_search_filters(tenant_state: TenantState) -> list[dict[str, Any]]:
"""Returns filters for hybrid search.
def _get_search_filters(
tenant_state: TenantState,
include_hidden: bool,
access_control_list: list[str] | None,
source_types: list[DocumentSource],
tags: list[Tag],
document_sets: list[str],
user_file_ids: list[UUID],
project_id: int | None,
time_cutoff: datetime | None,
min_chunk_index: int | None,
max_chunk_index: int | None,
max_chunk_size: int | None = None,
document_id: str | None = None,
) -> list[dict[str, Any]]:
"""Returns filters to be passed into the "filter" key of a search query.
For now only fetches public and not hidden documents.
The "filter" key applies a logical AND operator to its elements, so
every subfilter must evaluate to true in order for the document to be
retrieved. This function returns a list of such subfilters.
See https://docs.opensearch.org/latest/query-dsl/compound/bool/
The return of this function is not sufficient to be directly supplied to
the OpenSearch client. See get_hybrid_search_query.
Args:
tenant_state: Tenant state containing the tenant ID.
include_hidden: Whether to include hidden documents.
access_control_list: Access control list for the documents to
retrieve. If None, there is no restriction on the documents that
can be retrieved. If not None, only public documents can be
retrieved, or non-public documents where at least one acl
provided here is present in the document's acl list.
source_types: If supplied, only documents of one of these source
types will be retrieved.
tags: If supplied, only documents with an entry in their metadata
list corresponding to a tag will be retrieved.
document_sets: If supplied, only documents with at least one
document set ID from this list will be retrieved.
user_file_ids: If supplied, only document IDs in this list will be
retrieved.
project_id: If not None, only documents with this project ID in user
projects will be retrieved.
time_cutoff: Time cutoff for the documents to retrieve. If not None,
Documents which were last updated before this date will not be
returned. For documents which do not have a value for their last
updated time, we assume some default age of
ASSUMED_DOCUMENT_AGE_DAYS for when the document was last
updated.
min_chunk_index: The minimum chunk index to retrieve, inclusive. If
None, no minimum chunk index will be applied.
max_chunk_index: The maximum chunk index to retrieve, inclusive. If
None, no maximum chunk index will be applied.
max_chunk_size: The type of chunk to retrieve, specified by the
maximum number of tokens it can hold. If None, no filter will be
applied for this. Defaults to None.
NOTE: See DocumentChunk.max_chunk_size.
document_id: The document ID to retrieve. If None, no filter will be
applied for this. Defaults to None.
WARNING: This filters on the same property as user_file_ids.
Although it would never make sense to supply both, note that if
user_file_ids is supplied and does not contain document_id, no
matches will be retrieved.
TODO(andrei): Add ACL filters and stuff.
Returns:
A list of filters to be passed into the "filter" key of a search
query.
"""
hybrid_search_filters: list[dict[str, Any]] = [
{"term": {PUBLIC_FIELD_NAME: {"value": True}}},
{"term": {HIDDEN_FIELD_NAME: {"value": False}}},
]
def _get_acl_visibility_filter(
access_control_list: list[str],
) -> dict[str, Any]:
# Logical OR operator on its elements.
acl_visibility_filter: dict[str, Any] = {"bool": {"should": []}}
acl_visibility_filter["bool"]["should"].append(
{"term": {PUBLIC_FIELD_NAME: {"value": True}}}
)
for acl in access_control_list:
acl_subclause: dict[str, Any] = {
"term": {ACCESS_CONTROL_LIST_FIELD_NAME: {"value": acl}}
}
acl_visibility_filter["bool"]["should"].append(acl_subclause)
return acl_visibility_filter
def _get_source_type_filter(
source_types: list[DocumentSource],
) -> dict[str, Any]:
# Logical OR operator on its elements.
source_type_filter: dict[str, Any] = {"bool": {"should": []}}
for source_type in source_types:
source_type_filter["bool"]["should"].append(
{"term": {SOURCE_TYPE_FIELD_NAME: {"value": source_type.value}}}
)
return source_type_filter
def _get_tag_filter(tags: list[Tag]) -> dict[str, Any]:
# Logical OR operator on its elements.
tag_filter: dict[str, Any] = {"bool": {"should": []}}
for tag in tags:
# Kind of an abstraction leak, see
# convert_metadata_dict_to_list_of_strings for why metadata list
# entries are expected to look this way.
tag_str = f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}"
tag_filter["bool"]["should"].append(
{"term": {METADATA_LIST_FIELD_NAME: {"value": tag_str}}}
)
return tag_filter
def _get_document_set_filter(document_sets: list[str]) -> dict[str, Any]:
# Logical OR operator on its elements.
document_set_filter: dict[str, Any] = {"bool": {"should": []}}
for document_set in document_sets:
document_set_filter["bool"]["should"].append(
{"term": {DOCUMENT_SETS_FIELD_NAME: {"value": document_set}}}
)
return document_set_filter
def _get_user_file_id_filter(user_file_ids: list[UUID]) -> dict[str, Any]:
# Logical OR operator on its elements.
user_file_id_filter: dict[str, Any] = {"bool": {"should": []}}
for user_file_id in user_file_ids:
user_file_id_filter["bool"]["should"].append(
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": str(user_file_id)}}}
)
return user_file_id_filter
def _get_user_project_filter(project_id: int) -> dict[str, Any]:
# Logical OR operator on its elements.
user_project_filter: dict[str, Any] = {"bool": {"should": []}}
user_project_filter["bool"]["should"].append(
{"term": {USER_PROJECTS_FIELD_NAME: {"value": project_id}}}
)
return user_project_filter
def _get_time_cutoff_filter(time_cutoff: datetime) -> dict[str, Any]:
# Convert to UTC if not already so the cutoff is comparable to the
# document data.
time_cutoff = set_or_convert_timezone_to_utc(time_cutoff)
# Logical OR operator on its elements.
time_cutoff_filter: dict[str, Any] = {"bool": {"should": []}}
time_cutoff_filter["bool"]["should"].append(
{
"range": {
LAST_UPDATED_FIELD_NAME: {"gte": int(time_cutoff.timestamp())}
}
}
)
if time_cutoff < datetime.now(timezone.utc) - timedelta(
days=ASSUMED_DOCUMENT_AGE_DAYS
):
# Since the time cutoff is older than ASSUMED_DOCUMENT_AGE_DAYS
# ago, we include documents which have no
# LAST_UPDATED_FIELD_NAME value.
time_cutoff_filter["bool"]["should"].append(
{
"bool": {
"must_not": {"exists": {"field": LAST_UPDATED_FIELD_NAME}}
}
}
)
return time_cutoff_filter
def _get_chunk_index_filter(
min_chunk_index: int | None, max_chunk_index: int | None
) -> dict[str, Any]:
range_clause: dict[str, Any] = {"range": {CHUNK_INDEX_FIELD_NAME: {}}}
if min_chunk_index is not None:
range_clause["range"][CHUNK_INDEX_FIELD_NAME]["gte"] = min_chunk_index
if max_chunk_index is not None:
range_clause["range"][CHUNK_INDEX_FIELD_NAME]["lte"] = max_chunk_index
return range_clause
filter_clauses: list[dict[str, Any]] = []
if not include_hidden:
filter_clauses.append({"term": {HIDDEN_FIELD_NAME: {"value": False}}})
if access_control_list is not None:
# If an access control list is provided, the caller can only
# retrieve public documents, and non-public documents where at least
# one acl provided here is present in the document's acl list. If
# there is explicitly no list provided, we make no restrictions on
# the documents that can be retrieved.
filter_clauses.append(_get_acl_visibility_filter(access_control_list))
if source_types:
# If at least one source type is provided, the caller will only
# retrieve documents whose source type is present in this input
# list.
filter_clauses.append(_get_source_type_filter(source_types))
if tags:
# If at least one tag is provided, the caller will only retrieve
# documents where at least one tag provided here is present in the
# document's metadata list.
filter_clauses.append(_get_tag_filter(tags))
if document_sets:
# If at least one document set is provided, the caller will only
# retrieve documents where at least one document set provided here
# is present in the document's document sets list.
filter_clauses.append(_get_document_set_filter(document_sets))
if user_file_ids:
# If at least one user file ID is provided, the caller will only
# retrieve documents where the document ID is in this input list of
# file IDs. Note that these IDs correspond to Onyx documents whereas
# the entries retrieved from the document index correspond to Onyx
# document chunks.
filter_clauses.append(_get_user_file_id_filter(user_file_ids))
if project_id is not None:
# If a project ID is provided, the caller will only retrieve
# documents where the project ID provided here is present in the
# document's user projects list.
filter_clauses.append(_get_user_project_filter(project_id))
if time_cutoff is not None:
# If a time cutoff is provided, the caller will only retrieve
# documents where the document was last updated at or after the time
# cutoff. For documents which do not have a value for
# LAST_UPDATED_FIELD_NAME, we assume some default age for the
# purposes of time cutoff.
filter_clauses.append(_get_time_cutoff_filter(time_cutoff))
if min_chunk_index is not None or max_chunk_index is not None:
filter_clauses.append(
_get_chunk_index_filter(min_chunk_index, max_chunk_index)
)
if document_id is not None:
# WARNING: If user_file_ids has elements and if none of them are
# document_id, no matches will be retrieved.
filter_clauses.append(
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
)
if max_chunk_size is not None:
filter_clauses.append(
{"term": {MAX_CHUNK_SIZE_FIELD_NAME: {"value": max_chunk_size}}}
)
if tenant_state.multitenant:
hybrid_search_filters.append(
filter_clauses.append(
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
)
return hybrid_search_filters
return filter_clauses
@staticmethod
def _get_match_highlights_configuration() -> dict[str, Any]:
@@ -378,4 +664,5 @@ class DocumentQuery:
}
}
}
return match_highlights_configuration

View File

@@ -17,7 +17,7 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
get_experts_stores_representations,
)
from onyx.document_index.chunk_content_enrichment import (
generate_enriched_content_for_chunk,
generate_enriched_content_for_chunk_text,
)
from onyx.document_index.document_index_utils import get_uuid_from_chunk
from onyx.document_index.document_index_utils import get_uuid_from_chunk_info_old
@@ -186,7 +186,7 @@ def _index_vespa_chunk(
# For the BM25 index, the keyword suffix is used, the vector is already generated with the more
# natural language representation of the metadata section
CONTENT: remove_invalid_unicode_chars(
generate_enriched_content_for_chunk(chunk)
generate_enriched_content_for_chunk_text(chunk)
),
# This duplication of `content` is needed for keyword highlighting
# Note that it's not exactly the same as the actual content

View File

@@ -7,6 +7,9 @@ from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorStopSignal
from onyx.connectors.models import DocumentFailure
from onyx.db.models import SearchSettings
from onyx.document_index.chunk_content_enrichment import (
generate_enriched_content_for_chunk_embedding,
)
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import DocAwareChunk
@@ -126,7 +129,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
if chunk.large_chunk_reference_ids:
large_chunks_present = True
chunk_text = (
f"{chunk.title_prefix}{chunk.doc_summary}{chunk.content}{chunk.chunk_context}{chunk.metadata_suffix_semantic}"
generate_enriched_content_for_chunk_embedding(chunk)
) or chunk.source_document.get_title_for_document_index()
if not chunk_text:

View File

@@ -37,6 +37,7 @@ from onyx.document_index.document_index_utils import (
get_multipass_config,
)
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import DocumentInsertionRecord
from onyx.document_index.interfaces import DocumentMetadata
from onyx.document_index.interfaces import IndexBatchParams
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
@@ -163,7 +164,7 @@ def index_doc_batch_with_handler(
*,
chunker: Chunker,
embedder: IndexingEmbedder,
document_index: DocumentIndex,
document_indices: list[DocumentIndex],
document_batch: list[Document],
request_id: str | None,
tenant_id: str,
@@ -176,7 +177,7 @@ def index_doc_batch_with_handler(
index_pipeline_result = index_doc_batch(
chunker=chunker,
embedder=embedder,
document_index=document_index,
document_indices=document_indices,
document_batch=document_batch,
request_id=request_id,
tenant_id=tenant_id,
@@ -627,7 +628,7 @@ def index_doc_batch(
document_batch: list[Document],
chunker: Chunker,
embedder: IndexingEmbedder,
document_index: DocumentIndex,
document_indices: list[DocumentIndex],
request_id: str | None,
tenant_id: str,
adapter: IndexingBatchAdapter,
@@ -743,47 +744,57 @@ def index_doc_batch(
short_descriptor_log = str(short_descriptor_list)[:1024]
logger.debug(f"Indexing the following chunks: {short_descriptor_log}")
# A document will not be spread across different batches, so all the
# documents with chunks in this set, are fully represented by the chunks
# in this set
(
insertion_records,
vector_db_write_failures,
) = write_chunks_to_vector_db_with_backoff(
document_index=document_index,
chunks=result.chunks,
index_batch_params=IndexBatchParams(
doc_id_to_previous_chunk_cnt=result.doc_id_to_previous_chunk_cnt,
doc_id_to_new_chunk_cnt=result.doc_id_to_new_chunk_cnt,
tenant_id=tenant_id,
large_chunks_enabled=chunker.enable_large_chunks,
),
)
primary_doc_idx_insertion_records: list[DocumentInsertionRecord] | None = None
primary_doc_idx_vector_db_write_failures: list[ConnectorFailure] | None = None
for document_index in document_indices:
# A document will not be spread across different batches, so all the
# documents with chunks in this set, are fully represented by the chunks
# in this set
(
insertion_records,
vector_db_write_failures,
) = write_chunks_to_vector_db_with_backoff(
document_index=document_index,
chunks=result.chunks,
index_batch_params=IndexBatchParams(
doc_id_to_previous_chunk_cnt=result.doc_id_to_previous_chunk_cnt,
doc_id_to_new_chunk_cnt=result.doc_id_to_new_chunk_cnt,
tenant_id=tenant_id,
large_chunks_enabled=chunker.enable_large_chunks,
),
)
all_returned_doc_ids = (
{record.document_id for record in insertion_records}
.union(
{
record.failed_document.document_id
for record in vector_db_write_failures
if record.failed_document
}
)
.union(
{
record.failed_document.document_id
for record in embedding_failures
if record.failed_document
}
)
)
if all_returned_doc_ids != set(updatable_ids):
raise RuntimeError(
f"Some documents were not successfully indexed. "
f"Updatable IDs: {updatable_ids}, "
f"Returned IDs: {all_returned_doc_ids}. "
"This should never happen."
all_returned_doc_ids: set[str] = (
{record.document_id for record in insertion_records}
.union(
{
record.failed_document.document_id
for record in vector_db_write_failures
if record.failed_document
}
)
.union(
{
record.failed_document.document_id
for record in embedding_failures
if record.failed_document
}
)
)
if all_returned_doc_ids != set(updatable_ids):
raise RuntimeError(
f"Some documents were not successfully indexed. "
f"Updatable IDs: {updatable_ids}, "
f"Returned IDs: {all_returned_doc_ids}. "
"This should never happen."
f"This occured for document index {document_index.__class__.__name__}"
)
# We treat the first document index we got as the primary one used
# for reporting the state of indexing.
if primary_doc_idx_insertion_records is None:
primary_doc_idx_insertion_records = insertion_records
if primary_doc_idx_vector_db_write_failures is None:
primary_doc_idx_vector_db_write_failures = vector_db_write_failures
adapter.post_index(
context=context,
@@ -792,11 +803,15 @@ def index_doc_batch(
result=result,
)
assert primary_doc_idx_insertion_records is not None
assert primary_doc_idx_vector_db_write_failures is not None
return IndexingPipelineResult(
new_docs=len([r for r in insertion_records if not r.already_existed]),
new_docs=len(
[r for r in primary_doc_idx_insertion_records if not r.already_existed]
),
total_docs=len(filtered_documents),
total_chunks=len(chunks_with_embeddings),
failures=vector_db_write_failures + embedding_failures,
failures=primary_doc_idx_vector_db_write_failures + embedding_failures,
)
@@ -805,7 +820,7 @@ def run_indexing_pipeline(
document_batch: list[Document],
request_id: str | None,
embedder: IndexingEmbedder,
document_index: DocumentIndex,
document_indices: list[DocumentIndex],
db_session: Session,
tenant_id: str,
adapter: IndexingBatchAdapter,
@@ -846,7 +861,7 @@ def run_indexing_pipeline(
return index_doc_batch_with_handler(
chunker=chunker,
embedder=embedder,
document_index=document_index,
document_indices=document_indices,
document_batch=document_batch,
request_id=request_id,
tenant_id=tenant_id,

View File

@@ -41,6 +41,11 @@ alphanum_regex = re.compile(r"[^a-z0-9]+")
rem_email_regex = re.compile(r"(?<=\S)@([a-z0-9-]+)\.([a-z]{2,6})$")
def _ngrams(sequence: str, n: int) -> list[tuple[str, ...]]:
"""Generate n-grams from a sequence."""
return [tuple(sequence[i : i + n]) for i in range(len(sequence) - n + 1)]
def _clean_name(entity_name: str) -> str:
"""
Clean an entity string by removing non-alphanumeric characters and email addresses.
@@ -58,8 +63,6 @@ def _normalize_one_entity(
attributes: dict[str, str],
allowed_docs_temp_view_name: str | None = None,
) -> str | None:
from nltk import ngrams # type: ignore
"""
Matches a single entity to the best matching entity of the same type.
"""
@@ -150,16 +153,16 @@ def _normalize_one_entity(
# step 2: do a weighted ngram analysis and damerau levenshtein distance to rerank
n1, n2, n3 = (
set(ngrams(cleaned_entity, 1)),
set(ngrams(cleaned_entity, 2)),
set(ngrams(cleaned_entity, 3)),
set(_ngrams(cleaned_entity, 1)),
set(_ngrams(cleaned_entity, 2)),
set(_ngrams(cleaned_entity, 3)),
)
for i, (candidate_id_name, candidate_name, _) in enumerate(candidates):
cleaned_candidate = _clean_name(candidate_name)
h_n1, h_n2, h_n3 = (
set(ngrams(cleaned_candidate, 1)),
set(ngrams(cleaned_candidate, 2)),
set(ngrams(cleaned_candidate, 3)),
set(_ngrams(cleaned_candidate, 1)),
set(_ngrams(cleaned_candidate, 2)),
set(_ngrams(cleaned_candidate, 3)),
)
# compute ngram overlap, renormalize scores if the names are too short for larger ngrams

View File

@@ -369,6 +369,8 @@ def _patch_openai_responses_chunk_parser() -> None:
# New output item added
output_item = parsed_chunk.get("item", {})
if output_item.get("type") == "function_call":
# Track that we've received tool calls via streaming
self._has_streamed_tool_calls = True
return GenericStreamingChunk(
text="",
tool_use=ChatCompletionToolCallChunk(
@@ -394,6 +396,8 @@ def _patch_openai_responses_chunk_parser() -> None:
elif event_type == "response.function_call_arguments.delta":
content_part: Optional[str] = parsed_chunk.get("delta", None)
if content_part:
# Track that we've received tool calls via streaming
self._has_streamed_tool_calls = True
return GenericStreamingChunk(
text="",
tool_use=ChatCompletionToolCallChunk(
@@ -491,22 +495,72 @@ def _patch_openai_responses_chunk_parser() -> None:
elif event_type == "response.completed":
# Final event signaling all output items (including parallel tool calls) are done
# Check if we already received tool calls via streaming events
# There is an issue where OpenAI (not via Azure) will give back the tool calls streamed out as tokens
# But on Azure, it's only given out all at once. OpenAI also happens to give back the tool calls in the
# response.completed event so we need to throw it out here or there are duplicate tool calls.
has_streamed_tool_calls = getattr(self, "_has_streamed_tool_calls", False)
response_data = parsed_chunk.get("response", {})
# Determine finish reason based on response content
finish_reason = "stop"
if response_data.get("output"):
for item in response_data["output"]:
if isinstance(item, dict) and item.get("type") == "function_call":
finish_reason = "tool_calls"
break
return GenericStreamingChunk(
text="",
tool_use=None,
is_finished=True,
finish_reason=finish_reason,
usage=None,
output_items = response_data.get("output", [])
# Check if there are function_call items in the output
has_function_calls = any(
isinstance(item, dict) and item.get("type") == "function_call"
for item in output_items
)
if has_function_calls and not has_streamed_tool_calls:
# Azure's Responses API returns all tool calls in response.completed
# without streaming them incrementally. Extract them here.
from litellm.types.utils import (
Delta,
ModelResponseStream,
StreamingChoices,
)
tool_calls = []
for idx, item in enumerate(output_items):
if isinstance(item, dict) and item.get("type") == "function_call":
tool_calls.append(
ChatCompletionToolCallChunk(
id=item.get("call_id"),
index=idx,
type="function",
function=ChatCompletionToolCallFunctionChunk(
name=item.get("name"),
arguments=item.get("arguments", ""),
),
)
)
return ModelResponseStream(
choices=[
StreamingChoices(
index=0,
delta=Delta(tool_calls=tool_calls),
finish_reason="tool_calls",
)
]
)
elif has_function_calls:
# Tool calls were already streamed, just signal completion
return GenericStreamingChunk(
text="",
tool_use=None,
is_finished=True,
finish_reason="tool_calls",
usage=None,
)
else:
return GenericStreamingChunk(
text="",
tool_use=None,
is_finished=True,
finish_reason="stop",
usage=None,
)
else:
pass
@@ -631,6 +685,40 @@ def _patch_openai_responses_transform_response() -> None:
LiteLLMResponsesTransformationHandler.transform_response = _patched_transform_response # type: ignore[method-assign]
def _patch_azure_responses_should_fake_stream() -> None:
"""
Patches AzureOpenAIResponsesAPIConfig.should_fake_stream to always return False.
By default, LiteLLM uses "fake streaming" (MockResponsesAPIStreamingIterator) for models
not in its database. This causes Azure custom model deployments to buffer the entire
response before yielding, resulting in poor time-to-first-token.
Azure's Responses API supports native streaming, so we override this to always use
real streaming (SyncResponsesAPIStreamingIterator).
"""
from litellm.llms.azure.responses.transformation import (
AzureOpenAIResponsesAPIConfig,
)
if (
getattr(AzureOpenAIResponsesAPIConfig.should_fake_stream, "__name__", "")
== "_patched_should_fake_stream"
):
return
def _patched_should_fake_stream(
self: Any,
model: Optional[str],
stream: Optional[bool],
custom_llm_provider: Optional[str] = None,
) -> bool:
# Azure Responses API supports native streaming - never fake it
return False
_patched_should_fake_stream.__name__ = "_patched_should_fake_stream"
AzureOpenAIResponsesAPIConfig.should_fake_stream = _patched_should_fake_stream # type: ignore[method-assign]
def apply_monkey_patches() -> None:
"""
Apply all necessary monkey patches to LiteLLM for compatibility.
@@ -640,12 +728,13 @@ def apply_monkey_patches() -> None:
- Patching OllamaChatCompletionResponseIterator.chunk_parser for streaming content
- Patching OpenAiResponsesToChatCompletionStreamIterator.chunk_parser for OpenAI Responses API
- Patching LiteLLMResponsesTransformationHandler.transform_response for non-streaming responses
- Patching LiteLLMResponsesTransformationHandler._convert_content_str_to_input_text for tool content types
- Patching AzureOpenAIResponsesAPIConfig.should_fake_stream to enable native streaming
"""
_patch_ollama_transform_request()
_patch_ollama_chunk_parser()
_patch_openai_responses_chunk_parser()
_patch_openai_responses_transform_response()
_patch_azure_responses_should_fake_stream()
def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[str]]:

View File

@@ -54,11 +54,6 @@
"model_vendor": "amazon",
"model_version": "v1:0"
},
"anthropic.claude-3-5-haiku-20241022-v1:0": {
"display_name": "Claude Haiku 3.5",
"model_vendor": "anthropic",
"model_version": "20241022-v1:0"
},
"anthropic.claude-3-5-sonnet-20240620-v1:0": {
"display_name": "Claude Sonnet 3.5",
"model_vendor": "anthropic",
@@ -1465,11 +1460,6 @@
"model_vendor": "mistral",
"model_version": "v0:1"
},
"bedrock/us.anthropic.claude-3-5-haiku-20241022-v1:0": {
"display_name": "Claude Haiku 3.5",
"model_vendor": "anthropic",
"model_version": "20241022-v1:0"
},
"chat-bison": {
"display_name": "Chat Bison",
"model_vendor": "google",
@@ -1500,16 +1490,6 @@
"model_vendor": "openai",
"model_version": "latest"
},
"claude-3-5-haiku-20241022": {
"display_name": "Claude Haiku 3.5",
"model_vendor": "anthropic",
"model_version": "20241022"
},
"claude-3-5-haiku-latest": {
"display_name": "Claude Haiku 3.5",
"model_vendor": "anthropic",
"model_version": "latest"
},
"claude-3-5-sonnet-20240620": {
"display_name": "Claude Sonnet 3.5",
"model_vendor": "anthropic",
@@ -1715,11 +1695,6 @@
"model_vendor": "amazon",
"model_version": "v1:0"
},
"eu.anthropic.claude-3-5-haiku-20241022-v1:0": {
"display_name": "Claude Haiku 3.5",
"model_vendor": "anthropic",
"model_version": "20241022-v1:0"
},
"eu.anthropic.claude-3-5-sonnet-20240620-v1:0": {
"display_name": "Claude Sonnet 3.5",
"model_vendor": "anthropic",
@@ -3251,15 +3226,6 @@
"model_vendor": "anthropic",
"model_version": "latest"
},
"openrouter/anthropic/claude-3-5-haiku": {
"display_name": "Claude Haiku 3.5",
"model_vendor": "anthropic"
},
"openrouter/anthropic/claude-3-5-haiku-20241022": {
"display_name": "Claude Haiku 3.5",
"model_vendor": "anthropic",
"model_version": "20241022"
},
"openrouter/anthropic/claude-3-haiku": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic"
@@ -3774,11 +3740,6 @@
"model_vendor": "amazon",
"model_version": "1:0"
},
"us.anthropic.claude-3-5-haiku-20241022-v1:0": {
"display_name": "Claude Haiku 3.5",
"model_vendor": "anthropic",
"model_version": "20241022"
},
"us.anthropic.claude-3-5-sonnet-20240620-v1:0": {
"display_name": "Claude Sonnet 3.5",
"model_vendor": "anthropic",
@@ -3899,15 +3860,6 @@
"model_vendor": "twelvelabs",
"model_version": "v1:0"
},
"vertex_ai/claude-3-5-haiku": {
"display_name": "Claude Haiku 3.5",
"model_vendor": "anthropic"
},
"vertex_ai/claude-3-5-haiku@20241022": {
"display_name": "Claude Haiku 3.5",
"model_vendor": "anthropic",
"model_version": "20241022"
},
"vertex_ai/claude-3-5-sonnet": {
"display_name": "Claude Sonnet 3.5",
"model_vendor": "anthropic"

View File

@@ -301,6 +301,12 @@ class LitellmLLM(LLM):
)
is_ollama = self._model_provider == LlmProviderNames.OLLAMA_CHAT
is_mistral = self._model_provider == LlmProviderNames.MISTRAL
is_vertex_ai = self._model_provider == LlmProviderNames.VERTEX_AI
# Vertex Anthropic Opus 4.5 rejects output_config (LiteLLM maps reasoning_effort).
# Keep this guard until LiteLLM/Vertex accept the field for this model.
is_vertex_opus_4_5 = (
is_vertex_ai and "claude-opus-4-5" in self.config.model_name.lower()
)
#########################
# Build arguments
@@ -331,12 +337,16 @@ class LitellmLLM(LLM):
# Temperature
temperature = 1 if is_reasoning else self._temperature
if stream:
if stream and not is_vertex_opus_4_5:
optional_kwargs["stream_options"] = {"include_usage": True}
# Use configured default if not provided (if not set in env, low)
reasoning_effort = reasoning_effort or ReasoningEffort(DEFAULT_REASONING_EFFORT)
if is_reasoning and reasoning_effort != ReasoningEffort.OFF:
if (
is_reasoning
and reasoning_effort != ReasoningEffort.OFF
and not is_vertex_opus_4_5
):
if is_openai_model:
# OpenAI API does not accept reasoning params for GPT 5 chat models
# (neither reasoning nor reasoning_effort are accepted)

View File

@@ -96,6 +96,7 @@ from onyx.server.long_term_logs.long_term_logs_api import (
router as long_term_logs_router,
)
from onyx.server.manage.administrative import router as admin_router
from onyx.server.manage.discord_bot.api import router as discord_bot_router
from onyx.server.manage.embedding.api import admin_router as embedding_admin_router
from onyx.server.manage.embedding.api import basic_router as embedding_router
from onyx.server.manage.get_state import router as state_router
@@ -380,6 +381,7 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
include_router_with_global_prefix_prepended(
application, slack_bot_management_router
)
include_router_with_global_prefix_prepended(application, discord_bot_router)
include_router_with_global_prefix_prepended(application, persona_router)
include_router_with_global_prefix_prepended(application, admin_persona_router)
include_router_with_global_prefix_prepended(application, agents_router)

View File

@@ -0,0 +1,225 @@
import re
ENGLISH_STOPWORDS = [
"a",
"about",
"above",
"after",
"again",
"against",
"ain",
"all",
"am",
"an",
"and",
"any",
"are",
"aren",
"aren't",
"as",
"at",
"be",
"because",
"been",
"before",
"being",
"below",
"between",
"both",
"but",
"by",
"can",
"couldn",
"couldn't",
"d",
"did",
"didn",
"didn't",
"do",
"does",
"doesn",
"doesn't",
"doing",
"don",
"don't",
"down",
"during",
"each",
"few",
"for",
"from",
"further",
"had",
"hadn",
"hadn't",
"has",
"hasn",
"hasn't",
"have",
"haven",
"haven't",
"having",
"he",
"he'd",
"he'll",
"he's",
"her",
"here",
"hers",
"herself",
"him",
"himself",
"his",
"how",
"i",
"i'd",
"i'll",
"i'm",
"i've",
"if",
"in",
"into",
"is",
"isn",
"isn't",
"it",
"it'd",
"it'll",
"it's",
"its",
"itself",
"just",
"ll",
"m",
"ma",
"me",
"mightn",
"mightn't",
"more",
"most",
"mustn",
"mustn't",
"my",
"myself",
"needn",
"needn't",
"no",
"nor",
"not",
"now",
"o",
"of",
"off",
"on",
"once",
"only",
"or",
"other",
"our",
"ours",
"ourselves",
"out",
"over",
"own",
"re",
"s",
"same",
"shan",
"shan't",
"she",
"she'd",
"she'll",
"she's",
"should",
"should've",
"shouldn",
"shouldn't",
"so",
"some",
"such",
"t",
"than",
"that",
"that'll",
"the",
"their",
"theirs",
"them",
"themselves",
"then",
"there",
"these",
"they",
"they'd",
"they'll",
"they're",
"they've",
"this",
"those",
"through",
"to",
"too",
"under",
"until",
"up",
"ve",
"very",
"was",
"wasn",
"wasn't",
"we",
"we'd",
"we'll",
"we're",
"we've",
"were",
"weren",
"weren't",
"what",
"when",
"where",
"which",
"while",
"who",
"whom",
"why",
"will",
"with",
"won",
"won't",
"wouldn",
"wouldn't",
"y",
"you",
"you'd",
"you'll",
"you're",
"you've",
"your",
"yours",
"yourself",
"yourselves",
]
ENGLISH_STOPWORDS_SET = frozenset(ENGLISH_STOPWORDS)
def strip_stopwords(text: str) -> list[str]:
"""Remove English stopwords from text.
Matching is case-insensitive and ignores leading/trailing punctuation
on each word. Internal punctuation (like apostrophes in contractions)
is preserved for matching, so "you're" matches the stopword "you're"
but "youre" would not.
"""
words = text.split()
result = []
for word in words:
# Strip leading/trailing punctuation to get the core word for comparison
# This preserves internal punctuation like apostrophes
core = re.sub(r"^[^\w']+|[^\w']+$", "", word)
if core.lower() not in ENGLISH_STOPWORDS_SET:
result.append(word)
return result

View File

@@ -0,0 +1,287 @@
# Discord Bot Multitenant Architecture
This document analyzes how the Discord cache manager and API client coordinate to handle multitenant API keys from a single Discord client.
## Overview
The Discord bot uses a **single-client, multi-tenant** architecture where one `OnyxDiscordClient` instance serves multiple tenants (organizations) simultaneously. Tenant isolation is achieved through:
- **Cache Manager**: Maps Discord guilds to tenants and stores per-tenant API keys
- **API Client**: Stateless HTTP client that accepts dynamic API keys per request
```
┌─────────────────────────────────────────────────────────────────────┐
│ OnyxDiscordClient │
│ │
│ ┌─────────────────────────┐ ┌─────────────────────────────┐ │
│ │ DiscordCacheManager │ │ OnyxAPIClient │ │
│ │ │ │ │ │
│ │ guild_id → tenant_id │───▶│ send_chat_message( │ │
│ │ tenant_id → api_key │ │ message, │ │
│ │ │ │ api_key=<per-tenant>, │ │
│ └─────────────────────────┘ │ persona_id=... │ │
│ │ ) │ │
│ └─────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────────┘
```
---
## Component Details
### 1. Cache Manager (`backend/onyx/onyxbot/discord/cache.py`)
The `DiscordCacheManager` maintains two critical in-memory mappings:
```python
class DiscordCacheManager:
_guild_tenants: dict[int, str] # guild_id → tenant_id
_api_keys: dict[str, str] # tenant_id → api_key
_lock: asyncio.Lock # Concurrency control
```
#### Key Responsibilities
| Function | Purpose |
|----------|---------|
| `get_tenant(guild_id)` | O(1) lookup: guild → tenant |
| `get_api_key(tenant_id)` | O(1) lookup: tenant → API key |
| `refresh_all()` | Full cache rebuild from database |
| `refresh_guild()` | Incremental update for single guild |
#### API Key Provisioning Strategy
API keys are **lazily provisioned** - only created when first needed:
```python
async def _load_tenant_data(self, tenant_id: str) -> tuple[list[int], str | None]:
needs_key = tenant_id not in self._api_keys
with get_session_with_tenant(tenant_id) as db:
# Load guild configs
configs = get_discord_bot_configs(db)
guild_ids = [c.guild_id for c in configs if c.enabled]
# Only provision API key if not already cached
api_key = None
if needs_key:
api_key = get_or_create_discord_service_api_key(db, tenant_id)
return guild_ids, api_key
```
This optimization avoids repeated database calls for API key generation.
#### Concurrency Control
All write operations acquire an async lock to prevent race conditions:
```python
async def refresh_all(self) -> None:
async with self._lock:
# Safe to modify _guild_tenants and _api_keys
for tenant_id in get_all_tenant_ids():
guild_ids, api_key = await self._load_tenant_data(tenant_id)
# Update mappings...
```
Read operations (`get_tenant`, `get_api_key`) are lock-free since Python dict lookups are atomic.
---
### 2. API Client (`backend/onyx/onyxbot/discord/api_client.py`)
The `OnyxAPIClient` is a **stateless async HTTP client** that communicates with Onyx API pods.
#### Key Design: Per-Request API Key Injection
```python
class OnyxAPIClient:
async def send_chat_message(
self,
message: str,
api_key: str, # Injected per-request
persona_id: int | None,
...
) -> ChatFullResponse:
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}", # Tenant-specific auth
}
# Make request...
```
The client accepts `api_key` as a parameter to each method, enabling **dynamic tenant selection at request time**. This design allows a single client instance to serve multiple tenants:
```python
# Same client, different tenants
await api_client.send_chat_message(msg, api_key=key_for_tenant_1, ...)
await api_client.send_chat_message(msg, api_key=key_for_tenant_2, ...)
```
---
## Coordination Flow
### Message Processing Pipeline
When a Discord message arrives, the client coordinates cache and API client:
```python
async def on_message(self, message: Message) -> None:
guild_id = message.guild.id
# Step 1: Cache lookup - guild → tenant
tenant_id = self.cache.get_tenant(guild_id)
if not tenant_id:
return # Guild not registered
# Step 2: Cache lookup - tenant → API key
api_key = self.cache.get_api_key(tenant_id)
if not api_key:
logger.warning(f"No API key for tenant {tenant_id}")
return
# Step 3: API call with tenant-specific credentials
await process_chat_message(
message=message,
api_key=api_key, # Tenant-specific
persona_id=persona_id, # Tenant-specific
api_client=self.api_client,
)
```
### Startup Sequence
```python
async def setup_hook(self) -> None:
# 1. Initialize API client (create aiohttp session)
await self.api_client.initialize()
# 2. Populate cache with all tenants
await self.cache.refresh_all()
# 3. Start background refresh task
self._cache_refresh_task = self.loop.create_task(
self._periodic_cache_refresh() # Every 60 seconds
)
```
### Shutdown Sequence
```python
async def close(self) -> None:
# 1. Cancel background refresh
if self._cache_refresh_task:
self._cache_refresh_task.cancel()
# 2. Close Discord connection
await super().close()
# 3. Close API client session
await self.api_client.close()
# 4. Clear cache
self.cache.clear()
```
---
## Tenant Isolation Mechanisms
### 1. Per-Tenant API Keys
Each tenant has a dedicated service API key:
```python
# backend/onyx/db/discord_bot.py
def get_or_create_discord_service_api_key(db_session: Session, tenant_id: str) -> str:
existing = get_discord_service_api_key(db_session)
if existing:
return regenerate_key(existing)
# Create LIMITED role key (chat-only permissions)
return insert_api_key(
db_session=db_session,
api_key_args=APIKeyArgs(
name=DISCORD_SERVICE_API_KEY_NAME,
role=UserRole.LIMITED, # Minimal permissions
),
user_id=None, # Service account (system-owned)
).api_key
```
### 2. Database Context Variables
The cache uses context variables for proper tenant-scoped DB sessions:
```python
context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
with get_session_with_tenant(tenant_id) as db:
# All DB operations scoped to this tenant
...
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token)
```
### 3. Enterprise Gating Support
Gated tenants are filtered during cache refresh:
```python
gated_tenants = fetch_ee_implementation_or_noop(
"onyx.server.tenants.product_gating",
"get_gated_tenants",
set(),
)()
for tenant_id in get_all_tenant_ids():
if tenant_id in gated_tenants:
continue # Skip gated tenants
```
---
## Cache Refresh Strategy
| Trigger | Method | Scope |
|---------|--------|-------|
| Startup | `refresh_all()` | All tenants |
| Periodic (60s) | `refresh_all()` | All tenants |
| Guild registration | `refresh_guild()` | Single tenant |
### Error Handling
- **Tenant-level errors**: Logged and skipped (doesn't stop other tenants)
- **Missing API key**: Bot silently ignores messages from that guild
- **Network errors**: Logged, cache continues with stale data until next refresh
---
## Key Design Insights
1. **Single Client, Multiple Tenants**: One `OnyxAPIClient` and one `DiscordCacheManager` instance serves all tenants via dynamic API key injection.
2. **Cache-First Architecture**: Guild lookups are O(1) in-memory; API keys are cached after first provisioning to avoid repeated DB calls.
3. **Graceful Degradation**: If an API key is missing or stale, the bot simply doesn't respond (no crash or error propagation).
4. **Thread Safety Without Blocking**: `asyncio.Lock` prevents race conditions while maintaining async concurrency for reads.
5. **Lazy Provisioning**: API keys are only created when first needed, then cached for performance.
6. **Stateless API Client**: The HTTP client holds no tenant state - all tenant context is injected per-request via the `api_key` parameter.
---
## File References
| Component | Path |
|-----------|------|
| Cache Manager | `backend/onyx/onyxbot/discord/cache.py` |
| API Client | `backend/onyx/onyxbot/discord/api_client.py` |
| Discord Client | `backend/onyx/onyxbot/discord/client.py` |
| API Key DB Operations | `backend/onyx/db/discord_bot.py` |
| Cache Manager Tests | `backend/tests/unit/onyx/onyxbot/discord/test_cache_manager.py` |
| API Client Tests | `backend/tests/unit/onyx/onyxbot/discord/test_api_client.py` |

View File

@@ -0,0 +1,215 @@
"""Async HTTP client for communicating with Onyx API pods."""
import aiohttp
from onyx.chat.models import ChatFullResponse
from onyx.onyxbot.discord.constants import API_REQUEST_TIMEOUT
from onyx.onyxbot.discord.exceptions import APIConnectionError
from onyx.onyxbot.discord.exceptions import APIResponseError
from onyx.onyxbot.discord.exceptions import APITimeoutError
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
from onyx.server.query_and_chat.models import MessageOrigin
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import build_api_server_url_for_http_requests
logger = setup_logger()
class OnyxAPIClient:
"""Async HTTP client for sending chat requests to Onyx API pods.
This client manages an aiohttp session for making non-blocking HTTP
requests to the Onyx API server. It handles authentication with per-tenant
API keys and multi-tenant routing.
Usage:
client = OnyxAPIClient()
await client.initialize()
try:
response = await client.send_chat_message(
message="What is our deployment process?",
tenant_id="tenant_123",
api_key="dn_xxx...",
persona_id=1,
)
print(response.answer)
finally:
await client.close()
"""
def __init__(
self,
timeout: int = API_REQUEST_TIMEOUT,
) -> None:
"""Initialize the API client.
Args:
timeout: Request timeout in seconds.
"""
# Helm chart uses API_SERVER_URL_OVERRIDE_FOR_HTTP_REQUESTS to set the base URL
# TODO: Ideally, this override is only used when someone is launching an Onyx service independently
self._base_url = build_api_server_url_for_http_requests(
respect_env_override_if_set=True
).rstrip("/")
self._timeout = timeout
self._session: aiohttp.ClientSession | None = None
async def initialize(self) -> None:
"""Create the aiohttp session.
Must be called before making any requests. The session is created
with a total timeout and connection timeout.
"""
if self._session is not None:
logger.warning("API client session already initialized")
return
timeout = aiohttp.ClientTimeout(
total=self._timeout,
connect=30, # 30 seconds to establish connection
)
self._session = aiohttp.ClientSession(timeout=timeout)
logger.info(f"API client initialized with base URL: {self._base_url}")
async def close(self) -> None:
"""Close the aiohttp session.
Should be called when shutting down the bot to properly release
resources.
"""
if self._session is not None:
await self._session.close()
self._session = None
logger.info("API client session closed")
@property
def is_initialized(self) -> bool:
"""Check if the session is initialized."""
return self._session is not None
async def send_chat_message(
self,
message: str,
api_key: str,
persona_id: int | None = None,
) -> ChatFullResponse:
"""Send a chat message to the Onyx API server and get a response.
This method sends a non-streaming chat request to the API server. The response
contains the complete answer with any citations and metadata.
Args:
message: The user's message to process.
api_key: The API key for authentication.
persona_id: Optional persona ID to use for the response.
Returns:
ChatFullResponse containing the answer, citations, and metadata.
Raises:
APIConnectionError: If unable to connect to the API.
APITimeoutError: If the request times out.
APIResponseError: If the API returns an error response.
"""
if self._session is None:
raise APIConnectionError(
"API client not initialized. Call initialize() first."
)
url = f"{self._base_url}/chat/send-chat-message"
# Build request payload
request = SendMessageRequest(
message=message,
stream=False,
origin=MessageOrigin.DISCORDBOT,
chat_session_info=ChatSessionCreationRequest(
persona_id=persona_id if persona_id is not None else 0,
),
)
# Build headers
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
}
try:
async with self._session.post(
url,
json=request.model_dump(mode="json"),
headers=headers,
) as response:
if response.status == 401:
raise APIResponseError(
"Authentication failed - invalid API key",
status_code=401,
)
elif response.status == 403:
raise APIResponseError(
"Access denied - insufficient permissions",
status_code=403,
)
elif response.status == 404:
raise APIResponseError(
"API endpoint not found",
status_code=404,
)
elif response.status >= 500:
error_text = await response.text()
raise APIResponseError(
f"Server error: {error_text}",
status_code=response.status,
)
elif response.status >= 400:
error_text = await response.text()
raise APIResponseError(
f"Request error: {error_text}",
status_code=response.status,
)
# Parse successful response
data = await response.json()
response_obj = ChatFullResponse.model_validate(data)
if response_obj.error_msg:
logger.warning(f"Chat API returned error: {response_obj.error_msg}")
return response_obj
except aiohttp.ClientConnectorError as e:
logger.error(f"Failed to connect to API: {e}")
raise APIConnectionError(
f"Failed to connect to API at {self._base_url}: {e}"
) from e
except TimeoutError as e:
logger.error(f"API request timed out after {self._timeout}s")
raise APITimeoutError(
f"Request timed out after {self._timeout} seconds"
) from e
except aiohttp.ClientError as e:
logger.error(f"HTTP client error: {e}")
raise APIConnectionError(f"HTTP client error: {e}") from e
async def health_check(self) -> bool:
"""Check if the API server is healthy.
Returns:
True if the API server is reachable and healthy, False otherwise.
"""
if self._session is None:
logger.warning("API client not initialized. Call initialize() first.")
return False
try:
url = f"{self._base_url}/health"
async with self._session.get(
url, timeout=aiohttp.ClientTimeout(total=10)
) as response:
return response.status == 200
except Exception as e:
logger.warning(f"API server health check failed: {e}")
return False

View File

@@ -0,0 +1,154 @@
"""Multi-tenant cache for Discord bot guild-tenant mappings and API keys."""
import asyncio
from onyx.db.discord_bot import get_guild_configs
from onyx.db.discord_bot import get_or_create_discord_service_api_key
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.engine.tenant_utils import get_all_tenant_ids
from onyx.onyxbot.discord.exceptions import CacheError
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
class DiscordCacheManager:
"""Caches guild->tenant mappings and tenant->API key mappings.
Refreshed on startup, periodically (every 60s), and when guilds register.
"""
def __init__(self) -> None:
self._guild_tenants: dict[int, str] = {} # guild_id -> tenant_id
self._api_keys: dict[str, str] = {} # tenant_id -> api_key
self._lock = asyncio.Lock()
self._initialized = False
@property
def is_initialized(self) -> bool:
return self._initialized
async def refresh_all(self) -> None:
"""Full cache refresh from all tenants."""
async with self._lock:
logger.info("Starting Discord cache refresh")
new_guild_tenants: dict[int, str] = {}
new_api_keys: dict[str, str] = {}
try:
gated = fetch_ee_implementation_or_noop(
"onyx.server.tenants.product_gating",
"get_gated_tenants",
set(),
)()
tenant_ids = await asyncio.to_thread(get_all_tenant_ids)
for tenant_id in tenant_ids:
if tenant_id in gated:
continue
context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
guild_ids, api_key = await self._load_tenant_data(tenant_id)
if not guild_ids:
logger.debug(f"No guilds found for tenant {tenant_id}")
continue
if not api_key:
logger.warning(
"Discord service API key missing for tenant that has registered guilds. "
f"{tenant_id} will not be handled in this refresh cycle."
)
continue
for guild_id in guild_ids:
new_guild_tenants[guild_id] = tenant_id
new_api_keys[tenant_id] = api_key
except Exception as e:
logger.warning(f"Failed to refresh tenant {tenant_id}: {e}")
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token)
self._guild_tenants = new_guild_tenants
self._api_keys = new_api_keys
self._initialized = True
logger.info(
f"Cache refresh complete: {len(new_guild_tenants)} guilds, "
f"{len(new_api_keys)} tenants"
)
except Exception as e:
logger.error(f"Cache refresh failed: {e}")
raise CacheError(f"Failed to refresh cache: {e}") from e
async def refresh_guild(self, guild_id: int, tenant_id: str) -> None:
"""Add a single guild to cache after registration."""
async with self._lock:
logger.info(f"Refreshing cache for guild {guild_id} (tenant: {tenant_id})")
guild_ids, api_key = await self._load_tenant_data(tenant_id)
if guild_id in guild_ids:
self._guild_tenants[guild_id] = tenant_id
if api_key:
self._api_keys[tenant_id] = api_key
logger.info(f"Cache updated for guild {guild_id}")
else:
logger.warning(f"Guild {guild_id} not found or disabled")
async def _load_tenant_data(self, tenant_id: str) -> tuple[list[int], str | None]:
"""Load guild IDs and provision API key if needed.
Returns:
(active_guild_ids, api_key) - api_key is the cached key if available,
otherwise a newly created key. Returns None if no guilds found.
"""
cached_key = self._api_keys.get(tenant_id)
def _sync() -> tuple[list[int], str | None]:
with get_session_with_tenant(tenant_id=tenant_id) as db:
configs = get_guild_configs(db)
guild_ids = [
config.guild_id
for config in configs
if config.enabled and config.guild_id is not None
]
if not guild_ids:
return [], None
if not cached_key:
new_key = get_or_create_discord_service_api_key(db, tenant_id)
db.commit()
return guild_ids, new_key
return guild_ids, cached_key
return await asyncio.to_thread(_sync)
def get_tenant(self, guild_id: int) -> str | None:
"""Get tenant ID for a guild."""
return self._guild_tenants.get(guild_id)
def get_api_key(self, tenant_id: str) -> str | None:
"""Get API key for a tenant."""
return self._api_keys.get(tenant_id)
def remove_guild(self, guild_id: int) -> None:
"""Remove a guild from cache."""
self._guild_tenants.pop(guild_id, None)
def get_all_guild_ids(self) -> list[int]:
"""Get all cached guild IDs."""
return list(self._guild_tenants.keys())
def clear(self) -> None:
"""Clear all caches."""
self._guild_tenants.clear()
self._api_keys.clear()
self._initialized = False

View File

@@ -0,0 +1,232 @@
"""Discord bot client with integrated message handling."""
import asyncio
import time
import discord
from discord.ext import commands
from onyx.configs.app_configs import DISCORD_BOT_INVOKE_CHAR
from onyx.onyxbot.discord.api_client import OnyxAPIClient
from onyx.onyxbot.discord.cache import DiscordCacheManager
from onyx.onyxbot.discord.constants import CACHE_REFRESH_INTERVAL
from onyx.onyxbot.discord.handle_commands import handle_dm
from onyx.onyxbot.discord.handle_commands import handle_registration_command
from onyx.onyxbot.discord.handle_commands import handle_sync_channels_command
from onyx.onyxbot.discord.handle_message import process_chat_message
from onyx.onyxbot.discord.handle_message import should_respond
from onyx.onyxbot.discord.utils import get_bot_token
from onyx.utils.logger import setup_logger
logger = setup_logger()
class OnyxDiscordClient(commands.Bot):
"""Discord bot client with integrated cache, API client, and message handling.
This client handles:
- Guild registration via !register command
- Message processing with persona-based responses
- Thread context for conversation continuity
- Multi-tenant support via cached API keys
"""
def __init__(self, command_prefix: str = DISCORD_BOT_INVOKE_CHAR) -> None:
intents = discord.Intents.default()
intents.message_content = True
intents.members = True
super().__init__(command_prefix=command_prefix, intents=intents)
self.ready = False
self.cache = DiscordCacheManager()
self.api_client = OnyxAPIClient()
self._cache_refresh_task: asyncio.Task | None = None
# -------------------------------------------------------------------------
# Lifecycle Methods
# -------------------------------------------------------------------------
async def setup_hook(self) -> None:
"""Called before on_ready. Initialize components."""
logger.info("Initializing Discord bot components...")
# Initialize API client
await self.api_client.initialize()
# Initial cache load
await self.cache.refresh_all()
# Start periodic cache refresh
self._cache_refresh_task = self.loop.create_task(self._periodic_cache_refresh())
logger.info("Discord bot components initialized")
async def _periodic_cache_refresh(self) -> None:
"""Background task to refresh cache periodically."""
while not self.is_closed():
await asyncio.sleep(CACHE_REFRESH_INTERVAL)
try:
await self.cache.refresh_all()
except Exception as e:
logger.error(f"Cache refresh failed: {e}")
async def on_ready(self) -> None:
"""Bot connected and ready."""
if self.ready:
return
if not self.user:
raise RuntimeError("Critical error: Discord Bot user not found")
logger.info(f"Discord Bot connected as {self.user} (ID: {self.user.id})")
logger.info(f"Connected to {len(self.guilds)} guild(s)")
logger.info(f"Cached {len(self.cache.get_all_guild_ids())} registered guild(s)")
self.ready = True
async def close(self) -> None:
"""Graceful shutdown."""
logger.info("Shutting down Discord bot...")
# Cancel cache refresh task
if self._cache_refresh_task:
self._cache_refresh_task.cancel()
try:
await self._cache_refresh_task
except asyncio.CancelledError:
pass
# Close Discord connection first - stops new commands from triggering cache ops
if not self.is_closed():
await super().close()
# Close API client
await self.api_client.close()
# Clear cache (safe now - no concurrent operations possible)
self.cache.clear()
self.ready = False
logger.info("Discord bot shutdown complete")
# -------------------------------------------------------------------------
# Message Handling
# -------------------------------------------------------------------------
async def on_message(self, message: discord.Message) -> None:
"""Main message handler."""
# mypy
if not self.user:
raise RuntimeError("Critical error: Discord Bot user not found")
try:
# Ignore bot messages
if message.author.bot:
return
# Ignore thread starter messages (empty reference nodes that don't contain content)
if message.type == discord.MessageType.thread_starter_message:
return
# Handle DMs
if isinstance(message.channel, discord.DMChannel):
await handle_dm(message)
return
# Must have a guild
if not message.guild or not message.guild.id:
return
guild_id = message.guild.id
# Check for registration command first
if await handle_registration_command(message, self.cache):
return
# Look up guild in cache
tenant_id = self.cache.get_tenant(guild_id)
# Check for sync-channels command (requires registered guild)
if await handle_sync_channels_command(message, tenant_id, self):
return
if not tenant_id:
# Guild not registered, ignore
return
# Get API key
api_key = self.cache.get_api_key(tenant_id)
if not api_key:
logger.warning(f"No API key cached for tenant {tenant_id}")
return
# Check if bot should respond
should_respond_context = await should_respond(message, tenant_id, self.user)
if not should_respond_context.should_respond:
return
logger.debug(
f"Processing message: '{message.content[:50]}' in "
f"#{getattr(message.channel, 'name', 'unknown')} ({message.guild.name}), "
f"persona_id={should_respond_context.persona_id}"
)
# Process the message
await process_chat_message(
message=message,
api_key=api_key,
persona_id=should_respond_context.persona_id,
thread_only_mode=should_respond_context.thread_only_mode,
api_client=self.api_client,
bot_user=self.user,
)
except Exception as e:
logger.exception(f"Error processing message: {e}")
# -----------------------------------------------------------------------------
# Entry Point
# -----------------------------------------------------------------------------
def main() -> None:
"""Main entry point for Discord bot."""
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
logger.info("Starting Onyx Discord Bot...")
# Initialize the database engine (required before any DB operations)
SqlEngine.init_engine(pool_size=20, max_overflow=5)
# Initialize EE features based on environment
set_is_ee_based_on_env_variable()
counter = 0
while True:
token = get_bot_token()
if not token:
if counter % 180 == 0:
logger.info(
"Discord bot is dormant. Waiting for token configuration..."
)
counter += 1
time.sleep(5)
continue
counter = 0
bot = OnyxDiscordClient()
try:
# bot.run() handles SIGINT/SIGTERM and calls close() automatically
bot.run(token)
except Exception:
logger.exception("Fatal error in Discord bot")
raise
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,19 @@
"""Discord bot constants."""
# API settings
API_REQUEST_TIMEOUT: int = 3 * 60 # 3 minutes
# Cache settings
CACHE_REFRESH_INTERVAL: int = 60 # 1 minute
# Message settings
MAX_MESSAGE_LENGTH: int = 2000 # Discord's character limit
MAX_CONTEXT_MESSAGES: int = 10 # Max messages to include in conversation context
# Note: Discord.py's add_reaction() requires unicode emoji, not :name: format
THINKING_EMOJI: str = "🤔" # U+1F914 - Thinking Face
SUCCESS_EMOJI: str = "" # U+2705 - White Heavy Check Mark
ERROR_EMOJI: str = "" # U+274C - Cross Mark
# Command prefix
REGISTER_COMMAND: str = "register"
SYNC_CHANNELS_COMMAND: str = "sync-channels"

View File

@@ -0,0 +1,37 @@
"""Custom exception classes for Discord bot."""
class DiscordBotError(Exception):
"""Base exception for Discord bot errors."""
class RegistrationError(DiscordBotError):
"""Error during guild registration."""
class SyncChannelsError(DiscordBotError):
"""Error during channel sync."""
class APIError(DiscordBotError):
"""Base API error."""
class CacheError(DiscordBotError):
"""Error during cache operations."""
class APIConnectionError(APIError):
"""Failed to connect to API."""
class APITimeoutError(APIError):
"""Request timed out."""
class APIResponseError(APIError):
"""API returned an error response."""
def __init__(self, message: str, status_code: int | None = None):
super().__init__(message)
self.status_code = status_code

View File

@@ -0,0 +1,437 @@
"""Discord bot command handlers for registration and channel sync."""
import asyncio
from datetime import datetime
from datetime import timezone
import discord
from onyx.configs.app_configs import DISCORD_BOT_INVOKE_CHAR
from onyx.configs.constants import ONYX_DISCORD_URL
from onyx.db.discord_bot import bulk_create_channel_configs
from onyx.db.discord_bot import get_guild_config_by_discord_id
from onyx.db.discord_bot import get_guild_config_by_internal_id
from onyx.db.discord_bot import get_guild_config_by_registration_key
from onyx.db.discord_bot import sync_channel_configs
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.utils import DiscordChannelView
from onyx.onyxbot.discord.cache import DiscordCacheManager
from onyx.onyxbot.discord.constants import REGISTER_COMMAND
from onyx.onyxbot.discord.constants import SYNC_CHANNELS_COMMAND
from onyx.onyxbot.discord.exceptions import RegistrationError
from onyx.onyxbot.discord.exceptions import SyncChannelsError
from onyx.server.manage.discord_bot.utils import parse_discord_registration_key
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
async def handle_dm(message: discord.Message) -> None:
"""Handle direct messages."""
dm_response = (
"**I can't respond to DMs** :sweat:\n\n"
f"Please chat with me in a server channel, or join the official "
f"[Onyx Discord]({ONYX_DISCORD_URL}) for help!"
)
await message.channel.send(dm_response)
# -------------------------------------------------------------------------
# Helper functions for error handling
# -------------------------------------------------------------------------
async def _try_dm_author(message: discord.Message, content: str) -> bool:
"""Attempt to DM the message author. Returns True if successful."""
logger.debug(f"Responding in Discord DM with {content}")
try:
await message.author.send(content)
return True
except (discord.Forbidden, discord.HTTPException) as e:
# User has DMs disabled or other error
logger.warning(f"Failed to DM author {message.author.id}: {e}")
except Exception as e:
logger.exception(f"Unexpected error DMing author {message.author.id}: {e}")
return False
async def _try_delete_message(message: discord.Message) -> bool:
"""Attempt to delete a message. Returns True if successful."""
logger.debug(f"Deleting potentially sensitive message {message.id}")
try:
await message.delete()
return True
except (discord.Forbidden, discord.HTTPException) as e:
# Bot lacks permission or other error
logger.warning(f"Failed to delete message {message.id}: {e}")
except Exception as e:
logger.exception(f"Unexpected error deleting message {message.id}: {e}")
return False
async def _try_react_x(message: discord.Message) -> bool:
"""Attempt to react to a message with ❌. Returns True if successful."""
try:
await message.add_reaction("")
return True
except (discord.Forbidden, discord.HTTPException) as e:
# Bot lacks permission or other error
logger.warning(f"Failed to react to message {message.id}: {e}")
except Exception as e:
logger.exception(f"Unexpected error reacting to message {message.id}: {e}")
return False
# -------------------------------------------------------------------------
# Registration
# -------------------------------------------------------------------------
async def handle_registration_command(
message: discord.Message,
cache: DiscordCacheManager,
) -> bool:
"""Handle !register command. Returns True if command was handled."""
content = message.content.strip()
# Check for !register command
if not content.startswith(f"{DISCORD_BOT_INVOKE_CHAR}{REGISTER_COMMAND}"):
return False
# Must be in a server
if not message.guild:
await _try_dm_author(
message, "This command can only be used in a server channel."
)
return True
guild_name = message.guild.name
logger.info(f"Registration command received: {guild_name}")
try:
# Parse the registration key
parts = content.split(maxsplit=1)
if len(parts) < 2:
raise RegistrationError(
"Invalid registration key format. Please check the key and try again."
)
registration_key = parts[1].strip()
if not message.author or not isinstance(message.author, discord.Member):
raise RegistrationError(
"You need to be a server administrator to register the bot."
)
# Check permissions - require admin or manage_guild
if not message.author.guild_permissions.administrator:
if not message.author.guild_permissions.manage_guild:
raise RegistrationError(
"You need **Administrator** or **Manage Server** permissions "
"to register this bot."
)
await _register_guild(message, registration_key, cache)
logger.info(f"Registration successful: {guild_name}")
await message.reply(
":white_check_mark: **Successfully registered!**\n\n"
"This server is now connected to Onyx. "
"I'll respond to messages based on your server and channel settings set in Onyx."
)
except RegistrationError as e:
logger.debug(f"Registration failed: {guild_name}, error={e}")
await _try_dm_author(message, f":x: **Registration failed.**\n\n{e}")
await _try_delete_message(message)
except Exception:
logger.exception(f"Registration failed unexpectedly: {guild_name}")
await _try_dm_author(
message,
":x: **Registration failed.**\n\n"
"An unexpected error occurred. Please try again later.",
)
await _try_delete_message(message)
return True
async def _register_guild(
message: discord.Message,
registration_key: str,
cache: DiscordCacheManager,
) -> None:
"""Register a guild with a registration key."""
if not message.guild:
# mypy, even though we already know that message.guild is not None
raise RegistrationError("This command can only be used in a server.")
logger.info(f"Guild '{message.guild.name}' attempting to register Discord bot")
registration_key = registration_key.strip()
# Parse tenant_id from registration key
parsed = parse_discord_registration_key(registration_key)
if parsed is None:
raise RegistrationError(
"Invalid registration key format. Please check the key and try again."
)
tenant_id = parsed
logger.info(f"Parsed tenant_id {tenant_id} from registration key")
# Check if this guild is already registered to any tenant
guild_id = message.guild.id
existing_tenant = cache.get_tenant(guild_id)
if existing_tenant is not None:
logger.warning(
f"Guild {guild_id} is already registered to tenant {existing_tenant}"
)
raise RegistrationError(
"This server is already registered.\n\n"
"OnyxBot can only connect one Discord server to one Onyx workspace."
)
context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
guild = message.guild
guild_name = guild.name
# Collect all text channels from the guild
channels = get_text_channels(guild)
logger.info(f"Found {len(channels)} text channels in guild '{guild_name}'")
# Validate and update in database
def _sync_register() -> int:
with get_session_with_tenant(tenant_id=tenant_id) as db:
# Find the guild config by registration key
config = get_guild_config_by_registration_key(db, registration_key)
if not config:
raise RegistrationError(
"Registration key not found.\n\n"
"The key may have expired or been deleted. "
"Please generate a new one from the Onyx admin panel."
)
# Check if already used
if config.guild_id is not None:
raise RegistrationError(
"This registration key has already been used.\n\n"
"Each key can only be used once. "
"Please generate a new key from the Onyx admin panel."
)
# Update the guild config
config.guild_id = guild_id
config.guild_name = guild_name
config.registered_at = datetime.now(timezone.utc)
# Create channel configs for all text channels
bulk_create_channel_configs(db, config.id, channels)
db.commit()
return config.id
await asyncio.to_thread(_sync_register)
# Refresh cache for this guild
await cache.refresh_guild(guild_id, tenant_id)
logger.info(
f"Guild '{guild_name}' registered with {len(channels)} channel configs"
)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token)
def get_text_channels(guild: discord.Guild) -> list[DiscordChannelView]:
"""Get all text channels from a guild as DiscordChannelView objects."""
channels: list[DiscordChannelView] = []
for channel in guild.channels:
# Include text channels and forum channels (where threads can be created)
if isinstance(channel, (discord.TextChannel, discord.ForumChannel)):
# Check if channel is private (not visible to @everyone)
everyone_perms = channel.permissions_for(guild.default_role)
is_private = not everyone_perms.view_channel
logger.debug(
f"Found channel: #{channel.name}, "
f"type={channel.type.name}, is_private={is_private}"
)
channels.append(
DiscordChannelView(
channel_id=channel.id,
channel_name=channel.name,
channel_type=channel.type.name, # "text" or "forum"
is_private=is_private,
)
)
logger.debug(f"Retrieved {len(channels)} channels from guild '{guild.name}'")
return channels
# -------------------------------------------------------------------------
# Sync Channels
# -------------------------------------------------------------------------
async def handle_sync_channels_command(
message: discord.Message,
tenant_id: str | None,
bot: discord.Client,
) -> bool:
"""Handle !sync-channels command. Returns True if command was handled."""
content = message.content.strip()
# Check for !sync-channels command
if not content.startswith(f"{DISCORD_BOT_INVOKE_CHAR}{SYNC_CHANNELS_COMMAND}"):
return False
# Must be in a server
if not message.guild:
await _try_dm_author(
message, "This command can only be used in a server channel."
)
return True
guild_name = message.guild.name
logger.info(f"Sync-channels command received: {guild_name}")
try:
# Must be registered
if not tenant_id:
raise SyncChannelsError(
"This server is not registered. Please register it first."
)
# Check permissions - require admin or manage_guild
if not message.author or not isinstance(message.author, discord.Member):
raise SyncChannelsError(
"You need to be a server administrator to sync channels."
)
if not message.author.guild_permissions.administrator:
if not message.author.guild_permissions.manage_guild:
raise SyncChannelsError(
"You need **Administrator** or **Manage Server** permissions "
"to sync channels."
)
# Get guild config ID
def _get_guild_config_id() -> int | None:
with get_session_with_tenant(tenant_id=tenant_id) as db:
if not message.guild:
raise SyncChannelsError(
"Server not found. This shouldn't happen. Please contact Onyx support."
)
config = get_guild_config_by_discord_id(db, message.guild.id)
return config.id if config else None
guild_config_id = await asyncio.to_thread(_get_guild_config_id)
if not guild_config_id:
raise SyncChannelsError(
"Server config not found. This shouldn't happen. Please contact Onyx support."
)
# Perform the sync
added, removed, updated = await sync_guild_channels(
guild_config_id, tenant_id, bot
)
logger.info(
f"Sync-channels successful: {guild_name}, "
f"added={added}, removed={removed}, updated={updated}"
)
await message.reply(
f":white_check_mark: **Channel sync complete!**\n\n"
f"* **{added}** new channel(s) added\n"
f"* **{removed}** deleted channel(s) removed\n"
f"* **{updated}** channel name(s) updated\n\n"
"New channels are disabled by default. Enable them in the Onyx admin panel."
)
except SyncChannelsError as e:
logger.debug(f"Sync-channels failed: {guild_name}, error={e}")
await _try_dm_author(message, f":x: **Channel sync failed.**\n\n{e}")
await _try_react_x(message)
except Exception:
logger.exception(f"Sync-channels failed unexpectedly: {guild_name}")
await _try_dm_author(
message,
":x: **Channel sync failed.**\n\n"
"An unexpected error occurred. Please try again later.",
)
await _try_react_x(message)
return True
async def sync_guild_channels(
guild_config_id: int,
tenant_id: str,
bot: discord.Client,
) -> tuple[int, int, int]:
"""Sync channel configs with current Discord channels for a guild.
Fetches current channels from Discord and syncs with database:
- Creates configs for new channels (disabled by default)
- Removes configs for deleted channels
- Updates names for existing channels if changed
Args:
guild_config_id: Internal ID of the guild config
tenant_id: Tenant ID for database access
bot: Discord bot client
Returns:
(added_count, removed_count, updated_count)
Raises:
ValueError: If guild config not found or guild not registered
"""
context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
# Get guild_id from config
def _get_guild_id() -> int | None:
with get_session_with_tenant(tenant_id=tenant_id) as db:
config = get_guild_config_by_internal_id(db, guild_config_id)
if not config:
return None
return config.guild_id
guild_id = await asyncio.to_thread(_get_guild_id)
if guild_id is None:
raise ValueError(
f"Guild config {guild_config_id} not found or not registered"
)
# Get the guild from Discord
guild = bot.get_guild(guild_id)
if not guild:
raise ValueError(f"Guild {guild_id} not found in Discord cache")
# Get current channels from Discord
channels = get_text_channels(guild)
logger.info(f"Syncing {len(channels)} channels for guild '{guild.name}'")
# Sync with database
def _sync() -> tuple[int, int, int]:
with get_session_with_tenant(tenant_id=tenant_id) as db:
added, removed, updated = sync_channel_configs(
db, guild_config_id, channels
)
db.commit()
return added, removed, updated
added, removed, updated = await asyncio.to_thread(_sync)
logger.info(
f"Channel sync complete for guild '{guild.name}': "
f"added={added}, removed={removed}, updated={updated}"
)
return added, removed, updated
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token)

View File

@@ -0,0 +1,493 @@
"""Discord bot message handling and response logic."""
import asyncio
import discord
from pydantic import BaseModel
from onyx.chat.models import ChatFullResponse
from onyx.db.discord_bot import get_channel_config_by_discord_ids
from onyx.db.discord_bot import get_guild_config_by_discord_id
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.models import DiscordChannelConfig
from onyx.db.models import DiscordGuildConfig
from onyx.onyxbot.discord.api_client import OnyxAPIClient
from onyx.onyxbot.discord.constants import MAX_CONTEXT_MESSAGES
from onyx.onyxbot.discord.constants import MAX_MESSAGE_LENGTH
from onyx.onyxbot.discord.constants import THINKING_EMOJI
from onyx.onyxbot.discord.exceptions import APIError
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Message types with actual content (excludes system notifications like "user joined")
CONTENT_MESSAGE_TYPES = (
discord.MessageType.default,
discord.MessageType.reply,
discord.MessageType.thread_starter_message,
)
class ShouldRespondContext(BaseModel):
"""Context for whether the bot should respond to a message."""
should_respond: bool
persona_id: int | None
thread_only_mode: bool
# -------------------------------------------------------------------------
# Response Logic
# -------------------------------------------------------------------------
async def should_respond(
message: discord.Message,
tenant_id: str,
bot_user: discord.ClientUser,
) -> ShouldRespondContext:
"""Determine if bot should respond and which persona to use."""
if not message.guild:
logger.warning("Received a message that isn't in a server.")
return ShouldRespondContext(
should_respond=False, persona_id=None, thread_only_mode=False
)
guild_id = message.guild.id
channel_id = message.channel.id
bot_mentioned = bot_user in message.mentions
def _get_configs() -> tuple[DiscordGuildConfig | None, DiscordChannelConfig | None]:
with get_session_with_tenant(tenant_id=tenant_id) as db:
guild_config = get_guild_config_by_discord_id(db, guild_id)
if not guild_config or not guild_config.enabled:
return None, None
# For threads, use parent channel ID
actual_channel_id = channel_id
if isinstance(message.channel, discord.Thread) and message.channel.parent:
actual_channel_id = message.channel.parent.id
channel_config = get_channel_config_by_discord_ids(
db, guild_id, actual_channel_id
)
return guild_config, channel_config
guild_config, channel_config = await asyncio.to_thread(_get_configs)
if not guild_config or not channel_config or not channel_config.enabled:
return ShouldRespondContext(
should_respond=False, persona_id=None, thread_only_mode=False
)
# Determine persona (channel override or guild default)
persona_id = channel_config.persona_override_id or guild_config.default_persona_id
# Check mention requirement (with exceptions for implicit invocation)
if channel_config.require_bot_invocation and not bot_mentioned:
if not await check_implicit_invocation(message, bot_user):
return ShouldRespondContext(
should_respond=False, persona_id=None, thread_only_mode=False
)
return ShouldRespondContext(
should_respond=True,
persona_id=persona_id,
thread_only_mode=channel_config.thread_only_mode,
)
async def check_implicit_invocation(
message: discord.Message,
bot_user: discord.ClientUser,
) -> bool:
"""Check if the bot should respond without explicit mention.
Returns True if:
1. User is replying to a bot message
2. User is in a thread owned by the bot
3. User is in a thread created from a bot message
"""
# Check if replying to a bot message
if message.reference and message.reference.message_id:
try:
referenced_msg = await message.channel.fetch_message(
message.reference.message_id
)
if referenced_msg.author.id == bot_user.id:
logger.debug(
f"Implicit invocation via reply: '{message.content[:50]}...'"
)
return True
except (discord.NotFound, discord.HTTPException):
pass
# Check thread-related conditions
if isinstance(message.channel, discord.Thread):
thread = message.channel
# Bot owns the thread
if thread.owner_id == bot_user.id:
logger.debug(
f"Implicit invocation via bot-owned thread: "
f"'{message.content[:50]}...' in #{thread.name}"
)
return True
# Thread was created from a bot message
if thread.parent and not isinstance(thread.parent, discord.ForumChannel):
try:
starter = await thread.parent.fetch_message(thread.id)
if starter.author.id == bot_user.id:
logger.debug(
f"Implicit invocation via bot-started thread: "
f"'{message.content[:50]}...' in #{thread.name}"
)
return True
except (discord.NotFound, discord.HTTPException):
pass
return False
# -------------------------------------------------------------------------
# Message Processing
# -------------------------------------------------------------------------
async def process_chat_message(
message: discord.Message,
api_key: str,
persona_id: int | None,
thread_only_mode: bool,
api_client: OnyxAPIClient,
bot_user: discord.ClientUser,
) -> None:
"""Process a message and send response."""
try:
await message.add_reaction(THINKING_EMOJI)
except discord.DiscordException:
logger.warning(
f"Failed to add thinking reaction to message: '{message.content[:50]}...'"
)
try:
# Build conversation context
context = await _build_conversation_context(message, bot_user)
# Prepare full message content
parts = []
if context:
parts.append(context)
if isinstance(message.channel, discord.Thread):
if isinstance(message.channel.parent, discord.ForumChannel):
parts.append(f"Forum post title: {message.channel.name}")
parts.append(
f"Current message from @{message.author.display_name}: {format_message_content(message)}"
)
# Send to API
response = await api_client.send_chat_message(
message="\n\n".join(parts),
api_key=api_key,
persona_id=persona_id,
)
# Format response with citations
answer = response.answer or "I couldn't generate a response."
answer = _append_citations(answer, response)
await send_response(message, answer, thread_only_mode)
try:
await message.remove_reaction(THINKING_EMOJI, bot_user)
except discord.DiscordException:
pass
except APIError as e:
logger.error(f"API error processing message: {e}")
await send_error_response(message, bot_user)
except Exception as e:
logger.exception(f"Error processing chat message: {e}")
await send_error_response(message, bot_user)
async def _build_conversation_context(
message: discord.Message,
bot_user: discord.ClientUser,
) -> str | None:
"""Build conversation context from thread history or reply chain."""
if isinstance(message.channel, discord.Thread):
return await _build_thread_context(message, bot_user)
elif message.reference:
return await _build_reply_chain_context(message, bot_user)
return None
def _append_citations(answer: str, response: ChatFullResponse) -> str:
"""Append citation sources to the answer if present."""
if not response.citation_info or not response.top_documents:
return answer
cited_docs: list[tuple[int, str, str | None]] = []
for citation in response.citation_info:
doc = next(
(
d
for d in response.top_documents
if d.document_id == citation.document_id
),
None,
)
if doc:
cited_docs.append(
(
citation.citation_number,
doc.semantic_identifier or "Source",
doc.link,
)
)
if not cited_docs:
return answer
cited_docs.sort(key=lambda x: x[0])
citations = "\n\n**Sources:**\n"
for num, name, link in cited_docs[:5]:
if link:
citations += f"{num}. [{name}](<{link}>)\n"
else:
citations += f"{num}. {name}\n"
return answer + citations
# -------------------------------------------------------------------------
# Context Building
# -------------------------------------------------------------------------
async def _build_reply_chain_context(
message: discord.Message,
bot_user: discord.ClientUser,
) -> str | None:
"""Build context by following the reply chain backwards."""
if not message.reference or not message.reference.message_id:
return None
try:
messages: list[discord.Message] = []
current = message
# Follow reply chain backwards up to MAX_CONTEXT_MESSAGES
while (
current.reference
and current.reference.message_id
and len(messages) < MAX_CONTEXT_MESSAGES
):
try:
parent = await message.channel.fetch_message(
current.reference.message_id
)
messages.append(parent)
current = parent
except (discord.NotFound, discord.HTTPException):
break
if not messages:
return None
messages.reverse() # Chronological order
logger.debug(
f"Built reply chain context: {len(messages)} messages in #{getattr(message.channel, 'name', 'unknown')}"
)
return _format_messages_as_context(messages, bot_user)
except Exception as e:
logger.warning(f"Failed to build reply chain context: {e}")
return None
async def _build_thread_context(
message: discord.Message,
bot_user: discord.ClientUser,
) -> str | None:
"""Build context from thread message history."""
if not isinstance(message.channel, discord.Thread):
return None
try:
thread = message.channel
messages: list[discord.Message] = []
# Fetch recent messages (excluding current)
async for msg in thread.history(limit=MAX_CONTEXT_MESSAGES, oldest_first=False):
if msg.id != message.id:
messages.append(msg)
# Include thread starter message and its reply chain if not already present
if thread.parent and not isinstance(thread.parent, discord.ForumChannel):
try:
starter = await thread.parent.fetch_message(thread.id)
if starter.id != message.id and not any(
m.id == starter.id for m in messages
):
messages.append(starter)
# Trace back through the starter's reply chain for more context
current = starter
while (
current.reference
and current.reference.message_id
and len(messages) < MAX_CONTEXT_MESSAGES
):
try:
parent = await thread.parent.fetch_message(
current.reference.message_id
)
if not any(m.id == parent.id for m in messages):
messages.append(parent)
current = parent
except (discord.NotFound, discord.HTTPException):
break
except (discord.NotFound, discord.HTTPException):
pass
if not messages:
return None
messages.sort(key=lambda m: m.id) # Chronological order
logger.debug(
f"Built thread context: {len(messages)} messages in #{thread.name}"
)
return _format_messages_as_context(messages, bot_user)
except Exception as e:
logger.warning(f"Failed to build thread context: {e}")
return None
def _format_messages_as_context(
messages: list[discord.Message],
bot_user: discord.ClientUser,
) -> str | None:
"""Format a list of messages into a conversation context string."""
formatted = []
for msg in messages:
if msg.type not in CONTENT_MESSAGE_TYPES:
continue
sender = (
"OnyxBot" if msg.author.id == bot_user.id else f"@{msg.author.display_name}"
)
formatted.append(f"{sender}: {format_message_content(msg)}")
if not formatted:
return None
return (
"You are a Discord bot named OnyxBot.\n"
'Always assume that [user] is the same as the "Current message" author.'
"Conversation history:\n"
"---\n" + "\n".join(formatted) + "\n---"
)
# -------------------------------------------------------------------------
# Message Formatting
# -------------------------------------------------------------------------
def format_message_content(message: discord.Message) -> str:
"""Format message content with readable mentions."""
content = message.content
for user in message.mentions:
content = content.replace(f"<@{user.id}>", f"@{user.display_name}")
content = content.replace(f"<@!{user.id}>", f"@{user.display_name}")
for role in message.role_mentions:
content = content.replace(f"<@&{role.id}>", f"@{role.name}")
for channel in message.channel_mentions:
content = content.replace(f"<#{channel.id}>", f"#{channel.name}")
return content
# -------------------------------------------------------------------------
# Response Sending
# -------------------------------------------------------------------------
async def send_response(
message: discord.Message,
content: str,
thread_only_mode: bool,
) -> None:
"""Send response based on thread_only_mode setting."""
chunks = _split_message(content)
if isinstance(message.channel, discord.Thread):
for chunk in chunks:
await message.channel.send(chunk)
elif thread_only_mode:
thread_name = f"OnyxBot <> {message.author.display_name}"[:100]
thread = await message.create_thread(name=thread_name)
for chunk in chunks:
await thread.send(chunk)
else:
for i, chunk in enumerate(chunks):
if i == 0:
await message.reply(chunk)
else:
await message.channel.send(chunk)
def _split_message(content: str) -> list[str]:
"""Split content into chunks that fit Discord's message limit."""
chunks = []
while content:
if len(content) <= MAX_MESSAGE_LENGTH:
chunks.append(content)
break
# Find a good split point
split_at = MAX_MESSAGE_LENGTH
for sep in ["\n\n", "\n", ". ", " "]:
idx = content.rfind(sep, 0, MAX_MESSAGE_LENGTH)
if idx > MAX_MESSAGE_LENGTH // 2:
split_at = idx + len(sep)
break
chunks.append(content[:split_at])
content = content[split_at:]
return chunks
async def send_error_response(
message: discord.Message,
bot_user: discord.ClientUser,
) -> None:
"""Send error response and clean up reaction."""
try:
await message.remove_reaction(THINKING_EMOJI, bot_user)
except discord.DiscordException:
pass
error_msg = "Sorry, I encountered an error processing your message. You may want to contact Onyx for support :sweat_smile:"
try:
if isinstance(message.channel, discord.Thread):
await message.channel.send(error_msg)
else:
thread = await message.create_thread(
name=f"Response to {message.author.display_name}"[:100]
)
await thread.send(error_msg)
except discord.DiscordException:
pass

View File

@@ -0,0 +1,39 @@
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import DISCORD_BOT_TOKEN
from onyx.configs.constants import AuthType
from onyx.db.discord_bot import get_discord_bot_config
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.utils.logger import setup_logger
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
logger = setup_logger()
def get_bot_token() -> str | None:
"""Get Discord bot token from env var or database.
Priority:
1. DISCORD_BOT_TOKEN env var (always takes precedence)
2. For self-hosted: DiscordBotConfig in database (default tenant)
3. For Cloud: should always have env var set
Returns:
Bot token string, or None if not configured.
"""
# Environment variable takes precedence
if DISCORD_BOT_TOKEN:
return DISCORD_BOT_TOKEN
# Cloud should always have env var; if not, return None
if AUTH_TYPE == AuthType.CLOUD:
logger.warning("Cloud deployment missing DISCORD_BOT_TOKEN env var")
return None
# Self-hosted: check database for bot config
try:
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db:
config = get_discord_bot_config(db)
except Exception as e:
logger.error(f"Failed to get bot token from database: {e}")
return None
return config.bot_token if config else None

View File

@@ -592,11 +592,8 @@ def build_slack_response_blocks(
)
citations_blocks = []
document_blocks = []
if answer.citation_info:
citations_blocks = _build_citations_blocks(answer)
else:
document_blocks = _priority_ordered_documents_blocks(answer)
citations_divider = [DividerBlock()] if citations_blocks else []
buttons_divider = [DividerBlock()] if web_follow_up_block or follow_up_block else []
@@ -608,7 +605,6 @@ def build_slack_response_blocks(
+ ai_feedback_block
+ citations_divider
+ citations_blocks
+ document_blocks
+ buttons_divider
+ web_follow_up_block
+ follow_up_block

View File

@@ -1,12 +1,149 @@
from mistune import Markdown # type: ignore[import-untyped]
from mistune import Renderer
import re
from collections.abc import Callable
from typing import Any
from mistune import create_markdown
from mistune import HTMLRenderer
# Tags that should be replaced with a newline (line-break and block-level elements)
_HTML_NEWLINE_TAG_PATTERN = re.compile(
r"<br\s*/?>|</(?:p|div|li|h[1-6]|tr|blockquote|section|article)>",
re.IGNORECASE,
)
# Strips HTML tags but excludes autolinks like <https://...> and <mailto:...>
_HTML_TAG_PATTERN = re.compile(
r"<(?!https?://|mailto:)/?[a-zA-Z][^>]*>",
)
# Matches fenced code blocks (``` ... ```) so we can skip sanitization inside them
_FENCED_CODE_BLOCK_PATTERN = re.compile(r"```[\s\S]*?```")
# Matches the start of any markdown link: [text]( or [[n]](
# The inner group handles nested brackets for citation links like [[1]](.
_MARKDOWN_LINK_PATTERN = re.compile(r"\[(?:[^\[\]]|\[[^\]]*\])*\]\(")
# Matches Slack-style links <url|text> that LLMs sometimes output directly.
# Mistune doesn't recognise this syntax, so text() would escape the angle
# brackets and Slack would render them as literal text instead of links.
_SLACK_LINK_PATTERN = re.compile(r"<(https?://[^|>]+)\|([^>]+)>")
def _sanitize_html(text: str) -> str:
"""Strip HTML tags from a text fragment.
Block-level closing tags and <br> are converted to newlines.
All other HTML tags are removed. Autolinks (<https://...>) are preserved.
"""
text = _HTML_NEWLINE_TAG_PATTERN.sub("\n", text)
text = _HTML_TAG_PATTERN.sub("", text)
return text
def _transform_outside_code_blocks(
message: str, transform: Callable[[str], str]
) -> str:
"""Apply *transform* only to text outside fenced code blocks."""
parts = _FENCED_CODE_BLOCK_PATTERN.split(message)
code_blocks = _FENCED_CODE_BLOCK_PATTERN.findall(message)
result: list[str] = []
for i, part in enumerate(parts):
result.append(transform(part))
if i < len(code_blocks):
result.append(code_blocks[i])
return "".join(result)
def _extract_link_destination(message: str, start_idx: int) -> tuple[str, int | None]:
"""Extract markdown link destination, allowing nested parentheses in the URL."""
depth = 0
i = start_idx
while i < len(message):
curr = message[i]
if curr == "\\":
i += 2
continue
if curr == "(":
depth += 1
elif curr == ")":
if depth == 0:
return message[start_idx:i], i
depth -= 1
i += 1
return message[start_idx:], None
def _normalize_link_destinations(message: str) -> str:
"""Wrap markdown link URLs in angle brackets so the parser handles special chars safely.
Markdown link syntax [text](url) breaks when the URL contains unescaped
parentheses, spaces, or other special characters. Wrapping the URL in angle
brackets — [text](<url>) — tells the parser to treat everything inside as
a literal URL. This applies to all links, not just citations.
"""
if "](" not in message:
return message
normalized_parts: list[str] = []
cursor = 0
while match := _MARKDOWN_LINK_PATTERN.search(message, cursor):
normalized_parts.append(message[cursor : match.end()])
destination_start = match.end()
destination, end_idx = _extract_link_destination(message, destination_start)
if end_idx is None:
normalized_parts.append(message[destination_start:])
return "".join(normalized_parts)
already_wrapped = destination.startswith("<") and destination.endswith(">")
if destination and not already_wrapped:
destination = f"<{destination}>"
normalized_parts.append(destination)
normalized_parts.append(")")
cursor = end_idx + 1
normalized_parts.append(message[cursor:])
return "".join(normalized_parts)
def _convert_slack_links_to_markdown(message: str) -> str:
"""Convert Slack-style <url|text> links to standard markdown [text](url).
LLMs sometimes emit Slack mrkdwn link syntax directly. Mistune doesn't
recognise it, so the angle brackets would be escaped by text() and Slack
would render the link as literal text instead of a clickable link.
"""
return _transform_outside_code_blocks(
message, lambda text: _SLACK_LINK_PATTERN.sub(r"[\2](\1)", text)
)
def format_slack_message(message: str | None) -> str:
return Markdown(renderer=SlackRenderer()).render(message)
if message is None:
return ""
message = _transform_outside_code_blocks(message, _sanitize_html)
message = _convert_slack_links_to_markdown(message)
normalized_message = _normalize_link_destinations(message)
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"])
result = md(normalized_message)
# With HTMLRenderer, result is always str (not AST list)
assert isinstance(result, str)
return result.rstrip("\n")
class SlackRenderer(Renderer):
class SlackRenderer(HTMLRenderer):
"""Renders markdown as Slack mrkdwn format instead of HTML.
Overrides all HTMLRenderer methods that produce HTML tags to ensure
no raw HTML ever appears in Slack messages.
"""
SPECIALS: dict[str, str] = {"&": "&amp;", "<": "&lt;", ">": "&gt;"}
def escape_special(self, text: str) -> str:
@@ -14,52 +151,72 @@ class SlackRenderer(Renderer):
text = text.replace(special, replacement)
return text
def header(self, text: str, level: int, raw: str | None = None) -> str:
return f"*{text}*\n"
def heading(self, text: str, level: int, **attrs: Any) -> str: # noqa: ARG002
return f"*{text}*\n\n"
def emphasis(self, text: str) -> str:
return f"_{text}_"
def double_emphasis(self, text: str) -> str:
def strong(self, text: str) -> str:
return f"*{text}*"
def strikethrough(self, text: str) -> str:
return f"~{text}~"
def list(self, body: str, ordered: bool = True) -> str:
lines = body.split("\n")
def list(self, text: str, ordered: bool, **attrs: Any) -> str:
lines = text.split("\n")
count = 0
for i, line in enumerate(lines):
if line.startswith("li: "):
count += 1
prefix = f"{count}. " if ordered else ""
lines[i] = f"{prefix}{line[4:]}"
return "\n".join(lines)
return "\n".join(lines) + "\n"
def list_item(self, text: str) -> str:
return f"li: {text}\n"
def link(self, link: str, title: str | None, content: str | None) -> str:
escaped_link = self.escape_special(link)
if content:
return f"<{escaped_link}|{content}>"
def link(self, text: str, url: str, title: str | None = None) -> str:
escaped_url = self.escape_special(url)
if text:
return f"<{escaped_url}|{text}>"
if title:
return f"<{escaped_link}|{title}>"
return f"<{escaped_link}>"
return f"<{escaped_url}|{title}>"
return f"<{escaped_url}>"
def image(self, src: str, title: str | None, text: str | None) -> str:
escaped_src = self.escape_special(src)
def image(self, text: str, url: str, title: str | None = None) -> str:
escaped_url = self.escape_special(url)
display_text = title or text
return f"<{escaped_src}|{display_text}>" if display_text else f"<{escaped_src}>"
return f"<{escaped_url}|{display_text}>" if display_text else f"<{escaped_url}>"
def codespan(self, text: str) -> str:
return f"`{text}`"
def block_code(self, text: str, lang: str | None) -> str:
return f"```\n{text}\n```\n"
def block_code(self, code: str, info: str | None = None) -> str: # noqa: ARG002
return f"```\n{code.rstrip(chr(10))}\n```\n\n"
def linebreak(self) -> str:
return "\n"
def thematic_break(self) -> str:
return "---\n\n"
def block_quote(self, text: str) -> str:
lines = text.strip().split("\n")
quoted = "\n".join(f">{line}" for line in lines)
return quoted + "\n\n"
def block_html(self, html: str) -> str:
return _sanitize_html(html) + "\n\n"
def block_error(self, text: str) -> str:
return f"```\n{text}\n```\n\n"
def text(self, text: str) -> str:
# Only escape the three entities Slack recognizes: & < >
# HTMLRenderer.text() also escapes " to &quot; which Slack renders
# as literal &quot; text since Slack doesn't recognize that entity.
return self.escape_special(text)
def paragraph(self, text: str) -> str:
return f"{text}\n"
def autolink(self, link: str, is_email: bool) -> str:
return link if is_email else self.link(link, None, None)
return f"{text}\n\n"

Some files were not shown because too many files have changed in this diff Show More