mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-14 04:02:41 +00:00
Compare commits
113 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c4a2ff2593 | ||
|
|
4b74a6dc76 | ||
|
|
eea5f5b380 | ||
|
|
ae428ba684 | ||
|
|
7b927e79c2 | ||
|
|
a6815d1221 | ||
|
|
f73d103b6b | ||
|
|
5ec424a3f3 | ||
|
|
0bd3e9a11c | ||
|
|
a336691882 | ||
|
|
bd4965b4d9 | ||
|
|
3c8a24eeba | ||
|
|
613be0de66 | ||
|
|
6f05dbd650 | ||
|
|
8dc7aae816 | ||
|
|
e4527cf117 | ||
|
|
868c9428e2 | ||
|
|
be61c54d45 | ||
|
|
aec0c28c59 | ||
|
|
ab9e3e5338 | ||
|
|
d17c748f75 | ||
|
|
196b6b0514 | ||
|
|
608491ac36 | ||
|
|
a4a664fa2c | ||
|
|
8a6e349741 | ||
|
|
11f8408558 | ||
|
|
24de76ad28 | ||
|
|
e264356eb5 | ||
|
|
c5c08c5da6 | ||
|
|
78a9b386c7 | ||
|
|
dbcbfc1629 | ||
|
|
fabbb00c49 | ||
|
|
809dab5746 | ||
|
|
1649bed548 | ||
|
|
dd07b3cf27 | ||
|
|
c57ea65d42 | ||
|
|
c1ce180b72 | ||
|
|
b5474dc127 | ||
|
|
e1df3f533a | ||
|
|
df5252db05 | ||
|
|
f01f210af8 | ||
|
|
781219cf18 | ||
|
|
ca39da7de9 | ||
|
|
abf76cd747 | ||
|
|
a78607f1b5 | ||
|
|
e213853f63 | ||
|
|
8dc379c6fd | ||
|
|
787f117e17 | ||
|
|
665640fac8 | ||
|
|
d2d44c1e68 | ||
|
|
ffe04ab91f | ||
|
|
6499b21235 | ||
|
|
c5bfd5a152 | ||
|
|
a0329161b0 | ||
|
|
334b7a6d2f | ||
|
|
36196373a8 | ||
|
|
533aa8eff8 | ||
|
|
ecbb267f80 | ||
|
|
66023dbb6d | ||
|
|
f97466e4de | ||
|
|
2cc8303e5f | ||
|
|
a92ff61f64 | ||
|
|
17551a907e | ||
|
|
9e42951fa4 | ||
|
|
dcb18c2411 | ||
|
|
2f628e39d3 | ||
|
|
fd200d46f8 | ||
|
|
ec7482619b | ||
|
|
9d1a357533 | ||
|
|
fbe823b551 | ||
|
|
1608e2f274 | ||
|
|
4dbb1fa606 | ||
|
|
19b33e4d93 | ||
|
|
e56fa57c21 | ||
|
|
5cdeb84164 | ||
|
|
5b5100a07a | ||
|
|
77f58fbad5 | ||
|
|
cf74afc65e | ||
|
|
a887bc616c | ||
|
|
fef1fd093e | ||
|
|
8d085a4ccf | ||
|
|
28310b9138 | ||
|
|
f71fab580c | ||
|
|
89593b353f | ||
|
|
91e24ae63a | ||
|
|
d2b37724d1 | ||
|
|
87f0849330 | ||
|
|
2ec7526772 | ||
|
|
bbd68e2795 | ||
|
|
e74c36001a | ||
|
|
fe593a15da | ||
|
|
27df690a8d | ||
|
|
edbe569edd | ||
|
|
5118193d16 | ||
|
|
63d3efd380 | ||
|
|
ec978d9a3f | ||
|
|
d4d98a6cd0 | ||
|
|
dc40e86dac | ||
|
|
e495f7a13e | ||
|
|
4761e4b132 | ||
|
|
6b5ab54b85 | ||
|
|
959cf444f8 | ||
|
|
2ebccea6d6 | ||
|
|
5fe7a474db | ||
|
|
9d7dc3da21 | ||
|
|
2899be4c5e | ||
|
|
64ee7fc23f | ||
|
|
e07764285d | ||
|
|
cc2e6ffa8a | ||
|
|
d3ee5c9b59 | ||
|
|
dfa0efc093 | ||
|
|
9aad4077f1 | ||
|
|
29d9ebf7b3 |
3
.github/CODEOWNERS
vendored
3
.github/CODEOWNERS
vendored
@@ -8,3 +8,6 @@
|
||||
# Agent context files
|
||||
/CLAUDE.md @Weves
|
||||
/AGENTS.md @Weves
|
||||
|
||||
# Beta cherry-pick workflow owners
|
||||
/.github/workflows/post-merge-beta-cherry-pick.yml @justin-tahara @jmelahman
|
||||
|
||||
31
.github/actions/slack-notify/action.yml
vendored
31
.github/actions/slack-notify/action.yml
vendored
@@ -1,11 +1,14 @@
|
||||
name: "Slack Notify on Failure"
|
||||
description: "Sends a Slack notification when a workflow fails"
|
||||
name: "Slack Notify"
|
||||
description: "Sends a Slack notification for workflow events"
|
||||
inputs:
|
||||
webhook-url:
|
||||
description: "Slack webhook URL (can also use SLACK_WEBHOOK_URL env var)"
|
||||
required: false
|
||||
details:
|
||||
description: "Additional message body content"
|
||||
required: false
|
||||
failed-jobs:
|
||||
description: "List of failed job names (newline-separated)"
|
||||
description: "Deprecated alias for details"
|
||||
required: false
|
||||
title:
|
||||
description: "Title for the notification"
|
||||
@@ -21,6 +24,7 @@ runs:
|
||||
shell: bash
|
||||
env:
|
||||
SLACK_WEBHOOK_URL: ${{ inputs.webhook-url }}
|
||||
DETAILS: ${{ inputs.details }}
|
||||
FAILED_JOBS: ${{ inputs.failed-jobs }}
|
||||
TITLE: ${{ inputs.title }}
|
||||
REF_NAME: ${{ inputs.ref-name }}
|
||||
@@ -44,6 +48,18 @@ runs:
|
||||
REF_NAME="$GITHUB_REF_NAME"
|
||||
fi
|
||||
|
||||
if [ -z "$DETAILS" ]; then
|
||||
DETAILS="$FAILED_JOBS"
|
||||
fi
|
||||
|
||||
normalize_multiline() {
|
||||
printf '%s' "$1" | awk 'BEGIN { ORS=""; first=1 } { if (!first) printf "\\n"; printf "%s", $0; first=0 }'
|
||||
}
|
||||
|
||||
DETAILS="$(normalize_multiline "$DETAILS")"
|
||||
REF_NAME="$(normalize_multiline "$REF_NAME")"
|
||||
TITLE="$(normalize_multiline "$TITLE")"
|
||||
|
||||
# Escape JSON special characters
|
||||
escape_json() {
|
||||
local input="$1"
|
||||
@@ -59,12 +75,12 @@ runs:
|
||||
}
|
||||
|
||||
REF_NAME_ESC=$(escape_json "$REF_NAME")
|
||||
FAILED_JOBS_ESC=$(escape_json "$FAILED_JOBS")
|
||||
DETAILS_ESC=$(escape_json "$DETAILS")
|
||||
WORKFLOW_URL_ESC=$(escape_json "$WORKFLOW_URL")
|
||||
TITLE_ESC=$(escape_json "$TITLE")
|
||||
|
||||
# Build JSON payload piece by piece
|
||||
# Note: FAILED_JOBS_ESC already contains \n sequences that should remain as \n in JSON
|
||||
# Note: DETAILS_ESC already contains \n sequences that should remain as \n in JSON
|
||||
PAYLOAD="{"
|
||||
PAYLOAD="${PAYLOAD}\"text\":\"${TITLE_ESC}\","
|
||||
PAYLOAD="${PAYLOAD}\"blocks\":[{"
|
||||
@@ -79,10 +95,10 @@ runs:
|
||||
PAYLOAD="${PAYLOAD}{\"type\":\"mrkdwn\",\"text\":\"*Run ID:*\\n#${RUN_NUMBER}\"}"
|
||||
PAYLOAD="${PAYLOAD}]"
|
||||
PAYLOAD="${PAYLOAD}}"
|
||||
if [ -n "$FAILED_JOBS" ]; then
|
||||
if [ -n "$DETAILS" ]; then
|
||||
PAYLOAD="${PAYLOAD},{"
|
||||
PAYLOAD="${PAYLOAD}\"type\":\"section\","
|
||||
PAYLOAD="${PAYLOAD}\"text\":{\"type\":\"mrkdwn\",\"text\":\"*Failed Jobs:*\\n${FAILED_JOBS_ESC}\"}"
|
||||
PAYLOAD="${PAYLOAD}\"text\":{\"type\":\"mrkdwn\",\"text\":\"${DETAILS_ESC}\"}"
|
||||
PAYLOAD="${PAYLOAD}}"
|
||||
fi
|
||||
PAYLOAD="${PAYLOAD},{"
|
||||
@@ -99,4 +115,3 @@ runs:
|
||||
curl -X POST -H 'Content-type: application/json' \
|
||||
--data "$PAYLOAD" \
|
||||
"$SLACK_WEBHOOK_URL"
|
||||
|
||||
|
||||
62
.github/workflows/deployment.yml
vendored
62
.github/workflows/deployment.yml
vendored
@@ -29,20 +29,32 @@ jobs:
|
||||
build-backend-craft: ${{ steps.check.outputs.build-backend-craft }}
|
||||
build-model-server: ${{ steps.check.outputs.build-model-server }}
|
||||
is-cloud-tag: ${{ steps.check.outputs.is-cloud-tag }}
|
||||
is-stable: ${{ steps.check.outputs.is-stable }}
|
||||
is-beta: ${{ steps.check.outputs.is-beta }}
|
||||
is-stable-standalone: ${{ steps.check.outputs.is-stable-standalone }}
|
||||
is-beta-standalone: ${{ steps.check.outputs.is-beta-standalone }}
|
||||
is-craft-latest: ${{ steps.check.outputs.is-craft-latest }}
|
||||
is-latest: ${{ steps.check.outputs.is-latest }}
|
||||
is-test-run: ${{ steps.check.outputs.is-test-run }}
|
||||
sanitized-tag: ${{ steps.check.outputs.sanitized-tag }}
|
||||
short-sha: ${{ steps.check.outputs.short-sha }}
|
||||
steps:
|
||||
- name: Checkout (for git tags)
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
fetch-tags: true
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
version: "0.9.9"
|
||||
enable-cache: false
|
||||
|
||||
- name: Check which components to build and version info
|
||||
id: check
|
||||
env:
|
||||
EVENT_NAME: ${{ github.event_name }}
|
||||
run: |
|
||||
set -eo pipefail
|
||||
TAG="${GITHUB_REF_NAME}"
|
||||
# Sanitize tag name by replacing slashes with hyphens (for Docker tag compatibility)
|
||||
SANITIZED_TAG=$(echo "$TAG" | tr '/' '-')
|
||||
@@ -54,9 +66,8 @@ jobs:
|
||||
IS_VERSION_TAG=false
|
||||
IS_STABLE=false
|
||||
IS_BETA=false
|
||||
IS_STABLE_STANDALONE=false
|
||||
IS_BETA_STANDALONE=false
|
||||
IS_CRAFT_LATEST=false
|
||||
IS_LATEST=false
|
||||
IS_PROD_TAG=false
|
||||
IS_TEST_RUN=false
|
||||
BUILD_DESKTOP=false
|
||||
@@ -67,9 +78,6 @@ jobs:
|
||||
BUILD_MODEL_SERVER=true
|
||||
|
||||
# Determine tag type based on pattern matching (do regex checks once)
|
||||
if [[ "$TAG" == craft-* ]]; then
|
||||
IS_CRAFT_LATEST=true
|
||||
fi
|
||||
if [[ "$TAG" == *cloud* ]]; then
|
||||
IS_CLOUD=true
|
||||
fi
|
||||
@@ -97,20 +105,28 @@ jobs:
|
||||
fi
|
||||
fi
|
||||
|
||||
# Craft-latest builds backend with Craft enabled
|
||||
if [[ "$IS_CRAFT_LATEST" == "true" ]]; then
|
||||
BUILD_BACKEND_CRAFT=true
|
||||
BUILD_BACKEND=false
|
||||
fi
|
||||
|
||||
# Standalone version checks (for backend/model-server - version excluding cloud tags)
|
||||
if [[ "$IS_STABLE" == "true" ]] && [[ "$IS_CLOUD" != "true" ]]; then
|
||||
IS_STABLE_STANDALONE=true
|
||||
fi
|
||||
if [[ "$IS_BETA" == "true" ]] && [[ "$IS_CLOUD" != "true" ]]; then
|
||||
IS_BETA_STANDALONE=true
|
||||
fi
|
||||
|
||||
# Determine if this tag should get the "latest" Docker tag.
|
||||
# Only the highest semver stable tag (vX.Y.Z exactly) gets "latest".
|
||||
if [[ "$IS_STABLE" == "true" ]]; then
|
||||
HIGHEST_STABLE=$(uv run --no-sync --with onyx-devtools ods latest-stable-tag) || {
|
||||
echo "::error::Failed to determine highest stable tag via 'ods latest-stable-tag'"
|
||||
exit 1
|
||||
}
|
||||
if [[ "$TAG" == "$HIGHEST_STABLE" ]]; then
|
||||
IS_LATEST=true
|
||||
fi
|
||||
fi
|
||||
|
||||
# Build craft-latest backend alongside the regular latest.
|
||||
if [[ "$IS_LATEST" == "true" ]]; then
|
||||
BUILD_BACKEND_CRAFT=true
|
||||
fi
|
||||
|
||||
# Determine if this is a production tag
|
||||
# Production tags are: version tags (v1.2.3*) or nightly tags
|
||||
if [[ "$IS_VERSION_TAG" == "true" ]] || [[ "$IS_NIGHTLY" == "true" ]]; then
|
||||
@@ -129,11 +145,9 @@ jobs:
|
||||
echo "build-backend-craft=$BUILD_BACKEND_CRAFT"
|
||||
echo "build-model-server=$BUILD_MODEL_SERVER"
|
||||
echo "is-cloud-tag=$IS_CLOUD"
|
||||
echo "is-stable=$IS_STABLE"
|
||||
echo "is-beta=$IS_BETA"
|
||||
echo "is-stable-standalone=$IS_STABLE_STANDALONE"
|
||||
echo "is-beta-standalone=$IS_BETA_STANDALONE"
|
||||
echo "is-craft-latest=$IS_CRAFT_LATEST"
|
||||
echo "is-latest=$IS_LATEST"
|
||||
echo "is-test-run=$IS_TEST_RUN"
|
||||
echo "sanitized-tag=$SANITIZED_TAG"
|
||||
echo "short-sha=$SHORT_SHA"
|
||||
@@ -151,7 +165,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
version: "0.9.9"
|
||||
# NOTE: This isn't caching much and zizmor suggests this could be poisoned, so disable.
|
||||
@@ -600,7 +614,7 @@ jobs:
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('web-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-stable == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta == 'true' && 'beta' || '' }}
|
||||
|
||||
@@ -1037,7 +1051,7 @@ jobs:
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('backend-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-stable-standalone == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}
|
||||
|
||||
@@ -1473,7 +1487,7 @@ jobs:
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('model-server-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-stable-standalone == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}
|
||||
|
||||
|
||||
235
.github/workflows/post-merge-beta-cherry-pick.yml
vendored
235
.github/workflows/post-merge-beta-cherry-pick.yml
vendored
@@ -1,67 +1,112 @@
|
||||
name: Post-Merge Beta Cherry-Pick
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request_target:
|
||||
types:
|
||||
- closed
|
||||
|
||||
# SECURITY NOTE:
|
||||
# This workflow intentionally uses pull_request_target so post-merge automation can
|
||||
# use base-repo credentials. Do not checkout PR head refs in this workflow
|
||||
# (e.g. github.event.pull_request.head.sha). Only trusted base refs are allowed.
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
cherry-pick-to-latest-release:
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
resolve-cherry-pick-request:
|
||||
if: >-
|
||||
github.event.pull_request.merged == true
|
||||
&& github.event.pull_request.base.ref == 'main'
|
||||
&& github.event.pull_request.head.repo.full_name == github.repository
|
||||
outputs:
|
||||
should_cherrypick: ${{ steps.gate.outputs.should_cherrypick }}
|
||||
pr_number: ${{ steps.gate.outputs.pr_number }}
|
||||
cherry_pick_reason: ${{ steps.run_cherry_pick.outputs.reason }}
|
||||
cherry_pick_details: ${{ steps.run_cherry_pick.outputs.details }}
|
||||
merge_commit_sha: ${{ steps.gate.outputs.merge_commit_sha }}
|
||||
merged_by: ${{ steps.gate.outputs.merged_by }}
|
||||
gate_error: ${{ steps.gate.outputs.gate_error }}
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- name: Resolve merged PR and checkbox state
|
||||
id: gate
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
# SECURITY: keep PR body in env/plain-text handling; avoid directly
|
||||
# inlining github.event.pull_request.body into shell commands.
|
||||
PR_BODY: ${{ github.event.pull_request.body }}
|
||||
MERGE_COMMIT_SHA: ${{ github.event.pull_request.merge_commit_sha }}
|
||||
MERGED_BY: ${{ github.event.pull_request.merged_by.login }}
|
||||
# Explicit merger allowlist used because pull_request_target runs with
|
||||
# the default GITHUB_TOKEN, which cannot reliably read org/team
|
||||
# membership for this repository context.
|
||||
ALLOWED_MERGERS: |
|
||||
acaprau
|
||||
bo-onyx
|
||||
danelegend
|
||||
duo-onyx
|
||||
evan-onyx
|
||||
jessicasingh7
|
||||
jmelahman
|
||||
joachim-danswer
|
||||
justin-tahara
|
||||
nmgarza5
|
||||
raunakab
|
||||
rohoswagger
|
||||
subash-mohan
|
||||
trial2onyx
|
||||
wenxi-onyx
|
||||
weves
|
||||
yuhongsun96
|
||||
run: |
|
||||
# For the commit that triggered this workflow (HEAD on main), fetch all
|
||||
# associated PRs and keep only the PR that was actually merged into main
|
||||
# with this exact merge commit SHA.
|
||||
pr_numbers="$(gh api "repos/${GITHUB_REPOSITORY}/commits/${GITHUB_SHA}/pulls" | jq -r --arg sha "${GITHUB_SHA}" '.[] | select(.merged_at != null and .base.ref == "main" and .merge_commit_sha == $sha) | .number')"
|
||||
match_count="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | wc -l | tr -d ' ')"
|
||||
pr_number="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | head -n 1)"
|
||||
echo "pr_number=${PR_NUMBER}" >> "$GITHUB_OUTPUT"
|
||||
echo "merged_by=${MERGED_BY}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
if [ "${match_count}" -gt 1 ]; then
|
||||
echo "::warning::Multiple merged PRs matched commit ${GITHUB_SHA}. Using PR #${pr_number}."
|
||||
fi
|
||||
|
||||
if [ -z "$pr_number" ]; then
|
||||
echo "No merged PR associated with commit ${GITHUB_SHA}; skipping."
|
||||
if ! echo "${PR_BODY}" | grep -qiE "\\[x\\][[:space:]]*(\\[[^]]+\\][[:space:]]*)?Please cherry-pick this PR to the latest release version"; then
|
||||
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
|
||||
echo "Cherry-pick checkbox not checked for PR #${PR_NUMBER}. Skipping."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Read the PR once so we can gate behavior and infer preferred actor.
|
||||
pr_json="$(gh api "repos/${GITHUB_REPOSITORY}/pulls/${pr_number}")"
|
||||
pr_body="$(printf '%s' "$pr_json" | jq -r '.body // ""')"
|
||||
merged_by="$(printf '%s' "$pr_json" | jq -r '.merged_by.login // ""')"
|
||||
# Keep should_cherrypick output before any possible exit 1 below so
|
||||
# notify-slack can still gate on this output even if this job fails.
|
||||
echo "should_cherrypick=true" >> "$GITHUB_OUTPUT"
|
||||
echo "Cherry-pick checkbox checked for PR #${PR_NUMBER}."
|
||||
|
||||
echo "pr_number=$pr_number" >> "$GITHUB_OUTPUT"
|
||||
echo "merged_by=$merged_by" >> "$GITHUB_OUTPUT"
|
||||
|
||||
if echo "$pr_body" | grep -qiE "\\[x\\][[:space:]]*(\\[[^]]+\\][[:space:]]*)?Please cherry-pick this PR to the latest release version"; then
|
||||
echo "should_cherrypick=true" >> "$GITHUB_OUTPUT"
|
||||
echo "Cherry-pick checkbox checked for PR #${pr_number}."
|
||||
exit 0
|
||||
if [ -z "${MERGE_COMMIT_SHA}" ] || [ "${MERGE_COMMIT_SHA}" = "null" ]; then
|
||||
echo "gate_error=missing-merge-commit-sha" >> "$GITHUB_OUTPUT"
|
||||
echo "::error::PR #${PR_NUMBER} requested cherry-pick, but merge_commit_sha is missing."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
|
||||
echo "Cherry-pick checkbox not checked for PR #${pr_number}. Skipping."
|
||||
echo "merge_commit_sha=${MERGE_COMMIT_SHA}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
normalized_merged_by="$(printf '%s' "${MERGED_BY}" | tr '[:upper:]' '[:lower:]')"
|
||||
normalized_allowed_mergers="$(printf '%s\n' "${ALLOWED_MERGERS}" | tr '[:upper:]' '[:lower:]')"
|
||||
if ! printf '%s\n' "${normalized_allowed_mergers}" | grep -Fxq "${normalized_merged_by}"; then
|
||||
echo "gate_error=not-allowed-merger" >> "$GITHUB_OUTPUT"
|
||||
echo "::error::${MERGED_BY} is not in the explicit cherry-pick merger allowlist. Failing cherry-pick gate."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
exit 0
|
||||
|
||||
cherry-pick-to-latest-release:
|
||||
needs:
|
||||
- resolve-cherry-pick-request
|
||||
if: needs.resolve-cherry-pick-request.outputs.should_cherrypick == 'true' && needs.resolve-cherry-pick-request.result == 'success'
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
outputs:
|
||||
cherry_pick_pr_url: ${{ steps.run_cherry_pick.outputs.pr_url }}
|
||||
cherry_pick_reason: ${{ steps.run_cherry_pick.outputs.reason }}
|
||||
cherry_pick_details: ${{ steps.run_cherry_pick.outputs.details }}
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
# SECURITY: keep checkout pinned to trusted base branch; do not switch to PR head refs.
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
@@ -69,34 +114,44 @@ jobs:
|
||||
ref: main
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
- name: Configure git identity
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
run: |
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
- name: Create cherry-pick PR to latest release
|
||||
id: run_cherry_pick
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
continue-on-error: true
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
GITHUB_TOKEN: ${{ github.token }}
|
||||
CHERRY_PICK_ASSIGNEE: ${{ steps.gate.outputs.merged_by }}
|
||||
CHERRY_PICK_ASSIGNEE: ${{ needs.resolve-cherry-pick-request.outputs.merged_by }}
|
||||
MERGE_COMMIT_SHA: ${{ needs.resolve-cherry-pick-request.outputs.merge_commit_sha }}
|
||||
run: |
|
||||
set -o pipefail
|
||||
output_file="$(mktemp)"
|
||||
uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify 2>&1 | tee "$output_file"
|
||||
exit_code="${PIPESTATUS[0]}"
|
||||
set +e
|
||||
uv run --no-sync --with onyx-devtools ods cherry-pick "${MERGE_COMMIT_SHA}" --yes --no-verify 2>&1 | tee "$output_file"
|
||||
pipe_statuses=("${PIPESTATUS[@]}")
|
||||
exit_code="${pipe_statuses[0]}"
|
||||
tee_exit="${pipe_statuses[1]:-0}"
|
||||
set -e
|
||||
if [ "${tee_exit}" -ne 0 ]; then
|
||||
echo "status=failure" >> "$GITHUB_OUTPUT"
|
||||
echo "reason=output-capture-failed" >> "$GITHUB_OUTPUT"
|
||||
echo "::error::tee failed to capture cherry-pick output (exit ${tee_exit}); cannot classify result."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ "${exit_code}" -eq 0 ]; then
|
||||
pr_url="$(sed -n 's/^.*PR created successfully: \(https:\/\/github\.com\/[^[:space:]]\+\/pull\/[0-9]\+\).*$/\1/p' "$output_file" | tail -n 1)"
|
||||
echo "status=success" >> "$GITHUB_OUTPUT"
|
||||
if [ -n "${pr_url}" ]; then
|
||||
echo "pr_url=${pr_url}" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
exit 0
|
||||
fi
|
||||
|
||||
@@ -115,17 +170,18 @@ jobs:
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Mark workflow as failed if cherry-pick failed
|
||||
if: steps.gate.outputs.should_cherrypick == 'true' && steps.run_cherry_pick.outputs.status == 'failure'
|
||||
if: steps.run_cherry_pick.outputs.status == 'failure'
|
||||
env:
|
||||
CHERRY_PICK_REASON: ${{ steps.run_cherry_pick.outputs.reason }}
|
||||
run: |
|
||||
echo "::error::Automated cherry-pick failed (${CHERRY_PICK_REASON})."
|
||||
exit 1
|
||||
|
||||
notify-slack-on-cherry-pick-failure:
|
||||
notify-slack-on-cherry-pick-success:
|
||||
needs:
|
||||
- resolve-cherry-pick-request
|
||||
- cherry-pick-to-latest-release
|
||||
if: always() && needs.cherry-pick-to-latest-release.outputs.should_cherrypick == 'true' && needs.cherry-pick-to-latest-release.result != 'success'
|
||||
if: needs.resolve-cherry-pick-request.outputs.should_cherrypick == 'true' && needs.resolve-cherry-pick-request.result == 'success' && needs.cherry-pick-to-latest-release.result == 'success'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
@@ -134,22 +190,95 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Fail if Slack webhook secret is missing
|
||||
env:
|
||||
CHERRY_PICK_PRS_WEBHOOK: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
|
||||
run: |
|
||||
if [ -z "${CHERRY_PICK_PRS_WEBHOOK}" ]; then
|
||||
echo "::error::CHERRY_PICK_PRS_WEBHOOK is not configured."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Build cherry-pick success summary
|
||||
id: success-summary
|
||||
env:
|
||||
SOURCE_PR_NUMBER: ${{ needs.resolve-cherry-pick-request.outputs.pr_number }}
|
||||
MERGE_COMMIT_SHA: ${{ needs.resolve-cherry-pick-request.outputs.merge_commit_sha }}
|
||||
CHERRY_PICK_PR_URL: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_pr_url }}
|
||||
run: |
|
||||
source_pr_url="https://github.com/${GITHUB_REPOSITORY}/pull/${SOURCE_PR_NUMBER}"
|
||||
details="*Cherry-pick PR opened successfully.*\\n• source PR: ${source_pr_url}"
|
||||
if [ -n "${CHERRY_PICK_PR_URL}" ]; then
|
||||
details="${details}\\n• cherry-pick PR: ${CHERRY_PICK_PR_URL}"
|
||||
fi
|
||||
if [ -n "${MERGE_COMMIT_SHA}" ]; then
|
||||
details="${details}\\n• merge SHA: ${MERGE_COMMIT_SHA}"
|
||||
fi
|
||||
|
||||
echo "details=${details}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Notify #cherry-pick-prs about cherry-pick success
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
|
||||
details: ${{ steps.success-summary.outputs.details }}
|
||||
title: "✅ Automated Cherry-Pick PR Opened"
|
||||
ref-name: ${{ github.event.pull_request.base.ref }}
|
||||
|
||||
notify-slack-on-cherry-pick-failure:
|
||||
needs:
|
||||
- resolve-cherry-pick-request
|
||||
- cherry-pick-to-latest-release
|
||||
if: always() && needs.resolve-cherry-pick-request.outputs.should_cherrypick == 'true' && (needs.resolve-cherry-pick-request.result == 'failure' || needs.cherry-pick-to-latest-release.result == 'failure')
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Fail if Slack webhook secret is missing
|
||||
env:
|
||||
CHERRY_PICK_PRS_WEBHOOK: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
|
||||
run: |
|
||||
if [ -z "${CHERRY_PICK_PRS_WEBHOOK}" ]; then
|
||||
echo "::error::CHERRY_PICK_PRS_WEBHOOK is not configured."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Build cherry-pick failure summary
|
||||
id: failure-summary
|
||||
env:
|
||||
SOURCE_PR_NUMBER: ${{ needs.cherry-pick-to-latest-release.outputs.pr_number }}
|
||||
SOURCE_PR_NUMBER: ${{ needs.resolve-cherry-pick-request.outputs.pr_number }}
|
||||
MERGE_COMMIT_SHA: ${{ needs.resolve-cherry-pick-request.outputs.merge_commit_sha }}
|
||||
GATE_ERROR: ${{ needs.resolve-cherry-pick-request.outputs.gate_error }}
|
||||
CHERRY_PICK_REASON: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_reason }}
|
||||
CHERRY_PICK_DETAILS: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_details }}
|
||||
run: |
|
||||
source_pr_url="https://github.com/${GITHUB_REPOSITORY}/pull/${SOURCE_PR_NUMBER}"
|
||||
|
||||
reason_text="cherry-pick command failed"
|
||||
if [ "${CHERRY_PICK_REASON}" = "merge-conflict" ]; then
|
||||
if [ "${GATE_ERROR}" = "missing-merge-commit-sha" ]; then
|
||||
reason_text="requested cherry-pick but merge commit SHA was missing"
|
||||
elif [ "${GATE_ERROR}" = "not-allowed-merger" ]; then
|
||||
reason_text="merger is not in the explicit cherry-pick allowlist"
|
||||
elif [ "${CHERRY_PICK_REASON}" = "output-capture-failed" ]; then
|
||||
reason_text="failed to capture cherry-pick output for classification"
|
||||
elif [ "${CHERRY_PICK_REASON}" = "merge-conflict" ]; then
|
||||
reason_text="merge conflict during cherry-pick"
|
||||
fi
|
||||
|
||||
details_excerpt="$(printf '%s' "${CHERRY_PICK_DETAILS}" | tail -n 8 | tr '\n' ' ' | sed "s/[[:space:]]\\+/ /g" | sed "s/\"/'/g" | cut -c1-350)"
|
||||
failed_jobs="• cherry-pick-to-latest-release\\n• source PR: ${source_pr_url}\\n• reason: ${reason_text}"
|
||||
if [ -n "${GATE_ERROR}" ]; then
|
||||
failed_job_label="resolve-cherry-pick-request"
|
||||
else
|
||||
failed_job_label="cherry-pick-to-latest-release"
|
||||
fi
|
||||
failed_jobs="• ${failed_job_label}\\n• source PR: ${source_pr_url}\\n• reason: ${reason_text}"
|
||||
if [ -n "${MERGE_COMMIT_SHA}" ]; then
|
||||
failed_jobs="${failed_jobs}\\n• merge SHA: ${MERGE_COMMIT_SHA}"
|
||||
fi
|
||||
if [ -n "${details_excerpt}" ]; then
|
||||
failed_jobs="${failed_jobs}\\n• excerpt: ${details_excerpt}"
|
||||
fi
|
||||
@@ -160,6 +289,6 @@ jobs:
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
|
||||
failed-jobs: ${{ steps.failure-summary.outputs.jobs }}
|
||||
details: ${{ steps.failure-summary.outputs.jobs }}
|
||||
title: "🚨 Automated Cherry-Pick Failed"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
ref-name: ${{ github.event.pull_request.base.ref }}
|
||||
|
||||
11
.github/workflows/pr-helm-chart-testing.yml
vendored
11
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -133,7 +133,7 @@ jobs:
|
||||
echo "=== Validating chart dependencies ==="
|
||||
cd deployment/helm/charts/onyx
|
||||
helm dependency update
|
||||
helm lint .
|
||||
helm lint . --set auth.userauth.values.user_auth_secret=placeholder
|
||||
|
||||
- name: Run chart-testing (install) with enhanced monitoring
|
||||
timeout-minutes: 25
|
||||
@@ -194,6 +194,7 @@ jobs:
|
||||
--set=vespa.enabled=false \
|
||||
--set=opensearch.enabled=true \
|
||||
--set=auth.opensearch.enabled=true \
|
||||
--set=auth.userauth.values.user_auth_secret=test-secret \
|
||||
--set=slackbot.enabled=false \
|
||||
--set=postgresql.enabled=true \
|
||||
--set=postgresql.cluster.storage.storageClass=standard \
|
||||
@@ -230,6 +231,10 @@ jobs:
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Post-install verification ==="
|
||||
if ! kubectl cluster-info >/dev/null 2>&1; then
|
||||
echo "ERROR: Kubernetes cluster is not reachable after install"
|
||||
exit 1
|
||||
fi
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get services --all-namespaces
|
||||
# Only show issues if they exist
|
||||
@@ -239,6 +244,10 @@ jobs:
|
||||
if: failure() && steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Cleanup on failure ==="
|
||||
if ! kubectl cluster-info >/dev/null 2>&1; then
|
||||
echo "Skipping failure cleanup: Kubernetes cluster is not reachable"
|
||||
exit 0
|
||||
fi
|
||||
echo "=== Final cluster state ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -10
|
||||
|
||||
4
.github/workflows/pr-integration-tests.yml
vendored
4
.github/workflows/pr-integration-tests.yml
vendored
@@ -316,6 +316,7 @@ jobs:
|
||||
# Base config shared by both editions
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
OPENSEARCH_FOR_ONYX_ENABLED=false
|
||||
AUTH_TYPE=basic
|
||||
POSTGRES_POOL_PRE_PING=true
|
||||
POSTGRES_USE_NULL_POOL=true
|
||||
@@ -418,6 +419,7 @@ jobs:
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e ENABLE_OPENSEARCH_INDEXING_FOR_ONYX=false \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
@@ -637,6 +639,7 @@ jobs:
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
|
||||
DEV_MODE=true \
|
||||
OPENSEARCH_FOR_ONYX_ENABLED=false \
|
||||
docker compose -f docker-compose.multitenant-dev.yml up \
|
||||
relational_db \
|
||||
index \
|
||||
@@ -691,6 +694,7 @@ jobs:
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e ENABLE_OPENSEARCH_INDEXING_FOR_ONYX=false \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
|
||||
7
.github/workflows/pr-playwright-tests.yml
vendored
7
.github/workflows/pr-playwright-tests.yml
vendored
@@ -12,6 +12,9 @@ on:
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
# TODO: Remove this if we enable merge-queues for release branches.
|
||||
branches:
|
||||
- "release/**"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -468,7 +471,7 @@ jobs:
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: always()
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
@@ -707,7 +710,7 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Download visual diff summaries
|
||||
uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3
|
||||
with:
|
||||
pattern: screenshot-diff-summary-*
|
||||
path: summaries/
|
||||
|
||||
2
.github/workflows/pr-quality-checks.yml
vendored
2
.github/workflows/pr-quality-checks.yml
vendored
@@ -28,7 +28,7 @@ jobs:
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Setup Terraform
|
||||
uses: hashicorp/setup-terraform@b9cd54a3c349d3f38e8881555d616ced269862dd # ratchet:hashicorp/setup-terraform@v3
|
||||
uses: hashicorp/setup-terraform@5e8dbf3c6d9deaf4193ca7a8fb23f2ac83bb6c85 # ratchet:hashicorp/setup-terraform@v4.0.0
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v6
|
||||
with: # zizmor: ignore[cache-poisoning]
|
||||
|
||||
2
.github/workflows/release-devtools.yml
vendored
2
.github/workflows/release-devtools.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
- uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
69
.github/workflows/storybook-deploy.yml
vendored
Normal file
69
.github/workflows/storybook-deploy.yml
vendored
Normal file
@@ -0,0 +1,69 @@
|
||||
name: Storybook Deploy
|
||||
env:
|
||||
VERCEL_ORG_ID: ${{ secrets.VERCEL_ORG_ID }}
|
||||
VERCEL_PROJECT_ID: prj_sG49mVsA25UsxIPhN2pmBJlikJZM
|
||||
VERCEL_CLI: vercel@50.14.1
|
||||
VERCEL_TOKEN: ${{ secrets.VERCEL_TOKEN }}
|
||||
|
||||
concurrency:
|
||||
group: storybook-deploy-production
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "web/lib/opal/**"
|
||||
- "web/src/refresh-components/**"
|
||||
- "web/.storybook/**"
|
||||
- "web/package.json"
|
||||
- "web/package-lock.json"
|
||||
permissions:
|
||||
contents: read
|
||||
jobs:
|
||||
Deploy-Storybook:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: "npm"
|
||||
cache-dependency-path: ./web/package-lock.json
|
||||
|
||||
- name: Install dependencies
|
||||
working-directory: web
|
||||
run: npm ci
|
||||
|
||||
- name: Build Storybook
|
||||
working-directory: web
|
||||
run: npm run storybook:build
|
||||
|
||||
- name: Deploy to Vercel (Production)
|
||||
working-directory: web
|
||||
run: npx --yes "$VERCEL_CLI" deploy storybook-static/ --prod --yes --token="$VERCEL_TOKEN"
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs: Deploy-Storybook
|
||||
if: always() && needs.Deploy-Storybook.result == 'failure'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
sparse-checkout: .github/actions/slack-notify
|
||||
|
||||
- name: Send Slack notification
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
|
||||
failed-jobs: "• Deploy-Storybook"
|
||||
title: "🚨 Storybook Deploy Failed"
|
||||
2
.github/workflows/zizmor.yml
vendored
2
.github/workflows/zizmor.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install the latest version of uv
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
3
.vscode/env_template.txt
vendored
3
.vscode/env_template.txt
vendored
@@ -7,6 +7,9 @@
|
||||
|
||||
|
||||
AUTH_TYPE=basic
|
||||
# Recommended for basic auth - used for signing password reset and verification tokens
|
||||
# Generate a secure value with: openssl rand -hex 32
|
||||
USER_AUTH_SECRET=""
|
||||
DEV_MODE=true
|
||||
|
||||
|
||||
|
||||
@@ -598,7 +598,7 @@ Before writing your plan, make sure to do research. Explore the relevant section
|
||||
Never hardcode status codes or use `starlette.status` / `fastapi.status` constants directly.**
|
||||
|
||||
A global FastAPI exception handler converts `OnyxError` into a JSON response with the standard
|
||||
`{"error_code": "...", "message": "..."}` shape. This eliminates boilerplate and keeps error
|
||||
`{"error_code": "...", "detail": "..."}` shape. This eliminates boilerplate and keeps error
|
||||
handling consistent across the entire backend.
|
||||
|
||||
```python
|
||||
|
||||
@@ -46,7 +46,9 @@ RUN apt-get update && \
|
||||
pkg-config \
|
||||
gcc \
|
||||
nano \
|
||||
vim && \
|
||||
vim \
|
||||
libjemalloc2 \
|
||||
&& \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
apt-get clean
|
||||
|
||||
@@ -141,6 +143,7 @@ COPY --chown=onyx:onyx ./scripts/debugging /app/scripts/debugging
|
||||
COPY --chown=onyx:onyx ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
|
||||
COPY --chown=onyx:onyx ./scripts/supervisord_entrypoint.sh /app/scripts/supervisord_entrypoint.sh
|
||||
COPY --chown=onyx:onyx ./scripts/setup_craft_templates.sh /app/scripts/setup_craft_templates.sh
|
||||
COPY --chown=onyx:onyx ./scripts/reencrypt_secrets.py /app/scripts/reencrypt_secrets.py
|
||||
RUN chmod +x /app/scripts/supervisord_entrypoint.sh /app/scripts/setup_craft_templates.sh
|
||||
|
||||
# Run Craft template setup at build time when ENABLE_CRAFT=true
|
||||
@@ -164,6 +167,13 @@ ENV PYTHONPATH=/app
|
||||
ARG ONYX_VERSION=0.0.0-dev
|
||||
ENV ONYX_VERSION=${ONYX_VERSION}
|
||||
|
||||
# Use jemalloc instead of glibc malloc to reduce memory fragmentation
|
||||
# in long-running Python processes (API server, Celery workers).
|
||||
# The soname is architecture-independent; the dynamic linker resolves
|
||||
# the correct path from standard library directories.
|
||||
# Placed after all RUN steps so build-time processes are unaffected.
|
||||
ENV LD_PRELOAD=libjemalloc.so.2
|
||||
|
||||
# Default command which does nothing
|
||||
# This container is used by api server and background which specify their own CMD
|
||||
CMD ["tail", "-f", "/dev/null"]
|
||||
|
||||
@@ -244,7 +244,10 @@ def do_run_migrations(
|
||||
|
||||
|
||||
def provide_iam_token_for_alembic(
|
||||
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any # noqa: ARG001
|
||||
dialect: Any, # noqa: ARG001
|
||||
conn_rec: Any, # noqa: ARG001
|
||||
cargs: Any, # noqa: ARG001
|
||||
cparams: Any,
|
||||
) -> None:
|
||||
if USE_IAM_AUTH:
|
||||
# Database connection settings
|
||||
@@ -360,8 +363,7 @@ async def run_async_migrations() -> None:
|
||||
# upgrade_all_tenants=true or schemas in multi-tenant mode
|
||||
# and for non-multi-tenant mode, we should use schemas with the default schema
|
||||
raise ValueError(
|
||||
"No migration target specified. Use either upgrade_all_tenants=true for all tenants "
|
||||
"or schemas for specific schemas."
|
||||
"No migration target specified. Use either upgrade_all_tenants=true for all tenants or schemas for specific schemas."
|
||||
)
|
||||
|
||||
await engine.dispose()
|
||||
@@ -457,8 +459,7 @@ def run_migrations_offline() -> None:
|
||||
else:
|
||||
# This should not happen in the new design
|
||||
raise ValueError(
|
||||
"No migration target specified. Use either upgrade_all_tenants=true for all tenants "
|
||||
"or schemas for specific schemas."
|
||||
"No migration target specified. Use either upgrade_all_tenants=true for all tenants or schemas for specific schemas."
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ Usage examples::
|
||||
# custom settings
|
||||
python alembic/run_multitenant_migrations.py -j 8 -b 100
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
@@ -117,8 +118,7 @@ def run_migrations_parallel(
|
||||
batches = [schemas[i : i + batch_size] for i in range(0, len(schemas), batch_size)]
|
||||
total_batches = len(batches)
|
||||
print(
|
||||
f"{len(schemas)} schemas in {total_batches} batch(es) "
|
||||
f"with {max_workers} workers (batch size: {batch_size})...",
|
||||
f"{len(schemas)} schemas in {total_batches} batch(es) with {max_workers} workers (batch size: {batch_size})...",
|
||||
flush=True,
|
||||
)
|
||||
all_success = True
|
||||
@@ -166,8 +166,7 @@ def run_migrations_parallel(
|
||||
with lock:
|
||||
in_flight[batch_idx] = batch
|
||||
print(
|
||||
f"Batch {batch_idx + 1}/{total_batches} started "
|
||||
f"({len(batch)} schemas): {', '.join(batch)}",
|
||||
f"Batch {batch_idx + 1}/{total_batches} started ({len(batch)} schemas): {', '.join(batch)}",
|
||||
flush=True,
|
||||
)
|
||||
result = run_alembic_for_batch(batch)
|
||||
@@ -201,7 +200,7 @@ def run_migrations_parallel(
|
||||
|
||||
except Exception as e:
|
||||
print(
|
||||
f"Batch {batch_idx + 1}/{total_batches} " f"✗ exception: {e}",
|
||||
f"Batch {batch_idx + 1}/{total_batches} ✗ exception: {e}",
|
||||
flush=True,
|
||||
)
|
||||
all_success = False
|
||||
@@ -268,14 +267,12 @@ def main() -> int:
|
||||
|
||||
if not schemas_to_migrate:
|
||||
print(
|
||||
f"All {len(tenant_schemas)} tenants are already at head "
|
||||
f"revision ({head_rev})."
|
||||
f"All {len(tenant_schemas)} tenants are already at head revision ({head_rev})."
|
||||
)
|
||||
return 0
|
||||
|
||||
print(
|
||||
f"{len(schemas_to_migrate)}/{len(tenant_schemas)} tenants need "
|
||||
f"migration (head: {head_rev})."
|
||||
f"{len(schemas_to_migrate)}/{len(tenant_schemas)} tenants need migration (head: {head_rev})."
|
||||
)
|
||||
|
||||
success = run_migrations_parallel(
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
"""add timestamps to user table
|
||||
|
||||
Revision ID: 27fb147a843f
|
||||
Revises: b5c4d7e8f9a1
|
||||
Create Date: 2026-03-08 17:18:40.828644
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "27fb147a843f"
|
||||
down_revision = "b5c4d7e8f9a1"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "updated_at")
|
||||
op.drop_column("user", "created_at")
|
||||
@@ -50,8 +50,7 @@ def upgrade() -> None:
|
||||
|
||||
if orphaned_count > 0:
|
||||
logger.warning(
|
||||
f"WARNING: {orphaned_count} chat_session records still have "
|
||||
f"folder_id without project_id. Proceeding anyway."
|
||||
f"WARNING: {orphaned_count} chat_session records still have folder_id without project_id. Proceeding anyway."
|
||||
)
|
||||
|
||||
# === Step 2: Drop chat_session.folder_id ===
|
||||
|
||||
@@ -75,8 +75,7 @@ def batch_delete(
|
||||
|
||||
if failed_batches:
|
||||
logger.warning(
|
||||
f"Failed to delete {len(failed_batches)} batches from {table_name}. "
|
||||
f"Total deleted: {total_deleted}/{total_count}"
|
||||
f"Failed to delete {len(failed_batches)} batches from {table_name}. Total deleted: {total_deleted}/{total_count}"
|
||||
)
|
||||
# Fail the migration to avoid silently succeeding on partial cleanup
|
||||
raise RuntimeError(
|
||||
|
||||
@@ -18,8 +18,7 @@ depends_on = None
|
||||
def upgrade() -> None:
|
||||
# Set all existing records to not migrated
|
||||
op.execute(
|
||||
"UPDATE user_file SET document_id_migrated = FALSE "
|
||||
"WHERE document_id_migrated IS DISTINCT FROM FALSE;"
|
||||
"UPDATE user_file SET document_id_migrated = FALSE WHERE document_id_migrated IS DISTINCT FROM FALSE;"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -35,7 +35,6 @@ def upgrade() -> None:
|
||||
# environment variables MUST be set. Otherwise, an exception will be raised.
|
||||
|
||||
if not MULTI_TENANT:
|
||||
|
||||
# Enable pg_trgm extension if not already enabled
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")
|
||||
|
||||
@@ -481,8 +480,7 @@ def upgrade() -> None:
|
||||
f"ON kg_entity USING GIN (name {POSTGRES_DEFAULT_SCHEMA}.gin_trgm_ops)"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_kg_entity_normalization_trigrams "
|
||||
"ON kg_entity USING GIN (name_trigrams)"
|
||||
"CREATE INDEX IF NOT EXISTS idx_kg_entity_normalization_trigrams ON kg_entity USING GIN (name_trigrams)"
|
||||
)
|
||||
|
||||
# Create kg_entity trigger to update kg_entity.name and its trigrams
|
||||
|
||||
@@ -51,10 +51,7 @@ def upgrade() -> None:
|
||||
next_email = f"{username.lower()}_{attempt}@{domain.lower()}"
|
||||
# Email conflict occurred, append `_1`, `_2`, etc., to the username
|
||||
logger.warning(
|
||||
f"Conflict while lowercasing email: "
|
||||
f"old_email={email} "
|
||||
f"conflicting_email={new_email} "
|
||||
f"next_email={next_email}"
|
||||
f"Conflict while lowercasing email: old_email={email} conflicting_email={new_email} next_email={next_email}"
|
||||
)
|
||||
new_email = next_email
|
||||
attempt += 1
|
||||
|
||||
@@ -24,12 +24,10 @@ depends_on = None
|
||||
def upgrade() -> None:
|
||||
# Convert existing lowercase values to uppercase to match enum member names
|
||||
op.execute(
|
||||
"UPDATE connector_credential_pair SET processing_mode = 'REGULAR' "
|
||||
"WHERE processing_mode = 'regular'"
|
||||
"UPDATE connector_credential_pair SET processing_mode = 'REGULAR' WHERE processing_mode = 'regular'"
|
||||
)
|
||||
op.execute(
|
||||
"UPDATE connector_credential_pair SET processing_mode = 'FILE_SYSTEM' "
|
||||
"WHERE processing_mode = 'file_system'"
|
||||
"UPDATE connector_credential_pair SET processing_mode = 'FILE_SYSTEM' WHERE processing_mode = 'file_system'"
|
||||
)
|
||||
|
||||
# Update the server default to use uppercase
|
||||
|
||||
@@ -289,8 +289,7 @@ def upgrade() -> None:
|
||||
attributes_str = json.dumps(attributes).replace("'", "''")
|
||||
op.execute(
|
||||
sa.text(
|
||||
f"UPDATE kg_entity_type SET attributes = '{attributes_str}'"
|
||||
f"WHERE id_name = '{entity_type}'"
|
||||
f"UPDATE kg_entity_type SET attributes = '{attributes_str}'WHERE id_name = '{entity_type}'"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -312,7 +311,6 @@ def downgrade() -> None:
|
||||
attributes_str = json.dumps(attributes).replace("'", "''")
|
||||
op.execute(
|
||||
sa.text(
|
||||
f"UPDATE kg_entity_type SET attributes = '{attributes_str}'"
|
||||
f"WHERE id_name = '{entity_type}'"
|
||||
f"UPDATE kg_entity_type SET attributes = '{attributes_str}'WHERE id_name = '{entity_type}'"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -160,7 +160,7 @@ def remove_old_tags() -> None:
|
||||
f"""
|
||||
DELETE FROM document__tag
|
||||
WHERE document_id = '{document_id}'
|
||||
AND tag_id IN ({','.join(to_delete)})
|
||||
AND tag_id IN ({",".join(to_delete)})
|
||||
"""
|
||||
)
|
||||
)
|
||||
@@ -239,7 +239,7 @@ def _get_batch_documents_with_multiple_tags(
|
||||
).fetchall()
|
||||
if not batch:
|
||||
break
|
||||
doc_ids = [document_id for document_id, in batch]
|
||||
doc_ids = [document_id for (document_id,) in batch]
|
||||
yield doc_ids
|
||||
offset_clause = f"AND document__tag.document_id > '{doc_ids[-1]}'"
|
||||
|
||||
|
||||
@@ -0,0 +1,117 @@
|
||||
"""add_voice_provider_and_user_voice_prefs
|
||||
|
||||
Revision ID: 93a2e195e25c
|
||||
Revises: 27fb147a843f
|
||||
Create Date: 2026-02-23 15:16:39.507304
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import column
|
||||
from sqlalchemy import true
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "93a2e195e25c"
|
||||
down_revision = "27fb147a843f"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create voice_provider table
|
||||
op.create_table(
|
||||
"voice_provider",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column("name", sa.String(), unique=True, nullable=False),
|
||||
sa.Column("provider_type", sa.String(), nullable=False),
|
||||
sa.Column("api_key", sa.LargeBinary(), nullable=True),
|
||||
sa.Column("api_base", sa.String(), nullable=True),
|
||||
sa.Column("custom_config", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("stt_model", sa.String(), nullable=True),
|
||||
sa.Column("tts_model", sa.String(), nullable=True),
|
||||
sa.Column("default_voice", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"is_default_stt", sa.Boolean(), nullable=False, server_default="false"
|
||||
),
|
||||
sa.Column(
|
||||
"is_default_tts", sa.Boolean(), nullable=False, server_default="false"
|
||||
),
|
||||
sa.Column("deleted", sa.Boolean(), nullable=False, server_default="false"),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"time_updated",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
onupdate=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
|
||||
# Add partial unique indexes to enforce only one default STT/TTS provider
|
||||
op.create_index(
|
||||
"ix_voice_provider_one_default_stt",
|
||||
"voice_provider",
|
||||
["is_default_stt"],
|
||||
unique=True,
|
||||
postgresql_where=column("is_default_stt") == true(),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_voice_provider_one_default_tts",
|
||||
"voice_provider",
|
||||
["is_default_tts"],
|
||||
unique=True,
|
||||
postgresql_where=column("is_default_tts") == true(),
|
||||
)
|
||||
|
||||
# Add voice preference columns to user table
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"voice_auto_send",
|
||||
sa.Boolean(),
|
||||
default=False,
|
||||
nullable=False,
|
||||
server_default="false",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"voice_auto_playback",
|
||||
sa.Boolean(),
|
||||
default=False,
|
||||
nullable=False,
|
||||
server_default="false",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"voice_playback_speed",
|
||||
sa.Float(),
|
||||
default=1.0,
|
||||
nullable=False,
|
||||
server_default="1.0",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove user voice preference columns
|
||||
op.drop_column("user", "voice_playback_speed")
|
||||
op.drop_column("user", "voice_auto_playback")
|
||||
op.drop_column("user", "voice_auto_send")
|
||||
|
||||
op.drop_index("ix_voice_provider_one_default_tts", table_name="voice_provider")
|
||||
op.drop_index("ix_voice_provider_one_default_stt", table_name="voice_provider")
|
||||
|
||||
# Drop voice_provider table
|
||||
op.drop_table("voice_provider")
|
||||
@@ -24,8 +24,7 @@ TOOL_DESCRIPTIONS = {
|
||||
"The action will be used when the user asks the agent to generate an image."
|
||||
),
|
||||
"WebSearchTool": (
|
||||
"The Web Search Action allows the agent "
|
||||
"to perform internet searches for up-to-date information."
|
||||
"The Web Search Action allows the agent to perform internet searches for up-to-date information."
|
||||
),
|
||||
"KnowledgeGraphTool": (
|
||||
"The Knowledge Graph Search Action allows the agent to search the "
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
"""add hierarchy_node_by_connector_credential_pair table
|
||||
|
||||
Revision ID: b5c4d7e8f9a1
|
||||
Revises: a3b8d9e2f1c4
|
||||
Create Date: 2026-03-04
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = "b5c4d7e8f9a1"
|
||||
down_revision = "a3b8d9e2f1c4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"hierarchy_node_by_connector_credential_pair",
|
||||
sa.Column("hierarchy_node_id", sa.Integer(), nullable=False),
|
||||
sa.Column("connector_id", sa.Integer(), nullable=False),
|
||||
sa.Column("credential_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["hierarchy_node_id"],
|
||||
["hierarchy_node.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["connector_id", "credential_id"],
|
||||
[
|
||||
"connector_credential_pair.connector_id",
|
||||
"connector_credential_pair.credential_id",
|
||||
],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("hierarchy_node_id", "connector_id", "credential_id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_hierarchy_node_cc_pair_connector_credential",
|
||||
"hierarchy_node_by_connector_credential_pair",
|
||||
["connector_id", "credential_id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(
|
||||
"ix_hierarchy_node_cc_pair_connector_credential",
|
||||
table_name="hierarchy_node_by_connector_credential_pair",
|
||||
)
|
||||
op.drop_table("hierarchy_node_by_connector_credential_pair")
|
||||
@@ -140,8 +140,7 @@ def _migrate_files_to_postgres() -> None:
|
||||
# Fetch rows that have external storage pointers (bucket/object_key not NULL)
|
||||
result = session.execute(
|
||||
text(
|
||||
"SELECT file_id, bucket_name, object_key FROM file_record "
|
||||
"WHERE bucket_name IS NOT NULL AND object_key IS NOT NULL"
|
||||
"SELECT file_id, bucket_name, object_key FROM file_record WHERE bucket_name IS NOT NULL AND object_key IS NOT NULL"
|
||||
)
|
||||
)
|
||||
|
||||
@@ -182,8 +181,7 @@ def _migrate_files_to_postgres() -> None:
|
||||
# Update DB row: set lobj_oid, clear bucket/object_key
|
||||
session.execute(
|
||||
text(
|
||||
"UPDATE file_record SET lobj_oid = :lobj_oid, bucket_name = NULL, "
|
||||
"object_key = NULL WHERE file_id = :file_id"
|
||||
"UPDATE file_record SET lobj_oid = :lobj_oid, bucket_name = NULL, object_key = NULL WHERE file_id = :file_id"
|
||||
),
|
||||
{"lobj_oid": lobj_oid, "file_id": file_id},
|
||||
)
|
||||
@@ -224,8 +222,7 @@ def _migrate_files_to_external_storage() -> None:
|
||||
# Find all files currently stored in PostgreSQL (lobj_oid is not null)
|
||||
result = session.execute(
|
||||
text(
|
||||
"SELECT file_id FROM file_record WHERE lobj_oid IS NOT NULL "
|
||||
"AND bucket_name IS NULL AND object_key IS NULL"
|
||||
"SELECT file_id FROM file_record WHERE lobj_oid IS NOT NULL AND bucket_name IS NULL AND object_key IS NULL"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -39,8 +39,7 @@ BUILT_IN_TOOLS = [
|
||||
"name": "WebSearchTool",
|
||||
"display_name": "Web Search",
|
||||
"description": (
|
||||
"The Web Search Action allows the assistant "
|
||||
"to perform internet searches for up-to-date information."
|
||||
"The Web Search Action allows the assistant to perform internet searches for up-to-date information."
|
||||
),
|
||||
"in_code_tool_id": "WebSearchTool",
|
||||
},
|
||||
|
||||
@@ -11,7 +11,6 @@ from sqlalchemy import text
|
||||
from alembic import op
|
||||
from onyx.configs.app_configs import DB_READONLY_PASSWORD
|
||||
from onyx.configs.app_configs import DB_READONLY_USER
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
@@ -22,59 +21,52 @@ depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
if MULTI_TENANT:
|
||||
# Enable pg_trgm extension if not already enabled
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")
|
||||
|
||||
# Enable pg_trgm extension if not already enabled
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")
|
||||
# Create the read-only db user if it does not already exist.
|
||||
if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
|
||||
raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")
|
||||
|
||||
# Create read-only db user here only in multi-tenant mode. For single-tenant mode,
|
||||
# the user is created in the standard migration.
|
||||
if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
|
||||
raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")
|
||||
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
-- Check if the read-only user already exists
|
||||
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- Create the read-only user with the specified password
|
||||
EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
|
||||
-- First revoke all privileges to ensure a clean slate
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Grant only the CONNECT privilege to allow the user to connect to the database
|
||||
-- but not perform any operations without additional specific grants
|
||||
EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
if MULTI_TENANT:
|
||||
# Drop read-only db user here only in single tenant mode. For multi-tenant mode,
|
||||
# the user is dropped in the alembic_tenants migration.
|
||||
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- First revoke all privileges from the database
|
||||
-- Check if the read-only user already exists
|
||||
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- Create the read-only user with the specified password
|
||||
EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
|
||||
-- First revoke all privileges to ensure a clean slate
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Then revoke all privileges from the public schema
|
||||
EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}');
|
||||
-- Then drop the user
|
||||
EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
|
||||
-- Grant only the CONNECT privilege to allow the user to connect to the database
|
||||
-- but not perform any operations without additional specific grants
|
||||
EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
"""
|
||||
)
|
||||
op.execute(text("DROP EXTENSION IF EXISTS pg_trgm"))
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- First revoke all privileges from the database
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Then revoke all privileges from the public schema
|
||||
EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}');
|
||||
-- Then drop the user
|
||||
EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
op.execute(text("DROP EXTENSION IF EXISTS pg_trgm"))
|
||||
|
||||
@@ -9,12 +9,15 @@ from onyx.access.access import (
|
||||
_get_access_for_documents as get_access_for_documents_without_groups,
|
||||
)
|
||||
from onyx.access.access import _get_acl_for_user as get_acl_for_user_without_groups
|
||||
from onyx.access.access import collect_user_file_access
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.access.utils import prefix_external_group
|
||||
from onyx.access.utils import prefix_user_group
|
||||
from onyx.db.document import get_document_sources
|
||||
from onyx.db.document import get_documents_by_ids
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.user_file import fetch_user_files_with_access_relationships
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -116,6 +119,68 @@ def _get_access_for_documents(
|
||||
return access_map
|
||||
|
||||
|
||||
def _collect_user_file_group_names(user_file: UserFile) -> set[str]:
|
||||
"""Extract user-group names from the already-loaded Persona.groups
|
||||
relationships on a UserFile (skipping deleted personas)."""
|
||||
groups: set[str] = set()
|
||||
for persona in user_file.assistants:
|
||||
if persona.deleted:
|
||||
continue
|
||||
for group in persona.groups:
|
||||
groups.add(group.name)
|
||||
return groups
|
||||
|
||||
|
||||
def get_access_for_user_files_impl(
|
||||
user_file_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
"""EE version: extends the MIT user file ACL with user group names
|
||||
from personas shared via user groups.
|
||||
|
||||
Uses a single DB query (via fetch_user_files_with_access_relationships)
|
||||
that eagerly loads both the MIT-needed and EE-needed relationships.
|
||||
|
||||
NOTE: is imported in onyx.access.access by `fetch_versioned_implementation`
|
||||
DO NOT REMOVE."""
|
||||
user_files = fetch_user_files_with_access_relationships(
|
||||
user_file_ids, db_session, eager_load_groups=True
|
||||
)
|
||||
return build_access_for_user_files_impl(user_files)
|
||||
|
||||
|
||||
def build_access_for_user_files_impl(
|
||||
user_files: list[UserFile],
|
||||
) -> dict[str, DocumentAccess]:
|
||||
"""EE version: works on pre-loaded UserFile objects.
|
||||
Expects Persona.groups to be eagerly loaded.
|
||||
|
||||
NOTE: is imported in onyx.access.access by `fetch_versioned_implementation`
|
||||
DO NOT REMOVE."""
|
||||
result: dict[str, DocumentAccess] = {}
|
||||
for user_file in user_files:
|
||||
if user_file.user is None:
|
||||
result[str(user_file.id)] = DocumentAccess.build(
|
||||
user_emails=[],
|
||||
user_groups=[],
|
||||
is_public=True,
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
)
|
||||
continue
|
||||
|
||||
emails, is_public = collect_user_file_access(user_file)
|
||||
group_names = _collect_user_file_group_names(user_file)
|
||||
result[str(user_file.id)] = DocumentAccess.build(
|
||||
user_emails=list(emails),
|
||||
user_groups=list(group_names),
|
||||
is_public=is_public,
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _get_acl_for_user(user: User, db_session: Session) -> set[str]:
|
||||
"""Returns a list of ACL entries that the user has access to. This is meant to be
|
||||
used downstream to filter out documents that the user does not have access to. The
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import jwt
|
||||
@@ -20,7 +21,12 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def verify_auth_setting() -> None:
|
||||
# All the Auth flows are valid for EE version
|
||||
# All the Auth flows are valid for EE version, but warn about deprecated 'disabled'
|
||||
raw_auth_type = (os.environ.get("AUTH_TYPE") or "").lower()
|
||||
if raw_auth_type == "disabled":
|
||||
logger.warning(
|
||||
"AUTH_TYPE='disabled' is no longer supported. Using 'basic' instead. Please update your configuration."
|
||||
)
|
||||
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
|
||||
|
||||
|
||||
|
||||
@@ -59,7 +59,6 @@ def cloud_beat_task_generator(
|
||||
# gated_tenants = get_gated_tenants()
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
|
||||
# Same comment here as the above NOTE
|
||||
# if tenant_id in gated_tenants:
|
||||
# continue
|
||||
|
||||
@@ -424,10 +424,7 @@ def connector_permission_sync_generator_task(
|
||||
raise ValueError(error_msg)
|
||||
|
||||
if not redis_connector.permissions.fenced: # The fence must exist
|
||||
error_msg = (
|
||||
f"connector_permission_sync_generator_task - fence not found: "
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
error_msg = f"connector_permission_sync_generator_task - fence not found: fence={redis_connector.permissions.fence_key}"
|
||||
_fail_doc_permission_sync_attempt(attempt_id, error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
@@ -441,8 +438,7 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
if payload.celery_task_id is None:
|
||||
logger.info(
|
||||
f"connector_permission_sync_generator_task - Waiting for fence: "
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
f"connector_permission_sync_generator_task - Waiting for fence: fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
sleep(1)
|
||||
continue
|
||||
@@ -608,8 +604,7 @@ def connector_permission_sync_generator_task(
|
||||
docs_with_permission_errors=docs_with_errors,
|
||||
)
|
||||
task_logger.info(
|
||||
f"Completed doc permission sync attempt {attempt_id}: "
|
||||
f"{tasks_generated} docs, {docs_with_errors} errors"
|
||||
f"Completed doc permission sync attempt {attempt_id}: {tasks_generated} docs, {docs_with_errors} errors"
|
||||
)
|
||||
|
||||
redis_connector.permissions.generator_complete = tasks_generated
|
||||
@@ -716,9 +711,7 @@ def element_update_permissions(
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"{element_type}={element_id} "
|
||||
f"action=update_permissions "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
f"{element_type}={element_id} action=update_permissions elapsed={elapsed:.2f}"
|
||||
)
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
@@ -900,8 +893,7 @@ def validate_permission_sync_fence(
|
||||
tasks_not_in_celery += 1
|
||||
|
||||
task_logger.info(
|
||||
"validate_permission_sync_fence task check: "
|
||||
f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
|
||||
f"validate_permission_sync_fence task check: tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
|
||||
)
|
||||
|
||||
# we're active if there are still tasks to run and those tasks all exist in celery
|
||||
@@ -1007,7 +999,10 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
|
||||
|
||||
|
||||
def monitor_ccpair_permissions_taskset(
|
||||
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session # noqa: ARG001
|
||||
tenant_id: str,
|
||||
key_bytes: bytes,
|
||||
r: Redis, # noqa: ARG001
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
@@ -1031,8 +1026,7 @@ def monitor_ccpair_permissions_taskset(
|
||||
payload = redis_connector.permissions.payload
|
||||
except ValidationError:
|
||||
task_logger.exception(
|
||||
"Permissions sync payload failed to validate. "
|
||||
"Schema may have been updated."
|
||||
"Permissions sync payload failed to validate. Schema may have been updated."
|
||||
)
|
||||
return
|
||||
|
||||
@@ -1041,11 +1035,7 @@ def monitor_ccpair_permissions_taskset(
|
||||
|
||||
remaining = redis_connector.permissions.get_remaining()
|
||||
task_logger.info(
|
||||
f"Permissions sync progress: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"id={payload.id} "
|
||||
f"remaining={remaining} "
|
||||
f"initial={initial}"
|
||||
f"Permissions sync progress: cc_pair={cc_pair_id} id={payload.id} remaining={remaining} initial={initial}"
|
||||
)
|
||||
|
||||
# Add telemetry for permission syncing progress
|
||||
@@ -1064,10 +1054,7 @@ def monitor_ccpair_permissions_taskset(
|
||||
|
||||
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), payload.started)
|
||||
task_logger.info(
|
||||
f"Permissions sync finished: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"id={payload.id} "
|
||||
f"num_synced={initial}"
|
||||
f"Permissions sync finished: cc_pair={cc_pair_id} id={payload.id} num_synced={initial}"
|
||||
)
|
||||
|
||||
# Add telemetry for permission syncing complete
|
||||
|
||||
@@ -111,23 +111,20 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
|
||||
if cc_pair.access_type != AccessType.SYNC:
|
||||
task_logger.error(
|
||||
f"Received non-sync CC Pair {cc_pair.id} for external "
|
||||
f"group sync. Actual access type: {cc_pair.access_type}"
|
||||
f"Received non-sync CC Pair {cc_pair.id} for external group sync. Actual access type: {cc_pair.access_type}"
|
||||
)
|
||||
return False
|
||||
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
task_logger.debug(
|
||||
f"Skipping group sync for CC Pair {cc_pair.id} - "
|
||||
f"CC Pair is being deleted"
|
||||
f"Skipping group sync for CC Pair {cc_pair.id} - CC Pair is being deleted"
|
||||
)
|
||||
return False
|
||||
|
||||
sync_config = get_source_perm_sync_config(cc_pair.connector.source)
|
||||
if sync_config is None:
|
||||
task_logger.debug(
|
||||
f"Skipping group sync for CC Pair {cc_pair.id} - "
|
||||
f"no sync config found for {cc_pair.connector.source}"
|
||||
f"Skipping group sync for CC Pair {cc_pair.id} - no sync config found for {cc_pair.connector.source}"
|
||||
)
|
||||
return False
|
||||
|
||||
@@ -135,8 +132,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
# This is fine because all sources dont necessarily have a concept of groups
|
||||
if sync_config.group_sync_config is None:
|
||||
task_logger.debug(
|
||||
f"Skipping group sync for CC Pair {cc_pair.id} - "
|
||||
f"no group sync config found for {cc_pair.connector.source}"
|
||||
f"Skipping group sync for CC Pair {cc_pair.id} - no group sync config found for {cc_pair.connector.source}"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
@@ -74,8 +74,7 @@ def perform_ttl_management_task(
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"delete_chat_session exceptioned. "
|
||||
f"user_id={user_id} session_id={session_id}"
|
||||
f"delete_chat_session exceptioned. user_id={user_id} session_id={session_id}"
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
mark_task_as_finished_with_id(
|
||||
|
||||
@@ -7,7 +7,8 @@ QUERY_HISTORY_TASK_NAME_PREFIX = OnyxCeleryTask.EXPORT_QUERY_HISTORY_TASK
|
||||
|
||||
|
||||
def name_chat_ttl_task(
|
||||
retention_limit_days: float, tenant_id: str | None = None # noqa: ARG001
|
||||
retention_limit_days: float,
|
||||
tenant_id: str | None = None, # noqa: ARG001
|
||||
) -> str:
|
||||
return f"chat_ttl_{retention_limit_days}_days"
|
||||
|
||||
|
||||
@@ -31,7 +31,8 @@ def fetch_query_analytics(
|
||||
func.sum(case((ChatMessageFeedback.is_positive, 1), else_=0)),
|
||||
func.sum(
|
||||
case(
|
||||
(ChatMessageFeedback.is_positive == False, 1), else_=0 # noqa: E712
|
||||
(ChatMessageFeedback.is_positive == False, 1), # noqa: E712
|
||||
else_=0, # noqa: E712
|
||||
)
|
||||
),
|
||||
cast(ChatMessage.time_sent, Date),
|
||||
@@ -66,7 +67,8 @@ def fetch_per_user_query_analytics(
|
||||
func.sum(case((ChatMessageFeedback.is_positive, 1), else_=0)),
|
||||
func.sum(
|
||||
case(
|
||||
(ChatMessageFeedback.is_positive == False, 1), else_=0 # noqa: E712
|
||||
(ChatMessageFeedback.is_positive == False, 1), # noqa: E712
|
||||
else_=0, # noqa: E712
|
||||
)
|
||||
),
|
||||
cast(ChatMessage.time_sent, Date),
|
||||
|
||||
@@ -23,8 +23,7 @@ def _delete_connector_credential_pair_user_groups_relationship__no_commit(
|
||||
)
|
||||
if cc_pair is None:
|
||||
raise ValueError(
|
||||
f"ConnectorCredentialPair with connector_id: {connector_id} "
|
||||
f"and credential_id: {credential_id} not found"
|
||||
f"ConnectorCredentialPair with connector_id: {connector_id} and credential_id: {credential_id} not found"
|
||||
)
|
||||
|
||||
stmt = delete(UserGroup__ConnectorCredentialPair).where(
|
||||
|
||||
@@ -123,8 +123,7 @@ def upsert_external_groups(
|
||||
user_id = email_id_map.get(user_email.lower())
|
||||
if user_id is None:
|
||||
logger.warning(
|
||||
f"User in group {external_group.id}"
|
||||
f" with email {user_email} not found"
|
||||
f"User in group {external_group.id} with email {user_email} not found"
|
||||
)
|
||||
continue
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from onyx.db.models import HierarchyNode
|
||||
|
||||
|
||||
def _build_hierarchy_access_filter(
|
||||
user_email: str | None,
|
||||
user_email: str,
|
||||
external_group_ids: list[str],
|
||||
) -> ColumnElement[bool]:
|
||||
"""Build SQLAlchemy filter for hierarchy node access.
|
||||
@@ -43,7 +43,7 @@ def _build_hierarchy_access_filter(
|
||||
def _get_accessible_hierarchy_nodes_for_source(
|
||||
db_session: Session,
|
||||
source: DocumentSource,
|
||||
user_email: str | None,
|
||||
user_email: str,
|
||||
external_group_ids: list[str],
|
||||
) -> list[HierarchyNode]:
|
||||
"""
|
||||
|
||||
@@ -7,6 +7,7 @@ from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__User
|
||||
from onyx.db.models import Persona__UserGroup
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.db.persona import mark_persona_user_files_for_sync
|
||||
from onyx.server.features.persona.models import PersonaSharedNotificationData
|
||||
|
||||
|
||||
@@ -26,7 +27,9 @@ def update_persona_access(
|
||||
|
||||
NOTE: Callers are responsible for committing."""
|
||||
|
||||
needs_sync = False
|
||||
if is_public is not None:
|
||||
needs_sync = True
|
||||
persona = db_session.query(Persona).filter(Persona.id == persona_id).first()
|
||||
if persona:
|
||||
persona.is_public = is_public
|
||||
@@ -35,6 +38,7 @@ def update_persona_access(
|
||||
# and a non-empty list means "replace with these shares".
|
||||
|
||||
if user_ids is not None:
|
||||
needs_sync = True
|
||||
db_session.query(Persona__User).filter(
|
||||
Persona__User.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
@@ -54,6 +58,7 @@ def update_persona_access(
|
||||
)
|
||||
|
||||
if group_ids is not None:
|
||||
needs_sync = True
|
||||
db_session.query(Persona__UserGroup).filter(
|
||||
Persona__UserGroup.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
@@ -63,3 +68,7 @@ def update_persona_access(
|
||||
db_session.add(
|
||||
Persona__UserGroup(persona_id=persona_id, user_group_id=group_id)
|
||||
)
|
||||
|
||||
# When sharing changes, user file ACLs need to be updated in the vector DB
|
||||
if needs_sync:
|
||||
mark_persona_user_files_for_sync(persona_id, db_session)
|
||||
|
||||
@@ -191,8 +191,7 @@ def create_initial_default_standard_answer_category(db_session: Session) -> None
|
||||
if default_category is not None:
|
||||
if default_category.name != default_category_name:
|
||||
raise ValueError(
|
||||
"DB is not in a valid initial state. "
|
||||
"Default standard answer category does not have expected name."
|
||||
"DB is not in a valid initial state. Default standard answer category does not have expected name."
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
@@ -424,8 +424,7 @@ def fetch_user_groups_for_documents(
|
||||
def _check_user_group_is_modifiable(user_group: UserGroup) -> None:
|
||||
if not user_group.is_up_to_date:
|
||||
raise ValueError(
|
||||
"Specified user group is currently syncing. Wait until the current "
|
||||
"sync has finished before editing."
|
||||
"Specified user group is currently syncing. Wait until the current sync has finished before editing."
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -56,8 +56,7 @@ def _run_with_retry(
|
||||
if retry_count < MAX_RETRY_COUNT:
|
||||
sleep_after_rate_limit_exception(github_client)
|
||||
logger.warning(
|
||||
f"Rate limit exceeded while {description}. Retrying... "
|
||||
f"(attempt {retry_count + 1}/{MAX_RETRY_COUNT})"
|
||||
f"Rate limit exceeded while {description}. Retrying... (attempt {retry_count + 1}/{MAX_RETRY_COUNT})"
|
||||
)
|
||||
return _run_with_retry(
|
||||
operation, description, github_client, retry_count + 1
|
||||
@@ -91,7 +90,9 @@ class TeamInfo(BaseModel):
|
||||
|
||||
|
||||
def _fetch_organization_members(
|
||||
github_client: Github, org_name: str, retry_count: int = 0 # noqa: ARG001
|
||||
github_client: Github,
|
||||
org_name: str,
|
||||
retry_count: int = 0, # noqa: ARG001
|
||||
) -> List[UserInfo]:
|
||||
"""Fetch all organization members including owners and regular members."""
|
||||
org_members: List[UserInfo] = []
|
||||
@@ -124,7 +125,9 @@ def _fetch_organization_members(
|
||||
|
||||
|
||||
def _fetch_repository_teams_detailed(
|
||||
repo: Repository, github_client: Github, retry_count: int = 0 # noqa: ARG001
|
||||
repo: Repository,
|
||||
github_client: Github,
|
||||
retry_count: int = 0, # noqa: ARG001
|
||||
) -> List[TeamInfo]:
|
||||
"""Fetch teams with access to the repository and their members."""
|
||||
teams_data: List[TeamInfo] = []
|
||||
@@ -167,7 +170,9 @@ def _fetch_repository_teams_detailed(
|
||||
|
||||
|
||||
def fetch_repository_team_slugs(
|
||||
repo: Repository, github_client: Github, retry_count: int = 0 # noqa: ARG001
|
||||
repo: Repository,
|
||||
github_client: Github,
|
||||
retry_count: int = 0, # noqa: ARG001
|
||||
) -> List[str]:
|
||||
"""Fetch team slugs with access to the repository."""
|
||||
logger.info(f"Fetching team slugs for repository {repo.full_name}")
|
||||
|
||||
@@ -68,6 +68,7 @@ def get_external_access_for_raw_gdrive_file(
|
||||
company_domain: str,
|
||||
retriever_drive_service: GoogleDriveService | None,
|
||||
admin_drive_service: GoogleDriveService,
|
||||
fallback_user_email: str,
|
||||
add_prefix: bool = False,
|
||||
) -> ExternalAccess:
|
||||
"""
|
||||
@@ -79,6 +80,11 @@ def get_external_access_for_raw_gdrive_file(
|
||||
set add_prefix to True so group IDs are prefixed with the source type.
|
||||
When invoked from doc_sync (permission sync), use the default (False)
|
||||
since upsert_document_external_perms handles prefixing.
|
||||
fallback_user_email: When we cannot retrieve any permission info for a file
|
||||
(e.g. externally-owned files where the API returns no permissions
|
||||
and permissions.list returns 403), fall back to granting access
|
||||
to this user. This is typically the impersonated org user whose
|
||||
drive contained the file.
|
||||
"""
|
||||
doc_id = file.get("id")
|
||||
if not doc_id:
|
||||
@@ -109,14 +115,33 @@ def get_external_access_for_raw_gdrive_file(
|
||||
)
|
||||
if len(permissions_list) != len(permission_ids) and retriever_drive_service:
|
||||
logger.warning(
|
||||
f"Failed to get all permissions for file {doc_id} with retriever service, "
|
||||
"trying admin service"
|
||||
f"Failed to get all permissions for file {doc_id} with retriever service, trying admin service"
|
||||
)
|
||||
backup_permissions_list = _get_permissions(admin_drive_service)
|
||||
permissions_list = _merge_permissions_lists(
|
||||
[permissions_list, backup_permissions_list]
|
||||
)
|
||||
|
||||
# For externally-owned files, the Drive API may return no permissions
|
||||
# and permissions.list may return 403. In this case, fall back to
|
||||
# granting access to the user who found the file in their drive.
|
||||
# Note, even if other users also have access to this file,
|
||||
# they will not be granted access in Onyx.
|
||||
# We check permissions_list (the final result after all fetch attempts)
|
||||
# rather than the raw fields, because permission_ids may be present
|
||||
# but the actual fetch can still return empty due to a 403.
|
||||
if not permissions_list:
|
||||
logger.info(
|
||||
f"No permission info available for file {doc_id} "
|
||||
f"(likely owned by a user outside of your organization). "
|
||||
f"Falling back to granting access to retriever user: {fallback_user_email}"
|
||||
)
|
||||
return ExternalAccess(
|
||||
external_user_emails={fallback_user_email},
|
||||
external_user_group_ids=set(),
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
folder_ids_to_inherit_permissions_from: set[str] = set()
|
||||
user_emails: set[str] = set()
|
||||
group_emails: set[str] = set()
|
||||
@@ -140,9 +165,7 @@ def get_external_access_for_raw_gdrive_file(
|
||||
user_emails.add(permission.email_address)
|
||||
else:
|
||||
logger.error(
|
||||
"Permission is type `user` but no email address is "
|
||||
f"provided for document {doc_id}"
|
||||
f"\n {permission}"
|
||||
f"Permission is type `user` but no email address is provided for document {doc_id}\n {permission}"
|
||||
)
|
||||
elif permission.type == PermissionType.GROUP:
|
||||
# groups are represented as email addresses within Drive
|
||||
@@ -150,17 +173,14 @@ def get_external_access_for_raw_gdrive_file(
|
||||
group_emails.add(permission.email_address)
|
||||
else:
|
||||
logger.error(
|
||||
"Permission is type `group` but no email address is "
|
||||
f"provided for document {doc_id}"
|
||||
f"\n {permission}"
|
||||
f"Permission is type `group` but no email address is provided for document {doc_id}\n {permission}"
|
||||
)
|
||||
elif permission.type == PermissionType.DOMAIN and company_domain:
|
||||
if permission.domain == company_domain:
|
||||
public = True
|
||||
else:
|
||||
logger.warning(
|
||||
"Permission is type domain but does not match company domain:"
|
||||
f"\n {permission}"
|
||||
f"Permission is type domain but does not match company domain:\n {permission}"
|
||||
)
|
||||
elif permission.type == PermissionType.ANYONE:
|
||||
public = True
|
||||
|
||||
@@ -18,10 +18,7 @@ logger = setup_logger()
|
||||
# Only include fields we need - folder ID and permissions
|
||||
# IMPORTANT: must fetch permissionIds, since sometimes the drive API
|
||||
# seems to miss permissions when requesting them directly
|
||||
FOLDER_PERMISSION_FIELDS = (
|
||||
"nextPageToken, files(id, name, permissionIds, "
|
||||
"permissions(id, emailAddress, type, domain, permissionDetails))"
|
||||
)
|
||||
FOLDER_PERMISSION_FIELDS = "nextPageToken, files(id, name, permissionIds, permissions(id, emailAddress, type, domain, permissionDetails))"
|
||||
|
||||
|
||||
def get_folder_permissions_by_ids(
|
||||
|
||||
@@ -142,8 +142,7 @@ def _drive_folder_to_onyx_group(
|
||||
elif permission.type == PermissionType.GROUP:
|
||||
if permission.email_address not in group_email_to_member_emails_map:
|
||||
logger.warning(
|
||||
f"Group email {permission.email_address} for folder {folder.id} "
|
||||
"not found in group_email_to_member_emails_map"
|
||||
f"Group email {permission.email_address} for folder {folder.id} not found in group_email_to_member_emails_map"
|
||||
)
|
||||
continue
|
||||
folder_member_emails.update(
|
||||
@@ -238,8 +237,7 @@ def _drive_member_map_to_onyx_groups(
|
||||
for group_email in group_emails:
|
||||
if group_email not in group_email_to_member_emails_map:
|
||||
logger.warning(
|
||||
f"Group email {group_email} for drive {drive_id} not found in "
|
||||
"group_email_to_member_emails_map"
|
||||
f"Group email {group_email} for drive {drive_id} not found in group_email_to_member_emails_map"
|
||||
)
|
||||
continue
|
||||
drive_member_emails.update(group_email_to_member_emails_map[group_email])
|
||||
@@ -326,8 +324,7 @@ def _build_onyx_groups(
|
||||
for group_email in group_emails:
|
||||
if group_email not in group_email_to_member_emails_map:
|
||||
logger.warning(
|
||||
f"Group email {group_email} for drive {drive_id} not found in "
|
||||
"group_email_to_member_emails_map"
|
||||
f"Group email {group_email} for drive {drive_id} not found in group_email_to_member_emails_map"
|
||||
)
|
||||
continue
|
||||
drive_member_emails.update(group_email_to_member_emails_map[group_email])
|
||||
|
||||
@@ -55,8 +55,7 @@ def get_permissions_by_ids(
|
||||
if len(filtered_permissions) < len(permission_ids):
|
||||
missing_ids = permission_id_set - {p.id for p in filtered_permissions if p.id}
|
||||
logger.warning(
|
||||
f"Could not find all requested permission IDs for document {doc_id}. "
|
||||
f"Missing IDs: {missing_ids}"
|
||||
f"Could not find all requested permission IDs for document {doc_id}. Missing IDs: {missing_ids}"
|
||||
)
|
||||
|
||||
return filtered_permissions
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from jira import JIRA
|
||||
from jira.exceptions import JIRAError
|
||||
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from onyx.connectors.jira.utils import build_jira_client
|
||||
@@ -9,107 +11,101 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_ATLASSIAN_ACCOUNT_TYPE = "atlassian"
|
||||
_GROUP_MEMBER_PAGE_SIZE = 50
|
||||
|
||||
def _get_jira_group_members_email(
|
||||
# The GET /group/member endpoint was introduced in Jira 6.0.
|
||||
# Jira versions older than 6.0 do not have group management REST APIs at all.
|
||||
_MIN_JIRA_VERSION_FOR_GROUP_MEMBER = "6.0"
|
||||
|
||||
|
||||
def _fetch_group_member_page(
|
||||
jira_client: JIRA,
|
||||
group_name: str,
|
||||
) -> list[str]:
|
||||
"""Get all member emails for a Jira group.
|
||||
start_at: int,
|
||||
) -> dict[str, Any]:
|
||||
"""Fetch a single page from the non-deprecated GET /group/member endpoint.
|
||||
|
||||
Filters out app accounts (bots, integrations) and only returns real user emails.
|
||||
The old GET /group endpoint (used by jira_client.group_members()) is deprecated
|
||||
and decommissioned in Jira Server 10.3+. This uses the replacement endpoint
|
||||
directly via the library's internal _get_json helper, following the same pattern
|
||||
as enhanced_search_ids / bulk_fetch_issues in connector.py.
|
||||
|
||||
There is an open PR to the library to switch to this endpoint since last year:
|
||||
https://github.com/pycontribs/jira/pull/2356
|
||||
so once it is merged and released, we can switch to using the library function.
|
||||
"""
|
||||
emails: list[str] = []
|
||||
|
||||
try:
|
||||
# group_members returns an OrderedDict of account_id -> member_info
|
||||
members = jira_client.group_members(group=group_name)
|
||||
|
||||
if not members:
|
||||
logger.warning(f"No members found for group {group_name}")
|
||||
return emails
|
||||
|
||||
for account_id, member_info in members.items():
|
||||
# member_info is a dict with keys like 'fullname', 'email', 'active'
|
||||
email = member_info.get("email")
|
||||
|
||||
# Skip "hidden" emails - these are typically app accounts
|
||||
if email and email != "hidden":
|
||||
emails.append(email)
|
||||
else:
|
||||
# For cloud, we might need to fetch user details separately
|
||||
try:
|
||||
user = jira_client.user(id=account_id)
|
||||
|
||||
# Skip app accounts (bots, integrations, etc.)
|
||||
if hasattr(user, "accountType") and user.accountType == "app":
|
||||
logger.info(
|
||||
f"Skipping app account {account_id} for group {group_name}"
|
||||
)
|
||||
continue
|
||||
|
||||
if hasattr(user, "emailAddress") and user.emailAddress:
|
||||
emails.append(user.emailAddress)
|
||||
else:
|
||||
logger.warning(f"User {account_id} has no email address")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Could not fetch email for user {account_id} in group {group_name}: {e}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching members for group {group_name}: {e}")
|
||||
|
||||
return emails
|
||||
return jira_client._get_json(
|
||||
"group/member",
|
||||
params={
|
||||
"groupname": group_name,
|
||||
"includeInactiveUsers": "false",
|
||||
"startAt": start_at,
|
||||
"maxResults": _GROUP_MEMBER_PAGE_SIZE,
|
||||
},
|
||||
)
|
||||
except JIRAError as e:
|
||||
if e.status_code == 404:
|
||||
raise RuntimeError(
|
||||
f"GET /group/member returned 404 for group '{group_name}'. "
|
||||
f"This endpoint requires Jira {_MIN_JIRA_VERSION_FOR_GROUP_MEMBER}+. "
|
||||
f"If you are running a self-hosted Jira instance, please upgrade "
|
||||
f"to at least Jira {_MIN_JIRA_VERSION_FOR_GROUP_MEMBER}."
|
||||
) from e
|
||||
raise
|
||||
|
||||
|
||||
def _build_group_member_email_map(
|
||||
def _get_group_member_emails(
|
||||
jira_client: JIRA,
|
||||
) -> dict[str, set[str]]:
|
||||
"""Build a map of group names to member emails."""
|
||||
group_member_emails: dict[str, set[str]] = {}
|
||||
group_name: str,
|
||||
) -> set[str]:
|
||||
"""Get all member emails for a single Jira group.
|
||||
|
||||
try:
|
||||
# Get all groups from Jira - returns a list of group name strings
|
||||
group_names = jira_client.groups()
|
||||
Uses the non-deprecated GET /group/member endpoint which returns full user
|
||||
objects including accountType, so we can filter out app/customer accounts
|
||||
without making separate user() calls.
|
||||
"""
|
||||
emails: set[str] = set()
|
||||
start_at = 0
|
||||
|
||||
if not group_names:
|
||||
logger.warning("No groups found in Jira")
|
||||
return group_member_emails
|
||||
while True:
|
||||
try:
|
||||
page = _fetch_group_member_page(jira_client, group_name, start_at)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching members for group {group_name}: {e}")
|
||||
raise
|
||||
|
||||
logger.info(f"Found {len(group_names)} groups in Jira")
|
||||
|
||||
for group_name in group_names:
|
||||
if not group_name:
|
||||
members: list[dict[str, Any]] = page.get("values", [])
|
||||
for member in members:
|
||||
account_type = member.get("accountType")
|
||||
# On Jira DC < 9.0, accountType is absent; include those users.
|
||||
# On Cloud / DC 9.0+, filter to real user accounts only.
|
||||
if account_type is not None and account_type != _ATLASSIAN_ACCOUNT_TYPE:
|
||||
continue
|
||||
|
||||
member_emails = _get_jira_group_members_email(
|
||||
jira_client=jira_client,
|
||||
group_name=group_name,
|
||||
)
|
||||
|
||||
if member_emails:
|
||||
group_member_emails[group_name] = set(member_emails)
|
||||
logger.debug(
|
||||
f"Found {len(member_emails)} members for group {group_name}"
|
||||
)
|
||||
email = member.get("emailAddress")
|
||||
if email:
|
||||
emails.add(email)
|
||||
else:
|
||||
logger.debug(f"No members found for group {group_name}")
|
||||
logger.warning(
|
||||
f"Atlassian user {member.get('accountId', 'unknown')} in group {group_name} has no visible email address"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error building group member email map: {e}")
|
||||
if page.get("isLast", True) or not members:
|
||||
break
|
||||
start_at += len(members)
|
||||
|
||||
return group_member_emails
|
||||
return emails
|
||||
|
||||
|
||||
def jira_group_sync(
|
||||
tenant_id: str, # noqa: ARG001
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> Generator[ExternalUserGroup, None, None]:
|
||||
"""
|
||||
Sync Jira groups and their members.
|
||||
"""Sync Jira groups and their members, yielding one group at a time.
|
||||
|
||||
This function fetches all groups from Jira and yields ExternalUserGroup
|
||||
objects containing the group ID and member emails.
|
||||
Streams group-by-group rather than accumulating all groups in memory.
|
||||
"""
|
||||
jira_base_url = cc_pair.connector.connector_specific_config.get("jira_base_url", "")
|
||||
scoped_token = cc_pair.connector.connector_specific_config.get(
|
||||
@@ -130,12 +126,26 @@ def jira_group_sync(
|
||||
scoped_token=scoped_token,
|
||||
)
|
||||
|
||||
group_member_email_map = _build_group_member_email_map(jira_client=jira_client)
|
||||
if not group_member_email_map:
|
||||
raise ValueError(f"No groups with members found for cc_pair_id={cc_pair.id}")
|
||||
group_names = jira_client.groups()
|
||||
if not group_names:
|
||||
raise ValueError(f"No groups found for cc_pair_id={cc_pair.id}")
|
||||
|
||||
for group_id, group_member_emails in group_member_email_map.items():
|
||||
yield ExternalUserGroup(
|
||||
id=group_id,
|
||||
user_emails=list(group_member_emails),
|
||||
logger.info(f"Found {len(group_names)} groups in Jira")
|
||||
|
||||
for group_name in group_names:
|
||||
if not group_name:
|
||||
continue
|
||||
|
||||
member_emails = _get_group_member_emails(
|
||||
jira_client=jira_client,
|
||||
group_name=group_name,
|
||||
)
|
||||
if not member_emails:
|
||||
logger.debug(f"No members found for group {group_name}")
|
||||
continue
|
||||
|
||||
logger.debug(f"Found {len(member_emails)} members for group {group_name}")
|
||||
yield ExternalUserGroup(
|
||||
id=group_name,
|
||||
user_emails=list(member_emails),
|
||||
)
|
||||
|
||||
@@ -69,8 +69,7 @@ def _post_query_chunk_censoring(
|
||||
censored_chunks = censor_chunks_for_source(chunks_for_source, user.email)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to censor chunks for source {source} so throwing out all"
|
||||
f" chunks for this source and continuing: {e}"
|
||||
f"Failed to censor chunks for source {source} so throwing out all chunks for this source and continuing: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
|
||||
@@ -23,7 +23,9 @@ ContentRange = tuple[int, int | None] # (start_index, end_index) None means to
|
||||
|
||||
# NOTE: Used for testing timing
|
||||
def _get_dummy_object_access_map(
|
||||
object_ids: set[str], user_email: str, chunks: list[InferenceChunk] # noqa: ARG001
|
||||
object_ids: set[str],
|
||||
user_email: str, # noqa: ARG001
|
||||
chunks: list[InferenceChunk], # noqa: ARG001
|
||||
) -> dict[str, bool]:
|
||||
time.sleep(0.15)
|
||||
# return {object_id: True for object_id in object_ids}
|
||||
|
||||
@@ -61,8 +61,7 @@ def _graph_api_get(
|
||||
):
|
||||
wait = min(int(resp.headers.get("Retry-After", str(2**attempt))), 60)
|
||||
logger.warning(
|
||||
f"Graph API {resp.status_code} on attempt {attempt + 1}, "
|
||||
f"retrying in {wait}s: {url}"
|
||||
f"Graph API {resp.status_code} on attempt {attempt + 1}, retrying in {wait}s: {url}"
|
||||
)
|
||||
time.sleep(wait)
|
||||
continue
|
||||
@@ -72,8 +71,7 @@ def _graph_api_get(
|
||||
if attempt < GRAPH_API_MAX_RETRIES:
|
||||
wait = min(2**attempt, 60)
|
||||
logger.warning(
|
||||
f"Graph API connection error on attempt {attempt + 1}, "
|
||||
f"retrying in {wait}s: {url}"
|
||||
f"Graph API connection error on attempt {attempt + 1}, retrying in {wait}s: {url}"
|
||||
)
|
||||
time.sleep(wait)
|
||||
continue
|
||||
@@ -767,8 +765,7 @@ def get_sharepoint_external_groups(
|
||||
|
||||
if not enumerate_all_ad_groups or get_access_token is None:
|
||||
logger.info(
|
||||
"Skipping exhaustive Azure AD group enumeration. "
|
||||
"Only groups found in site role assignments are included."
|
||||
"Skipping exhaustive Azure AD group enumeration. Only groups found in site role assignments are included."
|
||||
)
|
||||
return external_user_groups
|
||||
|
||||
|
||||
@@ -166,8 +166,7 @@ def slack_doc_sync(
|
||||
user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
|
||||
if not user_id_to_email_map:
|
||||
raise ValueError(
|
||||
"No user id to email map found. Please check to make sure that "
|
||||
"your Slack bot token has the `users:read.email` scope"
|
||||
"No user id to email map found. Please check to make sure that your Slack bot token has the `users:read.email` scope"
|
||||
)
|
||||
|
||||
workspace_permissions = _fetch_workspace_permissions(
|
||||
|
||||
@@ -152,10 +152,7 @@ def create_new_usage_report(
|
||||
zip_buffer.seek(0)
|
||||
|
||||
# store zip blob to file_store
|
||||
report_name = (
|
||||
f"{datetime.now(tz=timezone.utc).strftime('%Y-%m-%d')}"
|
||||
f"_{report_id}_usage_report.zip"
|
||||
)
|
||||
report_name = f"{datetime.now(tz=timezone.utc).strftime('%Y-%m-%d')}_{report_id}_usage_report.zip"
|
||||
file_store.save_file(
|
||||
content=zip_buffer,
|
||||
display_name=report_name,
|
||||
|
||||
@@ -449,8 +449,7 @@ def _apply_group_remove(
|
||||
match = _MEMBER_FILTER_RE.match(op.path)
|
||||
if not match:
|
||||
raise ScimPatchError(
|
||||
f"Unsupported remove path '{op.path}'. "
|
||||
'Expected: members[value eq "user-id"]'
|
||||
f"Unsupported remove path '{op.path}'. Expected: members[value eq \"user-id\"]"
|
||||
)
|
||||
|
||||
target_id = match.group(1)
|
||||
|
||||
@@ -26,6 +26,7 @@ from onyx.db.models import Tool
|
||||
from onyx.db.persona import upsert_persona
|
||||
from onyx.server.features.persona.models import PersonaUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.server.settings.store import store_settings as store_base_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -125,10 +126,16 @@ def _seed_llms(
|
||||
existing = fetch_existing_llm_provider(name=request.name, db_session=db_session)
|
||||
if existing:
|
||||
request.id = existing.id
|
||||
seeded_providers = [
|
||||
upsert_llm_provider(llm_upsert_request, db_session)
|
||||
for llm_upsert_request in llm_upsert_requests
|
||||
]
|
||||
seeded_providers: list[LLMProviderView] = []
|
||||
for llm_upsert_request in llm_upsert_requests:
|
||||
try:
|
||||
seeded_providers.append(upsert_llm_provider(llm_upsert_request, db_session))
|
||||
except ValueError as e:
|
||||
logger.warning(
|
||||
"Failed to upsert LLM provider '%s' during seeding: %s",
|
||||
llm_upsert_request.name,
|
||||
e,
|
||||
)
|
||||
|
||||
default_provider = next(
|
||||
(p for p in seeded_providers if p.model_configurations), None
|
||||
|
||||
@@ -123,7 +123,8 @@ async def get_or_provision_tenant(
|
||||
|
||||
|
||||
async def create_tenant(
|
||||
email: str, referral_source: str | None = None # noqa: ARG001
|
||||
email: str,
|
||||
referral_source: str | None = None, # noqa: ARG001
|
||||
) -> str:
|
||||
"""
|
||||
Create a new tenant on-demand when no pre-provisioned tenants are available.
|
||||
@@ -679,7 +680,9 @@ async def setup_tenant(tenant_id: str) -> None:
|
||||
|
||||
|
||||
async def assign_tenant_to_user(
|
||||
tenant_id: str, email: str, referral_source: str | None = None # noqa: ARG001
|
||||
tenant_id: str,
|
||||
email: str,
|
||||
referral_source: str | None = None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""
|
||||
Assign a tenant to a user and perform necessary operations.
|
||||
|
||||
@@ -14,67 +14,90 @@ from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
@lru_cache(maxsize=2)
|
||||
def _get_trimmed_key(key: str) -> bytes:
|
||||
encoded_key = key.encode()
|
||||
key_length = len(encoded_key)
|
||||
if key_length < 16:
|
||||
raise RuntimeError("Invalid ENCRYPTION_KEY_SECRET - too short")
|
||||
elif key_length > 32:
|
||||
key = key[:32]
|
||||
elif key_length not in (16, 24, 32):
|
||||
valid_lengths = [16, 24, 32]
|
||||
key = key[: min(valid_lengths, key=lambda x: abs(x - key_length))]
|
||||
|
||||
return encoded_key
|
||||
# Trim to the largest valid AES key size that fits
|
||||
valid_lengths = [32, 24, 16]
|
||||
for size in valid_lengths:
|
||||
if key_length >= size:
|
||||
return encoded_key[:size]
|
||||
|
||||
raise AssertionError("unreachable")
|
||||
|
||||
|
||||
def _encrypt_string(input_str: str) -> bytes:
|
||||
if not ENCRYPTION_KEY_SECRET:
|
||||
def _encrypt_string(input_str: str, key: str | None = None) -> bytes:
|
||||
effective_key = key if key is not None else ENCRYPTION_KEY_SECRET
|
||||
if not effective_key:
|
||||
return input_str.encode()
|
||||
|
||||
key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
|
||||
trimmed = _get_trimmed_key(effective_key)
|
||||
iv = urandom(16)
|
||||
padder = padding.PKCS7(algorithms.AES.block_size).padder()
|
||||
padded_data = padder.update(input_str.encode()) + padder.finalize()
|
||||
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
|
||||
cipher = Cipher(algorithms.AES(trimmed), modes.CBC(iv), backend=default_backend())
|
||||
encryptor = cipher.encryptor()
|
||||
encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
|
||||
|
||||
return iv + encrypted_data
|
||||
|
||||
|
||||
def _decrypt_bytes(input_bytes: bytes) -> str:
|
||||
if not ENCRYPTION_KEY_SECRET:
|
||||
def _decrypt_bytes(input_bytes: bytes, key: str | None = None) -> str:
|
||||
effective_key = key if key is not None else ENCRYPTION_KEY_SECRET
|
||||
if not effective_key:
|
||||
return input_bytes.decode()
|
||||
|
||||
key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
|
||||
iv = input_bytes[:16]
|
||||
encrypted_data = input_bytes[16:]
|
||||
trimmed = _get_trimmed_key(effective_key)
|
||||
try:
|
||||
iv = input_bytes[:16]
|
||||
encrypted_data = input_bytes[16:]
|
||||
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
|
||||
decryptor = cipher.decryptor()
|
||||
decrypted_padded_data = decryptor.update(encrypted_data) + decryptor.finalize()
|
||||
cipher = Cipher(
|
||||
algorithms.AES(trimmed), modes.CBC(iv), backend=default_backend()
|
||||
)
|
||||
decryptor = cipher.decryptor()
|
||||
decrypted_padded_data = decryptor.update(encrypted_data) + decryptor.finalize()
|
||||
|
||||
unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
|
||||
decrypted_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()
|
||||
unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
|
||||
decrypted_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()
|
||||
|
||||
return decrypted_data.decode()
|
||||
return decrypted_data.decode()
|
||||
except (ValueError, UnicodeDecodeError):
|
||||
if key is not None:
|
||||
# Explicit key was provided — don't fall back silently
|
||||
raise
|
||||
# Read path: attempt raw UTF-8 decode as a fallback for legacy data.
|
||||
# Does NOT handle data encrypted with a different key — that
|
||||
# ciphertext is not valid UTF-8 and will raise below.
|
||||
logger.warning(
|
||||
"AES decryption failed — falling back to raw decode. Run the re-encrypt secrets script to rotate to the current key."
|
||||
)
|
||||
try:
|
||||
return input_bytes.decode()
|
||||
except UnicodeDecodeError:
|
||||
raise ValueError(
|
||||
"Data is not valid UTF-8 — likely encrypted with a different key. "
|
||||
"Run the re-encrypt secrets script to rotate to the current key."
|
||||
) from None
|
||||
|
||||
|
||||
def encrypt_string_to_bytes(input_str: str) -> bytes:
|
||||
def encrypt_string_to_bytes(input_str: str, key: str | None = None) -> bytes:
|
||||
versioned_encryption_fn = fetch_versioned_implementation(
|
||||
"onyx.utils.encryption", "_encrypt_string"
|
||||
)
|
||||
return versioned_encryption_fn(input_str)
|
||||
return versioned_encryption_fn(input_str, key=key)
|
||||
|
||||
|
||||
def decrypt_bytes_to_string(input_bytes: bytes) -> str:
|
||||
def decrypt_bytes_to_string(input_bytes: bytes, key: str | None = None) -> str:
|
||||
versioned_decryption_fn = fetch_versioned_implementation(
|
||||
"onyx.utils.encryption", "_decrypt_bytes"
|
||||
)
|
||||
return versioned_decryption_fn(input_bytes)
|
||||
return versioned_decryption_fn(input_bytes, key=key)
|
||||
|
||||
|
||||
def test_encryption() -> None:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
@@ -12,6 +11,7 @@ from onyx.db.document import get_access_info_for_document
|
||||
from onyx.db.document import get_access_info_for_documents
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.user_file import fetch_user_files_with_access_relationships
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
@@ -96,7 +96,9 @@ def get_access_for_documents(
|
||||
return versioned_get_access_for_documents_fn(document_ids, db_session)
|
||||
|
||||
|
||||
def _get_acl_for_user(user: User, db_session: Session) -> set[str]: # noqa: ARG001
|
||||
def _get_acl_for_user(
|
||||
user: User, db_session: Session # noqa: ARG001
|
||||
) -> set[str]: # noqa: ARG001
|
||||
"""Returns a list of ACL entries that the user has access to. This is meant to be
|
||||
used downstream to filter out documents that the user does not have access to. The
|
||||
user should have access to a document if at least one entry in the document's ACL
|
||||
@@ -132,19 +134,61 @@ def get_access_for_user_files(
|
||||
user_file_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
user_files = (
|
||||
db_session.query(UserFile)
|
||||
.options(joinedload(UserFile.user)) # Eager load the user relationship
|
||||
.filter(UserFile.id.in_(user_file_ids))
|
||||
.all()
|
||||
versioned_fn = fetch_versioned_implementation(
|
||||
"onyx.access.access", "get_access_for_user_files_impl"
|
||||
)
|
||||
return {
|
||||
str(user_file.id): DocumentAccess.build(
|
||||
user_emails=[user_file.user.email] if user_file.user else [],
|
||||
return versioned_fn(user_file_ids, db_session)
|
||||
|
||||
|
||||
def get_access_for_user_files_impl(
|
||||
user_file_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
user_files = fetch_user_files_with_access_relationships(user_file_ids, db_session)
|
||||
return build_access_for_user_files_impl(user_files)
|
||||
|
||||
|
||||
def build_access_for_user_files(
|
||||
user_files: list[UserFile],
|
||||
) -> dict[str, DocumentAccess]:
|
||||
"""Compute access from pre-loaded UserFile objects (with relationships).
|
||||
Callers must ensure UserFile.user, Persona.users, and Persona.user are
|
||||
eagerly loaded (and Persona.groups for the EE path)."""
|
||||
versioned_fn = fetch_versioned_implementation(
|
||||
"onyx.access.access", "build_access_for_user_files_impl"
|
||||
)
|
||||
return versioned_fn(user_files)
|
||||
|
||||
|
||||
def build_access_for_user_files_impl(
|
||||
user_files: list[UserFile],
|
||||
) -> dict[str, DocumentAccess]:
|
||||
result: dict[str, DocumentAccess] = {}
|
||||
for user_file in user_files:
|
||||
emails, is_public = collect_user_file_access(user_file)
|
||||
result[str(user_file.id)] = DocumentAccess.build(
|
||||
user_emails=list(emails),
|
||||
user_groups=[],
|
||||
is_public=True if user_file.user is None else False,
|
||||
is_public=is_public,
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
)
|
||||
for user_file in user_files
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
def collect_user_file_access(user_file: UserFile) -> tuple[set[str], bool]:
|
||||
"""Collect all user emails that should have access to this user file.
|
||||
Includes the owner plus any users who have access via shared personas.
|
||||
Returns (emails, is_public)."""
|
||||
emails: set[str] = {user_file.user.email}
|
||||
is_public = False
|
||||
for persona in user_file.assistants:
|
||||
if persona.deleted:
|
||||
continue
|
||||
if persona.is_public:
|
||||
is_public = True
|
||||
if persona.user_id is not None and persona.user:
|
||||
emails.add(persona.user.email)
|
||||
for shared_user in persona.users:
|
||||
emails.add(shared_user.email)
|
||||
return emails, is_public
|
||||
|
||||
@@ -5,7 +5,8 @@ from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
|
||||
def _get_user_external_group_ids(
|
||||
db_session: Session, user: User # noqa: ARG001
|
||||
db_session: Session, # noqa: ARG001
|
||||
user: User, # noqa: ARG001
|
||||
) -> list[str]:
|
||||
return []
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ from onyx.configs.constants import PUBLIC_DOC_PAT
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExternalAccess:
|
||||
|
||||
# arbitrary limit to prevent excessively large permissions sets
|
||||
# not internally enforced ... the caller can check this before using the instance
|
||||
MAX_NUM_ENTRIES = 5000
|
||||
|
||||
@@ -96,8 +96,7 @@ async def verify_captcha_token(
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Captcha verification passed: score={result.score}, "
|
||||
f"action={result.action}"
|
||||
f"Captcha verification passed: score={result.score}, action={result.action}"
|
||||
)
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
|
||||
@@ -353,20 +353,11 @@ def build_user_email_invite(
|
||||
"or login with Google and complete your registration.</p>"
|
||||
)
|
||||
elif auth_type == AuthType.BASIC:
|
||||
message += (
|
||||
"<p>To join the organization, please click the button below to set a password "
|
||||
"and complete your registration.</p>"
|
||||
)
|
||||
message += "<p>To join the organization, please click the button below to set a password and complete your registration.</p>"
|
||||
elif auth_type == AuthType.GOOGLE_OAUTH:
|
||||
message += (
|
||||
"<p>To join the organization, please click the button below to login with Google "
|
||||
"and complete your registration.</p>"
|
||||
)
|
||||
message += "<p>To join the organization, please click the button below to login with Google and complete your registration.</p>"
|
||||
elif auth_type == AuthType.OIDC or auth_type == AuthType.SAML:
|
||||
message += (
|
||||
"<p>To join the organization, please click the button below to"
|
||||
" complete your registration.</p>"
|
||||
)
|
||||
message += "<p>To join the organization, please click the button below to complete your registration.</p>"
|
||||
else:
|
||||
raise ValueError(f"Invalid auth type: {auth_type}")
|
||||
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import secrets
|
||||
import string
|
||||
@@ -28,6 +31,8 @@ from fastapi import Query
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi import status
|
||||
from fastapi import WebSocket
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from fastapi_users import BaseUserManager
|
||||
@@ -54,6 +59,7 @@ from fastapi_users.router.common import ErrorModel
|
||||
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
|
||||
from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
|
||||
from httpx_oauth.oauth2 import BaseOAuth2
|
||||
from httpx_oauth.oauth2 import GetAccessTokenError
|
||||
from httpx_oauth.oauth2 import OAuth2Token
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import nulls_last
|
||||
@@ -119,7 +125,12 @@ from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.pat import fetch_user_for_pat
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import log_onyx_error
|
||||
from onyx.error_handling.exceptions import onyx_error_to_json_response
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.redis.redis_pool import get_async_redis_connection
|
||||
from onyx.redis.redis_pool import retrieve_ws_token_data
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -145,10 +156,21 @@ def is_user_admin(user: User) -> bool:
|
||||
|
||||
|
||||
def verify_auth_setting() -> None:
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
"""Log warnings for AUTH_TYPE issues.
|
||||
|
||||
This only runs on app startup not during migrations/scripts.
|
||||
"""
|
||||
raw_auth_type = (os.environ.get("AUTH_TYPE") or "").lower()
|
||||
|
||||
if raw_auth_type == "cloud":
|
||||
raise ValueError(
|
||||
f"{AUTH_TYPE.value} is not a valid auth type for self-hosted deployments."
|
||||
"'cloud' is not a valid auth type for self-hosted deployments."
|
||||
)
|
||||
if raw_auth_type == "disabled":
|
||||
logger.warning(
|
||||
"AUTH_TYPE='disabled' is no longer supported. Using 'basic' instead. Please update your configuration."
|
||||
)
|
||||
|
||||
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
|
||||
|
||||
|
||||
@@ -589,8 +611,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
char in PASSWORD_SPECIAL_CHARS for char in password
|
||||
):
|
||||
raise exceptions.InvalidPasswordException(
|
||||
reason="Password must contain at least one special character from the following set: "
|
||||
f"{PASSWORD_SPECIAL_CHARS}."
|
||||
reason=f"Password must contain at least one special character from the following set: {PASSWORD_SPECIAL_CHARS}."
|
||||
)
|
||||
return
|
||||
|
||||
@@ -857,7 +878,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
)
|
||||
|
||||
async def on_after_forgot_password(
|
||||
self, user: User, token: str, request: Optional[Request] = None # noqa: ARG002
|
||||
self,
|
||||
user: User,
|
||||
token: str,
|
||||
request: Optional[Request] = None, # noqa: ARG002
|
||||
) -> None:
|
||||
if not EMAIL_CONFIGURED:
|
||||
logger.error(
|
||||
@@ -876,7 +900,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
send_forgot_password_email(user.email, tenant_id=tenant_id, token=token)
|
||||
|
||||
async def on_after_request_verify(
|
||||
self, user: User, token: str, request: Optional[Request] = None # noqa: ARG002
|
||||
self,
|
||||
user: User,
|
||||
token: str,
|
||||
request: Optional[Request] = None, # noqa: ARG002
|
||||
) -> None:
|
||||
verify_email_domain(user.email)
|
||||
|
||||
@@ -1172,7 +1199,9 @@ class SingleTenantJWTStrategy(JWTStrategy[User, uuid.UUID]):
|
||||
return
|
||||
|
||||
async def refresh_token(
|
||||
self, token: Optional[str], user: User # noqa: ARG002
|
||||
self,
|
||||
token: Optional[str], # noqa: ARG002
|
||||
user: User, # noqa: ARG002
|
||||
) -> str:
|
||||
"""Issue a fresh JWT with a new expiry."""
|
||||
return await self.write_token(user)
|
||||
@@ -1200,8 +1229,7 @@ def get_jwt_strategy() -> SingleTenantJWTStrategy:
|
||||
if AUTH_BACKEND == AuthBackend.JWT:
|
||||
if MULTI_TENANT or AUTH_TYPE == AuthType.CLOUD:
|
||||
raise ValueError(
|
||||
"JWT auth backend is only supported for single-tenant, self-hosted deployments. "
|
||||
"Use 'redis' or 'postgres' instead."
|
||||
"JWT auth backend is only supported for single-tenant, self-hosted deployments. Use 'redis' or 'postgres' instead."
|
||||
)
|
||||
if not USER_AUTH_SECRET:
|
||||
raise ValueError("USER_AUTH_SECRET is required for JWT auth backend.")
|
||||
@@ -1599,6 +1627,102 @@ async def current_admin_user(user: User = Depends(current_user)) -> User:
|
||||
return user
|
||||
|
||||
|
||||
async def _get_user_from_token_data(token_data: dict) -> User | None:
|
||||
"""Shared logic: token data dict → User object.
|
||||
|
||||
Args:
|
||||
token_data: Decoded token data containing 'sub' (user ID).
|
||||
|
||||
Returns:
|
||||
User object if found and active, None otherwise.
|
||||
"""
|
||||
user_id = token_data.get("sub")
|
||||
if not user_id:
|
||||
return None
|
||||
|
||||
try:
|
||||
user_uuid = uuid.UUID(user_id)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
async with get_async_session_context_manager() as async_db_session:
|
||||
user = await async_db_session.get(User, user_uuid)
|
||||
if user is None or not user.is_active:
|
||||
return None
|
||||
return user
|
||||
|
||||
|
||||
async def current_user_from_websocket(
|
||||
websocket: WebSocket,
|
||||
token: str = Query(..., description="WebSocket authentication token"),
|
||||
) -> User:
|
||||
"""
|
||||
WebSocket authentication dependency using query parameter.
|
||||
|
||||
Validates the WS token from query param and returns the User.
|
||||
Raises BasicAuthenticationError if authentication fails.
|
||||
|
||||
The token must be obtained from POST /voice/ws-token before connecting.
|
||||
Tokens are single-use and expire after 60 seconds.
|
||||
|
||||
Usage:
|
||||
1. POST /voice/ws-token -> {"token": "xxx"}
|
||||
2. Connect to ws://host/path?token=xxx
|
||||
|
||||
This applies the same auth checks as current_user() for HTTP endpoints.
|
||||
"""
|
||||
# Check Origin header to prevent Cross-Site WebSocket Hijacking (CSWSH)
|
||||
# Browsers always send Origin on WebSocket connections
|
||||
origin = websocket.headers.get("origin")
|
||||
expected_origin = WEB_DOMAIN.rstrip("/")
|
||||
if not origin:
|
||||
logger.warning("WS auth: missing Origin header")
|
||||
raise BasicAuthenticationError(detail="Access denied. Missing origin.")
|
||||
|
||||
actual_origin = origin.rstrip("/")
|
||||
if actual_origin != expected_origin:
|
||||
logger.warning(
|
||||
f"WS auth: origin mismatch. Expected {expected_origin}, got {actual_origin}"
|
||||
)
|
||||
raise BasicAuthenticationError(detail="Access denied. Invalid origin.")
|
||||
|
||||
# Validate WS token in Redis (single-use, deleted after retrieval)
|
||||
try:
|
||||
token_data = await retrieve_ws_token_data(token)
|
||||
if token_data is None:
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. Invalid or expired authentication token."
|
||||
)
|
||||
except BasicAuthenticationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"WS auth: error during token validation: {e}")
|
||||
raise BasicAuthenticationError(
|
||||
detail="Authentication verification failed."
|
||||
) from e
|
||||
|
||||
# Get user from token data
|
||||
user = await _get_user_from_token_data(token_data)
|
||||
if user is None:
|
||||
logger.warning(f"WS auth: user not found for id={token_data.get('sub')}")
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User not found or inactive."
|
||||
)
|
||||
|
||||
# Apply same checks as HTTP auth (verification, OIDC expiry, role)
|
||||
user = await double_check_user(user)
|
||||
|
||||
# Block LIMITED users (same as current_user)
|
||||
if user.role == UserRole.LIMITED:
|
||||
logger.warning(f"WS auth: user {user.email} has LIMITED role")
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User role is LIMITED. BASIC or higher permissions are required.",
|
||||
)
|
||||
|
||||
logger.debug(f"WS auth: authenticated {user.email}")
|
||||
return user
|
||||
|
||||
|
||||
def get_default_admin_user_emails_() -> list[str]:
|
||||
# No default seeding available for Onyx MIT
|
||||
return []
|
||||
@@ -1608,6 +1732,7 @@ STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state"
|
||||
STATE_TOKEN_LIFETIME_SECONDS = 3600
|
||||
CSRF_TOKEN_KEY = "csrftoken"
|
||||
CSRF_TOKEN_COOKIE_NAME = "fastapiusersoauthcsrf"
|
||||
PKCE_COOKIE_NAME_PREFIX = "fastapiusersoauthpkce"
|
||||
|
||||
|
||||
class OAuth2AuthorizeResponse(BaseModel):
|
||||
@@ -1628,6 +1753,21 @@ def generate_csrf_token() -> str:
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
def _base64url_encode(data: bytes) -> str:
|
||||
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
|
||||
|
||||
|
||||
def generate_pkce_pair() -> tuple[str, str]:
|
||||
verifier = secrets.token_urlsafe(64)
|
||||
challenge = _base64url_encode(hashlib.sha256(verifier.encode("ascii")).digest())
|
||||
return verifier, challenge
|
||||
|
||||
|
||||
def get_pkce_cookie_name(state: str) -> str:
|
||||
state_hash = hashlib.sha256(state.encode("utf-8")).hexdigest()
|
||||
return f"{PKCE_COOKIE_NAME_PREFIX}_{state_hash}"
|
||||
|
||||
|
||||
# refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91
|
||||
def create_onyx_oauth_router(
|
||||
oauth_client: BaseOAuth2,
|
||||
@@ -1636,6 +1776,7 @@ def create_onyx_oauth_router(
|
||||
redirect_url: Optional[str] = None,
|
||||
associate_by_email: bool = False,
|
||||
is_verified_by_default: bool = False,
|
||||
enable_pkce: bool = False,
|
||||
) -> APIRouter:
|
||||
return get_oauth_router(
|
||||
oauth_client,
|
||||
@@ -1645,6 +1786,7 @@ def create_onyx_oauth_router(
|
||||
redirect_url,
|
||||
associate_by_email,
|
||||
is_verified_by_default,
|
||||
enable_pkce=enable_pkce,
|
||||
)
|
||||
|
||||
|
||||
@@ -1663,6 +1805,7 @@ def get_oauth_router(
|
||||
csrf_token_cookie_secure: Optional[bool] = None,
|
||||
csrf_token_cookie_httponly: bool = True,
|
||||
csrf_token_cookie_samesite: Optional[Literal["lax", "strict", "none"]] = "lax",
|
||||
enable_pkce: bool = False,
|
||||
) -> APIRouter:
|
||||
"""Generate a router with the OAuth routes."""
|
||||
router = APIRouter()
|
||||
@@ -1679,6 +1822,13 @@ def get_oauth_router(
|
||||
route_name=callback_route_name,
|
||||
)
|
||||
|
||||
async def null_access_token_state() -> tuple[OAuth2Token, Optional[str]] | None:
|
||||
return None
|
||||
|
||||
access_token_state_dependency = (
|
||||
oauth2_authorize_callback if not enable_pkce else null_access_token_state
|
||||
)
|
||||
|
||||
if csrf_token_cookie_secure is None:
|
||||
csrf_token_cookie_secure = WEB_DOMAIN.startswith("https")
|
||||
|
||||
@@ -1712,13 +1862,26 @@ def get_oauth_router(
|
||||
CSRF_TOKEN_KEY: csrf_token,
|
||||
}
|
||||
state = generate_state_token(state_data, state_secret)
|
||||
pkce_cookie: tuple[str, str] | None = None
|
||||
|
||||
# Get the basic authorization URL
|
||||
authorization_url = await oauth_client.get_authorization_url(
|
||||
authorize_redirect_url,
|
||||
state,
|
||||
scopes,
|
||||
)
|
||||
if enable_pkce:
|
||||
code_verifier, code_challenge = generate_pkce_pair()
|
||||
pkce_cookie_name = get_pkce_cookie_name(state)
|
||||
pkce_cookie = (pkce_cookie_name, code_verifier)
|
||||
authorization_url = await oauth_client.get_authorization_url(
|
||||
authorize_redirect_url,
|
||||
state,
|
||||
scopes,
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method="S256",
|
||||
)
|
||||
else:
|
||||
# Get the basic authorization URL
|
||||
authorization_url = await oauth_client.get_authorization_url(
|
||||
authorize_redirect_url,
|
||||
state,
|
||||
scopes,
|
||||
)
|
||||
|
||||
# For Google OAuth, add parameters to request refresh tokens
|
||||
if oauth_client.name == "google":
|
||||
@@ -1726,11 +1889,15 @@ def get_oauth_router(
|
||||
authorization_url, {"access_type": "offline", "prompt": "consent"}
|
||||
)
|
||||
|
||||
if redirect:
|
||||
redirect_response = RedirectResponse(authorization_url, status_code=302)
|
||||
redirect_response.set_cookie(
|
||||
key=csrf_token_cookie_name,
|
||||
value=csrf_token,
|
||||
def set_oauth_cookie(
|
||||
target_response: Response,
|
||||
*,
|
||||
key: str,
|
||||
value: str,
|
||||
) -> None:
|
||||
target_response.set_cookie(
|
||||
key=key,
|
||||
value=value,
|
||||
max_age=STATE_TOKEN_LIFETIME_SECONDS,
|
||||
path=csrf_token_cookie_path,
|
||||
domain=csrf_token_cookie_domain,
|
||||
@@ -1738,18 +1905,28 @@ def get_oauth_router(
|
||||
httponly=csrf_token_cookie_httponly,
|
||||
samesite=csrf_token_cookie_samesite,
|
||||
)
|
||||
return redirect_response
|
||||
|
||||
response.set_cookie(
|
||||
response_with_cookies: Response
|
||||
if redirect:
|
||||
response_with_cookies = RedirectResponse(authorization_url, status_code=302)
|
||||
else:
|
||||
response_with_cookies = response
|
||||
|
||||
set_oauth_cookie(
|
||||
response_with_cookies,
|
||||
key=csrf_token_cookie_name,
|
||||
value=csrf_token,
|
||||
max_age=STATE_TOKEN_LIFETIME_SECONDS,
|
||||
path=csrf_token_cookie_path,
|
||||
domain=csrf_token_cookie_domain,
|
||||
secure=csrf_token_cookie_secure,
|
||||
httponly=csrf_token_cookie_httponly,
|
||||
samesite=csrf_token_cookie_samesite,
|
||||
)
|
||||
if pkce_cookie is not None:
|
||||
pkce_cookie_name, code_verifier = pkce_cookie
|
||||
set_oauth_cookie(
|
||||
response_with_cookies,
|
||||
key=pkce_cookie_name,
|
||||
value=code_verifier,
|
||||
)
|
||||
|
||||
if redirect:
|
||||
return response_with_cookies
|
||||
|
||||
return OAuth2AuthorizeResponse(authorization_url=authorization_url)
|
||||
|
||||
@@ -1780,119 +1957,242 @@ def get_oauth_router(
|
||||
)
|
||||
async def callback(
|
||||
request: Request,
|
||||
access_token_state: Tuple[OAuth2Token, str] = Depends(
|
||||
oauth2_authorize_callback
|
||||
access_token_state: Tuple[OAuth2Token, Optional[str]] | None = Depends(
|
||||
access_token_state_dependency
|
||||
),
|
||||
code: Optional[str] = None,
|
||||
state: Optional[str] = None,
|
||||
error: Optional[str] = None,
|
||||
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
|
||||
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
|
||||
) -> RedirectResponse:
|
||||
token, state = access_token_state
|
||||
account_id, account_email = await oauth_client.get_id_email(
|
||||
token["access_token"]
|
||||
)
|
||||
) -> Response:
|
||||
pkce_cookie_name: str | None = None
|
||||
|
||||
if account_email is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL,
|
||||
)
|
||||
def delete_pkce_cookie(response: Response) -> None:
|
||||
if enable_pkce and pkce_cookie_name:
|
||||
response.delete_cookie(
|
||||
key=pkce_cookie_name,
|
||||
path=csrf_token_cookie_path,
|
||||
domain=csrf_token_cookie_domain,
|
||||
secure=csrf_token_cookie_secure,
|
||||
httponly=csrf_token_cookie_httponly,
|
||||
samesite=csrf_token_cookie_samesite,
|
||||
)
|
||||
|
||||
try:
|
||||
state_data = decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE])
|
||||
except jwt.DecodeError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=getattr(
|
||||
ErrorCode, "ACCESS_TOKEN_DECODE_ERROR", "ACCESS_TOKEN_DECODE_ERROR"
|
||||
),
|
||||
)
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=getattr(
|
||||
ErrorCode,
|
||||
"ACCESS_TOKEN_ALREADY_EXPIRED",
|
||||
"ACCESS_TOKEN_ALREADY_EXPIRED",
|
||||
),
|
||||
)
|
||||
def build_error_response(exc: OnyxError) -> JSONResponse:
|
||||
log_onyx_error(exc)
|
||||
error_response = onyx_error_to_json_response(exc)
|
||||
delete_pkce_cookie(error_response)
|
||||
return error_response
|
||||
|
||||
cookie_csrf_token = request.cookies.get(csrf_token_cookie_name)
|
||||
state_csrf_token = state_data.get(CSRF_TOKEN_KEY)
|
||||
if (
|
||||
not cookie_csrf_token
|
||||
or not state_csrf_token
|
||||
or not secrets.compare_digest(cookie_csrf_token, state_csrf_token)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=getattr(ErrorCode, "OAUTH_INVALID_STATE", "OAUTH_INVALID_STATE"),
|
||||
)
|
||||
def decode_and_validate_state(state_value: str) -> Dict[str, str]:
|
||||
try:
|
||||
state_data = decode_jwt(
|
||||
state_value, state_secret, [STATE_TOKEN_AUDIENCE]
|
||||
)
|
||||
except jwt.DecodeError:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
getattr(
|
||||
ErrorCode,
|
||||
"ACCESS_TOKEN_DECODE_ERROR",
|
||||
"ACCESS_TOKEN_DECODE_ERROR",
|
||||
),
|
||||
)
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
getattr(
|
||||
ErrorCode,
|
||||
"ACCESS_TOKEN_ALREADY_EXPIRED",
|
||||
"ACCESS_TOKEN_ALREADY_EXPIRED",
|
||||
),
|
||||
)
|
||||
except jwt.PyJWTError:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
getattr(
|
||||
ErrorCode,
|
||||
"ACCESS_TOKEN_DECODE_ERROR",
|
||||
"ACCESS_TOKEN_DECODE_ERROR",
|
||||
),
|
||||
)
|
||||
|
||||
next_url = state_data.get("next_url", "/")
|
||||
referral_source = state_data.get("referral_source", None)
|
||||
try:
|
||||
tenant_id = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
|
||||
)(account_email)
|
||||
except exceptions.UserNotExists:
|
||||
tenant_id = None
|
||||
cookie_csrf_token = request.cookies.get(csrf_token_cookie_name)
|
||||
state_csrf_token = state_data.get(CSRF_TOKEN_KEY)
|
||||
if (
|
||||
not cookie_csrf_token
|
||||
or not state_csrf_token
|
||||
or not secrets.compare_digest(cookie_csrf_token, state_csrf_token)
|
||||
):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
getattr(ErrorCode, "OAUTH_INVALID_STATE", "OAUTH_INVALID_STATE"),
|
||||
)
|
||||
|
||||
request.state.referral_source = referral_source
|
||||
return state_data
|
||||
|
||||
# Proceed to authenticate or create the user
|
||||
try:
|
||||
user = await user_manager.oauth_callback(
|
||||
oauth_client.name,
|
||||
token["access_token"],
|
||||
account_id,
|
||||
account_email,
|
||||
token.get("expires_at"),
|
||||
token.get("refresh_token"),
|
||||
request,
|
||||
associate_by_email=associate_by_email,
|
||||
is_verified_by_default=is_verified_by_default,
|
||||
)
|
||||
except UserAlreadyExists:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.OAUTH_USER_ALREADY_EXISTS,
|
||||
)
|
||||
token: OAuth2Token
|
||||
state_data: Dict[str, str]
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.LOGIN_BAD_CREDENTIALS,
|
||||
)
|
||||
# `code`, `state`, and `error` are read directly only in the PKCE path.
|
||||
# In the non-PKCE path, `oauth2_authorize_callback` consumes them.
|
||||
if enable_pkce:
|
||||
if state is not None:
|
||||
pkce_cookie_name = get_pkce_cookie_name(state)
|
||||
|
||||
# Login user
|
||||
response = await backend.login(strategy, user)
|
||||
await user_manager.on_after_login(user, request, response)
|
||||
if error is not None:
|
||||
return build_error_response(
|
||||
OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Authorization request failed or was denied",
|
||||
)
|
||||
)
|
||||
if code is None:
|
||||
return build_error_response(
|
||||
OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Missing authorization code in OAuth callback",
|
||||
)
|
||||
)
|
||||
if state is None:
|
||||
return build_error_response(
|
||||
OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Missing state parameter in OAuth callback",
|
||||
)
|
||||
)
|
||||
|
||||
# Prepare redirect response
|
||||
if tenant_id is None:
|
||||
# Use URL utility to add parameters
|
||||
redirect_url = add_url_params(next_url, {"new_team": "true"})
|
||||
redirect_response = RedirectResponse(redirect_url, status_code=302)
|
||||
else:
|
||||
# No parameters to add
|
||||
redirect_response = RedirectResponse(next_url, status_code=302)
|
||||
state_value = state
|
||||
|
||||
# Copy headers from auth response to redirect response, with special handling for Set-Cookie
|
||||
for header_name, header_value in response.headers.items():
|
||||
# FastAPI can have multiple Set-Cookie headers as a list
|
||||
if header_name.lower() == "set-cookie" and isinstance(header_value, list):
|
||||
for cookie_value in header_value:
|
||||
redirect_response.headers.append(header_name, cookie_value)
|
||||
if redirect_url is not None:
|
||||
callback_redirect_url = redirect_url
|
||||
else:
|
||||
callback_path = request.app.url_path_for(callback_route_name)
|
||||
callback_redirect_url = f"{WEB_DOMAIN}{callback_path}"
|
||||
|
||||
code_verifier = request.cookies.get(cast(str, pkce_cookie_name))
|
||||
if not code_verifier:
|
||||
return build_error_response(
|
||||
OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Missing PKCE verifier cookie in OAuth callback",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
state_data = decode_and_validate_state(state_value)
|
||||
except OnyxError as e:
|
||||
return build_error_response(e)
|
||||
|
||||
try:
|
||||
token = await oauth_client.get_access_token(
|
||||
code, callback_redirect_url, code_verifier
|
||||
)
|
||||
except GetAccessTokenError:
|
||||
return build_error_response(
|
||||
OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Authorization code exchange failed",
|
||||
)
|
||||
)
|
||||
else:
|
||||
if access_token_state is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR, "Missing OAuth callback state"
|
||||
)
|
||||
token, callback_state = access_token_state
|
||||
if callback_state is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Missing state parameter in OAuth callback",
|
||||
)
|
||||
state_data = decode_and_validate_state(callback_state)
|
||||
|
||||
async def complete_login_flow(
|
||||
token: OAuth2Token, state_data: Dict[str, str]
|
||||
) -> RedirectResponse:
|
||||
account_id, account_email = await oauth_client.get_id_email(
|
||||
token["access_token"]
|
||||
)
|
||||
|
||||
if account_email is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL,
|
||||
)
|
||||
|
||||
next_url = state_data.get("next_url", "/")
|
||||
referral_source = state_data.get("referral_source", None)
|
||||
try:
|
||||
tenant_id = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
|
||||
)(account_email)
|
||||
except exceptions.UserNotExists:
|
||||
tenant_id = None
|
||||
|
||||
request.state.referral_source = referral_source
|
||||
|
||||
# Proceed to authenticate or create the user
|
||||
try:
|
||||
user = await user_manager.oauth_callback(
|
||||
oauth_client.name,
|
||||
token["access_token"],
|
||||
account_id,
|
||||
account_email,
|
||||
token.get("expires_at"),
|
||||
token.get("refresh_token"),
|
||||
request,
|
||||
associate_by_email=associate_by_email,
|
||||
is_verified_by_default=is_verified_by_default,
|
||||
)
|
||||
except UserAlreadyExists:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
ErrorCode.OAUTH_USER_ALREADY_EXISTS,
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
ErrorCode.LOGIN_BAD_CREDENTIALS,
|
||||
)
|
||||
|
||||
# Login user
|
||||
response = await backend.login(strategy, user)
|
||||
await user_manager.on_after_login(user, request, response)
|
||||
|
||||
# Prepare redirect response
|
||||
if tenant_id is None:
|
||||
# Use URL utility to add parameters
|
||||
redirect_destination = add_url_params(next_url, {"new_team": "true"})
|
||||
redirect_response = RedirectResponse(
|
||||
redirect_destination, status_code=302
|
||||
)
|
||||
else:
|
||||
# No parameters to add
|
||||
redirect_response = RedirectResponse(next_url, status_code=302)
|
||||
|
||||
# Copy headers from auth response to redirect response, with special handling for Set-Cookie
|
||||
for header_name, header_value in response.headers.items():
|
||||
header_name_lower = header_name.lower()
|
||||
if header_name_lower == "set-cookie":
|
||||
redirect_response.headers.append(header_name, header_value)
|
||||
continue
|
||||
if header_name_lower in {"location", "content-length"}:
|
||||
continue
|
||||
redirect_response.headers[header_name] = header_value
|
||||
|
||||
if hasattr(response, "body"):
|
||||
redirect_response.body = response.body
|
||||
if hasattr(response, "status_code"):
|
||||
redirect_response.status_code = response.status_code
|
||||
if hasattr(response, "media_type"):
|
||||
redirect_response.media_type = response.media_type
|
||||
return redirect_response
|
||||
|
||||
return redirect_response
|
||||
if enable_pkce:
|
||||
try:
|
||||
redirect_response = await complete_login_flow(token, state_data)
|
||||
except OnyxError as e:
|
||||
return build_error_response(e)
|
||||
delete_pkce_cookie(redirect_response)
|
||||
return redirect_response
|
||||
|
||||
return await complete_login_flow(token, state_data)
|
||||
|
||||
return router
|
||||
|
||||
@@ -154,8 +154,7 @@ def on_task_postrun(
|
||||
tenant_id = cast(str, kwargs.get("tenant_id", POSTGRES_DEFAULT_SCHEMA))
|
||||
|
||||
task_logger.debug(
|
||||
f"Task {task.name} (ID: {task_id}) completed with state: {state} "
|
||||
f"{f'for tenant_id={tenant_id}' if tenant_id else ''}"
|
||||
f"Task {task.name} (ID: {task_id}) completed with state: {state} {f'for tenant_id={tenant_id}' if tenant_id else ''}"
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
@@ -211,7 +210,9 @@ def on_task_postrun(
|
||||
|
||||
|
||||
def on_celeryd_init(
|
||||
sender: str, conf: Any = None, **kwargs: Any # noqa: ARG001
|
||||
sender: str, # noqa: ARG001
|
||||
conf: Any = None, # noqa: ARG001
|
||||
**kwargs: Any, # noqa: ARG001
|
||||
) -> None:
|
||||
"""The first signal sent on celery worker startup"""
|
||||
|
||||
@@ -277,10 +278,7 @@ def wait_for_redis(sender: Any, **kwargs: Any) -> None: # noqa: ARG001
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
|
||||
if not ready:
|
||||
msg = (
|
||||
f"Redis: Readiness probe did not succeed within the timeout "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
msg = f"Redis: Readiness probe did not succeed within the timeout ({WAIT_LIMIT} seconds). Exiting..."
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
@@ -319,10 +317,7 @@ def wait_for_db(sender: Any, **kwargs: Any) -> None: # noqa: ARG001
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
|
||||
if not ready:
|
||||
msg = (
|
||||
f"Database: Readiness probe did not succeed within the timeout "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
msg = f"Database: Readiness probe did not succeed within the timeout ({WAIT_LIMIT} seconds). Exiting..."
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
@@ -349,10 +344,7 @@ def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None: # noqa: ARG00
|
||||
f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
)
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
msg = (
|
||||
f"Primary worker was not ready within the timeout. "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
msg = f"Primary worker was not ready within the timeout. ({WAIT_LIMIT} seconds). Exiting..."
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
@@ -522,7 +514,9 @@ def reset_tenant_id(
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
|
||||
def wait_for_vespa_or_shutdown(sender: Any, **kwargs: Any) -> None: # noqa: ARG001
|
||||
def wait_for_vespa_or_shutdown(
|
||||
sender: Any, **kwargs: Any # noqa: ARG001
|
||||
) -> None: # noqa: ARG001
|
||||
"""Waits for Vespa to become ready subject to a timeout.
|
||||
Raises WorkerShutdown if the timeout is reached."""
|
||||
|
||||
|
||||
@@ -181,9 +181,7 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
if not do_update:
|
||||
# exit early if nothing changed
|
||||
task_logger.info(
|
||||
f"_try_updating_schedule - Schedule unchanged: "
|
||||
f"tasks={len(new_schedule)} "
|
||||
f"beat_multiplier={beat_multiplier}"
|
||||
f"_try_updating_schedule - Schedule unchanged: tasks={len(new_schedule)} beat_multiplier={beat_multiplier}"
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
@@ -186,7 +186,6 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
# Check if the Celery task actually exists
|
||||
try:
|
||||
|
||||
result: AsyncResult = AsyncResult(attempt.celery_task_id)
|
||||
|
||||
# If the task is not in PENDING state, it exists in Celery
|
||||
@@ -207,8 +206,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
except Exception:
|
||||
# If we can't check the task status, be conservative and continue
|
||||
logger.warning(
|
||||
f"Could not verify Celery task status on startup for attempt {attempt.id}, "
|
||||
f"task_id={attempt.celery_task_id}"
|
||||
f"Could not verify Celery task status on startup for attempt {attempt.id}, task_id={attempt.celery_task_id}"
|
||||
)
|
||||
|
||||
|
||||
@@ -278,8 +276,7 @@ class HubPeriodicTask(bootsteps.StartStopStep):
|
||||
lock.reacquire()
|
||||
else:
|
||||
task_logger.warning(
|
||||
"Full acquisition of primary worker lock. "
|
||||
"Reasons could be worker restart or lock expiration."
|
||||
"Full acquisition of primary worker lock. Reasons could be worker restart or lock expiration."
|
||||
)
|
||||
lock = r.lock(
|
||||
OnyxRedisLocks.PRIMARY_WORKER,
|
||||
|
||||
@@ -115,18 +115,15 @@ def _extract_from_batch(
|
||||
for item in doc_list:
|
||||
if isinstance(item, HierarchyNode):
|
||||
hierarchy_nodes.append(item)
|
||||
if item.raw_node_id not in ids:
|
||||
ids[item.raw_node_id] = None
|
||||
elif isinstance(item, ConnectorFailure):
|
||||
failed_id = _get_failure_id(item)
|
||||
if failed_id:
|
||||
ids[failed_id] = None
|
||||
logger.warning(
|
||||
f"Failed to retrieve document {failed_id}: " f"{item.failure_message}"
|
||||
f"Failed to retrieve document {failed_id}: {item.failure_message}"
|
||||
)
|
||||
else:
|
||||
parent_raw = getattr(item, "parent_hierarchy_raw_node_id", None)
|
||||
ids[item.id] = parent_raw
|
||||
ids[item.id] = item.parent_hierarchy_raw_node_id
|
||||
return BatchResult(raw_id_to_parent=ids, hierarchy_nodes=hierarchy_nodes)
|
||||
|
||||
|
||||
@@ -192,9 +189,7 @@ def extract_ids_from_runnable_connector(
|
||||
batch_ids = batch_result.raw_id_to_parent
|
||||
batch_nodes = batch_result.hierarchy_nodes
|
||||
doc_batch_processing_func(batch_ids)
|
||||
for k, v in batch_ids.items():
|
||||
if v is not None or k not in all_raw_id_to_parent:
|
||||
all_raw_id_to_parent[k] = v
|
||||
all_raw_id_to_parent.update(batch_ids)
|
||||
all_hierarchy_nodes.extend(batch_nodes)
|
||||
|
||||
if callback:
|
||||
|
||||
@@ -307,14 +307,12 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
|
||||
if redis_connector.prune.fenced:
|
||||
raise TaskDependencyError(
|
||||
"Connector deletion - Delayed (pruning in progress): "
|
||||
f"cc_pair={cc_pair_id}"
|
||||
f"Connector deletion - Delayed (pruning in progress): cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
if redis_connector.permissions.fenced:
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (permissions in progress): "
|
||||
f"cc_pair={cc_pair_id}"
|
||||
f"Connector deletion - Delayed (permissions in progress): cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
@@ -354,8 +352,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
# return 0
|
||||
|
||||
task_logger.info(
|
||||
"RedisConnectorDeletion.generate_tasks finished. "
|
||||
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
|
||||
f"RedisConnectorDeletion.generate_tasks finished. cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
@@ -366,7 +363,9 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
|
||||
|
||||
def monitor_connector_deletion_taskset(
|
||||
tenant_id: str, key_bytes: bytes, r: Redis # noqa: ARG001
|
||||
tenant_id: str,
|
||||
key_bytes: bytes,
|
||||
r: Redis, # noqa: ARG001
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
@@ -690,8 +689,7 @@ def validate_connector_deletion_fence(
|
||||
tasks_not_in_celery += 1
|
||||
|
||||
task_logger.info(
|
||||
"validate_connector_deletion_fence task check: "
|
||||
f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
|
||||
f"validate_connector_deletion_fence task check: tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
|
||||
)
|
||||
|
||||
# we're active if there are still tasks to run and those tasks all exist in celery
|
||||
|
||||
@@ -109,9 +109,7 @@ def try_creating_docfetching_task(
|
||||
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"try_creating_indexing_task - Unexpected exception: "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings.id}"
|
||||
f"try_creating_indexing_task - Unexpected exception: cc_pair={cc_pair.id} search_settings={search_settings.id}"
|
||||
)
|
||||
|
||||
# Clean up on failure
|
||||
|
||||
@@ -60,15 +60,13 @@ def _verify_indexing_attempt(
|
||||
|
||||
if attempt.connector_credential_pair_id != cc_pair_id:
|
||||
raise SimpleJobException(
|
||||
f"docfetching_task - CC pair mismatch: "
|
||||
f"expected={cc_pair_id} actual={attempt.connector_credential_pair_id}",
|
||||
f"docfetching_task - CC pair mismatch: expected={cc_pair_id} actual={attempt.connector_credential_pair_id}",
|
||||
code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code,
|
||||
)
|
||||
|
||||
if attempt.search_settings_id != search_settings_id:
|
||||
raise SimpleJobException(
|
||||
f"docfetching_task - Search settings mismatch: "
|
||||
f"expected={search_settings_id} actual={attempt.search_settings_id}",
|
||||
f"docfetching_task - Search settings mismatch: expected={search_settings_id} actual={attempt.search_settings_id}",
|
||||
code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code,
|
||||
)
|
||||
|
||||
@@ -77,8 +75,7 @@ def _verify_indexing_attempt(
|
||||
IndexingStatus.IN_PROGRESS,
|
||||
]:
|
||||
raise SimpleJobException(
|
||||
f"docfetching_task - Invalid attempt status: "
|
||||
f"attempt_id={index_attempt_id} status={attempt.status}",
|
||||
f"docfetching_task - Invalid attempt status: attempt_id={index_attempt_id} status={attempt.status}",
|
||||
code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code,
|
||||
)
|
||||
|
||||
@@ -248,9 +245,7 @@ def _docfetching_task(
|
||||
raise e
|
||||
|
||||
logger.info(
|
||||
f"Indexing spawned task finished: attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
f"Indexing spawned task finished: attempt={index_attempt_id} cc_pair={cc_pair_id} search_settings={search_settings_id}"
|
||||
)
|
||||
os._exit(0) # ensure process exits cleanly
|
||||
|
||||
@@ -286,8 +281,7 @@ def process_job_result(
|
||||
result.status = IndexingWatchdogTerminalStatus.SUCCEEDED
|
||||
task_logger.warning(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - spawned task has non-zero exit code "
|
||||
"but completion signal is OK. Continuing...",
|
||||
"Indexing watchdog - spawned task has non-zero exit code but completion signal is OK. Continuing...",
|
||||
exit_code=str(result.exit_code),
|
||||
)
|
||||
)
|
||||
@@ -296,10 +290,7 @@ def process_job_result(
|
||||
result.status = IndexingWatchdogTerminalStatus.from_code(result.exit_code)
|
||||
|
||||
job_level_exception = job.exception()
|
||||
result.exception_str = (
|
||||
f"Docfetching returned exit code {result.exit_code} "
|
||||
f"with exception: {job_level_exception}"
|
||||
)
|
||||
result.exception_str = f"Docfetching returned exit code {result.exit_code} with exception: {job_level_exception}"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -158,7 +158,6 @@ def validate_active_indexing_attempts(
|
||||
logger.info("Validating active indexing attempts")
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
|
||||
# Find all active indexing attempts
|
||||
active_attempts = (
|
||||
db_session.execute(
|
||||
@@ -190,8 +189,7 @@ def validate_active_indexing_attempts(
|
||||
db_session.commit()
|
||||
|
||||
task_logger.info(
|
||||
f"Initialized heartbeat tracking for attempt {fresh_attempt.id}: "
|
||||
f"counter={fresh_attempt.heartbeat_counter}"
|
||||
f"Initialized heartbeat tracking for attempt {fresh_attempt.id}: counter={fresh_attempt.heartbeat_counter}"
|
||||
)
|
||||
continue
|
||||
|
||||
@@ -214,8 +212,7 @@ def validate_active_indexing_attempts(
|
||||
db_session.commit()
|
||||
|
||||
task_logger.debug(
|
||||
f"Heartbeat advanced for attempt {fresh_attempt.id}: "
|
||||
f"new_counter={current_counter}"
|
||||
f"Heartbeat advanced for attempt {fresh_attempt.id}: new_counter={current_counter}"
|
||||
)
|
||||
continue
|
||||
|
||||
@@ -350,9 +347,7 @@ def monitor_indexing_attempt_progress(
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to monitor document processing completion: "
|
||||
f"attempt={attempt.id} "
|
||||
f"error={str(e)}"
|
||||
f"Failed to monitor document processing completion: attempt={attempt.id} error={str(e)}"
|
||||
)
|
||||
|
||||
# Mark the attempt as failed if monitoring fails
|
||||
@@ -401,9 +396,7 @@ def check_indexing_completion(
|
||||
) -> None:
|
||||
|
||||
logger.info(
|
||||
f"Checking for indexing completion: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id}"
|
||||
f"Checking for indexing completion: attempt={index_attempt_id} tenant={tenant_id}"
|
||||
)
|
||||
|
||||
# Check if indexing is complete and all batches are processed
|
||||
@@ -445,7 +438,7 @@ def check_indexing_completion(
|
||||
if attempt.status == IndexingStatus.IN_PROGRESS:
|
||||
logger.error(
|
||||
f"Indexing attempt {index_attempt_id} has been indexing for "
|
||||
f"{stalled_timeout_hours//2}-{stalled_timeout_hours} hours without progress. "
|
||||
f"{stalled_timeout_hours // 2}-{stalled_timeout_hours} hours without progress. "
|
||||
f"Marking it as failed."
|
||||
)
|
||||
mark_attempt_failed(
|
||||
@@ -695,17 +688,12 @@ def _kickoff_indexing_tasks(
|
||||
|
||||
if attempt_id is not None:
|
||||
task_logger.info(
|
||||
f"Connector indexing queued: "
|
||||
f"index_attempt={attempt_id} "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings.id}"
|
||||
f"Connector indexing queued: index_attempt={attempt_id} cc_pair={cc_pair.id} search_settings={search_settings.id}"
|
||||
)
|
||||
tasks_created += 1
|
||||
else:
|
||||
task_logger.error(
|
||||
f"Failed to create indexing task: "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings.id}"
|
||||
f"Failed to create indexing task: cc_pair={cc_pair.id} search_settings={search_settings.id}"
|
||||
)
|
||||
|
||||
return tasks_created
|
||||
@@ -901,9 +889,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
and secondary_search_settings.switchover_type == SwitchoverType.INSTANT
|
||||
):
|
||||
task_logger.info(
|
||||
f"Skipping secondary indexing: "
|
||||
f"switchover_type=INSTANT "
|
||||
f"for search_settings={secondary_search_settings.id}"
|
||||
f"Skipping secondary indexing: switchover_type=INSTANT for search_settings={secondary_search_settings.id}"
|
||||
)
|
||||
|
||||
# 2/3: VALIDATE
|
||||
@@ -1005,8 +991,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
lock_beat.release()
|
||||
else:
|
||||
task_logger.error(
|
||||
"check_for_indexing - Lock not owned on completion: "
|
||||
f"tenant={tenant_id}"
|
||||
f"check_for_indexing - Lock not owned on completion: tenant={tenant_id}"
|
||||
)
|
||||
redis_lock_dump(lock_beat, redis_client)
|
||||
|
||||
@@ -1060,8 +1045,7 @@ def check_for_checkpoint_cleanup(self: Task, *, tenant_id: str) -> None:
|
||||
lock.release()
|
||||
else:
|
||||
task_logger.error(
|
||||
"check_for_checkpoint_cleanup - Lock not owned on completion: "
|
||||
f"tenant={tenant_id}"
|
||||
f"check_for_checkpoint_cleanup - Lock not owned on completion: tenant={tenant_id}"
|
||||
)
|
||||
|
||||
|
||||
@@ -1071,7 +1055,10 @@ def check_for_checkpoint_cleanup(self: Task, *, tenant_id: str) -> None:
|
||||
bind=True,
|
||||
)
|
||||
def cleanup_checkpoint_task(
|
||||
self: Task, *, index_attempt_id: int, tenant_id: str | None # noqa: ARG001
|
||||
self: Task, # noqa: ARG001
|
||||
*,
|
||||
index_attempt_id: int,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""Clean up a checkpoint for a given index attempt"""
|
||||
|
||||
@@ -1084,9 +1071,7 @@ def cleanup_checkpoint_task(
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
task_logger.info(
|
||||
f"cleanup_checkpoint_task completed: tenant_id={tenant_id} "
|
||||
f"index_attempt_id={index_attempt_id} "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
f"cleanup_checkpoint_task completed: tenant_id={tenant_id} index_attempt_id={index_attempt_id} elapsed={elapsed:.2f}"
|
||||
)
|
||||
|
||||
|
||||
@@ -1149,8 +1134,7 @@ def check_for_index_attempt_cleanup(self: Task, *, tenant_id: str) -> None:
|
||||
lock.release()
|
||||
else:
|
||||
task_logger.error(
|
||||
"check_for_index_attempt_cleanup - Lock not owned on completion: "
|
||||
f"tenant={tenant_id}"
|
||||
f"check_for_index_attempt_cleanup - Lock not owned on completion: tenant={tenant_id}"
|
||||
)
|
||||
|
||||
|
||||
@@ -1160,7 +1144,10 @@ def check_for_index_attempt_cleanup(self: Task, *, tenant_id: str) -> None:
|
||||
bind=True,
|
||||
)
|
||||
def cleanup_index_attempt_task(
|
||||
self: Task, *, index_attempt_ids: list[int], tenant_id: str # noqa: ARG001
|
||||
self: Task, # noqa: ARG001
|
||||
*,
|
||||
index_attempt_ids: list[int],
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
"""Clean up an index attempt"""
|
||||
start = time.monotonic()
|
||||
@@ -1207,15 +1194,13 @@ def _check_failure_threshold(
|
||||
FAILURE_RATIO_THRESHOLD = 0.1
|
||||
if total_failures > FAILURE_THRESHOLD and failure_ratio > FAILURE_RATIO_THRESHOLD:
|
||||
logger.error(
|
||||
f"Connector run failed with '{total_failures}' errors "
|
||||
f"after '{batch_num}' batches."
|
||||
f"Connector run failed with '{total_failures}' errors after '{batch_num}' batches."
|
||||
)
|
||||
if last_failure and last_failure.exception:
|
||||
raise last_failure.exception from last_failure.exception
|
||||
|
||||
raise RuntimeError(
|
||||
f"Connector run encountered too many errors, aborting. "
|
||||
f"Last error: {last_failure}"
|
||||
f"Connector run encountered too many errors, aborting. Last error: {last_failure}"
|
||||
)
|
||||
|
||||
|
||||
@@ -1339,9 +1324,7 @@ def _docprocessing_task(
|
||||
raise
|
||||
|
||||
task_logger.info(
|
||||
f"Processing document batch: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"batch_num={batch_num} "
|
||||
f"Processing document batch: attempt={index_attempt_id} batch_num={batch_num} "
|
||||
)
|
||||
|
||||
# Get the document batch storage
|
||||
@@ -1599,9 +1582,7 @@ def _docprocessing_task(
|
||||
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"Document batch processing failed: "
|
||||
f"batch_num={batch_num} "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"Document batch processing failed: batch_num={batch_num} attempt={index_attempt_id} "
|
||||
)
|
||||
|
||||
raise
|
||||
|
||||
@@ -84,8 +84,7 @@ def scheduled_eval_task(self: Task, **kwargs: Any) -> None: # noqa: ARG001
|
||||
run_timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
|
||||
logger.info(
|
||||
f"Starting scheduled eval pipeline for project '{project_name}' "
|
||||
f"with {len(dataset_names)} dataset(s): {dataset_names}"
|
||||
f"Starting scheduled eval pipeline for project '{project_name}' with {len(dataset_names)} dataset(s): {dataset_names}"
|
||||
)
|
||||
|
||||
pipeline_start = datetime.now(timezone.utc)
|
||||
@@ -101,8 +100,7 @@ def scheduled_eval_task(self: Task, **kwargs: Any) -> None: # noqa: ARG001
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
f"Running scheduled eval for dataset: {dataset_name} "
|
||||
f"(project: {project_name})"
|
||||
f"Running scheduled eval for dataset: {dataset_name} (project: {project_name})"
|
||||
)
|
||||
|
||||
configuration = EvalConfigurationOptions(
|
||||
@@ -142,6 +140,5 @@ def scheduled_eval_task(self: Task, **kwargs: Any) -> None: # noqa: ARG001
|
||||
|
||||
passed_count = sum(1 for r in results if r["success"])
|
||||
logger.info(
|
||||
f"Scheduled eval pipeline completed: {passed_count}/{len(results)} passed "
|
||||
f"in {total_duration:.1f}s"
|
||||
f"Scheduled eval pipeline completed: {passed_count}/{len(results)} passed in {total_duration:.1f}s"
|
||||
)
|
||||
|
||||
@@ -40,6 +40,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.hierarchy import upsert_hierarchy_node_cc_pair_entries
|
||||
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
|
||||
@@ -126,9 +127,7 @@ def _try_creating_hierarchy_fetching_task(
|
||||
raise RuntimeError("send_task for hierarchy_fetching_task failed.")
|
||||
|
||||
task_logger.info(
|
||||
f"Created hierarchy fetching task: "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"celery_task_id={custom_task_id}"
|
||||
f"Created hierarchy fetching task: cc_pair={cc_pair.id} celery_task_id={custom_task_id}"
|
||||
)
|
||||
|
||||
return custom_task_id
|
||||
@@ -214,8 +213,7 @@ def check_for_hierarchy_fetching(self: Task, *, tenant_id: str) -> int | None:
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.info(
|
||||
f"check_for_hierarchy_fetching finished: "
|
||||
f"tasks_created={tasks_created} elapsed={time_elapsed:.2f}s"
|
||||
f"check_for_hierarchy_fetching finished: tasks_created={tasks_created} elapsed={time_elapsed:.2f}s"
|
||||
)
|
||||
return tasks_created
|
||||
|
||||
@@ -289,6 +287,14 @@ def _run_hierarchy_extraction(
|
||||
is_connector_public=is_connector_public,
|
||||
)
|
||||
|
||||
upsert_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
hierarchy_node_ids=[n.id for n in upserted_nodes],
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
# Cache in Redis for fast ancestor resolution
|
||||
cache_entries = [
|
||||
HierarchyNodeCacheEntry.from_db_model(node) for node in upserted_nodes
|
||||
@@ -333,8 +339,7 @@ def connector_hierarchy_fetching_task(
|
||||
from the connector source and stores it in the database.
|
||||
"""
|
||||
task_logger.info(
|
||||
f"connector_hierarchy_fetching_task starting: "
|
||||
f"cc_pair={cc_pair_id} tenant={tenant_id}"
|
||||
f"connector_hierarchy_fetching_task starting: cc_pair={cc_pair_id} tenant={tenant_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -352,8 +357,7 @@ def connector_hierarchy_fetching_task(
|
||||
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
task_logger.info(
|
||||
f"Skipping hierarchy fetching for deleting connector: "
|
||||
f"cc_pair={cc_pair_id}"
|
||||
f"Skipping hierarchy fetching for deleting connector: cc_pair={cc_pair_id}"
|
||||
)
|
||||
return
|
||||
|
||||
@@ -366,8 +370,7 @@ def connector_hierarchy_fetching_task(
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"connector_hierarchy_fetching_task: "
|
||||
f"Extracted {total_nodes} hierarchy nodes for cc_pair={cc_pair_id}"
|
||||
f"connector_hierarchy_fetching_task: Extracted {total_nodes} hierarchy nodes for cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
# Update the last fetch time to prevent re-running until next interval
|
||||
|
||||
@@ -18,7 +18,9 @@ from onyx.llm.well_known_providers.auto_update_service import (
|
||||
bind=True,
|
||||
)
|
||||
def check_for_auto_llm_updates(
|
||||
self: Task, *, tenant_id: str # noqa: ARG001
|
||||
self: Task, # noqa: ARG001
|
||||
*,
|
||||
tenant_id: str, # noqa: ARG001
|
||||
) -> bool | None:
|
||||
"""Periodic task to fetch LLM model updates from GitHub
|
||||
and sync them to providers in Auto mode.
|
||||
|
||||
@@ -116,8 +116,7 @@ class Metric(BaseModel):
|
||||
string_value = self.value
|
||||
else:
|
||||
task_logger.error(
|
||||
f"Invalid metric value type: {type(self.value)} "
|
||||
f"({self.value}) for metric {self.name}."
|
||||
f"Invalid metric value type: {type(self.value)} ({self.value}) for metric {self.name}."
|
||||
)
|
||||
return
|
||||
|
||||
@@ -260,8 +259,7 @@ def _build_connector_final_metrics(
|
||||
)
|
||||
if _has_metric_been_emitted(redis_std, metric_key):
|
||||
task_logger.info(
|
||||
f"Skipping final metrics for connector {cc_pair.connector.id} "
|
||||
f"index attempt {attempt.id}, already emitted."
|
||||
f"Skipping final metrics for connector {cc_pair.connector.id} index attempt {attempt.id}, already emitted."
|
||||
)
|
||||
continue
|
||||
|
||||
@@ -1036,8 +1034,7 @@ def monitor_process_memory(self: Task, *, tenant_id: str) -> None: # noqa: ARG0
|
||||
if process_name in cmdline:
|
||||
if process_type in supervisor_processes.values():
|
||||
task_logger.error(
|
||||
f"Duplicate process type for type {process_type} "
|
||||
f"with cmd {cmdline} with pid={proc.pid}."
|
||||
f"Duplicate process type for type {process_type} with cmd {cmdline} with pid={proc.pid}."
|
||||
)
|
||||
continue
|
||||
|
||||
@@ -1046,8 +1043,7 @@ def monitor_process_memory(self: Task, *, tenant_id: str) -> None: # noqa: ARG0
|
||||
|
||||
if len(supervisor_processes) != len(process_type_mapping):
|
||||
task_logger.error(
|
||||
"Missing processes: "
|
||||
f"{set(process_type_mapping.keys()).symmetric_difference(supervisor_processes.values())}"
|
||||
f"Missing processes: {set(process_type_mapping.keys()).symmetric_difference(supervisor_processes.values())}"
|
||||
)
|
||||
|
||||
# Log memory usage for each process
|
||||
@@ -1101,9 +1097,7 @@ def cloud_monitor_celery_pidbox(
|
||||
|
||||
r_celery.delete(key)
|
||||
task_logger.info(
|
||||
f"Deleted idle pidbox: pidbox={key_str} "
|
||||
f"idletime={idletime} "
|
||||
f"max_idletime={MAX_PIDBOX_IDLE}"
|
||||
f"Deleted idle pidbox: pidbox={key_str} idletime={idletime} max_idletime={MAX_PIDBOX_IDLE}"
|
||||
)
|
||||
num_deleted += 1
|
||||
|
||||
|
||||
@@ -11,6 +11,9 @@
|
||||
# lock after its cleanup which happens at most after its soft timeout.
|
||||
|
||||
# Constants corresponding to migrate_documents_from_vespa_to_opensearch_task.
|
||||
from onyx.configs.app_configs import OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE
|
||||
|
||||
|
||||
MIGRATION_TASK_SOFT_TIME_LIMIT_S = 60 * 5 # 5 minutes.
|
||||
MIGRATION_TASK_TIME_LIMIT_S = 60 * 6 # 6 minutes.
|
||||
# The maximum time the lock can be held for. Will automatically be released
|
||||
@@ -44,7 +47,7 @@ TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE = 15
|
||||
|
||||
# WARNING: Do not change these values without knowing what changes also need to
|
||||
# be made to OpenSearchTenantMigrationRecord.
|
||||
GET_VESPA_CHUNKS_PAGE_SIZE = 500
|
||||
GET_VESPA_CHUNKS_PAGE_SIZE = OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE
|
||||
GET_VESPA_CHUNKS_SLICE_COUNT = 4
|
||||
|
||||
# String used to indicate in the vespa_visit_continuation_token mapping that the
|
||||
|
||||
@@ -205,8 +205,7 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
) = get_vespa_visit_state(db_session)
|
||||
if is_continuation_token_done_for_all_slices(continuation_token_map):
|
||||
task_logger.info(
|
||||
f"OpenSearch migration COMPLETED for tenant {tenant_id}. "
|
||||
f"Total chunks migrated: {total_chunks_migrated}."
|
||||
f"OpenSearch migration COMPLETED for tenant {tenant_id}. Total chunks migrated: {total_chunks_migrated}."
|
||||
)
|
||||
mark_migration_completed_time_if_not_set_with_commit(db_session)
|
||||
break
|
||||
|
||||
@@ -48,10 +48,15 @@ from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.hierarchy import delete_orphaned_hierarchy_nodes
|
||||
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
|
||||
from onyx.db.hierarchy import remove_stale_hierarchy_node_cc_pair_entries
|
||||
from onyx.db.hierarchy import reparent_orphaned_hierarchy_nodes
|
||||
from onyx.db.hierarchy import update_document_parent_hierarchy_nodes
|
||||
from onyx.db.hierarchy import upsert_hierarchy_node_cc_pair_entries
|
||||
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import HierarchyNode as DBHierarchyNode
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.db.tag import delete_orphan_tags__no_commit
|
||||
@@ -60,6 +65,7 @@ from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
|
||||
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
|
||||
from onyx.redis.redis_hierarchy import ensure_source_node_exists
|
||||
from onyx.redis.redis_hierarchy import evict_hierarchy_nodes_from_cache
|
||||
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
|
||||
from onyx.redis.redis_hierarchy import get_source_node_id_from_cache
|
||||
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
|
||||
@@ -145,8 +151,7 @@ def _resolve_and_update_document_parents(
|
||||
commit=True,
|
||||
)
|
||||
task_logger.info(
|
||||
f"Pruning: resolved and updated parent hierarchy for "
|
||||
f"{len(resolved)} documents (source={source.value})"
|
||||
f"Pruning: resolved and updated parent hierarchy for {len(resolved)} documents (source={source.value})"
|
||||
)
|
||||
|
||||
|
||||
@@ -214,7 +219,6 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
|
||||
|
||||
# but pruning only kicks off once per hour
|
||||
if not r.exists(OnyxRedisSignals.BLOCK_PRUNING):
|
||||
|
||||
task_logger.info("Checking for pruning due")
|
||||
|
||||
cc_pair_ids: list[int] = []
|
||||
@@ -478,8 +482,7 @@ def connector_pruning_generator_task(
|
||||
|
||||
if not redis_connector.prune.fenced: # The fence must exist
|
||||
raise ValueError(
|
||||
f"connector_prune_generator_task - fence not found: "
|
||||
f"fence={redis_connector.prune.fence_key}"
|
||||
f"connector_prune_generator_task - fence not found: fence={redis_connector.prune.fence_key}"
|
||||
)
|
||||
|
||||
payload = redis_connector.prune.payload # The payload must exist
|
||||
@@ -490,8 +493,7 @@ def connector_pruning_generator_task(
|
||||
|
||||
if payload.celery_task_id is None:
|
||||
logger.info(
|
||||
f"connector_prune_generator_task - Waiting for fence: "
|
||||
f"fence={redis_connector.prune.fence_key}"
|
||||
f"connector_prune_generator_task - Waiting for fence: fence={redis_connector.prune.fence_key}"
|
||||
)
|
||||
time.sleep(1)
|
||||
continue
|
||||
@@ -547,9 +549,7 @@ def connector_pruning_generator_task(
|
||||
redis_connector.prune.set_fence(new_payload)
|
||||
|
||||
task_logger.info(
|
||||
f"Pruning generator running connector: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"connector_source={cc_pair.connector.source}"
|
||||
f"Pruning generator running connector: cc_pair={cc_pair_id} connector_source={cc_pair.connector.source}"
|
||||
)
|
||||
|
||||
runnable_connector = instantiate_connector(
|
||||
@@ -579,11 +579,12 @@ def connector_pruning_generator_task(
|
||||
source = cc_pair.connector.source
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
ensure_source_node_exists(redis_client, db_session, source)
|
||||
|
||||
upserted_nodes: list[DBHierarchyNode] = []
|
||||
if extraction_result.hierarchy_nodes:
|
||||
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
|
||||
|
||||
ensure_source_node_exists(redis_client, db_session, source)
|
||||
|
||||
upserted_nodes = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=extraction_result.hierarchy_nodes,
|
||||
@@ -592,6 +593,14 @@ def connector_pruning_generator_task(
|
||||
is_connector_public=is_connector_public,
|
||||
)
|
||||
|
||||
upsert_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
hierarchy_node_ids=[n.id for n in upserted_nodes],
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
cache_entries = [
|
||||
HierarchyNodeCacheEntry.from_db_model(node)
|
||||
for node in upserted_nodes
|
||||
@@ -607,7 +616,6 @@ def connector_pruning_generator_task(
|
||||
f"hierarchy nodes for cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
ensure_source_node_exists(redis_client, db_session, source)
|
||||
# Resolve parent_hierarchy_raw_node_id → parent_hierarchy_node_id
|
||||
# and bulk-update documents, mirroring the docfetching resolution
|
||||
_resolve_and_update_document_parents(
|
||||
@@ -659,16 +667,50 @@ def connector_pruning_generator_task(
|
||||
return None
|
||||
|
||||
task_logger.info(
|
||||
"RedisConnector.prune.generate_tasks finished. "
|
||||
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
|
||||
f"RedisConnector.prune.generate_tasks finished. cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
redis_connector.prune.generator_complete = tasks_generated
|
||||
|
||||
# --- Hierarchy node pruning ---
|
||||
live_node_ids = {n.id for n in upserted_nodes}
|
||||
stale_removed = remove_stale_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
live_hierarchy_node_ids=live_node_ids,
|
||||
commit=True,
|
||||
)
|
||||
deleted_raw_ids = delete_orphaned_hierarchy_nodes(
|
||||
db_session=db_session,
|
||||
source=source,
|
||||
commit=True,
|
||||
)
|
||||
reparented_nodes = reparent_orphaned_hierarchy_nodes(
|
||||
db_session=db_session,
|
||||
source=source,
|
||||
commit=True,
|
||||
)
|
||||
if deleted_raw_ids:
|
||||
evict_hierarchy_nodes_from_cache(redis_client, source, deleted_raw_ids)
|
||||
if reparented_nodes:
|
||||
reparented_cache_entries = [
|
||||
HierarchyNodeCacheEntry.from_db_model(node)
|
||||
for node in reparented_nodes
|
||||
]
|
||||
cache_hierarchy_nodes_batch(
|
||||
redis_client, source, reparented_cache_entries
|
||||
)
|
||||
if stale_removed or deleted_raw_ids or reparented_nodes:
|
||||
task_logger.info(
|
||||
f"Hierarchy node pruning: cc_pair={cc_pair_id} "
|
||||
f"stale_entries_removed={stale_removed} "
|
||||
f"nodes_deleted={len(deleted_raw_ids)} "
|
||||
f"nodes_reparented={len(reparented_nodes)}"
|
||||
)
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"Pruning exceptioned: cc_pair={cc_pair_id} "
|
||||
f"connector={connector_id} "
|
||||
f"payload_id={payload_id}"
|
||||
f"Pruning exceptioned: cc_pair={cc_pair_id} connector={connector_id} payload_id={payload_id}"
|
||||
)
|
||||
|
||||
redis_connector.prune.reset()
|
||||
@@ -686,7 +728,10 @@ def connector_pruning_generator_task(
|
||||
|
||||
|
||||
def monitor_ccpair_pruning_taskset(
|
||||
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session # noqa: ARG001
|
||||
tenant_id: str,
|
||||
key_bytes: bytes,
|
||||
r: Redis, # noqa: ARG001
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
@@ -880,8 +925,7 @@ def validate_pruning_fence(
|
||||
tasks_not_in_celery += 1
|
||||
|
||||
task_logger.info(
|
||||
"validate_pruning_fence task check: "
|
||||
f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
|
||||
f"validate_pruning_fence task check: tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
|
||||
)
|
||||
|
||||
# we're active if there are still tasks to run and those tasks all exist in celery
|
||||
|
||||
@@ -192,10 +192,7 @@ def document_by_cc_pair_cleanup_task(
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"doc={document_id} "
|
||||
f"action={action} "
|
||||
f"refcount={count} "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
f"doc={document_id} action={action} refcount={count} elapsed={elapsed:.2f}"
|
||||
)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
|
||||
@@ -218,9 +215,7 @@ def document_by_cc_pair_cleanup_task(
|
||||
if isinstance(e, httpx.HTTPStatusError):
|
||||
if e.response.status_code == HTTPStatus.BAD_REQUEST:
|
||||
task_logger.exception(
|
||||
f"Non-retryable HTTPStatusError: "
|
||||
f"doc={document_id} "
|
||||
f"status={e.response.status_code}"
|
||||
f"Non-retryable HTTPStatusError: doc={document_id} status={e.response.status_code}"
|
||||
)
|
||||
completion_status = (
|
||||
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
|
||||
@@ -239,8 +234,7 @@ def document_by_cc_pair_cleanup_task(
|
||||
# This is the last attempt! mark the document as dirty in the db so that it
|
||||
# eventually gets fixed out of band via stale document reconciliation
|
||||
task_logger.warning(
|
||||
f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
|
||||
f"doc={document_id}"
|
||||
f"Max celery task retries reached. Marking doc as dirty for reconciliation: doc={document_id}"
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# delete the cc pair relationship now and let reconciliation clean it up
|
||||
@@ -285,4 +279,4 @@ def celery_beat_heartbeat(self: Task, *, tenant_id: str) -> None: # noqa: ARG00
|
||||
r: Redis = get_redis_client()
|
||||
r.set(ONYX_CELERY_BEAT_HEARTBEAT_KEY, 1, ex=600)
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.info(f"celery_beat_heartbeat finished: " f"elapsed={time_elapsed:.2f}")
|
||||
task_logger.info(f"celery_beat_heartbeat finished: elapsed={time_elapsed:.2f}")
|
||||
|
||||
@@ -12,9 +12,9 @@ from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from retry import retry
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.access import build_access_for_user_files
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
@@ -43,7 +43,9 @@ from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.search_settings import get_active_search_settings_list
|
||||
from onyx.db.user_file import fetch_user_files_with_access_relationships
|
||||
from onyx.document_index.factory import get_all_document_indices
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
@@ -54,6 +56,7 @@ from onyx.indexing.adapters.user_file_indexing_adapter import UserFileIndexingAd
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
def _as_uuid(value: str | UUID) -> UUID:
|
||||
@@ -282,8 +285,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
lock.release()
|
||||
|
||||
task_logger.info(
|
||||
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
|
||||
f"tasks for tenant={tenant_id}"
|
||||
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} tasks for tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -314,8 +316,7 @@ def _process_user_file_without_vector_db(
|
||||
token_count: int | None = len(encode(combined_text))
|
||||
except Exception:
|
||||
task_logger.warning(
|
||||
f"_process_user_file_without_vector_db - "
|
||||
f"Failed to compute token count for {uf.id}, falling back to None"
|
||||
f"_process_user_file_without_vector_db - Failed to compute token count for {uf.id}, falling back to None"
|
||||
)
|
||||
token_count = None
|
||||
|
||||
@@ -335,8 +336,7 @@ def _process_user_file_without_vector_db(
|
||||
db_session.commit()
|
||||
|
||||
task_logger.info(
|
||||
f"_process_user_file_without_vector_db - "
|
||||
f"Completed id={uf.id} tokens={token_count}"
|
||||
f"_process_user_file_without_vector_db - Completed id={uf.id} tokens={token_count}"
|
||||
)
|
||||
|
||||
|
||||
@@ -363,8 +363,7 @@ def _process_user_file_with_indexing(
|
||||
)
|
||||
if current_search_settings is None:
|
||||
raise RuntimeError(
|
||||
f"_process_user_file_with_indexing - "
|
||||
f"No current search settings found for tenant={tenant_id}"
|
||||
f"_process_user_file_with_indexing - No current search settings found for tenant={tenant_id}"
|
||||
)
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
@@ -394,8 +393,7 @@ def _process_user_file_with_indexing(
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"_process_user_file_with_indexing - "
|
||||
f"Indexing pipeline completed ={index_pipeline_result}"
|
||||
f"_process_user_file_with_indexing - Indexing pipeline completed ={index_pipeline_result}"
|
||||
)
|
||||
|
||||
if (
|
||||
@@ -404,8 +402,7 @@ def _process_user_file_with_indexing(
|
||||
or index_pipeline_result.total_chunks == 0
|
||||
):
|
||||
task_logger.error(
|
||||
f"_process_user_file_with_indexing - "
|
||||
f"Indexing pipeline failed id={user_file_id}"
|
||||
f"_process_user_file_with_indexing - Indexing pipeline failed id={user_file_id}"
|
||||
)
|
||||
if uf.status != UserFileStatus.DELETING:
|
||||
uf.status = UserFileStatus.FAILED
|
||||
@@ -532,7 +529,10 @@ def process_user_file_impl(
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
self: Task, # noqa: ARG001
|
||||
*,
|
||||
user_file_id: str,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
process_user_file_impl(
|
||||
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
|
||||
@@ -688,7 +688,10 @@ def delete_user_file_impl(
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_delete(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
self: Task, # noqa: ARG001
|
||||
*,
|
||||
user_file_id: str,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
delete_user_file_impl(
|
||||
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
|
||||
@@ -758,8 +761,7 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
|
||||
lock.release()
|
||||
|
||||
task_logger.info(
|
||||
f"Enqueued {enqueued} "
|
||||
f"Skipped guard {skipped_guard} tasks for tenant={tenant_id}"
|
||||
f"Enqueued {enqueued} Skipped guard {skipped_guard} tasks for tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -791,11 +793,12 @@ def project_sync_user_file_impl(
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file = db_session.execute(
|
||||
select(UserFile)
|
||||
.where(UserFile.id == _as_uuid(user_file_id))
|
||||
.options(selectinload(UserFile.assistants))
|
||||
).scalar_one_or_none()
|
||||
user_files = fetch_user_files_with_access_relationships(
|
||||
[user_file_id],
|
||||
db_session,
|
||||
eager_load_groups=global_version.is_ee_version(),
|
||||
)
|
||||
user_file = user_files[0] if user_files else None
|
||||
if not user_file:
|
||||
task_logger.info(
|
||||
f"project_sync_user_file_impl - User file not found id={user_file_id}"
|
||||
@@ -823,12 +826,21 @@ def project_sync_user_file_impl(
|
||||
|
||||
project_ids = [project.id for project in user_file.projects]
|
||||
persona_ids = [p.id for p in user_file.assistants if not p.deleted]
|
||||
|
||||
file_id_str = str(user_file.id)
|
||||
access_map = build_access_for_user_files([user_file])
|
||||
access = access_map.get(file_id_str)
|
||||
|
||||
for retry_document_index in retry_document_indices:
|
||||
retry_document_index.update_single(
|
||||
doc_id=str(user_file.id),
|
||||
doc_id=file_id_str,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=user_file.chunk_count,
|
||||
fields=None,
|
||||
fields=(
|
||||
VespaDocumentFields(access=access)
|
||||
if access is not None
|
||||
else None
|
||||
),
|
||||
user_fields=VespaDocumentUserFields(
|
||||
user_projects=project_ids,
|
||||
personas=persona_ids,
|
||||
@@ -863,7 +875,10 @@ def project_sync_user_file_impl(
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_project_sync(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
self: Task, # noqa: ARG001
|
||||
*,
|
||||
user_file_id: str,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
|
||||
|
||||
@@ -199,8 +199,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
|
||||
lock_beat.release()
|
||||
else:
|
||||
task_logger.error(
|
||||
"check_for_vespa_sync_task - Lock not owned on completion: "
|
||||
f"tenant={tenant_id}"
|
||||
f"check_for_vespa_sync_task - Lock not owned on completion: tenant={tenant_id}"
|
||||
)
|
||||
redis_lock_dump(lock_beat, r)
|
||||
|
||||
@@ -266,8 +265,7 @@ def try_generate_document_set_sync_tasks(
|
||||
# return 0
|
||||
|
||||
task_logger.info(
|
||||
f"RedisDocumentSet.generate_tasks finished. "
|
||||
f"document_set={document_set.id} tasks_generated={tasks_generated}"
|
||||
f"RedisDocumentSet.generate_tasks finished. document_set={document_set.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
@@ -342,8 +340,7 @@ def try_generate_user_group_sync_tasks(
|
||||
# return 0
|
||||
|
||||
task_logger.info(
|
||||
f"RedisUserGroup.generate_tasks finished. "
|
||||
f"usergroup={usergroup.id} tasks_generated={tasks_generated}"
|
||||
f"RedisUserGroup.generate_tasks finished. usergroup={usergroup.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
@@ -398,8 +395,7 @@ def monitor_document_set_taskset(
|
||||
|
||||
count = cast(int, r.scard(rds.taskset_key))
|
||||
task_logger.info(
|
||||
f"Document set sync progress: document_set={document_set_id} "
|
||||
f"remaining={count} initial={initial_count}"
|
||||
f"Document set sync progress: document_set={document_set_id} remaining={count} initial={initial_count}"
|
||||
)
|
||||
if count > 0:
|
||||
update_sync_record_status(
|
||||
@@ -444,9 +440,7 @@ def monitor_document_set_taskset(
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
"update_sync_record_status exceptioned. "
|
||||
f"document_set_id={document_set_id} "
|
||||
"Resetting document set regardless."
|
||||
f"update_sync_record_status exceptioned. document_set_id={document_set_id} Resetting document set regardless."
|
||||
)
|
||||
|
||||
rds.reset()
|
||||
@@ -483,9 +477,7 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
|
||||
if not doc:
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"doc={document_id} "
|
||||
f"action=no_operation "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
f"doc={document_id} action=no_operation elapsed={elapsed:.2f}"
|
||||
)
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.SKIPPED
|
||||
else:
|
||||
@@ -524,9 +516,7 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
|
||||
mark_document_as_synced(document_id, db_session)
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"doc={document_id} " f"action=sync " f"elapsed={elapsed:.2f}"
|
||||
)
|
||||
task_logger.info(f"doc={document_id} action=sync elapsed={elapsed:.2f}")
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
|
||||
@@ -549,9 +539,7 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
|
||||
if isinstance(e, httpx.HTTPStatusError):
|
||||
if e.response.status_code == HTTPStatus.BAD_REQUEST:
|
||||
task_logger.exception(
|
||||
f"Non-retryable HTTPStatusError: "
|
||||
f"doc={document_id} "
|
||||
f"status={e.response.status_code}"
|
||||
f"Non-retryable HTTPStatusError: doc={document_id} status={e.response.status_code}"
|
||||
)
|
||||
completion_status = (
|
||||
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
|
||||
|
||||
@@ -175,14 +175,16 @@ class SimpleJobClient:
|
||||
del self.jobs[job.id]
|
||||
|
||||
def submit(
|
||||
self, func: Callable, *args: Any, pure: bool = True # noqa: ARG002
|
||||
self,
|
||||
func: Callable,
|
||||
*args: Any,
|
||||
pure: bool = True, # noqa: ARG002
|
||||
) -> SimpleJob | None:
|
||||
"""NOTE: `pure` arg is needed so this can be a drop in replacement for Dask"""
|
||||
self._cleanup_completed_jobs()
|
||||
if len(self.jobs) >= self.n_workers:
|
||||
logger.debug(
|
||||
f"No available workers to run job. "
|
||||
f"Currently running '{len(self.jobs)}' jobs, with a limit of '{self.n_workers}'."
|
||||
f"No available workers to run job. Currently running '{len(self.jobs)}' jobs, with a limit of '{self.n_workers}'."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@@ -45,6 +45,7 @@ from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
from onyx.db.enums import ProcessingMode
|
||||
from onyx.db.hierarchy import upsert_hierarchy_node_cc_pair_entries
|
||||
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
|
||||
from onyx.db.index_attempt import create_index_attempt_error
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
@@ -225,15 +226,13 @@ def _check_failure_threshold(
|
||||
FAILURE_RATIO_THRESHOLD = 0.1
|
||||
if total_failures > FAILURE_THRESHOLD and failure_ratio > FAILURE_RATIO_THRESHOLD:
|
||||
logger.error(
|
||||
f"Connector run failed with '{total_failures}' errors "
|
||||
f"after '{batch_num}' batches."
|
||||
f"Connector run failed with '{total_failures}' errors after '{batch_num}' batches."
|
||||
)
|
||||
if last_failure and last_failure.exception:
|
||||
raise last_failure.exception from last_failure.exception
|
||||
|
||||
raise RuntimeError(
|
||||
f"Connector run encountered too many errors, aborting. "
|
||||
f"Last error: {last_failure}"
|
||||
f"Connector run encountered too many errors, aborting. Last error: {last_failure}"
|
||||
)
|
||||
|
||||
|
||||
@@ -587,6 +586,14 @@ def connector_document_extraction(
|
||||
is_connector_public=is_connector_public,
|
||||
)
|
||||
|
||||
upsert_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
hierarchy_node_ids=[n.id for n in upserted_nodes],
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
# Cache in Redis for fast ancestor resolution during doc processing
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
cache_entries = [
|
||||
@@ -600,8 +607,7 @@ def connector_document_extraction(
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Persisted and cached {len(hierarchy_node_batch_cleaned)} hierarchy nodes "
|
||||
f"for attempt={index_attempt_id}"
|
||||
f"Persisted and cached {len(hierarchy_node_batch_cleaned)} hierarchy nodes for attempt={index_attempt_id}"
|
||||
)
|
||||
|
||||
# below is all document processing task, so if no batch we can just continue
|
||||
@@ -803,15 +809,12 @@ def connector_document_extraction(
|
||||
queue=OnyxCeleryQueues.SANDBOX,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered sandbox file sync for user {creator_id} "
|
||||
f"source={source_value} after indexing complete"
|
||||
f"Triggered sandbox file sync for user {creator_id} source={source_value} after indexing complete"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Document extraction failed: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"error={str(e)}"
|
||||
f"Document extraction failed: attempt={index_attempt_id} error={str(e)}"
|
||||
)
|
||||
|
||||
# Do NOT clean up batches on failure; future runs will use those batches
|
||||
@@ -947,7 +950,6 @@ def reissue_old_batches(
|
||||
# is still in the filestore waiting for processing or not.
|
||||
last_batch_num = len(old_batches) + recent_batches
|
||||
logger.info(
|
||||
f"Starting from batch {last_batch_num} due to "
|
||||
f"re-issued batches: {old_batches}, completed batches: {recent_batches}"
|
||||
f"Starting from batch {last_batch_num} due to re-issued batches: {old_batches}, completed batches: {recent_batches}"
|
||||
)
|
||||
return len(old_batches), recent_batches
|
||||
|
||||
@@ -259,8 +259,7 @@ def _poller_loop(tenant_id: str) -> None:
|
||||
|
||||
periodic_tasks = _build_periodic_tasks()
|
||||
logger.info(
|
||||
f"Periodic poller started with {len(periodic_tasks)} periodic task(s): "
|
||||
f"{[t.name for t in periodic_tasks]}"
|
||||
f"Periodic poller started with {len(periodic_tasks)} periodic task(s): {[t.name for t in periodic_tasks]}"
|
||||
)
|
||||
|
||||
while not _shutdown_event.is_set():
|
||||
|
||||
3
backend/onyx/cache/factory.py
vendored
3
backend/onyx/cache/factory.py
vendored
@@ -38,8 +38,7 @@ def get_cache_backend(*, tenant_id: str | None = None) -> CacheBackend:
|
||||
builder = _BACKEND_BUILDERS.get(CACHE_BACKEND)
|
||||
if builder is None:
|
||||
raise ValueError(
|
||||
f"Unsupported CACHE_BACKEND={CACHE_BACKEND!r}. "
|
||||
f"Supported values: {[t.value for t in CacheBackendType]}"
|
||||
f"Unsupported CACHE_BACKEND={CACHE_BACKEND!r}. Supported values: {[t.value for t in CacheBackendType]}"
|
||||
)
|
||||
return builder(tenant_id)
|
||||
|
||||
|
||||
@@ -270,7 +270,10 @@ def extract_headers(
|
||||
|
||||
|
||||
def process_kg_commands(
|
||||
message: str, persona_name: str, tenant_id: str, db_session: Session # noqa: ARG001
|
||||
message: str,
|
||||
persona_name: str,
|
||||
tenant_id: str, # noqa: ARG001
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
# Temporarily, until we have a draft UI for the KG Operations/Management
|
||||
# TODO: move to api endpoint once we get frontend
|
||||
|
||||
@@ -472,8 +472,7 @@ class DynamicCitationProcessor:
|
||||
# Check if we have a mapping for this citation number
|
||||
if num not in self.citation_to_doc:
|
||||
logger.warning(
|
||||
f"Citation number {num} not found in mapping. "
|
||||
f"Available: {list(self.citation_to_doc.keys())}"
|
||||
f"Citation number {num} not found in mapping. Available: {list(self.citation_to_doc.keys())}"
|
||||
)
|
||||
continue
|
||||
|
||||
|
||||
@@ -50,6 +50,7 @@ from onyx.tools.built_in_tools import CITEABLE_TOOLS_NAMES
|
||||
from onyx.tools.built_in_tools import STOPPING_TOOLS_NAMES
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ChatFile
|
||||
from onyx.tools.models import CustomToolCallSummary
|
||||
from onyx.tools.models import MemoryToolResponseSnapshot
|
||||
from onyx.tools.models import PythonToolRichResponse
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
@@ -156,8 +157,7 @@ def _try_fallback_tool_extraction(
|
||||
)
|
||||
if extracted_tool_calls:
|
||||
logger.info(
|
||||
f"Extracted {len(extracted_tool_calls)} tool call(s) from response text "
|
||||
"as fallback"
|
||||
f"Extracted {len(extracted_tool_calls)} tool call(s) from response text as fallback"
|
||||
)
|
||||
return (
|
||||
LlmStepResult(
|
||||
@@ -396,8 +396,7 @@ def construct_message_history(
|
||||
]
|
||||
if forgotten_meta:
|
||||
logger.debug(
|
||||
f"FileReader: building forgotten-files message for "
|
||||
f"{[(m.file_id, m.filename) for m in forgotten_meta]}"
|
||||
f"FileReader: building forgotten-files message for {[(m.file_id, m.filename) for m in forgotten_meta]}"
|
||||
)
|
||||
forgotten_files_message = _create_file_tool_metadata_message(
|
||||
forgotten_meta, token_counter
|
||||
@@ -487,8 +486,7 @@ def _drop_orphaned_tool_call_responses(
|
||||
sanitized.append(msg)
|
||||
else:
|
||||
logger.debug(
|
||||
"Dropping orphaned tool response with tool_call_id=%s while "
|
||||
"constructing message history",
|
||||
"Dropping orphaned tool response with tool_call_id=%s while constructing message history",
|
||||
msg.tool_call_id,
|
||||
)
|
||||
continue
|
||||
@@ -514,8 +512,7 @@ def _create_file_tool_metadata_message(
|
||||
]
|
||||
for meta in file_metadata:
|
||||
lines.append(
|
||||
f'- file_id="{meta.file_id}" filename="{meta.filename}" '
|
||||
f"(~{meta.approx_char_count:,} chars)"
|
||||
f'- file_id="{meta.file_id}" filename="{meta.filename}" (~{meta.approx_char_count:,} chars)'
|
||||
)
|
||||
|
||||
message_content = "\n".join(lines)
|
||||
@@ -980,6 +977,10 @@ def run_llm_loop(
|
||||
|
||||
if memory_snapshot:
|
||||
saved_response = json.dumps(memory_snapshot.model_dump())
|
||||
elif isinstance(tool_response.rich_response, CustomToolCallSummary):
|
||||
saved_response = json.dumps(
|
||||
tool_response.rich_response.model_dump()
|
||||
)
|
||||
elif isinstance(tool_response.rich_response, str):
|
||||
saved_response = tool_response.rich_response
|
||||
else:
|
||||
|
||||
@@ -15,6 +15,7 @@ from onyx.chat.citation_processor import DynamicCitationProcessor
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.tool_call_args_streaming import maybe_emit_argument_delta
|
||||
from onyx.configs.app_configs import LOG_ONYX_MODEL_INTERACTIONS
|
||||
from onyx.configs.app_configs import PROMPT_CACHE_CHAT_HISTORY
|
||||
from onyx.configs.constants import MessageType
|
||||
@@ -54,6 +55,7 @@ from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tracing.framework.create import generation_span
|
||||
from onyx.utils.b64 import get_image_type_from_bytes
|
||||
from onyx.utils.jsonriver import Parser
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.postgres_sanitization import sanitize_string
|
||||
from onyx.utils.text_processing import find_all_json_objects
|
||||
@@ -693,8 +695,7 @@ def _build_structured_assistant_message(msg: ChatMessageSimple) -> AssistantMess
|
||||
def _build_structured_tool_response_message(msg: ChatMessageSimple) -> ToolMessage:
|
||||
if not msg.tool_call_id:
|
||||
raise ValueError(
|
||||
"Tool call response message encountered but tool_call_id is not available. "
|
||||
f"Message: {msg}"
|
||||
f"Tool call response message encountered but tool_call_id is not available. Message: {msg}"
|
||||
)
|
||||
|
||||
return ToolMessage(
|
||||
@@ -729,8 +730,7 @@ class _OllamaHistoryMessageFormatter(_HistoryMessageFormatter):
|
||||
|
||||
tool_call_lines = [
|
||||
(
|
||||
f"[Tool Call] name={tc.tool_name} "
|
||||
f"id={tc.tool_call_id} args={json.dumps(tc.tool_arguments)}"
|
||||
f"[Tool Call] name={tc.tool_name} id={tc.tool_call_id} args={json.dumps(tc.tool_arguments)}"
|
||||
)
|
||||
for tc in msg.tool_calls
|
||||
]
|
||||
@@ -748,8 +748,7 @@ class _OllamaHistoryMessageFormatter(_HistoryMessageFormatter):
|
||||
def format_tool_response_message(self, msg: ChatMessageSimple) -> UserMessage:
|
||||
if not msg.tool_call_id:
|
||||
raise ValueError(
|
||||
"Tool call response message encountered but tool_call_id is not available. "
|
||||
f"Message: {msg}"
|
||||
f"Tool call response message encountered but tool_call_id is not available. Message: {msg}"
|
||||
)
|
||||
|
||||
return UserMessage(
|
||||
@@ -837,8 +836,7 @@ def translate_history_to_llm_format(
|
||||
content_parts.append(image_part)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to process image file {img_file.file_id}: {e}. "
|
||||
"Skipping image."
|
||||
f"Failed to process image file {img_file.file_id}: {e}. Skipping image."
|
||||
)
|
||||
user_msg = UserMessage(
|
||||
role="user",
|
||||
@@ -1009,6 +1007,7 @@ def run_llm_step_pkt_generator(
|
||||
)
|
||||
|
||||
id_to_tool_call_map: dict[int, dict[str, Any]] = {}
|
||||
arg_parsers: dict[int, Parser] = {}
|
||||
reasoning_start = False
|
||||
answer_start = False
|
||||
accumulated_reasoning = ""
|
||||
@@ -1215,7 +1214,14 @@ def run_llm_step_pkt_generator(
|
||||
yield from _close_reasoning_if_active()
|
||||
|
||||
for tool_call_delta in delta.tool_calls:
|
||||
# maybe_emit depends and update being called first and attaching the delta
|
||||
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
|
||||
yield from maybe_emit_argument_delta(
|
||||
tool_calls_in_progress=id_to_tool_call_map,
|
||||
tool_call_delta=tool_call_delta,
|
||||
placement=_current_placement(),
|
||||
parsers=arg_parsers,
|
||||
)
|
||||
|
||||
# Flush any tail text buffered while checking for split "<function_calls" markers.
|
||||
filtered_content_tail = xml_tool_call_content_filter.flush()
|
||||
|
||||
@@ -796,8 +796,7 @@ def handle_stream_message_objects(
|
||||
|
||||
if all_injected_file_metadata:
|
||||
logger.debug(
|
||||
"FileReader: file metadata for LLM: "
|
||||
f"{[(fid, m.filename) for fid, m in all_injected_file_metadata.items()]}"
|
||||
f"FileReader: file metadata for LLM: {[(fid, m.filename) for fid, m in all_injected_file_metadata.items()]}"
|
||||
)
|
||||
|
||||
# Prepend summary message if compression exists
|
||||
|
||||
@@ -87,8 +87,7 @@ def _create_and_link_tool_calls(
|
||||
tool_call_tokens = len(default_tokenizer.encode(arguments_json_str))
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to tokenize tool call arguments for {tool_call_info.tool_call_id}: {e}. "
|
||||
f"Using length as (over) estimate."
|
||||
f"Failed to tokenize tool call arguments for {tool_call_info.tool_call_id}: {e}. Using length as (over) estimate."
|
||||
)
|
||||
arguments_json_str = json.dumps(tool_call_info.tool_call_arguments)
|
||||
tool_call_tokens = len(arguments_json_str)
|
||||
|
||||
77
backend/onyx/chat/tool_call_args_streaming.py
Normal file
77
backend/onyx/chat/tool_call_args_streaming.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from typing import Type
|
||||
|
||||
from onyx.llm.model_response import ChatCompletionDeltaToolCall
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta
|
||||
from onyx.tools.built_in_tools import TOOL_NAME_TO_CLASS
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.utils.jsonriver import Parser
|
||||
|
||||
|
||||
def _get_tool_class(
|
||||
tool_calls_in_progress: Mapping[int, Mapping[str, Any]],
|
||||
tool_call_delta: ChatCompletionDeltaToolCall,
|
||||
) -> Type[Tool] | None:
|
||||
"""Look up the Tool subclass for a streaming tool call delta."""
|
||||
tool_name = tool_calls_in_progress.get(tool_call_delta.index, {}).get("name")
|
||||
if not tool_name:
|
||||
return None
|
||||
return TOOL_NAME_TO_CLASS.get(tool_name)
|
||||
|
||||
|
||||
def maybe_emit_argument_delta(
|
||||
tool_calls_in_progress: Mapping[int, Mapping[str, Any]],
|
||||
tool_call_delta: ChatCompletionDeltaToolCall,
|
||||
placement: Placement,
|
||||
parsers: dict[int, Parser],
|
||||
) -> Generator[Packet, None, None]:
|
||||
"""Emit decoded tool-call argument deltas to the frontend.
|
||||
|
||||
Uses a ``jsonriver.Parser`` per tool-call index to incrementally parse
|
||||
the JSON argument string and extract only the newly-appended content
|
||||
for each string-valued argument.
|
||||
|
||||
NOTE: Non-string arguments (numbers, booleans, null, arrays, objects)
|
||||
are skipped — they are available in the final tool-call kickoff packet.
|
||||
|
||||
``parsers`` is a mutable dict keyed by tool-call index. A new
|
||||
``Parser`` is created automatically for each new index.
|
||||
"""
|
||||
tool_cls = _get_tool_class(tool_calls_in_progress, tool_call_delta)
|
||||
if not tool_cls or not tool_cls.should_emit_argument_deltas():
|
||||
return
|
||||
|
||||
fn = tool_call_delta.function
|
||||
delta_fragment = fn.arguments if fn else None
|
||||
if not delta_fragment:
|
||||
return
|
||||
|
||||
idx = tool_call_delta.index
|
||||
if idx not in parsers:
|
||||
parsers[idx] = Parser()
|
||||
parser = parsers[idx]
|
||||
|
||||
deltas = parser.feed(delta_fragment)
|
||||
|
||||
argument_deltas: dict[str, str] = {}
|
||||
for delta in deltas:
|
||||
if isinstance(delta, dict):
|
||||
for key, value in delta.items():
|
||||
if isinstance(value, str):
|
||||
argument_deltas[key] = argument_deltas.get(key, "") + value
|
||||
|
||||
if not argument_deltas:
|
||||
return
|
||||
|
||||
tc_data = tool_calls_in_progress[tool_call_delta.index]
|
||||
yield Packet(
|
||||
placement=placement,
|
||||
obj=ToolCallArgumentDelta(
|
||||
tool_type=tc_data.get("name", ""),
|
||||
argument_deltas=argument_deltas,
|
||||
),
|
||||
)
|
||||
@@ -68,6 +68,10 @@ FILE_TOKEN_COUNT_THRESHOLD = int(
|
||||
os.environ.get("FILE_TOKEN_COUNT_THRESHOLD", str(_DEFAULT_FILE_TOKEN_LIMIT))
|
||||
)
|
||||
|
||||
# Maximum upload size for a single user file (chat/projects) in MB.
|
||||
USER_FILE_MAX_UPLOAD_SIZE_MB = int(os.environ.get("USER_FILE_MAX_UPLOAD_SIZE_MB") or 50)
|
||||
USER_FILE_MAX_UPLOAD_SIZE_BYTES = USER_FILE_MAX_UPLOAD_SIZE_MB * 1024 * 1024
|
||||
|
||||
# If set to true, will show extra/uncommon connectors in the "Other" category
|
||||
SHOW_EXTRA_CONNECTORS = os.environ.get("SHOW_EXTRA_CONNECTORS", "").lower() == "true"
|
||||
|
||||
@@ -92,19 +96,12 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
|
||||
#####
|
||||
# Auth Configs
|
||||
#####
|
||||
# Upgrades users from disabled auth to basic auth and shows warning.
|
||||
_auth_type_str = (os.environ.get("AUTH_TYPE") or "basic").lower()
|
||||
if _auth_type_str == "disabled":
|
||||
logger.warning(
|
||||
"AUTH_TYPE='disabled' is no longer supported. "
|
||||
"Defaulting to 'basic'. Please update your configuration. "
|
||||
"Your existing data will be migrated automatically."
|
||||
)
|
||||
_auth_type_str = AuthType.BASIC.value
|
||||
try:
|
||||
# Silently default to basic - warnings/errors logged in verify_auth_setting()
|
||||
# which only runs on app startup, not during migrations/scripts
|
||||
_auth_type_str = (os.environ.get("AUTH_TYPE") or "").lower()
|
||||
if _auth_type_str in [auth_type.value for auth_type in AuthType]:
|
||||
AUTH_TYPE = AuthType(_auth_type_str)
|
||||
except ValueError:
|
||||
logger.error(f"Invalid AUTH_TYPE: {_auth_type_str}. Defaulting to 'basic'.")
|
||||
else:
|
||||
AUTH_TYPE = AuthType.BASIC
|
||||
|
||||
PASSWORD_MIN_LENGTH = int(os.getenv("PASSWORD_MIN_LENGTH", 8))
|
||||
@@ -199,6 +196,10 @@ if _OIDC_SCOPE_OVERRIDE:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Enables PKCE for OIDC login flow. Disabled by default to preserve
|
||||
# backwards compatibility for existing OIDC deployments.
|
||||
OIDC_PKCE_ENABLED = os.environ.get("OIDC_PKCE_ENABLED", "").lower() == "true"
|
||||
|
||||
# Applicable for SAML Auth
|
||||
SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/onyx/configs/saml_config"
|
||||
|
||||
@@ -207,6 +208,12 @@ JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)
|
||||
|
||||
USER_AUTH_SECRET = os.environ.get("USER_AUTH_SECRET", "")
|
||||
|
||||
if AUTH_TYPE == AuthType.BASIC and not USER_AUTH_SECRET:
|
||||
logger.warning(
|
||||
"USER_AUTH_SECRET is not set. This is required for secure password reset "
|
||||
"and email verification tokens. Please set USER_AUTH_SECRET in production."
|
||||
)
|
||||
|
||||
# Duration (in seconds) for which the FastAPI Users JWT token remains valid in the user's browser.
|
||||
# By default, this is set to match the Redis expiry time for consistency.
|
||||
AUTH_COOKIE_EXPIRE_TIME_SECONDS = int(
|
||||
@@ -288,8 +295,9 @@ OPENSEARCH_TEXT_ANALYZER = os.environ.get("OPENSEARCH_TEXT_ANALYZER") or "englis
|
||||
# environments we always want to be dual indexing into both OpenSearch and Vespa
|
||||
# to stress test the new codepaths. Only enable this if there is some instance
|
||||
# of OpenSearch running for the relevant Onyx instance.
|
||||
# NOTE: Now enabled on by default, unless the env indicates otherwise.
|
||||
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX = (
|
||||
os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "").lower() == "true"
|
||||
os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "true").lower() == "true"
|
||||
)
|
||||
# NOTE: This effectively does nothing anymore, admins can now toggle whether
|
||||
# retrieval is through OpenSearch. This value is only used as a final fallback
|
||||
@@ -307,6 +315,12 @@ VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT = (
|
||||
os.environ.get("VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT", "true").lower()
|
||||
== "true"
|
||||
)
|
||||
OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE = int(
|
||||
os.environ.get("OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE") or 500
|
||||
)
|
||||
OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = int(
|
||||
os.environ.get("OPENSEARCH_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES") or 0
|
||||
)
|
||||
|
||||
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
|
||||
# NOTE: this is used if and only if the vespa config server is accessible via a
|
||||
|
||||
@@ -90,8 +90,7 @@ def parse_airtable_url(
|
||||
match = _AIRTABLE_URL_PATTERN.search(url.strip())
|
||||
if not match:
|
||||
raise ValueError(
|
||||
f"Could not parse Airtable URL: '{url}'. "
|
||||
"Expected format: https://airtable.com/appXXX/tblYYY[/viwZZZ]"
|
||||
f"Could not parse Airtable URL: '{url}'. Expected format: https://airtable.com/appXXX/tblYYY[/viwZZZ]"
|
||||
)
|
||||
return match.group(1), match.group(2), match.group(3)
|
||||
|
||||
@@ -170,16 +169,14 @@ class AirtableConnector(LoadConnector):
|
||||
else:
|
||||
if not self.base_id or not self.table_name_or_id:
|
||||
raise ConnectorValidationError(
|
||||
"A valid Airtable URL or base_id and table_name_or_id are required "
|
||||
"when not using index_all mode."
|
||||
"A valid Airtable URL or base_id and table_name_or_id are required when not using index_all mode."
|
||||
)
|
||||
try:
|
||||
table = self.airtable_client.table(self.base_id, self.table_name_or_id)
|
||||
table.schema()
|
||||
except Exception as e:
|
||||
raise ConnectorValidationError(
|
||||
f"Failed to access table '{self.table_name_or_id}' "
|
||||
f"in base '{self.base_id}': {e}"
|
||||
f"Failed to access table '{self.table_name_or_id}' in base '{self.base_id}': {e}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -391,10 +388,7 @@ class AirtableConnector(LoadConnector):
|
||||
TextSection(
|
||||
link=link,
|
||||
text=(
|
||||
f"{field_name}:\n"
|
||||
"------------------------\n"
|
||||
f"{text}\n"
|
||||
"------------------------"
|
||||
f"{field_name}:\n------------------------\n{text}\n------------------------"
|
||||
),
|
||||
)
|
||||
for text, link in field_value_and_links
|
||||
@@ -440,8 +434,7 @@ class AirtableConnector(LoadConnector):
|
||||
field_type = field_schema.type
|
||||
|
||||
logger.debug(
|
||||
f"Processing field '{field_name}' of type '{field_type}' "
|
||||
f"for record '{record_id}'."
|
||||
f"Processing field '{field_name}' of type '{field_type}' for record '{record_id}'."
|
||||
)
|
||||
|
||||
field_sections, field_metadata = self._process_field(
|
||||
@@ -534,8 +527,7 @@ class AirtableConnector(LoadConnector):
|
||||
break
|
||||
|
||||
logger.info(
|
||||
f"Processing {len(records)} records from table "
|
||||
f"'{table_schema.name}' in base '{base_name or base_id}'."
|
||||
f"Processing {len(records)} records from table '{table_schema.name}' in base '{base_name or base_id}'."
|
||||
)
|
||||
|
||||
if not records:
|
||||
@@ -629,7 +621,6 @@ class AirtableConnector(LoadConnector):
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to index table '{table.name}' ({table.id}) "
|
||||
f"in base '{base_name}' ({base_id}), skipping."
|
||||
f"Failed to index table '{table.name}' ({table.id}) in base '{base_name}' ({base_id}), skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
@@ -68,7 +68,7 @@ class ClickupConnector(LoadConnector, PollConnector):
|
||||
response = self._make_request(url_endpoint)
|
||||
comments = [
|
||||
TextSection(
|
||||
link=f'https://app.clickup.com/t/{task_id}?comment={comment_dict["id"]}',
|
||||
link=f"https://app.clickup.com/t/{task_id}?comment={comment_dict['id']}",
|
||||
text=comment_dict["comment_text"],
|
||||
)
|
||||
for comment_dict in response["comments"]
|
||||
|
||||
@@ -698,8 +698,7 @@ class CodaConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
elif e.status_code == 429:
|
||||
raise ConnectorValidationError(
|
||||
"Validation failed due to Coda rate-limits being exceeded (HTTP 429). "
|
||||
"Please try again later."
|
||||
"Validation failed due to Coda rate-limits being exceeded (HTTP 429). Please try again later."
|
||||
)
|
||||
else:
|
||||
raise UnexpectedValidationError(
|
||||
|
||||
@@ -95,7 +95,6 @@ def _get_page_id(page: dict[str, Any], allow_missing: bool = False) -> str:
|
||||
|
||||
|
||||
class ConfluenceCheckpoint(ConnectorCheckpoint):
|
||||
|
||||
next_page_url: str | None
|
||||
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user