Compare commits

..

1 Commits

Author SHA1 Message Date
Jamison Lahman
63da8a5b89 chore(devtools): ods web installs node_modules on init 2026-04-09 23:47:55 +00:00
402 changed files with 6118 additions and 18047 deletions

View File

@@ -1,8 +1,8 @@
FROM ubuntu:26.04@sha256:cc925e589b7543b910fea57a240468940003fbfc0515245a495dd0ad8fe7cef1
RUN apt-get update && apt-get install -y --no-install-recommends \
acl \
curl \
default-jre \
fd-find \
fzf \
git \
@@ -25,11 +25,13 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
&& curl -fsSL https://deb.nodesource.com/setup_20.x | bash - \
&& apt-get install -y nodejs \
&& install -m 0755 -d /etc/apt/keyrings \
&& curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc \
&& echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu $(. /etc/os-release && echo "$VERSION_CODENAME") stable" > /etc/apt/sources.list.d/docker.list \
&& curl -fsSL https://cli.github.com/packages/githubcli-archive-keyring.gpg -o /etc/apt/keyrings/githubcli-archive-keyring.gpg \
&& chmod go+r /etc/apt/keyrings/githubcli-archive-keyring.gpg \
&& echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" > /etc/apt/sources.list.d/github-cli.list \
&& apt-get update \
&& apt-get install -y --no-install-recommends gh \
&& apt-get install -y --no-install-recommends docker-ce-cli docker-compose-plugin gh \
&& apt-get clean && rm -rf /var/lib/apt/lists/*
# fd-find installs as fdfind on Debian/Ubuntu — symlink to fd

View File

@@ -6,7 +6,7 @@ A containerized development environment for working on Onyx.
- Ubuntu 26.04 base image
- Node.js 20, uv, Claude Code
- GitHub CLI (`gh`)
- Docker CLI, GitHub CLI (`gh`)
- Neovim, ripgrep, fd, fzf, jq, make, wget, unzip
- Zsh as default shell (sources host `~/.zshrc` if available)
- Python venv auto-activation
@@ -14,6 +14,12 @@ A containerized development environment for working on Onyx.
## Usage
### VS Code
1. Install the [Dev Containers extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers)
2. Open this repo in VS Code
3. "Reopen in Container" when prompted
### CLI (`ods dev`)
The [`ods` devtools CLI](../tools/ods/README.md) provides workspace-aware wrappers
@@ -33,8 +39,25 @@ ods dev exec npm test
ods dev stop
```
If you don't have `ods` installed, use the `devcontainer` CLI directly:
```bash
npm install -g @devcontainers/cli
devcontainer up --workspace-folder .
devcontainer exec --workspace-folder . zsh
```
## Restarting the container
### VS Code
Open the Command Palette (`Ctrl+Shift+P` / `Cmd+Shift+P`) and run:
- **Dev Containers: Reopen in Container** — restarts the container without rebuilding
### CLI
```bash
# Restart the container
ods dev restart
@@ -43,6 +66,12 @@ ods dev restart
ods dev rebuild
```
Or without `ods`:
```bash
devcontainer up --workspace-folder . --remove-existing-container
```
## Image
The devcontainer uses a prebuilt image published to `onyxdotapp/onyx-devcontainer`.
@@ -59,19 +88,30 @@ The `devcontainer` target is defined in `docker-bake.hcl` at the repo root.
## User & permissions
The container runs as the `dev` user by default (`remoteUser` in devcontainer.json).
An init script (`init-dev-user.sh`) runs at container start to ensure the active
user has read/write access to the bind-mounted workspace:
An init script (`init-dev-user.sh`) runs at container start to ensure `dev` has
read/write access to the bind-mounted workspace:
- **Standard Docker** — `dev`'s UID/GID is remapped to match the workspace owner,
so file permissions work seamlessly.
- **Rootless Docker** — The workspace appears as root-owned (UID 0) inside the
container due to user-namespace mapping. `ods dev up` auto-detects rootless Docker
and sets `DEVCONTAINER_REMOTE_USER=root` so the container runs as root — which
maps back to your host user via the user namespace. New files are owned by your
host UID and no ACL workarounds are needed.
container due to user-namespace mapping. The init script grants `dev` access via
POSIX ACLs (`setfacl`), which adds a few seconds to the first container start on
large repos.
To override the auto-detection, set `DEVCONTAINER_REMOTE_USER` before running
`ods dev up`.
## Docker socket
The container mounts the host's Docker socket so you can run `docker` commands
from inside. `ods dev` auto-detects the socket path and sets `DOCKER_SOCK`:
| Environment | Socket path |
| ----------------------- | ------------------------------ |
| Linux (rootless Docker) | `$XDG_RUNTIME_DIR/docker.sock` |
| macOS (Docker Desktop) | `~/.docker/run/docker.sock` |
| Linux (standard Docker) | `/var/run/docker.sock` |
To override, set `DOCKER_SOCK` before running `ods dev up`. When using the
VS Code extension or `devcontainer` CLI directly (without `ods`), you must set
`DOCKER_SOCK` yourself.
## Firewall

View File

@@ -1,24 +1,20 @@
{
"name": "Onyx Dev Sandbox",
"image": "onyxdotapp/onyx-devcontainer@sha256:0f02d9299928849c7b15f3b348dcfdcdcb64411ff7a4580cbc026a6ee7aa1554",
"runArgs": ["--cap-add=NET_ADMIN", "--cap-add=NET_RAW", "--network=onyx_default"],
"image": "onyxdotapp/onyx-devcontainer@sha256:12184169c5bcc9cca0388286d5ffe504b569bc9c37bfa631b76ee8eee2064055",
"runArgs": ["--cap-add=NET_ADMIN", "--cap-add=NET_RAW"],
"mounts": [
"source=${localEnv:DOCKER_SOCK},target=/var/run/docker.sock,type=bind",
"source=${localEnv:HOME}/.claude,target=/home/dev/.claude,type=bind",
"source=${localEnv:HOME}/.claude.json,target=/home/dev/.claude.json,type=bind",
"source=${localEnv:HOME}/.zshrc,target=/home/dev/.zshrc.host,type=bind,readonly",
"source=${localEnv:HOME}/.gitconfig,target=/home/dev/.gitconfig,type=bind,readonly",
"source=${localEnv:HOME}/.config/nvim,target=/home/dev/.config/nvim,type=bind,readonly",
"source=${localEnv:HOME}/.gitconfig,target=/home/dev/.gitconfig.host,type=bind,readonly",
"source=${localEnv:HOME}/.ssh,target=/home/dev/.ssh.host,type=bind,readonly",
"source=${localEnv:HOME}/.config/nvim,target=/home/dev/.config/nvim.host,type=bind,readonly",
"source=onyx-devcontainer-cache,target=/home/dev/.cache,type=volume",
"source=onyx-devcontainer-local,target=/home/dev/.local,type=volume"
],
"containerEnv": {
"SSH_AUTH_SOCK": "/tmp/ssh-agent.sock",
"POSTGRES_HOST": "relational_db",
"REDIS_HOST": "cache"
},
"remoteUser": "${localEnv:DEVCONTAINER_REMOTE_USER:dev}",
"remoteUser": "dev",
"updateRemoteUserUID": false,
"initializeCommand": "docker network create onyx_default 2>/dev/null || true",
"workspaceMount": "source=${localWorkspaceFolder},target=/workspace,type=bind,consistency=delegated",
"workspaceFolder": "/workspace",
"postStartCommand": "sudo bash /workspace/.devcontainer/init-dev-user.sh && sudo bash /workspace/.devcontainer/init-firewall.sh",

View File

@@ -8,68 +8,38 @@ set -euo pipefail
# We remap dev to that UID -- fast and seamless.
#
# Rootless Docker: Workspace appears as root-owned (UID 0) inside the
# container due to user-namespace mapping. Requires
# DEVCONTAINER_REMOTE_USER=root (set automatically by
# ods dev up). Container root IS the host user, so
# bind-mounts and named volumes are symlinked into /root.
# container due to user-namespace mapping. We can't remap
# dev to UID 0 (that's root), so we grant access with
# POSIX ACLs instead.
WORKSPACE=/workspace
TARGET_USER=dev
REMOTE_USER="${SUDO_USER:-$TARGET_USER}"
WS_UID=$(stat -c '%u' "$WORKSPACE")
WS_GID=$(stat -c '%g' "$WORKSPACE")
DEV_UID=$(id -u "$TARGET_USER")
DEV_GID=$(id -g "$TARGET_USER")
# devcontainer.json bind-mounts and named volumes target /home/dev regardless
# of remoteUser. When running as root ($HOME=/root), Phase 1 bridges the gap
# with symlinks from ACTIVE_HOME → MOUNT_HOME.
MOUNT_HOME=/home/"$TARGET_USER"
DEV_HOME=/home/"$TARGET_USER"
if [ "$REMOTE_USER" = "root" ]; then
ACTIVE_HOME="/root"
else
ACTIVE_HOME="$MOUNT_HOME"
# Ensure directories that tools expect exist under ~dev.
# ~/.local and ~/.cache are named Docker volumes -- ensure they are owned by dev.
mkdir -p "$DEV_HOME"/.local/state "$DEV_HOME"/.local/share
chown -R "$TARGET_USER":"$TARGET_USER" "$DEV_HOME"/.local
chown -R "$TARGET_USER":"$TARGET_USER" "$DEV_HOME"/.cache
# Copy host configs mounted as *.host into their real locations.
# This gives the dev user owned copies without touching host originals.
if [ -d "$DEV_HOME/.ssh.host" ]; then
cp -a "$DEV_HOME/.ssh.host" "$DEV_HOME/.ssh"
chmod 700 "$DEV_HOME/.ssh"
chmod 600 "$DEV_HOME"/.ssh/id_* 2>/dev/null || true
chown -R "$TARGET_USER":"$TARGET_USER" "$DEV_HOME/.ssh"
fi
# ── Phase 1: home directory setup ───────────────────────────────────
# ~/.local and ~/.cache are named Docker volumes mounted under MOUNT_HOME.
mkdir -p "$MOUNT_HOME"/.local/state "$MOUNT_HOME"/.local/share
# When running as root, symlink bind-mounts and named volumes into /root
# so that $HOME-relative tools (Claude Code, git, etc.) find them.
if [ "$ACTIVE_HOME" != "$MOUNT_HOME" ]; then
for item in .claude .cache .local; do
[ -d "$MOUNT_HOME/$item" ] || continue
if [ -e "$ACTIVE_HOME/$item" ] && [ ! -L "$ACTIVE_HOME/$item" ]; then
echo "warning: replacing $ACTIVE_HOME/$item with symlink to $MOUNT_HOME/$item" >&2
rm -rf "$ACTIVE_HOME/$item"
fi
ln -sfn "$MOUNT_HOME/$item" "$ACTIVE_HOME/$item"
done
# Symlink files (not directories).
for file in .claude.json .gitconfig .zshrc.host; do
[ -f "$MOUNT_HOME/$file" ] && ln -sf "$MOUNT_HOME/$file" "$ACTIVE_HOME/$file"
done
# Nested mount: .config/nvim
if [ -d "$MOUNT_HOME/.config/nvim" ]; then
mkdir -p "$ACTIVE_HOME/.config"
if [ -e "$ACTIVE_HOME/.config/nvim" ] && [ ! -L "$ACTIVE_HOME/.config/nvim" ]; then
echo "warning: replacing $ACTIVE_HOME/.config/nvim with symlink" >&2
rm -rf "$ACTIVE_HOME/.config/nvim"
fi
ln -sfn "$MOUNT_HOME/.config/nvim" "$ACTIVE_HOME/.config/nvim"
fi
fi
# ── Phase 2: workspace access ───────────────────────────────────────
# Root always has workspace access; Phase 1 handled home setup.
if [ "$REMOTE_USER" = "root" ]; then
exit 0
if [ -d "$DEV_HOME/.config/nvim.host" ]; then
mkdir -p "$DEV_HOME/.config"
cp -a "$DEV_HOME/.config/nvim.host" "$DEV_HOME/.config/nvim"
chown -R "$TARGET_USER":"$TARGET_USER" "$DEV_HOME/.config/nvim"
fi
# Already matching -- nothing to do.
@@ -91,17 +61,45 @@ if [ "$WS_UID" != "0" ]; then
echo "warning: failed to remap $TARGET_USER UID to $WS_UID" >&2
fi
fi
if ! chown -R "$TARGET_USER":"$TARGET_USER" "$MOUNT_HOME" 2>&1; then
echo "warning: failed to chown $MOUNT_HOME" >&2
if ! chown -R "$TARGET_USER":"$TARGET_USER" /home/"$TARGET_USER" 2>&1; then
echo "warning: failed to chown /home/$TARGET_USER" >&2
fi
else
# ── Rootless Docker ──────────────────────────────────────────────
# Workspace is root-owned (UID 0) due to user-namespace mapping.
# The supported path is remoteUser=root (set DEVCONTAINER_REMOTE_USER=root),
# which is handled above. If we reach here, the user is running as dev
# under rootless Docker without the override.
echo "error: rootless Docker detected but remoteUser is not root." >&2
echo " Set DEVCONTAINER_REMOTE_USER=root before starting the container," >&2
echo " or use 'ods dev up' which sets it automatically." >&2
exit 1
# Workspace is root-owned inside the container. Grant dev access
# via POSIX ACLs (preserves ownership, works across the namespace
# boundary).
if command -v setfacl &>/dev/null; then
setfacl -Rm "u:${TARGET_USER}:rwX" "$WORKSPACE"
setfacl -Rdm "u:${TARGET_USER}:rwX" "$WORKSPACE" # default ACL for new files
# Git refuses to operate in repos owned by a different UID.
# Host gitconfig is mounted readonly as ~/.gitconfig.host.
# Create a real ~/.gitconfig that includes it plus container overrides.
printf '[include]\n\tpath = %s/.gitconfig.host\n[safe]\n\tdirectory = %s\n' \
"$DEV_HOME" "$WORKSPACE" > "$DEV_HOME/.gitconfig"
chown "$TARGET_USER":"$TARGET_USER" "$DEV_HOME/.gitconfig"
# If this is a worktree, the main .git dir is bind-mounted at its
# host absolute path. Grant dev access so git operations work.
GIT_COMMON_DIR=$(git -C "$WORKSPACE" rev-parse --git-common-dir 2>/dev/null || true)
if [ -n "$GIT_COMMON_DIR" ] && [ "$GIT_COMMON_DIR" != "$WORKSPACE/.git" ]; then
[ ! -d "$GIT_COMMON_DIR" ] && GIT_COMMON_DIR="$WORKSPACE/$GIT_COMMON_DIR"
if [ -d "$GIT_COMMON_DIR" ]; then
setfacl -Rm "u:${TARGET_USER}:rwX" "$GIT_COMMON_DIR"
setfacl -Rdm "u:${TARGET_USER}:rwX" "$GIT_COMMON_DIR"
git config -f "$DEV_HOME/.gitconfig" --add safe.directory "$(dirname "$GIT_COMMON_DIR")"
fi
fi
# Also fix bind-mounted dirs under ~dev that appear root-owned.
for dir in /home/"$TARGET_USER"/.claude; do
[ -d "$dir" ] && setfacl -Rm "u:${TARGET_USER}:rwX" "$dir" && setfacl -Rdm "u:${TARGET_USER}:rwX" "$dir"
done
[ -f /home/"$TARGET_USER"/.claude.json ] && \
setfacl -m "u:${TARGET_USER}:rw" /home/"$TARGET_USER"/.claude.json
else
echo "warning: setfacl not found; dev user may not have write access to workspace" >&2
echo " install the 'acl' package or set remoteUser to root" >&2
fi
fi

View File

@@ -4,12 +4,22 @@ set -euo pipefail
echo "Setting up firewall..."
# Only flush the filter table. The nat and mangle tables are managed by Docker
# (DNS DNAT to 127.0.0.11, container networking, etc.) and must not be touched —
# flushing them breaks Docker's embedded DNS resolver.
# Preserve docker dns resolution
DOCKER_DNS_RULES=$(iptables-save | grep -E "^-A.*-d 127.0.0.11/32" || true)
# Flush all rules
iptables -t nat -F
iptables -t nat -X
iptables -t mangle -F
iptables -t mangle -X
iptables -F
iptables -X
# Restore docker dns rules
if [ -n "$DOCKER_DNS_RULES" ]; then
echo "$DOCKER_DNS_RULES" | iptables-restore -n
fi
# Create ipset for allowed destinations
ipset create allowed-domains hash:net || true
ipset flush allowed-domains
@@ -24,7 +34,6 @@ done
# Resolve allowed domains
ALLOWED_DOMAINS=(
"github.com"
"registry.npmjs.org"
"api.anthropic.com"
"api-staging.anthropic.com"
@@ -47,23 +56,14 @@ for domain in "${ALLOWED_DOMAINS[@]}"; do
done
done
# Allow traffic to the Docker gateway so the container can reach host services
# (e.g. the Onyx stack at localhost:3000, localhost:8080, etc.)
DOCKER_GATEWAY=$(ip -4 route show default | awk '{print $3}')
if [ -n "$DOCKER_GATEWAY" ]; then
# Detect host network
if [[ "${DOCKER_HOST:-}" == "unix://"* ]]; then
DOCKER_GATEWAY=$(ip -4 route show | grep "^default" | awk '{print $3}')
if ! ipset add allowed-domains "$DOCKER_GATEWAY/32" -exist 2>&1; then
echo "warning: failed to add Docker gateway $DOCKER_GATEWAY to allowlist" >&2
fi
fi
# Allow traffic to all attached Docker network subnets so the container can
# reach sibling services (e.g. relational_db, cache) on shared compose networks.
for subnet in $(ip -4 -o addr show scope global | awk '{print $4}'); do
if ! ipset add allowed-domains "$subnet" -exist 2>&1; then
echo "warning: failed to add Docker subnet $subnet to allowlist" >&2
fi
done
# Set default policies to DROP
iptables -P FORWARD DROP
iptables -P INPUT DROP

View File

@@ -44,7 +44,7 @@ jobs:
fetch-tags: true
- name: Setup uv
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
with:
version: "0.9.9"
enable-cache: false
@@ -165,7 +165,7 @@ jobs:
fetch-depth: 0
- name: Setup uv
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
with:
version: "0.9.9"
# NOTE: This isn't caching much and zizmor suggests this could be poisoned, so disable.

View File

@@ -114,7 +114,7 @@ jobs:
ref: main
- name: Install the latest version of uv
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"

View File

@@ -471,7 +471,7 @@ jobs:
- name: Install the latest version of uv
if: always()
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"
@@ -710,7 +710,7 @@ jobs:
pull-requests: write
steps:
- name: Download visual diff summaries
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3
with:
pattern: screenshot-diff-summary-*
path: summaries/

View File

@@ -38,7 +38,7 @@ jobs:
- name: Install node dependencies
working-directory: ./web
run: npm ci
- uses: j178/prek-action@cbc2f23eb5539cf20d82d1aabd0d0ecbcc56f4e3
- uses: j178/prek-action@0bb87d7f00b0c99306c8bcb8b8beba1eb581c037 # ratchet:j178/prek-action@v1
with:
prek-version: '0.3.4'
extra-args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || github.event_name == 'merge_group' && format('--from-ref {0} --to-ref {1}', github.event.merge_group.base_sha, github.event.merge_group.head_sha) || github.ref_name == 'main' && '--all-files' || '' }}

View File

@@ -17,7 +17,7 @@ jobs:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
- uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"

View File

@@ -26,7 +26,7 @@ jobs:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
- uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"

View File

@@ -24,7 +24,7 @@ jobs:
persist-credentials: false
- name: Install the latest version of uv
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"

View File

@@ -1,57 +1,64 @@
{
"labels": [],
"comment": "",
"fixWithAI": true,
"hideFooter": false,
"strictness": 3,
"statusCheck": true,
"commentTypes": ["logic", "syntax", "style"],
"instructions": "",
"disabledLabels": [],
"excludeAuthors": ["dependabot[bot]", "renovate[bot]"],
"ignoreKeywords": "",
"ignorePatterns": "",
"includeAuthors": [],
"summarySection": {
"included": true,
"collapsible": false,
"defaultOpen": false
},
"excludeBranches": [],
"fileChangeLimit": 300,
"includeBranches": [],
"includeKeywords": "",
"triggerOnUpdates": false,
"updateExistingSummaryComment": true,
"updateSummaryOnly": false,
"issuesTableSection": {
"included": true,
"collapsible": false,
"defaultOpen": false
},
"statusCommentsEnabled": true,
"confidenceScoreSection": {
"included": true,
"collapsible": false
},
"sequenceDiagramSection": {
"included": true,
"collapsible": false,
"defaultOpen": false
},
"shouldUpdateDescription": false,
"rules": [
{
"scope": ["web/**"],
"rule": "In Onyx's Next.js app, the `app/ee/admin/` directory is a filesystem convention for Enterprise Edition route overrides — it does NOT add an `/ee/` prefix to the URL. Both `app/admin/groups/page.tsx` and `app/ee/admin/groups/page.tsx` serve the same URL `/admin/groups`. Hardcoded `/admin/...` paths in router.push() calls are correct and do NOT break EE deployments. Do not flag hardcoded admin paths as bugs."
"labels": [],
"comment": "",
"fixWithAI": true,
"hideFooter": false,
"strictness": 3,
"statusCheck": true,
"commentTypes": [
"logic",
"syntax",
"style"
],
"instructions": "",
"disabledLabels": [],
"excludeAuthors": [
"dependabot[bot]",
"renovate[bot]"
],
"ignoreKeywords": "",
"ignorePatterns": "",
"includeAuthors": [],
"summarySection": {
"included": true,
"collapsible": false,
"defaultOpen": false
},
{
"scope": ["web/**"],
"rule": "In Onyx, each API key creates a unique user row in the database with a unique `user_id` (UUID). There is a 1:1 mapping between API keys and their backing user records. Multiple API keys do NOT share the same `user_id`. Do not flag potential duplicate row IDs when using `user_id` from API key descriptors."
"excludeBranches": [],
"fileChangeLimit": 300,
"includeBranches": [],
"includeKeywords": "",
"triggerOnUpdates": true,
"updateExistingSummaryComment": true,
"updateSummaryOnly": false,
"issuesTableSection": {
"included": true,
"collapsible": false,
"defaultOpen": false
},
{
"scope": ["backend/**/*.py"],
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
}
]
"statusCommentsEnabled": true,
"confidenceScoreSection": {
"included": true,
"collapsible": false
},
"sequenceDiagramSection": {
"included": true,
"collapsible": false,
"defaultOpen": false
},
"shouldUpdateDescription": false,
"rules": [
{
"scope": ["web/**"],
"rule": "In Onyx's Next.js app, the `app/ee/admin/` directory is a filesystem convention for Enterprise Edition route overrides — it does NOT add an `/ee/` prefix to the URL. Both `app/admin/groups/page.tsx` and `app/ee/admin/groups/page.tsx` serve the same URL `/admin/groups`. Hardcoded `/admin/...` paths in router.push() calls are correct and do NOT break EE deployments. Do not flag hardcoded admin paths as bugs."
},
{
"scope": ["web/**"],
"rule": "In Onyx, each API key creates a unique user row in the database with a unique `user_id` (UUID). There is a 1:1 mapping between API keys and their backing user records. Multiple API keys do NOT share the same `user_id`. Do not flag potential duplicate row IDs when using `user_id` from API key descriptors."
},
{
"scope": ["backend/**/*.py"],
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
}
]
}

12
.vscode/launch.json vendored
View File

@@ -475,18 +475,6 @@
"order": 0
}
},
{
"name": "Start Monitoring Stack (Prometheus + Grafana)",
"type": "node",
"request": "launch",
"runtimeExecutable": "docker",
"runtimeArgs": ["compose", "up", "-d"],
"cwd": "${workspaceFolder}/profiling",
"console": "integratedTerminal",
"presentation": {
"group": "3"
}
},
{
"name": "Clear and Restart External Volumes and Containers",
"type": "node",

View File

@@ -49,12 +49,12 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
4. **Light Worker** (`light`)
- Handles lightweight, fast operations
- Tasks: vespa metadata sync, connector deletion, doc permissions upsert, checkpoint cleanup, index attempt cleanup
- Tasks: vespa operations, document permissions sync, external group sync
- Higher concurrency for quick tasks
5. **Heavy Worker** (`heavy`)
- Handles resource-intensive operations
- Tasks: connector pruning, document permissions sync, external group sync, CSV generation
- Primary task: document pruning operations
- Runs with 4 threads concurrency
6. **KG Processing Worker** (`kg_processing`)

View File

@@ -1,4 +1,4 @@
FROM python:3.11-slim-bookworm@sha256:9c6f90801e6b68e772b7c0ca74260cbf7af9f320acec894e26fccdaccfbe3b47
FROM python:3.11.7-slim-bookworm
LABEL com.danswer.maintainer="founders@onyx.app"
LABEL com.danswer.description="This image is the web/frontend container of Onyx which \

View File

@@ -1,5 +1,5 @@
# Base stage with dependencies
FROM python:3.11-slim-bookworm@sha256:9c6f90801e6b68e772b7c0ca74260cbf7af9f320acec894e26fccdaccfbe3b47 AS base
FROM python:3.11.7-slim-bookworm AS base
ENV DANSWER_RUNNING_IN_DOCKER="true" \
HF_HOME=/app/.cache/huggingface

View File

@@ -208,7 +208,7 @@ def do_run_migrations(
context.configure(
connection=connection,
target_metadata=target_metadata,
target_metadata=target_metadata, # type: ignore
version_table_schema=schema_name,
include_schemas=True,
compare_type=True,
@@ -380,7 +380,7 @@ def run_migrations_offline() -> None:
logger.info(f"Migrating schema: {schema}")
context.configure(
url=url,
target_metadata=target_metadata,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
version_table_schema=schema,
include_schemas=True,
@@ -421,7 +421,7 @@ def run_migrations_offline() -> None:
logger.info(f"Migrating schema: {schema}")
context.configure(
url=url,
target_metadata=target_metadata,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
version_table_schema=schema,
include_schemas=True,
@@ -464,7 +464,7 @@ def run_migrations_online() -> None:
context.configure(
connection=connection,
target_metadata=target_metadata,
target_metadata=target_metadata, # type: ignore
version_table_schema=schema_name,
include_schemas=True,
compare_type=True,

View File

@@ -25,7 +25,7 @@ def upgrade() -> None:
# Use batch mode to modify the enum type
with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.alter_column(
batch_op.alter_column( # type: ignore[attr-defined]
"role",
type_=sa.Enum(
"BASIC",
@@ -71,7 +71,7 @@ def downgrade() -> None:
op.drop_column("user__user_group", "is_curator")
with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.alter_column(
batch_op.alter_column( # type: ignore[attr-defined]
"role",
type_=sa.Enum(
"BASIC", "ADMIN", name="userrole", native_enum=False, length=20

View File

@@ -63,7 +63,7 @@ def upgrade() -> None:
"time_created",
existing_type=postgresql.TIMESTAMP(timezone=True),
nullable=False,
existing_server_default=sa.text("now()"),
existing_server_default=sa.text("now()"), # type: ignore
)
op.alter_column(
"index_attempt",
@@ -85,7 +85,7 @@ def downgrade() -> None:
"time_created",
existing_type=postgresql.TIMESTAMP(timezone=True),
nullable=True,
existing_server_default=sa.text("now()"),
existing_server_default=sa.text("now()"), # type: ignore
)
op.drop_index(op.f("ix_accesstoken_created_at"), table_name="accesstoken")
op.drop_table("accesstoken")

View File

@@ -19,7 +19,7 @@ depends_on: None = None
def upgrade() -> None:
sequence = Sequence("connector_credential_pair_id_seq")
op.execute(CreateSequence(sequence))
op.execute(CreateSequence(sequence)) # type: ignore
op.add_column(
"connector_credential_pair",
sa.Column(

View File

@@ -1,28 +0,0 @@
"""add_error_tracking_fields_to_index_attempt_errors
Revision ID: d129f37b3d87
Revises: 503883791c39
Create Date: 2026-04-06 19:11:18.261800
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "d129f37b3d87"
down_revision = "503883791c39"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"index_attempt_errors",
sa.Column("error_type", sa.String(), nullable=True),
)
def downgrade() -> None:
op.drop_column("index_attempt_errors", "error_type")

View File

@@ -49,7 +49,7 @@ def run_migrations_offline() -> None:
url = build_connection_string()
context.configure(
url=url,
target_metadata=target_metadata,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
@@ -61,7 +61,7 @@ def run_migrations_offline() -> None:
def do_run_migrations(connection: Connection) -> None:
context.configure(
connection=connection,
target_metadata=target_metadata,
target_metadata=target_metadata, # type: ignore[arg-type]
)
with context.begin_transaction():

View File

@@ -13,7 +13,6 @@ from ee.onyx.server.license.models import LicenseSource
from onyx.auth.schemas import UserRole
from onyx.cache.factory import get_cache_backend
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
from onyx.db.enums import AccountType
from onyx.db.models import License
from onyx.db.models import User
from onyx.utils.logger import setup_logger
@@ -108,13 +107,12 @@ def get_used_seats(tenant_id: str | None = None) -> int:
Get current seat usage directly from database.
For multi-tenant: counts users in UserTenantMapping for this tenant.
For self-hosted: counts all active users.
For self-hosted: counts all active users (excludes EXT_PERM_USER role
and the anonymous system user).
Only human accounts count toward seat limits.
SERVICE_ACCOUNT (API key dummy users), EXT_PERM_USER, and the
anonymous system user are excluded. BOT (Slack users) ARE counted
because they represent real humans and get upgraded to STANDARD
when they log in via web.
TODO: Exclude API key dummy users from seat counting. API keys create
users with emails like `__DANSWER_API_KEY_*` that should not count toward
seat limits. See: https://linear.app/onyx-app/issue/ENG-3518
"""
if MULTI_TENANT:
from ee.onyx.server.tenants.user_mapping import get_tenant_count
@@ -131,7 +129,6 @@ def get_used_seats(tenant_id: str | None = None) -> int:
User.is_active == True, # type: ignore # noqa: E712
User.role != UserRole.EXT_PERM_USER,
User.email != ANONYMOUS_USER_EMAIL, # type: ignore
User.account_type != AccountType.SERVICE_ACCOUNT,
)
)
return result.scalar() or 0

View File

@@ -11,8 +11,6 @@ require a valid SCIM bearer token.
from __future__ import annotations
import hashlib
import struct
from uuid import UUID
from fastapi import APIRouter
@@ -24,7 +22,6 @@ from fastapi import Response
from fastapi.responses import JSONResponse
from fastapi_users.password import PasswordHelper
from sqlalchemy import func
from sqlalchemy import text
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
@@ -68,25 +65,12 @@ from onyx.db.permissions import recompute_user_permissions__no_commit
from onyx.db.users import assign_user_to_default_groups__no_commit
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
# Group names reserved for system default groups (seeded by migration).
_RESERVED_GROUP_NAMES = frozenset({"Admin", "Basic"})
# Namespace prefix for the seat-allocation advisory lock. Hashed together
# with the tenant ID so the lock is scoped per-tenant (unrelated tenants
# never block each other) and cannot collide with unrelated advisory locks.
_SEAT_LOCK_NAMESPACE = "onyx_scim_seat_lock"
def _seat_lock_id_for_tenant(tenant_id: str) -> int:
"""Derive a stable 64-bit signed int lock id for this tenant's seat lock."""
digest = hashlib.sha256(f"{_SEAT_LOCK_NAMESPACE}:{tenant_id}".encode()).digest()
# pg_advisory_xact_lock takes a signed 8-byte int; unpack as such.
return struct.unpack("q", digest[:8])[0]
class ScimJSONResponse(JSONResponse):
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
@@ -225,37 +209,12 @@ def _apply_exclusions(
def _check_seat_availability(dal: ScimDAL) -> str | None:
"""Return an error message if seat limit is reached, else None.
Acquires a transaction-scoped advisory lock so that concurrent
SCIM requests are serialized. IdPs like Okta send provisioning
requests in parallel batches — without serialization the check is
vulnerable to a TOCTOU race where N concurrent requests each see
"seats available", all insert, and the tenant ends up over its
seat limit.
The lock is held until the caller's next COMMIT or ROLLBACK, which
means the seat count cannot change between the check here and the
subsequent INSERT/UPDATE. Each call site in this module follows
the pattern: _check_seat_availability → write → dal.commit()
(which releases the lock for the next waiting request).
"""
"""Return an error message if seat limit is reached, else None."""
check_fn = fetch_ee_implementation_or_noop(
"onyx.db.license", "check_seat_availability", None
)
if check_fn is None:
return None
# Transaction-scoped advisory lock — released on dal.commit() / dal.rollback().
# The lock id is derived from the tenant so unrelated tenants never block
# each other, and from a namespace string so it cannot collide with
# unrelated advisory locks elsewhere in the codebase.
lock_id = _seat_lock_id_for_tenant(get_current_tenant_id())
dal.session.execute(
text("SELECT pg_advisory_xact_lock(:lock_id)"),
{"lock_id": lock_id},
)
result = check_fn(dal.session, seats_needed=1)
if not result.available:
return result.error_message or "Seat limit reached"

View File

@@ -96,14 +96,11 @@ def get_model_app() -> FastAPI:
title="Onyx Model Server", version=__version__, lifespan=lifespan
)
if SENTRY_DSN:
from onyx.configs.sentry import _add_instance_tags
sentry_sdk.init(
dsn=SENTRY_DSN,
integrations=[StarletteIntegration(), FastApiIntegration()],
traces_sample_rate=0.1,
release=__version__,
before_send=_add_instance_tags,
)
logger.info("Sentry initialized")
else:

View File

@@ -10,7 +10,6 @@ from celery import bootsteps # type: ignore
from celery import Task
from celery.app import trace
from celery.exceptions import WorkerShutdown
from celery.signals import before_task_publish
from celery.signals import task_postrun
from celery.signals import task_prerun
from celery.states import READY_STATES
@@ -63,14 +62,11 @@ logger = setup_logger()
task_logger = get_task_logger(__name__)
if SENTRY_DSN:
from onyx.configs.sentry import _add_instance_tags
sentry_sdk.init(
dsn=SENTRY_DSN,
integrations=[CeleryIntegration()],
traces_sample_rate=0.1,
release=__version__,
before_send=_add_instance_tags,
)
logger.info("Sentry initialized")
else:
@@ -98,17 +94,6 @@ class TenantAwareTask(Task):
CURRENT_TENANT_ID_CONTEXTVAR.set(None)
@before_task_publish.connect
def on_before_task_publish(
headers: dict[str, Any] | None = None,
**kwargs: Any, # noqa: ARG001
) -> None:
"""Stamp the current wall-clock time into the task message headers so that
workers can compute queue wait time (time between publish and execution)."""
if headers is not None:
headers["enqueued_at"] = time.time()
@task_prerun.connect
def on_task_prerun(
sender: Any | None = None, # noqa: ARG001

View File

@@ -16,12 +16,6 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected
from onyx.server.metrics.celery_task_metrics import on_celery_task_retry
from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked
from onyx.server.metrics.metrics_server import start_metrics_server
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -42,7 +36,6 @@ def on_task_prerun(
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
on_celery_task_prerun(task_id, task)
@signals.task_postrun.connect
@@ -57,31 +50,6 @@ def on_task_postrun(
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
on_celery_task_postrun(task_id, task, state)
@signals.task_retry.connect
def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
task_id = getattr(getattr(sender, "request", None), "id", None)
on_celery_task_retry(task_id, sender)
@signals.task_revoked.connect
def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None:
task_name = getattr(sender, "name", None) or str(sender)
on_celery_task_revoked(kwargs.get("task_id"), task_name)
@signals.task_rejected.connect
def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
message = kwargs.get("message")
task_name: str | None = None
if message is not None:
headers = getattr(message, "headers", None) or {}
task_name = headers.get("task")
if task_name is None:
task_name = "unknown"
on_celery_task_rejected(None, task_name)
@celeryd_init.connect
@@ -122,7 +90,6 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
start_metrics_server("light")
app_base.on_worker_ready(sender, **kwargs)

View File

@@ -59,11 +59,6 @@ from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_delete import RedisConnectorDeletePayload
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.metrics.deletion_metrics import inc_deletion_blocked
from onyx.server.metrics.deletion_metrics import inc_deletion_completed
from onyx.server.metrics.deletion_metrics import inc_deletion_fence_reset
from onyx.server.metrics.deletion_metrics import inc_deletion_started
from onyx.server.metrics.deletion_metrics import observe_deletion_taskset_duration
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
@@ -107,7 +102,7 @@ def revoke_tasks_blocking_deletion(
f"Revoked permissions sync task {permissions_sync_payload.celery_task_id}."
)
except Exception:
task_logger.exception("Exception while revoking permissions sync task")
task_logger.exception("Exception while revoking pruning task")
try:
prune_payload = redis_connector.prune.payload
@@ -115,7 +110,7 @@ def revoke_tasks_blocking_deletion(
app.control.revoke(prune_payload.celery_task_id)
task_logger.info(f"Revoked pruning task {prune_payload.celery_task_id}.")
except Exception:
task_logger.exception("Exception while revoking pruning task")
task_logger.exception("Exception while revoking permissions sync task")
try:
external_group_sync_payload = redis_connector.external_group_sync.payload
@@ -305,7 +300,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
recent_index_attempts
and recent_index_attempts[0].status == IndexingStatus.IN_PROGRESS
):
inc_deletion_blocked(tenant_id, "indexing")
raise TaskDependencyError(
"Connector deletion - Delayed (indexing in progress): "
f"cc_pair={cc_pair_id} "
@@ -313,13 +307,11 @@ def try_generate_document_cc_pair_cleanup_tasks(
)
if redis_connector.prune.fenced:
inc_deletion_blocked(tenant_id, "pruning")
raise TaskDependencyError(
f"Connector deletion - Delayed (pruning in progress): cc_pair={cc_pair_id}"
)
if redis_connector.permissions.fenced:
inc_deletion_blocked(tenant_id, "permissions")
raise TaskDependencyError(
f"Connector deletion - Delayed (permissions in progress): cc_pair={cc_pair_id}"
)
@@ -367,7 +359,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
# set this only after all tasks have been added
fence_payload.num_tasks = tasks_generated
redis_connector.delete.set_fence(fence_payload)
inc_deletion_started(tenant_id)
return tasks_generated
@@ -517,11 +508,7 @@ def monitor_connector_deletion_taskset(
db_session=db_session,
connector_id=connector_id_to_delete,
)
if not connector:
task_logger.info(
"Connector deletion - Connector already deleted, skipping connector cleanup"
)
elif not len(connector.credentials):
if not connector or not len(connector.credentials):
task_logger.info(
"Connector deletion - Found no credentials left for connector, deleting connector"
)
@@ -536,12 +523,6 @@ def monitor_connector_deletion_taskset(
num_docs_synced=fence_data.num_tasks,
)
duration = (
datetime.now(timezone.utc) - fence_data.submitted
).total_seconds()
observe_deletion_taskset_duration(tenant_id, "success", duration)
inc_deletion_completed(tenant_id, "success")
except Exception as e:
db_session.rollback()
stack_trace = traceback.format_exc()
@@ -560,11 +541,6 @@ def monitor_connector_deletion_taskset(
f"Connector deletion exceptioned: "
f"cc_pair={cc_pair_id} connector={connector_id_to_delete} credential={credential_id_to_delete}"
)
duration = (
datetime.now(timezone.utc) - fence_data.submitted
).total_seconds()
observe_deletion_taskset_duration(tenant_id, "failure", duration)
inc_deletion_completed(tenant_id, "failure")
raise e
task_logger.info(
@@ -741,6 +717,5 @@ def validate_connector_deletion_fence(
f"fence={fence_key}"
)
inc_deletion_fence_reset(tenant_id)
redis_connector.delete.reset()
return

View File

@@ -135,13 +135,10 @@ def _docfetching_task(
# Since connector_indexing_proxy_task spawns a new process using this function as
# the entrypoint, we init Sentry here.
if SENTRY_DSN:
from onyx.configs.sentry import _add_instance_tags
sentry_sdk.init(
dsn=SENTRY_DSN,
traces_sample_rate=0.1,
release=__version__,
before_send=_add_instance_tags,
)
logger.info("Sentry initialized")
else:

View File

@@ -3,7 +3,6 @@ import os
import time
import traceback
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -51,7 +50,6 @@ from onyx.configs.constants import AuthType
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import NotificationType
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
@@ -87,8 +85,6 @@ from onyx.db.indexing_coordination import INDEXING_PROGRESS_TIMEOUT_HOURS
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.db.models import IndexAttempt
from onyx.db.models import SearchSettings
from onyx.db.notification import create_notification
from onyx.db.notification import get_notifications
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_secondary_search_settings
from onyx.db.swap_index import check_and_perform_index_swap
@@ -109,9 +105,6 @@ from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.redis.redis_utils import is_fence
from onyx.server.metrics.connector_health_metrics import on_connector_error_state_change
from onyx.server.metrics.connector_health_metrics import on_connector_indexing_success
from onyx.server.metrics.connector_health_metrics import on_index_attempt_status_change
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.logger import setup_logger
from onyx.utils.middleware import make_randomized_onyx_request_id
@@ -407,6 +400,7 @@ def check_indexing_completion(
tenant_id: str,
task: Task,
) -> None:
logger.info(
f"Checking for indexing completion: attempt={index_attempt_id} tenant={tenant_id}"
)
@@ -527,23 +521,13 @@ def check_indexing_completion(
# Update CC pair status if successful
cc_pair = get_connector_credential_pair_from_id(
db_session,
attempt.connector_credential_pair_id,
eager_load_connector=True,
db_session, attempt.connector_credential_pair_id
)
if cc_pair is None:
raise RuntimeError(
f"CC pair {attempt.connector_credential_pair_id} not found in database"
)
source = cc_pair.connector.source.value
on_index_attempt_status_change(
tenant_id=tenant_id,
source=source,
cc_pair_id=cc_pair.id,
status=attempt.status.value,
)
if attempt.status.is_successful():
# NOTE: we define the last successful index time as the time the last successful
# attempt finished. This is distinct from the poll_range_end of the last successful
@@ -564,39 +548,10 @@ def check_indexing_completion(
event=MilestoneRecordType.CONNECTOR_SUCCEEDED,
)
on_connector_indexing_success(
tenant_id=tenant_id,
source=source,
cc_pair_id=cc_pair.id,
docs_indexed=attempt.new_docs_indexed or 0,
success_timestamp=attempt.time_updated.timestamp(),
)
# Clear repeated error state on success
if cc_pair.in_repeated_error_state:
cc_pair.in_repeated_error_state = False
# Delete any existing error notification for this CC pair so a
# fresh one is created if the connector fails again later.
for notif in get_notifications(
user=None,
db_session=db_session,
notif_type=NotificationType.CONNECTOR_REPEATED_ERRORS,
include_dismissed=True,
):
if (
notif.additional_data
and notif.additional_data.get("cc_pair_id") == cc_pair.id
):
db_session.delete(notif)
db_session.commit()
on_connector_error_state_change(
tenant_id=tenant_id,
source=source,
cc_pair_id=cc_pair.id,
in_error=False,
)
if attempt.status == IndexingStatus.SUCCESS:
logger.info(
@@ -653,27 +608,6 @@ def active_indexing_attempt(
return bool(active_indexing_attempt)
@dataclass
class _KickoffResult:
"""Tracks diagnostic counts from a _kickoff_indexing_tasks run."""
created: int = 0
skipped_active: int = 0
skipped_not_found: int = 0
skipped_not_indexable: int = 0
failed_to_create: int = 0
@property
def evaluated(self) -> int:
return (
self.created
+ self.skipped_active
+ self.skipped_not_found
+ self.skipped_not_indexable
+ self.failed_to_create
)
def _kickoff_indexing_tasks(
celery_app: Celery,
db_session: Session,
@@ -683,12 +617,12 @@ def _kickoff_indexing_tasks(
redis_client: Redis,
lock_beat: RedisLock,
tenant_id: str,
) -> _KickoffResult:
) -> int:
"""Kick off indexing tasks for the given cc_pair_ids and search_settings.
Returns a _KickoffResult with diagnostic counts.
Returns the number of tasks successfully created.
"""
result = _KickoffResult()
tasks_created = 0
for cc_pair_id in cc_pair_ids:
lock_beat.reacquire()
@@ -699,7 +633,6 @@ def _kickoff_indexing_tasks(
search_settings_id=search_settings.id,
db_session=db_session,
):
result.skipped_active += 1
continue
cc_pair = get_connector_credential_pair_from_id(
@@ -710,7 +643,6 @@ def _kickoff_indexing_tasks(
task_logger.warning(
f"_kickoff_indexing_tasks - CC pair not found: cc_pair={cc_pair_id}"
)
result.skipped_not_found += 1
continue
# Heavyweight check after fetching cc pair
@@ -725,7 +657,6 @@ def _kickoff_indexing_tasks(
f"search_settings={search_settings.id}, "
f"secondary_index_building={secondary_index_building}"
)
result.skipped_not_indexable += 1
continue
task_logger.debug(
@@ -765,14 +696,13 @@ def _kickoff_indexing_tasks(
task_logger.info(
f"Connector indexing queued: index_attempt={attempt_id} cc_pair={cc_pair.id} search_settings={search_settings.id}"
)
result.created += 1
tasks_created += 1
else:
task_logger.error(
f"Failed to create indexing task: cc_pair={cc_pair.id} search_settings={search_settings.id}"
)
result.failed_to_create += 1
return result
return tasks_created
@shared_task(
@@ -798,8 +728,6 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
task_logger.warning("check_for_indexing - Starting")
tasks_created = 0
primary_result = _KickoffResult()
secondary_result: _KickoffResult | None = None
locked = False
redis_client = get_redis_client()
redis_client_replica = get_redis_replica_client()
@@ -920,39 +848,6 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
cc_pair_id=cc_pair_id,
in_repeated_error_state=True,
)
on_connector_error_state_change(
tenant_id=tenant_id,
source=cc_pair.connector.source.value,
cc_pair_id=cc_pair_id,
in_error=True,
)
connector_name = (
cc_pair.name
or cc_pair.connector.name
or f"CC pair {cc_pair.id}"
)
source = cc_pair.connector.source.value
connector_url = f"/admin/connector/{cc_pair.id}"
create_notification(
user_id=None,
notif_type=NotificationType.CONNECTOR_REPEATED_ERRORS,
db_session=db_session,
title=f"Connector '{connector_name}' has entered repeated error state",
description=(
f"The {source} connector has failed repeatedly and "
f"has been flagged. View indexing history in the "
f"Advanced section: {connector_url}"
),
additional_data={"cc_pair_id": cc_pair.id},
)
task_logger.error(
f"Connector entered repeated error state: "
f"cc_pair={cc_pair.id} "
f"connector={cc_pair.connector.name} "
f"source={source}"
)
# When entering repeated error state, also pause the connector
# to prevent continued indexing retry attempts burning through embedding credits.
# NOTE: only for Cloud, since most self-hosted users use self-hosted embedding
@@ -968,7 +863,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
# Heavy check, should_index(), is called in _kickoff_indexing_tasks
with get_session_with_current_tenant() as db_session:
# Primary first
primary_result = _kickoff_indexing_tasks(
tasks_created += _kickoff_indexing_tasks(
celery_app=self.app,
db_session=db_session,
search_settings=current_search_settings,
@@ -978,7 +873,6 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
lock_beat=lock_beat,
tenant_id=tenant_id,
)
tasks_created += primary_result.created
# Secondary indexing (only if secondary search settings exist and switchover_type is not INSTANT)
if (
@@ -986,7 +880,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
and secondary_search_settings.switchover_type != SwitchoverType.INSTANT
and secondary_cc_pair_ids
):
secondary_result = _kickoff_indexing_tasks(
tasks_created += _kickoff_indexing_tasks(
celery_app=self.app,
db_session=db_session,
search_settings=secondary_search_settings,
@@ -996,7 +890,6 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
lock_beat=lock_beat,
tenant_id=tenant_id,
)
tasks_created += secondary_result.created
elif (
secondary_search_settings
and secondary_search_settings.switchover_type == SwitchoverType.INSTANT
@@ -1109,26 +1002,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
redis_lock_dump(lock_beat, redis_client)
time_elapsed = time.monotonic() - time_start
task_logger.info(
f"check_for_indexing finished: "
f"elapsed={time_elapsed:.2f}s "
f"primary=[evaluated={primary_result.evaluated} "
f"created={primary_result.created} "
f"skipped_active={primary_result.skipped_active} "
f"skipped_not_found={primary_result.skipped_not_found} "
f"skipped_not_indexable={primary_result.skipped_not_indexable} "
f"failed={primary_result.failed_to_create}]"
+ (
f" secondary=[evaluated={secondary_result.evaluated} "
f"created={secondary_result.created} "
f"skipped_active={secondary_result.skipped_active} "
f"skipped_not_found={secondary_result.skipped_not_found} "
f"skipped_not_indexable={secondary_result.skipped_not_indexable} "
f"failed={secondary_result.failed_to_create}]"
if secondary_result
else ""
)
)
task_logger.info(f"check_for_indexing finished: elapsed={time_elapsed:.2f}")
return tasks_created

View File

@@ -172,10 +172,6 @@ def migrate_chunks_from_vespa_to_opensearch_task(
search_settings = get_current_search_settings(db_session)
indexing_setting = IndexingSetting.from_db_model(search_settings)
task_logger.debug(
"Verified tenant info, migration record, and search settings."
)
# 2.e. Build sanitized to original doc ID mapping to check for
# conflicts in the event we sanitize a doc ID to an
# already-existing doc ID.
@@ -329,7 +325,6 @@ def migrate_chunks_from_vespa_to_opensearch_task(
finally:
if lock.owned():
lock.release()
task_logger.debug("Released the OpenSearch migration lock.")
else:
task_logger.warning(
"The OpenSearch migration lock was not owned on completion of the migration task."

View File

@@ -38,7 +38,6 @@ from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.models import InputType
from onyx.db.connector import mark_ccpair_as_pruned
from onyx.db.connector_credential_pair import get_connector_credential_pair
@@ -526,14 +525,6 @@ def connector_pruning_generator_task(
return None
try:
# Session 1: pre-enumeration — load cc_pair and instantiate the connector.
# The session is closed before enumeration so the DB connection is not held
# open during the 1030+ minute connector crawl.
connector_source: DocumentSource | None = None
connector_type: str = ""
is_connector_public: bool = False
runnable_connector: BaseConnector | None = None
with get_session_with_current_tenant() as db_session:
cc_pair = get_connector_credential_pair(
db_session=db_session,
@@ -559,51 +550,49 @@ def connector_pruning_generator_task(
)
redis_connector.prune.set_fence(new_payload)
connector_source = cc_pair.connector.source
connector_type = connector_source.value
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
task_logger.info(
f"Pruning generator running connector: cc_pair={cc_pair_id} connector_source={connector_source}"
f"Pruning generator running connector: cc_pair={cc_pair_id} connector_source={cc_pair.connector.source}"
)
runnable_connector = instantiate_connector(
db_session,
connector_source,
cc_pair.connector.source,
InputType.SLIM_RETRIEVAL,
cc_pair.connector.connector_specific_config,
cc_pair.credential,
)
# Session 1 closed here — connection released before enumeration.
callback = PruneCallback(
0,
redis_connector,
lock,
r,
timeout_seconds=JOB_TIMEOUT,
)
callback = PruneCallback(
0,
redis_connector,
lock,
r,
timeout_seconds=JOB_TIMEOUT,
)
# Extract docs and hierarchy nodes from the source (no DB session held).
extraction_result = extract_ids_from_runnable_connector(
runnable_connector, callback, connector_type=connector_type
)
all_connector_doc_ids = extraction_result.raw_id_to_parent
# Extract docs and hierarchy nodes from the source
connector_type = cc_pair.connector.source.value
extraction_result = extract_ids_from_runnable_connector(
runnable_connector, callback, connector_type=connector_type
)
all_connector_doc_ids = extraction_result.raw_id_to_parent
# Session 2: post-enumeration — hierarchy upserts, diff computation, task dispatch.
with get_session_with_current_tenant() as db_session:
source = connector_source
# Process hierarchy nodes (same as docfetching):
# upsert to Postgres and cache in Redis
source = cc_pair.connector.source
redis_client = get_redis_client(tenant_id=tenant_id)
ensure_source_node_exists(redis_client, db_session, source)
upserted_nodes: list[DBHierarchyNode] = []
if extraction_result.hierarchy_nodes:
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
upserted_nodes = upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=extraction_result.hierarchy_nodes,
source=source,
commit=False,
commit=True,
is_connector_public=is_connector_public,
)
@@ -612,13 +601,9 @@ def connector_pruning_generator_task(
hierarchy_node_ids=[n.id for n in upserted_nodes],
connector_id=connector_id,
credential_id=credential_id,
commit=False,
commit=True,
)
# Single commit so the FK reference in the join table can never
# outrun the parent hierarchy_node insert.
db_session.commit()
cache_entries = [
HierarchyNodeCacheEntry.from_db_model(node)
for node in upserted_nodes
@@ -673,7 +658,7 @@ def connector_pruning_generator_task(
task_logger.info(
"Pruning set collected: "
f"cc_pair={cc_pair_id} "
f"connector_source={connector_source} "
f"connector_source={cc_pair.connector.source} "
f"docs_to_remove={len(doc_ids_to_remove)}"
)

View File

@@ -23,8 +23,6 @@ class IndexAttemptErrorPydantic(BaseModel):
index_attempt_id: int
error_type: str | None = None
@classmethod
def from_model(cls, model: IndexAttemptError) -> "IndexAttemptErrorPydantic":
return cls(
@@ -39,5 +37,4 @@ class IndexAttemptErrorPydantic(BaseModel):
is_resolved=model.is_resolved,
time_created=model.time_created,
index_attempt_id=model.index_attempt_id,
error_type=model.error_type,
)

View File

@@ -5,7 +5,6 @@ from datetime import datetime
from datetime import timedelta
from datetime import timezone
import sentry_sdk
from celery import Celery
from sqlalchemy.orm import Session
@@ -69,7 +68,6 @@ from onyx.redis.redis_pool import get_redis_client
from onyx.server.features.build.indexing.persistent_document_writer import (
get_persistent_document_writer,
)
from onyx.server.metrics.connector_health_metrics import on_index_attempt_status_change
from onyx.utils.logger import setup_logger
from onyx.utils.middleware import make_randomized_onyx_request_id
from onyx.utils.postgres_sanitization import sanitize_document_for_postgres
@@ -269,13 +267,6 @@ def run_docfetching_entrypoint(
)
credential_id = attempt.connector_credential_pair.credential_id
on_index_attempt_status_change(
tenant_id=tenant_id,
source=attempt.connector_credential_pair.connector.source.value,
cc_pair_id=connector_credential_pair_id,
status="in_progress",
)
logger.info(
f"Docfetching starting{tenant_str}: "
f"connector='{connector_name}' "
@@ -565,27 +556,6 @@ def connector_document_extraction(
# save record of any failures at the connector level
if failure is not None:
if failure.exception is not None:
with sentry_sdk.new_scope() as scope:
scope.set_tag("stage", "connector_fetch")
scope.set_tag("connector_source", db_connector.source.value)
scope.set_tag("cc_pair_id", str(cc_pair_id))
scope.set_tag("index_attempt_id", str(index_attempt_id))
scope.set_tag("tenant_id", tenant_id)
if failure.failed_document:
scope.set_tag(
"doc_id", failure.failed_document.document_id
)
if failure.failed_entity:
scope.set_tag(
"entity_id", failure.failed_entity.entity_id
)
scope.fingerprint = [
"connector-fetch-failure",
db_connector.source.value,
type(failure.exception).__name__,
]
sentry_sdk.capture_exception(failure.exception)
total_failures += 1
with get_session_with_current_tenant() as db_session:
create_index_attempt_error(

View File

@@ -364,7 +364,7 @@ def _get_or_extract_plaintext(
plaintext_io = file_store.read_file(plaintext_key, mode="b")
return plaintext_io.read().decode("utf-8")
except Exception:
logger.info(f"Cache miss for file with id={file_id}")
logger.exception(f"Error when reading file, id={file_id}")
# Cache miss — extract and store.
content_text = extract_fn()

View File

@@ -4,6 +4,8 @@ from collections.abc import Callable
from typing import Any
from typing import Literal
from sqlalchemy.orm import Session
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.chat_utils import create_tool_call_failure_messages
from onyx.chat.citation_processor import CitationMapping
@@ -633,6 +635,7 @@ def run_llm_loop(
user_memory_context: UserMemoryContext | None,
llm: LLM,
token_counter: Callable[[str], int],
db_session: Session,
forced_tool_id: int | None = None,
user_identity: LLMUserIdentity | None = None,
chat_session_id: str | None = None,
@@ -1017,16 +1020,20 @@ def run_llm_loop(
persisted_memory_id: int | None = None
if user_memory_context and user_memory_context.user_id:
if tool_response.rich_response.index_to_replace is not None:
persisted_memory_id = update_memory_at_index(
memory = update_memory_at_index(
user_id=user_memory_context.user_id,
index=tool_response.rich_response.index_to_replace,
new_text=tool_response.rich_response.memory_text,
db_session=db_session,
)
persisted_memory_id = memory.id if memory else None
else:
persisted_memory_id = add_memory(
memory = add_memory(
user_id=user_memory_context.user_id,
memory_text=tool_response.rich_response.memory_text,
db_session=db_session,
)
persisted_memory_id = memory.id
operation: Literal["add", "update"] = (
"update"
if tool_response.rich_response.index_to_replace is not None

View File

@@ -67,6 +67,7 @@ from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import reserve_message_id
from onyx.db.chat import reserve_multi_model_message_ids
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import HookPoint
from onyx.db.memory import get_memories
from onyx.db.models import ChatMessage
@@ -1005,86 +1006,93 @@ def _run_models(
model_llm = setup.llms[model_idx]
try:
# Each function opens short-lived DB sessions on demand.
# Do NOT pass a long-lived session here — it would hold a
# connection for the entire LLM loop (minutes), and cloud
# infrastructure may drop idle connections.
thread_tool_dict = construct_tools(
persona=setup.persona,
emitter=model_emitter,
user=user,
llm=model_llm,
search_tool_config=SearchToolConfig(
user_selected_filters=setup.new_msg_req.internal_search_filters,
project_id_filter=setup.search_params.project_id_filter,
persona_id_filter=setup.search_params.persona_id_filter,
bypass_acl=setup.bypass_acl,
slack_context=setup.slack_context,
enable_slack_search=_should_enable_slack_search(
setup.persona, setup.new_msg_req.internal_search_filters
),
),
custom_tool_config=CustomToolConfig(
chat_session_id=setup.chat_session.id,
message_id=setup.user_message.id,
additional_headers=setup.custom_tool_additional_headers,
mcp_headers=setup.mcp_headers,
),
file_reader_tool_config=FileReaderToolConfig(
user_file_ids=setup.available_files.user_file_ids,
chat_file_ids=setup.available_files.chat_file_ids,
),
allowed_tool_ids=setup.new_msg_req.allowed_tool_ids,
search_usage_forcing_setting=setup.search_params.search_usage,
)
model_tools = [
tool for tool_list in thread_tool_dict.values() for tool in tool_list
]
if setup.forced_tool_id and setup.forced_tool_id not in {
tool.id for tool in model_tools
}:
raise ValueError(
f"Forced tool {setup.forced_tool_id} not found in tools"
)
# Per-thread copy: run_llm_loop mutates simple_chat_history in-place.
if n_models == 1 and setup.new_msg_req.deep_research:
if setup.chat_session.project_id:
raise RuntimeError("Deep research is not supported for projects")
run_deep_research_llm_loop(
emitter=model_emitter,
state_container=sc,
simple_chat_history=list(setup.simple_chat_history),
tools=model_tools,
custom_agent_prompt=setup.custom_agent_prompt,
llm=model_llm,
token_counter=get_llm_token_counter(model_llm),
skip_clarification=setup.skip_clarification,
user_identity=setup.user_identity,
chat_session_id=str(setup.chat_session.id),
all_injected_file_metadata=setup.all_injected_file_metadata,
)
else:
run_llm_loop(
emitter=model_emitter,
state_container=sc,
simple_chat_history=list(setup.simple_chat_history),
tools=model_tools,
custom_agent_prompt=setup.custom_agent_prompt,
context_files=setup.extracted_context_files,
# Each worker opens its own session — SQLAlchemy sessions are not thread-safe.
# Do NOT write to the outer db_session (or any shared DB state) from here;
# all DB writes in this thread must go through thread_db_session.
with get_session_with_current_tenant() as thread_db_session:
thread_tool_dict = construct_tools(
persona=setup.persona,
user_memory_context=setup.user_memory_context,
db_session=thread_db_session,
emitter=model_emitter,
user=user,
llm=model_llm,
token_counter=get_llm_token_counter(model_llm),
forced_tool_id=setup.forced_tool_id,
user_identity=setup.user_identity,
chat_session_id=str(setup.chat_session.id),
chat_files=setup.chat_files_for_tools,
include_citations=setup.new_msg_req.include_citations,
all_injected_file_metadata=setup.all_injected_file_metadata,
inject_memories_in_prompt=user.use_memories,
search_tool_config=SearchToolConfig(
user_selected_filters=setup.new_msg_req.internal_search_filters,
project_id_filter=setup.search_params.project_id_filter,
persona_id_filter=setup.search_params.persona_id_filter,
bypass_acl=setup.bypass_acl,
slack_context=setup.slack_context,
enable_slack_search=_should_enable_slack_search(
setup.persona, setup.new_msg_req.internal_search_filters
),
),
custom_tool_config=CustomToolConfig(
chat_session_id=setup.chat_session.id,
message_id=setup.user_message.id,
additional_headers=setup.custom_tool_additional_headers,
mcp_headers=setup.mcp_headers,
),
file_reader_tool_config=FileReaderToolConfig(
user_file_ids=setup.available_files.user_file_ids,
chat_file_ids=setup.available_files.chat_file_ids,
),
allowed_tool_ids=setup.new_msg_req.allowed_tool_ids,
search_usage_forcing_setting=setup.search_params.search_usage,
)
model_tools = [
tool
for tool_list in thread_tool_dict.values()
for tool in tool_list
]
if setup.forced_tool_id and setup.forced_tool_id not in {
tool.id for tool in model_tools
}:
raise ValueError(
f"Forced tool {setup.forced_tool_id} not found in tools"
)
# Per-thread copy: run_llm_loop mutates simple_chat_history in-place.
if n_models == 1 and setup.new_msg_req.deep_research:
if setup.chat_session.project_id:
raise RuntimeError(
"Deep research is not supported for projects"
)
run_deep_research_llm_loop(
emitter=model_emitter,
state_container=sc,
simple_chat_history=list(setup.simple_chat_history),
tools=model_tools,
custom_agent_prompt=setup.custom_agent_prompt,
llm=model_llm,
token_counter=get_llm_token_counter(model_llm),
db_session=thread_db_session,
skip_clarification=setup.skip_clarification,
user_identity=setup.user_identity,
chat_session_id=str(setup.chat_session.id),
all_injected_file_metadata=setup.all_injected_file_metadata,
)
else:
run_llm_loop(
emitter=model_emitter,
state_container=sc,
simple_chat_history=list(setup.simple_chat_history),
tools=model_tools,
custom_agent_prompt=setup.custom_agent_prompt,
context_files=setup.extracted_context_files,
persona=setup.persona,
user_memory_context=setup.user_memory_context,
llm=model_llm,
token_counter=get_llm_token_counter(model_llm),
db_session=thread_db_session,
forced_tool_id=setup.forced_tool_id,
user_identity=setup.user_identity,
chat_session_id=str(setup.chat_session.id),
chat_files=setup.chat_files_for_tools,
include_citations=setup.new_msg_req.include_citations,
all_injected_file_metadata=setup.all_injected_file_metadata,
inject_memories_in_prompt=user.use_memories,
)
model_succeeded[model_idx] = True

View File

@@ -283,7 +283,6 @@ class NotificationType(str, Enum):
RELEASE_NOTES = "release_notes"
ASSISTANT_FILES_READY = "assistant_files_ready"
FEATURE_ANNOUNCEMENT = "feature_announcement"
CONNECTOR_REPEATED_ERRORS = "connector_repeated_errors"
class BlobType(str, Enum):

View File

@@ -1,48 +0,0 @@
from typing import Any
from sentry_sdk.types import Event
from onyx.utils.logger import setup_logger
logger = setup_logger()
_instance_id_resolved = False
def _add_instance_tags(
    event: Event,
    hint: dict[str, Any],  # noqa: ARG001
) -> Event | None:
    """Sentry ``before_send`` hook that tags events with this instance's identity.

    The first successfully-processed event resolves the instance id (a fixed
    string for multi-tenant cloud, otherwise the UUID from the KV store, which
    requires a working DB) and registers it as a global Sentry tag; later
    events then inherit the tag automatically.
    """
    global _instance_id_resolved
    if not _instance_id_resolved:
        try:
            import sentry_sdk
            from shared_configs.configs import MULTI_TENANT

            if MULTI_TENANT:
                resolved_id = "multi-tenant-cloud"
            else:
                from onyx.utils.telemetry import get_or_generate_uuid

                resolved_id = get_or_generate_uuid()
            sentry_sdk.set_tag("instance_id", resolved_id)
            # set_tag only affects future events — stamp this one directly too.
            event.setdefault("tags", {})["instance_id"] = resolved_id
            # Flip the flag only on success so a not-yet-ready DB is retried
            # on the next event instead of being skipped forever.
            _instance_id_resolved = True
        except Exception:
            logger.debug("Failed to resolve instance_id for Sentry tagging")
    return event

View File

@@ -26,10 +26,6 @@ from onyx.configs.constants import FileOrigin
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
process_onyx_metadata,
)
from onyx.connectors.cross_connector_utils.tabular_section_utils import is_tabular_file
from onyx.connectors.cross_connector_utils.tabular_section_utils import (
tabular_file_to_sections,
)
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
@@ -42,7 +38,6 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TabularSection
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
@@ -456,40 +451,6 @@ class BlobStorageConnector(LoadConnector, PollConnector):
logger.exception(f"Error processing image {key}")
continue
# Handle tabular files (xlsx, csv, tsv) — produce one
# TabularSection per sheet (or per file for csv/tsv)
# instead of a flat TextSection.
if is_tabular_file(file_name):
try:
downloaded_file = self._download_object(key)
if downloaded_file is None:
continue
tabular_sections = tabular_file_to_sections(
BytesIO(downloaded_file),
file_name=file_name,
link=link,
)
batch.append(
Document(
id=f"{self.bucket_type}:{self.bucket_name}:{key}",
sections=(
tabular_sections
if tabular_sections
else [TabularSection(link=link, text="")]
),
source=DocumentSource(self.bucket_type.value),
semantic_identifier=file_name,
doc_updated_at=last_modified,
metadata={},
)
)
if len(batch) == self.batch_size:
yield batch
batch = []
except Exception:
logger.exception(f"Error processing tabular file {key}")
continue
# Handle text and document files
try:
downloaded_file = self._download_object(key)

View File

@@ -27,19 +27,16 @@ _STATUS_TO_ERROR_CODE: dict[int, OnyxErrorCode] = {
401: OnyxErrorCode.CREDENTIAL_EXPIRED,
403: OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
404: OnyxErrorCode.BAD_GATEWAY,
429: OnyxErrorCode.RATE_LIMITED,
}
def _error_code_for_status(status_code: int) -> OnyxErrorCode:
"""Map an HTTP status code to the appropriate OnyxErrorCode.
Expects a >= 400 status code. Known codes (401, 403, 404) are
Expects a >= 400 status code. Known codes (401, 403, 404, 429) are
mapped to specific error codes; all other codes (unrecognised 4xx
and 5xx) map to BAD_GATEWAY as unexpected upstream errors.
Note: 429 is intentionally omitted — the rl_requests wrapper
handles rate limits transparently at the HTTP layer, so 429
responses never reach this function.
"""
if status_code in _STATUS_TO_ERROR_CODE:
return _STATUS_TO_ERROR_CODE[status_code]

View File

@@ -1,9 +1,10 @@
from datetime import datetime
from datetime import timezone
from enum import StrEnum
from typing import Any
from typing import cast
from typing import Literal
from typing import NoReturn
from typing import TypeAlias
from pydantic import BaseModel
from retry import retry
@@ -24,11 +25,8 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import EntityFailure
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.error_handling.exceptions import OnyxError
@@ -49,6 +47,10 @@ def _handle_canvas_api_error(e: OnyxError) -> NoReturn:
raise InsufficientPermissionsError(
"Canvas API token does not have sufficient permissions (HTTP 403)."
)
elif e.status_code == 429:
raise ConnectorValidationError(
"Canvas rate-limit exceeded (HTTP 429). Please try again later."
)
elif e.status_code >= 500:
raise UnexpectedValidationError(
f"Unexpected Canvas HTTP error (status={e.status_code}): {e}"
@@ -59,60 +61,6 @@ def _handle_canvas_api_error(e: OnyxError) -> NoReturn:
)
class CanvasStage(StrEnum):
    """Per-course indexing stages, in the order they are processed.

    Checkpoint advancement walks the members in declaration order
    (see CanvasConnectorCheckpoint.advance_stage).
    """

    PAGES = "pages"
    ASSIGNMENTS = "assignments"
    ANNOUNCEMENTS = "announcements"
# Per-stage request templates: the Canvas endpoint (relative path) and query
# params used when starting a stage. "{course_id}" placeholders are filled in
# via str.format for the course currently being indexed.
_STAGE_CONFIG: dict[CanvasStage, dict[str, Any]] = {
    CanvasStage.PAGES: {
        "endpoint": "courses/{course_id}/pages",
        # Sorted by updated_at desc so pagination can stop early once an
        # item falls at or before the requested time window's start.
        "params": {
            "per_page": "100",
            "include[]": "body",
            "published": "true",
            "sort": "updated_at",
            "order": "desc",
        },
    },
    CanvasStage.ASSIGNMENTS: {
        "endpoint": "courses/{course_id}/assignments",
        "params": {"per_page": "100", "published": "true"},
    },
    CanvasStage.ANNOUNCEMENTS: {
        # Announcements use the top-level endpoint, scoped to one course
        # via context_codes[].
        "endpoint": "announcements",
        "params": {
            "per_page": "100",
            "context_codes[]": "course_{course_id}",
            "active_only": "true",
        },
    },
}
def _parse_canvas_dt(timestamp_str: str) -> datetime:
"""Parse a Canvas ISO-8601 timestamp (e.g. '2025-06-15T12:00:00Z')
into a timezone-aware UTC datetime.
Canvas returns timestamps with a trailing 'Z' instead of '+00:00',
so we normalise before parsing.
"""
return datetime.fromisoformat(timestamp_str.replace("Z", "+00:00")).astimezone(
timezone.utc
)
def _unix_to_canvas_time(epoch: float) -> str:
"""Convert a Unix timestamp to Canvas ISO-8601 format (e.g. '2025-06-15T12:00:00Z')."""
return datetime.fromtimestamp(epoch, tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
def _in_time_window(timestamp_str: str, start: float, end: float) -> bool:
    """Return True when a Canvas ISO-8601 timestamp lies within (start, end]."""
    ts = _parse_canvas_dt(timestamp_str).timestamp()
    return start < ts <= end
class CanvasCourse(BaseModel):
id: int
name: str | None = None
@@ -197,6 +145,9 @@ class CanvasAnnouncement(BaseModel):
)
CanvasStage: TypeAlias = Literal["pages", "assignments", "announcements"]
class CanvasConnectorCheckpoint(ConnectorCheckpoint):
"""Checkpoint state for resumable Canvas indexing.
@@ -214,30 +165,15 @@ class CanvasConnectorCheckpoint(ConnectorCheckpoint):
course_ids: list[int] = []
current_course_index: int = 0
stage: CanvasStage = CanvasStage.PAGES
stage: CanvasStage = "pages"
next_url: str | None = None
def advance_course(self) -> None:
"""Move to the next course and reset within-course state."""
self.current_course_index += 1
self.stage = CanvasStage.PAGES
self.stage = "pages"
self.next_url = None
def advance_stage(self) -> None:
"""Advance past the current stage.
Moves to the next stage within the same course, or to the next
course if the current stage is the last one. Resets next_url so
the next call starts fresh on the new stage.
"""
self.next_url = None
stages: list[CanvasStage] = list(CanvasStage)
next_idx = stages.index(self.stage) + 1
if next_idx < len(stages):
self.stage = stages[next_idx]
else:
self.advance_course()
class CanvasConnector(
CheckpointedConnectorWithPermSync[CanvasConnectorCheckpoint],
@@ -359,7 +295,13 @@ class CanvasConnector(
if body_text:
text_parts.append(body_text)
doc_updated_at = _parse_canvas_dt(page.updated_at) if page.updated_at else None
doc_updated_at = (
datetime.fromisoformat(page.updated_at.replace("Z", "+00:00")).astimezone(
timezone.utc
)
if page.updated_at
else None
)
document = self._build_document(
doc_id=f"canvas-page-{page.course_id}-{page.page_id}",
@@ -383,11 +325,17 @@ class CanvasConnector(
if desc_text:
text_parts.append(desc_text)
if assignment.due_at:
due_dt = _parse_canvas_dt(assignment.due_at)
due_dt = datetime.fromisoformat(
assignment.due_at.replace("Z", "+00:00")
).astimezone(timezone.utc)
text_parts.append(f"Due: {due_dt.strftime('%B %d, %Y %H:%M UTC')}")
doc_updated_at = (
_parse_canvas_dt(assignment.updated_at) if assignment.updated_at else None
datetime.fromisoformat(
assignment.updated_at.replace("Z", "+00:00")
).astimezone(timezone.utc)
if assignment.updated_at
else None
)
document = self._build_document(
@@ -413,7 +361,11 @@ class CanvasConnector(
text_parts.append(msg_text)
doc_updated_at = (
_parse_canvas_dt(announcement.posted_at) if announcement.posted_at else None
datetime.fromisoformat(
announcement.posted_at.replace("Z", "+00:00")
).astimezone(timezone.utc)
if announcement.posted_at
else None
)
document = self._build_document(
@@ -448,314 +400,6 @@ class CanvasConnector(
self._canvas_client = client
return None
def _fetch_stage_page(
self,
next_url: str | None,
endpoint: str,
params: dict[str, Any],
) -> tuple[list[Any], str | None]:
"""Fetch one page of API results for the current stage.
Returns (items, next_url). All error handling is done by the
caller (_load_from_checkpoint).
"""
if next_url:
# Resuming mid-pagination: the next_url from Canvas's
# Link header already contains endpoint + query params.
response, result_next_url = self.canvas_client.get(full_url=next_url)
else:
# First request for this stage: build from endpoint + params.
response, result_next_url = self.canvas_client.get(
endpoint=endpoint, params=params
)
return response or [], result_next_url
    def _process_items(
        self,
        response: list[Any],
        stage: CanvasStage,
        course_id: int,
        start: float,
        end: float,
        include_permissions: bool,
    ) -> tuple[list[Document | ConnectorFailure], bool]:
        """Convert one page of raw API items into documents.

        Each item is parsed into the stage-specific model, filtered against
        the (start, end] time window, converted to a Document, and optionally
        tagged with course permissions. A failure on a single item is
        recorded as a ConnectorFailure instead of aborting the whole page.

        Returns (docs, early_exit). early_exit is True when pages
        (sorted desc by updated_at) hit an item older than start,
        signaling that pagination should stop.
        """
        results: list[Document | ConnectorFailure] = []
        early_exit = False
        for item in response:
            try:
                if stage == CanvasStage.PAGES:
                    page = CanvasPage.from_api(item, course_id=course_id)
                    # No timestamp means the page can't be window-filtered.
                    if not page.updated_at:
                        continue
                    # Pages are sorted by updated_at desc — once we see
                    # an item at or before `start`, all remaining items
                    # on this and subsequent pages are older too.
                    if not _in_time_window(page.updated_at, start, end):
                        if _parse_canvas_dt(page.updated_at).timestamp() <= start:
                            early_exit = True
                            break
                        # ts > end: page is newer than our window, skip it
                        continue
                    doc = self._convert_page_to_document(page)
                    results.append(
                        self._maybe_attach_permissions(
                            doc, course_id, include_permissions
                        )
                    )
                elif stage == CanvasStage.ASSIGNMENTS:
                    assignment = CanvasAssignment.from_api(item, course_id=course_id)
                    if not assignment.updated_at or not _in_time_window(
                        assignment.updated_at, start, end
                    ):
                        continue
                    doc = self._convert_assignment_to_document(assignment)
                    results.append(
                        self._maybe_attach_permissions(
                            doc, course_id, include_permissions
                        )
                    )
                elif stage == CanvasStage.ANNOUNCEMENTS:
                    announcement = CanvasAnnouncement.from_api(
                        item, course_id=course_id
                    )
                    if not announcement.posted_at:
                        logger.debug(
                            f"Skipping announcement {announcement.id} in "
                            f"course {course_id}: no posted_at"
                        )
                        continue
                    if not _in_time_window(announcement.posted_at, start, end):
                        continue
                    doc = self._convert_announcement_to_document(announcement)
                    results.append(
                        self._maybe_attach_permissions(
                            doc, course_id, include_permissions
                        )
                    )
            except Exception as e:
                # Record a per-item failure with the best-available link.
                item_id = item.get("id") or item.get("page_id", "unknown")
                if stage == CanvasStage.PAGES:
                    doc_link = (
                        f"{self.canvas_base_url}/courses/{course_id}"
                        f"/pages/{item.get('url', '')}"
                    )
                else:
                    doc_link = item.get("html_url", "")
                results.append(
                    ConnectorFailure(
                        failed_document=DocumentFailure(
                            document_id=f"canvas-{stage.removesuffix('s')}-{course_id}-{item_id}",
                            document_link=doc_link,
                        ),
                        failure_message=f"Failed to process {stage.removesuffix('s')}: {e}",
                        exception=e,
                    )
                )
        return results, early_exit
def _maybe_attach_permissions(
self,
document: Document,
course_id: int,
include_permissions: bool,
) -> Document:
if include_permissions:
document.external_access = self._get_course_permissions(course_id)
return document
    def _load_from_checkpoint(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: CanvasConnectorCheckpoint,
        include_permissions: bool = False,
    ) -> CheckpointOutput[CanvasConnectorCheckpoint]:
        """Shared implementation for load_from_checkpoint and
        load_from_checkpoint_with_perm_sync.

        Generator performing one unit of work per call — at most one page of
        one stage of one course — yielding Documents / ConnectorFailures and
        returning an updated checkpoint that records where to resume.
        """
        new_checkpoint = checkpoint.model_copy(deep=True)
        # First call: materialize the list of course IDs.
        # On failure, let the exception propagate so the framework fails the
        # attempt cleanly. Swallowing errors here would leave the checkpoint
        # state unchanged and cause an infinite retry loop.
        if not new_checkpoint.course_ids:
            try:
                courses = self._list_courses()
            except OnyxError as e:
                if e.status_code in (401, 403):
                    _handle_canvas_api_error(e)  # NoReturn — always raises
                raise
            new_checkpoint.course_ids = [c.id for c in courses]
            logger.info(f"Found {len(courses)} Canvas courses to process")
            new_checkpoint.has_more = len(new_checkpoint.course_ids) > 0
            return new_checkpoint
        # All courses done.
        if new_checkpoint.current_course_index >= len(new_checkpoint.course_ids):
            new_checkpoint.has_more = False
            return new_checkpoint
        course_id = new_checkpoint.course_ids[new_checkpoint.current_course_index]
        # Validate the persisted stage string before using it as an enum.
        try:
            stage = CanvasStage(new_checkpoint.stage)
        except ValueError as e:
            raise ValueError(
                f"Invalid checkpoint stage: {new_checkpoint.stage!r}. "
                f"Valid stages: {[s.value for s in CanvasStage]}"
            ) from e
        # Build endpoint + params from the static template.
        config = _STAGE_CONFIG[stage]
        endpoint = config["endpoint"].format(course_id=course_id)
        params = {k: v.format(course_id=course_id) for k, v in config["params"].items()}
        # Only the announcements API supports server-side date filtering
        # (start_date/end_date). Pages support server-side sorting
        # (sort=updated_at desc) enabling early exit, but not date
        # filtering. Assignments support neither. Both are filtered
        # client-side via _in_time_window after fetching.
        if stage == CanvasStage.ANNOUNCEMENTS:
            params["start_date"] = _unix_to_canvas_time(start)
            params["end_date"] = _unix_to_canvas_time(end)
        try:
            response, result_next_url = self._fetch_stage_page(
                next_url=new_checkpoint.next_url,
                endpoint=endpoint,
                params=params,
            )
        except OnyxError as oe:
            # Security errors from _parse_next_link (host/scheme
            # mismatch on pagination URLs) have no status code override
            # and must not be silenced.
            is_api_error = oe._status_code_override is not None
            if not is_api_error:
                raise
            if oe.status_code in (401, 403):
                _handle_canvas_api_error(oe)  # NoReturn — always raises
            # 404 means the course itself is gone or inaccessible. The
            # other stages on this course will hit the same 404, so skip
            # the whole course rather than burning API calls on each stage.
            if oe.status_code == 404:
                logger.warning(
                    f"Canvas course {course_id} not found while fetching "
                    f"{stage} (HTTP 404). Skipping course."
                )
                yield ConnectorFailure(
                    failed_entity=EntityFailure(
                        entity_id=f"canvas-course-{course_id}",
                    ),
                    failure_message=(f"Canvas course {course_id} not found: {oe}"),
                    exception=oe,
                )
                new_checkpoint.advance_course()
            else:
                # Other API errors: give up on this stage only, keep going.
                logger.warning(
                    f"Failed to fetch {stage} for course {course_id}: {oe}. "
                    f"Skipping remainder of this stage."
                )
                yield ConnectorFailure(
                    failed_entity=EntityFailure(
                        entity_id=f"canvas-{stage}-{course_id}",
                    ),
                    failure_message=(
                        f"Failed to fetch {stage} for course {course_id}: {oe}"
                    ),
                    exception=oe,
                )
                new_checkpoint.advance_stage()
            new_checkpoint.has_more = new_checkpoint.current_course_index < len(
                new_checkpoint.course_ids
            )
            return new_checkpoint
        except Exception as e:
            # Unknown error — skip the stage and try to continue.
            logger.warning(
                f"Failed to fetch {stage} for course {course_id}: {e}. "
                f"Skipping remainder of this stage."
            )
            yield ConnectorFailure(
                failed_entity=EntityFailure(
                    entity_id=f"canvas-{stage}-{course_id}",
                ),
                failure_message=(
                    f"Failed to fetch {stage} for course {course_id}: {e}"
                ),
                exception=e,
            )
            new_checkpoint.advance_stage()
            new_checkpoint.has_more = new_checkpoint.current_course_index < len(
                new_checkpoint.course_ids
            )
            return new_checkpoint
        # Process fetched items
        results, early_exit = self._process_items(
            response, stage, course_id, start, end, include_permissions
        )
        for result in results:
            yield result
        # If we hit an item older than our window (pages sorted desc),
        # skip remaining pagination and advance to the next stage.
        if early_exit:
            result_next_url = None
        # If there are more pages, save the cursor and return
        if result_next_url:
            new_checkpoint.next_url = result_next_url
        else:
            # Stage complete — advance to next stage (or next course if last).
            new_checkpoint.advance_stage()
        new_checkpoint.has_more = new_checkpoint.current_course_index < len(
            new_checkpoint.course_ids
        )
        return new_checkpoint
@override
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: CanvasConnectorCheckpoint,
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
return self._load_from_checkpoint(
start, end, checkpoint, include_permissions=False
)
@override
def load_from_checkpoint_with_perm_sync(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: CanvasConnectorCheckpoint,
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
"""Load documents from checkpoint with permission information included."""
return self._load_from_checkpoint(
start, end, checkpoint, include_permissions=True
)
@override
def build_dummy_checkpoint(self) -> CanvasConnectorCheckpoint:
return CanvasConnectorCheckpoint(has_more=True)
@override
def validate_checkpoint_json(
self, checkpoint_json: str
) -> CanvasConnectorCheckpoint:
return CanvasConnectorCheckpoint.model_validate_json(checkpoint_json)
@override
def validate_connector_settings(self) -> None:
"""Validate Canvas connector settings by testing API access."""
@@ -771,6 +415,38 @@ class CanvasConnector(
f"Unexpected error during Canvas settings validation: {exc}"
)
    @override
    def load_from_checkpoint(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: CanvasConnectorCheckpoint,
    ) -> CheckpointOutput[CanvasConnectorCheckpoint]:
        """Checkpointed document loading — stub, always raises NotImplementedError."""
        # TODO(benwu408): implemented in PR3 (checkpoint)
        raise NotImplementedError
    @override
    def load_from_checkpoint_with_perm_sync(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: CanvasConnectorCheckpoint,
    ) -> CheckpointOutput[CanvasConnectorCheckpoint]:
        """Checkpointed loading with permissions — stub, always raises NotImplementedError."""
        # TODO(benwu408): implemented in PR3 (checkpoint)
        raise NotImplementedError
    @override
    def build_dummy_checkpoint(self) -> CanvasConnectorCheckpoint:
        """Initial checkpoint construction — stub, always raises NotImplementedError."""
        # TODO(benwu408): implemented in PR3 (checkpoint)
        raise NotImplementedError
    @override
    def validate_checkpoint_json(
        self, checkpoint_json: str
    ) -> CanvasConnectorCheckpoint:
        """Checkpoint JSON validation — stub, always raises NotImplementedError."""
        # TODO(benwu408): implemented in PR3 (checkpoint)
        raise NotImplementedError
@override
def retrieve_all_slim_docs_perm_sync(
self,

View File

@@ -171,10 +171,7 @@ class ClickupConnector(LoadConnector, PollConnector):
document.metadata[extra_field] = task[extra_field]
if self.retrieve_task_comments:
document.sections = [
*document.sections,
*self._get_task_comments(task["id"]),
]
document.sections.extend(self._get_task_comments(task["id"]))
doc_batch.append(document)

View File

@@ -61,9 +61,6 @@ _USER_NOT_FOUND = "Unknown Confluence User"
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
_USER_EMAIL_CACHE: dict[str, str | None] = {}
_DEFAULT_PAGINATION_LIMIT = 1000
_MINIMUM_PAGINATION_LIMIT = 5
_SERVER_ERROR_CODES = {500, 502, 503, 504}
_CONFLUENCE_SPACES_API_V1 = "rest/api/space"
_CONFLUENCE_SPACES_API_V2 = "wiki/api/v2/spaces"
@@ -572,8 +569,7 @@ class OnyxConfluence:
if not limit:
limit = _DEFAULT_PAGINATION_LIMIT
current_limit = limit
url_suffix = update_param_in_path(url_suffix, "limit", str(current_limit))
url_suffix = update_param_in_path(url_suffix, "limit", str(limit))
while url_suffix:
logger.debug(f"Making confluence call to {url_suffix}")
@@ -613,61 +609,40 @@ class OnyxConfluence:
)
continue
if raw_response.status_code in _SERVER_ERROR_CODES:
# Try reducing the page size -- Confluence often times out
# on large result sets (especially Cloud 504s).
if current_limit > _MINIMUM_PAGINATION_LIMIT:
old_limit = current_limit
current_limit = max(
current_limit // 2, _MINIMUM_PAGINATION_LIMIT
)
logger.warning(
f"Confluence returned {raw_response.status_code}. "
f"Reducing limit from {old_limit} to {current_limit} "
f"and retrying."
)
url_suffix = update_param_in_path(
url_suffix, "limit", str(current_limit)
)
continue
# If we fail due to a 500, try one by one.
# NOTE: this iterative approach only works for server, since cloud uses cursor-based
# pagination
if raw_response.status_code == 500 and not self._is_cloud:
initial_start = get_start_param_from_url(url_suffix)
if initial_start is None:
# can't handle this if we don't have offset-based pagination
raise
# Limit reduction exhausted -- for Server, fall back to
# one-by-one offset pagination as a last resort.
if not self._is_cloud:
initial_start = get_start_param_from_url(url_suffix)
# this will just yield the successful items from the batch
new_url_suffix = (
yield from self._try_one_by_one_for_paginated_url(
url_suffix,
initial_start=initial_start,
limit=current_limit,
)
)
# this means we ran into an empty page
if new_url_suffix is None:
if next_page_callback:
next_page_callback("")
break
# this will just yield the successful items from the batch
new_url_suffix = yield from self._try_one_by_one_for_paginated_url(
url_suffix,
initial_start=initial_start,
limit=limit,
)
url_suffix = new_url_suffix
continue
# this means we ran into an empty page
if new_url_suffix is None:
if next_page_callback:
next_page_callback("")
break
url_suffix = new_url_suffix
continue
else:
logger.exception(
f"Error in confluence call to {url_suffix} "
f"after reducing limit to {current_limit}.\n"
f"Raw Response Text: {raw_response.text}\n"
f"Error: {e}\n"
f"Error in confluence call to {url_suffix} \n"
f"Raw Response Text: {raw_response.text} \n"
f"Full Response: {raw_response.__dict__} \n"
f"Error: {e} \n"
)
raise
logger.exception(
f"Error in confluence call to {url_suffix} \n"
f"Raw Response Text: {raw_response.text} \n"
f"Full Response: {raw_response.__dict__} \n"
f"Error: {e} \n"
)
raise
try:
next_response = raw_response.json()
except Exception as e:
@@ -705,10 +680,6 @@ class OnyxConfluence:
old_url_suffix = url_suffix
updated_start = get_start_param_from_url(old_url_suffix)
url_suffix = cast(str, next_response.get("_links", {}).get("next", ""))
if url_suffix and current_limit != limit:
url_suffix = update_param_in_path(
url_suffix, "limit", str(current_limit)
)
for i, result in enumerate(results):
updated_start += 1
if url_suffix and next_page_callback and i == len(results) - 1:

View File

@@ -1,69 +0,0 @@
import csv
import io
from typing import IO
from onyx.connectors.models import TabularSection
from onyx.file_processing.extract_file_text import file_io_to_text
from onyx.file_processing.extract_file_text import xlsx_sheet_extraction
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.utils.logger import setup_logger
logger = setup_logger()
def is_tabular_file(file_name: str) -> bool:
    """Return True when *file_name* carries a tabular extension (case-insensitive)."""
    return file_name.lower().endswith(tuple(OnyxFileExtensions.TABULAR_EXTENSIONS))
def _tsv_to_csv(tsv_text: str) -> str:
"""Re-serialize tab-separated text as CSV so downstream parsers that
assume the default Excel dialect read the columns correctly."""
out = io.StringIO()
csv.writer(out, lineterminator="\n").writerows(
csv.reader(io.StringIO(tsv_text), dialect="excel-tab")
)
return out.getvalue().rstrip("\n")
def tabular_file_to_sections(
    file: IO[bytes],
    file_name: str,
    link: str = "",
) -> list[TabularSection]:
    """Convert a tabular file into one or more TabularSections.

    - ``.xlsx`` -> one TabularSection per non-empty sheet.
    - ``.csv`` / ``.tsv`` -> a single TabularSection with the full decoded file.

    Returns an empty list when the file yields no extractable content and
    raises ValueError when *file_name* is not a recognised tabular file.
    """
    name_lower = file_name.lower()
    section_link = link or file_name
    if name_lower.endswith(".xlsx"):
        sections = []
        for csv_text, sheet_title in xlsx_sheet_extraction(file, file_name=file_name):
            sections.append(
                TabularSection(
                    link=section_link,
                    text=csv_text,
                    heading=f"{file_name} :: {sheet_title}",
                )
            )
        return sections
    if not name_lower.endswith((".csv", ".tsv")):
        raise ValueError(f"{file_name!r} is not a tabular file")
    try:
        text = file_io_to_text(file).strip()
    except Exception:
        logger.exception(f"Failure decoding {file_name}")
        raise
    if not text:
        return []
    if name_lower.endswith(".tsv"):
        # Normalise TSV content to the Excel CSV dialect.
        text = _tsv_to_csv(text)
    return [TabularSection(link=section_link, text=text)]

View File

@@ -15,10 +15,6 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
)
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import rate_limit_builder
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import rl_requests
from onyx.connectors.cross_connector_utils.tabular_section_utils import is_tabular_file
from onyx.connectors.cross_connector_utils.tabular_section_utils import (
tabular_file_to_sections,
)
from onyx.connectors.drupal_wiki.models import DrupalWikiCheckpoint
from onyx.connectors.drupal_wiki.models import DrupalWikiPage
from onyx.connectors.drupal_wiki.models import DrupalWikiPageResponse
@@ -37,7 +33,6 @@ from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TabularSection
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
@@ -218,7 +213,7 @@ class DrupalWikiConnector(
attachment: dict[str, Any],
page_id: int,
download_url: str,
) -> tuple[list[TextSection | ImageSection | TabularSection], str | None]:
) -> tuple[list[TextSection | ImageSection], str | None]:
"""
Process a single attachment and return generated sections.
@@ -231,7 +226,7 @@ class DrupalWikiConnector(
Tuple of (sections, error_message). If error_message is not None, the
sections list should be treated as invalid.
"""
sections: list[TextSection | ImageSection | TabularSection] = []
sections: list[TextSection | ImageSection] = []
try:
if not self._validate_attachment_filetype(attachment):
@@ -278,25 +273,6 @@ class DrupalWikiConnector(
return sections, None
# Tabular attachments (xlsx, csv, tsv) — produce
# TabularSections instead of a flat TextSection.
if is_tabular_file(file_name):
try:
sections.extend(
tabular_file_to_sections(
BytesIO(raw_bytes),
file_name=file_name,
link=download_url,
)
)
except Exception:
logger.exception(
f"Failed to extract tabular sections from {file_name}"
)
if not sections:
return [], f"No content extracted from tabular file {file_name}"
return sections, None
image_counter = 0
def _store_embedded_image(image_data: bytes, image_name: str) -> None:
@@ -521,7 +497,7 @@ class DrupalWikiConnector(
page_url = build_drupal_wiki_document_id(self.base_url, page.id)
# Create sections with just the page content
sections: list[TextSection | ImageSection | TabularSection] = [
sections: list[TextSection | ImageSection] = [
TextSection(text=text_content, link=page_url)
]

View File

@@ -2,7 +2,6 @@ import json
import os
from datetime import datetime
from datetime import timezone
from io import BytesIO
from pathlib import Path
from typing import Any
from typing import IO
@@ -13,16 +12,11 @@ from onyx.configs.constants import FileOrigin
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
process_onyx_metadata,
)
from onyx.connectors.cross_connector_utils.tabular_section_utils import is_tabular_file
from onyx.connectors.cross_connector_utils.tabular_section_utils import (
tabular_file_to_sections,
)
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TabularSection
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
@@ -185,32 +179,8 @@ def _process_file(
link = onyx_metadata.link or link
# Build sections: first the text as a single Section
sections: list[TextSection | ImageSection | TabularSection] = []
if is_tabular_file(file_name):
# Produce TabularSections
lowered_name = file_name.lower()
if lowered_name.endswith(".xlsx"):
file.seek(0)
tabular_source: IO[bytes] = file
else:
tabular_source = BytesIO(
extraction_result.text_content.encode("utf-8", errors="replace")
)
try:
sections.extend(
tabular_file_to_sections(
file=tabular_source,
file_name=file_name,
link=link or "",
)
)
except Exception as e:
logger.error(f"Failed to process tabular file {file_name}: {e}")
return []
if not sections:
logger.warning(f"No content extracted from tabular file {file_name}")
return []
elif extraction_result.text_content.strip():
sections: list[TextSection | ImageSection] = []
if extraction_result.text_content.strip():
logger.debug(f"Creating TextSection for {file_name} with link: {link}")
sections.append(
TextSection(link=link, text=extraction_result.text_content.strip())

View File

@@ -42,9 +42,6 @@ from onyx.connectors.google_drive.file_retrieval import (
get_all_files_in_my_drive_and_shared,
)
from onyx.connectors.google_drive.file_retrieval import get_external_access_for_folder
from onyx.connectors.google_drive.file_retrieval import (
get_files_by_web_view_links_batch,
)
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from onyx.connectors.google_drive.file_retrieval import get_folder_metadata
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
@@ -73,14 +70,11 @@ from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import NormalizationResult
from onyx.connectors.interfaces import Resolver
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import EntityFailure
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import SlimDocument
@@ -208,10 +202,7 @@ class DriveIdStatus(Enum):
class GoogleDriveConnector(
SlimConnector,
SlimConnectorWithPermSync,
CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint],
Resolver,
SlimConnectorWithPermSync, CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint]
):
def __init__(
self,
@@ -1674,89 +1665,12 @@ class GoogleDriveConnector(
start, end, checkpoint, include_permissions=True
)
@override
def resolve_errors(
self,
errors: list[ConnectorFailure],
include_permissions: bool = False,
) -> Generator[Document | ConnectorFailure | HierarchyNode, None, None]:
if self._creds is None or self._primary_admin_email is None:
raise RuntimeError(
"Credentials missing, should not call this method before calling load_credentials"
)
logger.info(f"Resolving {len(errors)} errors")
doc_ids = [
failure.failed_document.document_id
for failure in errors
if failure.failed_document
]
service = get_drive_service(self.creds, self.primary_admin_email)
field_type = (
DriveFileFieldType.WITH_PERMISSIONS
if include_permissions or self.exclude_domain_link_only
else DriveFileFieldType.STANDARD
)
batch_result = get_files_by_web_view_links_batch(service, doc_ids, field_type)
for doc_id, error in batch_result.errors.items():
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=doc_id,
document_link=doc_id,
),
failure_message=f"Failed to retrieve file during error resolution: {error}",
exception=error,
)
permission_sync_context = (
PermissionSyncContext(
primary_admin_email=self.primary_admin_email,
google_domain=self.google_domain,
)
if include_permissions
else None
)
retrieved_files = [
RetrievedDriveFile(
drive_file=file,
user_email=self.primary_admin_email,
completion_stage=DriveRetrievalStage.DONE,
)
for file in batch_result.files.values()
]
yield from self._get_new_ancestors_for_files(
files=retrieved_files,
seen_hierarchy_node_raw_ids=ThreadSafeSet(),
fully_walked_hierarchy_node_raw_ids=ThreadSafeSet(),
permission_sync_context=permission_sync_context,
add_prefix=True,
)
func_with_args = [
(
self._convert_retrieved_file_to_document,
(rf, permission_sync_context),
)
for rf in retrieved_files
]
results = cast(
list[Document | ConnectorFailure | None],
run_functions_tuples_in_parallel(func_with_args, max_workers=8),
)
for result in results:
if result is not None:
yield result
def _extract_slim_docs_from_google_drive(
self,
checkpoint: GoogleDriveCheckpoint,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
include_permissions: bool = True,
) -> GenerateSlimDocumentOutput:
files_batch: list[RetrievedDriveFile] = []
slim_batch: list[SlimDocument | HierarchyNode] = []
@@ -1766,13 +1680,9 @@ class GoogleDriveConnector(
nonlocal files_batch, slim_batch
# Get new ancestor hierarchy nodes first
permission_sync_context = (
PermissionSyncContext(
primary_admin_email=self.primary_admin_email,
google_domain=self.google_domain,
)
if include_permissions
else None
permission_sync_context = PermissionSyncContext(
primary_admin_email=self.primary_admin_email,
google_domain=self.google_domain,
)
new_ancestors = self._get_new_ancestors_for_files(
files=files_batch,
@@ -1786,7 +1696,10 @@ class GoogleDriveConnector(
if doc := build_slim_document(
self.creds,
file.drive_file,
permission_sync_context,
PermissionSyncContext(
primary_admin_email=self.primary_admin_email,
google_domain=self.google_domain,
),
retriever_email=file.user_email,
):
slim_batch.append(doc)
@@ -1826,12 +1739,11 @@ class GoogleDriveConnector(
if files_batch:
yield _yield_slim_batch()
def _retrieve_all_slim_docs_impl(
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
include_permissions: bool = True,
) -> GenerateSlimDocumentOutput:
try:
checkpoint = self.build_dummy_checkpoint()
@@ -1841,34 +1753,13 @@ class GoogleDriveConnector(
start=start,
end=end,
callback=callback,
include_permissions=include_permissions,
)
logger.info("Drive slim doc retrieval complete")
logger.info("Drive perm sync: Slim doc retrieval complete")
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise
@override
def retrieve_all_slim_docs(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
return self._retrieve_all_slim_docs_impl(
start=start, end=end, callback=callback, include_permissions=False
)
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
return self._retrieve_all_slim_docs_impl(
start=start, end=end, callback=callback, include_permissions=True
)
raise e
def validate_connector_settings(self) -> None:
if self._creds is None:

View File

@@ -13,10 +13,6 @@ from pydantic import BaseModel
from onyx.access.models import ExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.connectors.cross_connector_utils.tabular_section_utils import is_tabular_file
from onyx.connectors.cross_connector_utils.tabular_section_utils import (
tabular_file_to_sections,
)
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
from onyx.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
from onyx.connectors.google_drive.models import GDriveMimeType
@@ -32,16 +28,15 @@ from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TabularSection
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import pptx_to_text
from onyx.file_processing.extract_file_text import read_docx_file
from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.extract_file_text import xlsx_to_text
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.file_processing.file_types import OnyxMimeTypes
from onyx.file_processing.file_types import SPREADSHEET_MIME_TYPE
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import (
@@ -294,7 +289,7 @@ def _download_and_extract_sections_basic(
service: GoogleDriveService,
allow_images: bool,
size_threshold: int,
) -> list[TextSection | ImageSection | TabularSection]:
) -> list[TextSection | ImageSection]:
"""Extract text and images from a Google Drive file."""
file_id = file["id"]
file_name = file["name"]
@@ -313,7 +308,7 @@ def _download_and_extract_sections_basic(
return []
# Store images for later processing
sections: list[TextSection | ImageSection | TabularSection] = []
sections: list[TextSection | ImageSection] = []
try:
section, embedded_id = store_image_and_create_section(
image_data=response_call(),
@@ -328,9 +323,10 @@ def _download_and_extract_sections_basic(
logger.error(f"Failed to process image {file_name}: {e}")
return sections
# For Google Docs, Sheets, and Slides, export via the Drive API
# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
# Use the correct API call for exporting files
request = service.files().export_media(
fileId=file_id, mimeType=export_mime_type
)
@@ -339,17 +335,6 @@ def _download_and_extract_sections_basic(
logger.warning(f"Failed to export {file_name} as {export_mime_type}")
return []
if export_mime_type in OnyxMimeTypes.TABULAR_MIME_TYPES:
# Synthesize an extension on the filename
ext = ".xlsx" if export_mime_type == SPREADSHEET_MIME_TYPE else ".csv"
return list(
tabular_file_to_sections(
io.BytesIO(response),
file_name=f"{file_name}{ext}",
link=link,
)
)
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
@@ -371,15 +356,9 @@ def _download_and_extract_sections_basic(
elif (
mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
or is_tabular_file(file_name)
):
return list(
tabular_file_to_sections(
io.BytesIO(response_call()),
file_name=file_name,
link=link,
)
)
text = xlsx_to_text(io.BytesIO(response_call()), file_name=file_name)
return [TextSection(link=link, text=text)] if text else []
elif (
mime_type
@@ -390,7 +369,7 @@ def _download_and_extract_sections_basic(
elif mime_type == "application/pdf":
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response_call()))
pdf_sections: list[TextSection | ImageSection | TabularSection] = [
pdf_sections: list[TextSection | ImageSection] = [
TextSection(link=link, text=text)
]
@@ -431,9 +410,8 @@ def _find_nth(haystack: str, needle: str, n: int, start: int = 0) -> int:
def align_basic_advanced(
basic_sections: list[TextSection | ImageSection | TabularSection],
adv_sections: list[TextSection],
) -> list[TextSection | ImageSection | TabularSection]:
basic_sections: list[TextSection | ImageSection], adv_sections: list[TextSection]
) -> list[TextSection | ImageSection]:
"""Align the basic sections with the advanced sections.
In particular, the basic sections contain all content of the file,
including smart chips like dates and doc links. The advanced sections
@@ -450,7 +428,7 @@ def align_basic_advanced(
basic_full_text = "".join(
[section.text for section in basic_sections if isinstance(section, TextSection)]
)
new_sections: list[TextSection | ImageSection | TabularSection] = []
new_sections: list[TextSection | ImageSection] = []
heading_start = 0
for adv_ind in range(1, len(adv_sections)):
heading = adv_sections[adv_ind].text.split(HEADING_DELIMITER)[0]
@@ -621,7 +599,7 @@ def _convert_drive_item_to_document(
"""
Main entry point for converting a Google Drive file => Document object.
"""
sections: list[TextSection | ImageSection | TabularSection] = []
sections: list[TextSection | ImageSection] = []
# Only construct these services when needed
def _get_drive_service() -> GoogleDriveService:
@@ -661,9 +639,7 @@ def _convert_drive_item_to_document(
doc_id=file.get("id", ""),
)
if doc_sections:
sections = cast(
list[TextSection | ImageSection | TabularSection], doc_sections
)
sections = cast(list[TextSection | ImageSection], doc_sections)
if any(SMART_CHIP_CHAR in section.text for section in doc_sections):
logger.debug(
f"found smart chips in {file.get('name')}, aligning with basic sections"

View File

@@ -9,7 +9,6 @@ from urllib.parse import urlparse
from googleapiclient.discovery import Resource # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from googleapiclient.http import BatchHttpRequest # type: ignore
from onyx.access.models import ExternalAccess
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
@@ -61,8 +60,6 @@ SLIM_FILE_FIELDS = (
)
FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)"
MAX_BATCH_SIZE = 100
HIERARCHY_FIELDS = "id, name, parents, webViewLink, mimeType, driveId"
HIERARCHY_FIELDS_WITH_PERMISSIONS = (
@@ -219,7 +216,7 @@ def get_external_access_for_folder(
def _get_fields_for_file_type(field_type: DriveFileFieldType) -> str:
"""Get the appropriate fields string for files().list() based on the field type enum."""
"""Get the appropriate fields string based on the field type enum"""
if field_type == DriveFileFieldType.SLIM:
return SLIM_FILE_FIELDS
elif field_type == DriveFileFieldType.WITH_PERMISSIONS:
@@ -228,25 +225,6 @@ def _get_fields_for_file_type(field_type: DriveFileFieldType) -> str:
return FILE_FIELDS
def _extract_single_file_fields(list_fields: str) -> str:
"""Convert a files().list() fields string to one suitable for files().get().
List fields look like "nextPageToken, files(field1, field2, ...)"
Single-file fields should be just "field1, field2, ..."
"""
start = list_fields.find("files(")
if start == -1:
return list_fields
inner_start = start + len("files(")
inner_end = list_fields.rfind(")")
return list_fields[inner_start:inner_end]
def _get_single_file_fields(field_type: DriveFileFieldType) -> str:
"""Get the appropriate fields string for files().get() based on the field type enum."""
return _extract_single_file_fields(_get_fields_for_file_type(field_type))
def _get_files_in_parent(
service: Resource,
parent_id: str,
@@ -558,74 +536,3 @@ def get_file_by_web_view_link(
)
.execute()
)
class BatchRetrievalResult:
"""Result of a batch file retrieval, separating successes from errors."""
def __init__(self) -> None:
self.files: dict[str, GoogleDriveFileType] = {}
self.errors: dict[str, Exception] = {}
def get_files_by_web_view_links_batch(
service: GoogleDriveService,
web_view_links: list[str],
field_type: DriveFileFieldType,
) -> BatchRetrievalResult:
"""Retrieve multiple Google Drive files by webViewLink using the batch API.
Returns a BatchRetrievalResult containing successful file retrievals
and errors for any files that could not be fetched.
Automatically splits into chunks of MAX_BATCH_SIZE.
"""
fields = _get_single_file_fields(field_type)
if len(web_view_links) <= MAX_BATCH_SIZE:
return _get_files_by_web_view_links_batch(service, web_view_links, fields)
combined = BatchRetrievalResult()
for i in range(0, len(web_view_links), MAX_BATCH_SIZE):
chunk = web_view_links[i : i + MAX_BATCH_SIZE]
chunk_result = _get_files_by_web_view_links_batch(service, chunk, fields)
combined.files.update(chunk_result.files)
combined.errors.update(chunk_result.errors)
return combined
def _get_files_by_web_view_links_batch(
service: GoogleDriveService,
web_view_links: list[str],
fields: str,
) -> BatchRetrievalResult:
"""Single-batch implementation."""
result = BatchRetrievalResult()
def callback(
request_id: str,
response: GoogleDriveFileType,
exception: Exception | None,
) -> None:
if exception:
logger.warning(f"Error retrieving file {request_id}: {exception}")
result.errors[request_id] = exception
else:
result.files[request_id] = response
batch = cast(BatchHttpRequest, service.new_batch_http_request(callback=callback))
for web_view_link in web_view_links:
try:
file_id = _extract_file_id_from_web_view_link(web_view_link)
request = service.files().get(
fileId=file_id,
supportsAllDrives=True,
fields=fields,
)
batch.add(request, request_id=web_view_link)
except ValueError as e:
logger.warning(f"Failed to extract file ID from {web_view_link}: {e}")
result.errors[web_view_link] = e
batch.execute()
return result

View File

@@ -1,5 +1,4 @@
import json
from typing import Any
from typing import cast
from urllib.parse import parse_qs
from urllib.parse import ParseResult
@@ -54,21 +53,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
def _load_google_json(raw: object) -> dict[str, Any]:
"""Accept both the current (dict) and legacy (JSON string) KV payload shapes.
Payloads written before the fix for serializing Google credentials into
``EncryptedJson`` columns are stored as JSON strings; new writes store dicts.
Once every install has re-uploaded their Google credentials the legacy
``str`` branch can be removed.
"""
if isinstance(raw, dict):
return raw
if isinstance(raw, str):
return json.loads(raw)
raise ValueError(f"Unexpected Google credential payload type: {type(raw)!r}")
def _build_frontend_google_drive_redirect(source: DocumentSource) -> str:
if source == DocumentSource.GOOGLE_DRIVE:
return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback"
@@ -178,13 +162,12 @@ def build_service_account_creds(
def get_auth_url(credential_id: int, source: DocumentSource) -> str:
if source == DocumentSource.GOOGLE_DRIVE:
credential_json = _load_google_json(
get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY)
)
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
elif source == DocumentSource.GMAIL:
credential_json = _load_google_json(get_kv_store().load(KV_GMAIL_CRED_KEY))
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
else:
raise ValueError(f"Unsupported source: {source}")
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
scopes=GOOGLE_SCOPES[source],
@@ -205,12 +188,12 @@ def get_auth_url(credential_id: int, source: DocumentSource) -> str:
def get_google_app_cred(source: DocumentSource) -> GoogleAppCredentials:
if source == DocumentSource.GOOGLE_DRIVE:
creds = _load_google_json(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
elif source == DocumentSource.GMAIL:
creds = _load_google_json(get_kv_store().load(KV_GMAIL_CRED_KEY))
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
else:
raise ValueError(f"Unsupported source: {source}")
return GoogleAppCredentials(**creds)
return GoogleAppCredentials(**json.loads(creds_str))
def upsert_google_app_cred(
@@ -218,14 +201,10 @@ def upsert_google_app_cred(
) -> None:
if source == DocumentSource.GOOGLE_DRIVE:
get_kv_store().store(
KV_GOOGLE_DRIVE_CRED_KEY,
app_credentials.model_dump(mode="json"),
encrypt=True,
KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True
)
elif source == DocumentSource.GMAIL:
get_kv_store().store(
KV_GMAIL_CRED_KEY, app_credentials.model_dump(mode="json"), encrypt=True
)
get_kv_store().store(KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True)
else:
raise ValueError(f"Unsupported source: {source}")
@@ -241,14 +220,12 @@ def delete_google_app_cred(source: DocumentSource) -> None:
def get_service_account_key(source: DocumentSource) -> GoogleServiceAccountKey:
if source == DocumentSource.GOOGLE_DRIVE:
creds = _load_google_json(
get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
)
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY))
elif source == DocumentSource.GMAIL:
creds = _load_google_json(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
creds_str = str(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
else:
raise ValueError(f"Unsupported source: {source}")
return GoogleServiceAccountKey(**creds)
return GoogleServiceAccountKey(**json.loads(creds_str))
def upsert_service_account_key(
@@ -257,14 +234,12 @@ def upsert_service_account_key(
if source == DocumentSource.GOOGLE_DRIVE:
get_kv_store().store(
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY,
service_account_key.model_dump(mode="json"),
service_account_key.json(),
encrypt=True,
)
elif source == DocumentSource.GMAIL:
get_kv_store().store(
KV_GMAIL_SERVICE_ACCOUNT_KEY,
service_account_key.model_dump(mode="json"),
encrypt=True,
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
)
else:
raise ValueError(f"Unsupported source: {source}")

View File

@@ -123,9 +123,6 @@ class SlimConnector(BaseConnector):
@abc.abstractmethod
def retrieve_all_slim_docs(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
raise NotImplementedError
@@ -301,22 +298,6 @@ class CheckpointedConnectorWithPermSync(CheckpointedConnector[CT]):
raise NotImplementedError
class Resolver(BaseConnector):
@abc.abstractmethod
def resolve_errors(
self,
errors: list[ConnectorFailure],
include_permissions: bool = False,
) -> Generator[Document | ConnectorFailure | HierarchyNode, None, None]:
"""Attempts to yield back ALL the documents described by the errors, no checkpointing.
Caller's responsibility is to delete the old ConnectorFailures and replace with the new ones.
If include_permissions is True, the documents will have permissions synced.
May also yield HierarchyNode objects for ancestor folders of resolved documents.
"""
raise NotImplementedError
class HierarchyConnector(BaseConnector):
@abc.abstractmethod
def load_hierarchy(

View File

@@ -60,10 +60,8 @@ logger = setup_logger()
ONE_HOUR = 3600
_MAX_RESULTS_FETCH_IDS = 5000
_MAX_RESULTS_FETCH_IDS = 5000 # 5000
_JIRA_FULL_PAGE_SIZE = 50
# https://developer.atlassian.com/cloud/jira/platform/rest/v3/api-group-issues/
_JIRA_BULK_FETCH_LIMIT = 100
# Constants for Jira field names
_FIELD_REPORTER = "reporter"
@@ -257,13 +255,15 @@ def _bulk_fetch_request(
return resp.json()["issues"]
def _bulk_fetch_batch(
jira_client: JIRA, issue_ids: list[str], fields: str | None
) -> list[dict[str, Any]]:
"""Fetch a single batch (must be <= _JIRA_BULK_FETCH_LIMIT).
On JSONDecodeError, recursively bisects until it succeeds or reaches size 1."""
def bulk_fetch_issues(
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
) -> list[Issue]:
# TODO(evan): move away from this jira library if they continue to not support
# the endpoints we need. Using private fields is not ideal, but
# is likely fine for now since we pin the library version
try:
return _bulk_fetch_request(jira_client, issue_ids, fields)
raw_issues = _bulk_fetch_request(jira_client, issue_ids, fields)
except requests.exceptions.JSONDecodeError:
if len(issue_ids) <= 1:
logger.exception(
@@ -277,25 +277,12 @@ def _bulk_fetch_batch(
f"Jira bulk-fetch JSON decode failed for batch of {len(issue_ids)} issues. "
f"Splitting into sub-batches of {mid} and {len(issue_ids) - mid}."
)
left = _bulk_fetch_batch(jira_client, issue_ids[:mid], fields)
right = _bulk_fetch_batch(jira_client, issue_ids[mid:], fields)
left = bulk_fetch_issues(jira_client, issue_ids[:mid], fields)
right = bulk_fetch_issues(jira_client, issue_ids[mid:], fields)
return left + right
def bulk_fetch_issues(
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
) -> list[Issue]:
# TODO(evan): move away from this jira library if they continue to not support
# the endpoints we need. Using private fields is not ideal, but
# is likely fine for now since we pin the library version
raw_issues: list[dict[str, Any]] = []
for batch in chunked(issue_ids, _JIRA_BULK_FETCH_LIMIT):
try:
raw_issues.extend(_bulk_fetch_batch(jira_client, list(batch), fields))
except Exception as e:
logger.error(f"Error fetching issues: {e}")
raise
except Exception as e:
logger.error(f"Error fetching issues: {e}")
raise
return [
Issue(jira_client._options, jira_client._session, raw=issue)

View File

@@ -1,10 +1,8 @@
import sys
from collections.abc import Sequence
from datetime import datetime
from enum import Enum
from typing import Any
from typing import cast
from typing import Literal
from pydantic import BaseModel
from pydantic import Field
@@ -35,28 +33,17 @@ class ConnectorMissingCredentialError(PermissionError):
)
class SectionType(str, Enum):
"""Discriminator for Section subclasses."""
TEXT = "text"
IMAGE = "image"
TABULAR = "tabular"
class Section(BaseModel):
"""Base section class with common attributes"""
type: SectionType
link: str | None = None
text: str | None = None
image_file_id: str | None = None
heading: str | None = None
class TextSection(Section):
"""Section containing text content"""
type: Literal[SectionType.TEXT] = SectionType.TEXT
text: str
def __sizeof__(self) -> int:
@@ -66,25 +53,12 @@ class TextSection(Section):
class ImageSection(Section):
"""Section containing an image reference"""
type: Literal[SectionType.IMAGE] = SectionType.IMAGE
image_file_id: str
def __sizeof__(self) -> int:
return sys.getsizeof(self.image_file_id) + sys.getsizeof(self.link)
class TabularSection(Section):
"""Section containing tabular data (csv/tsv content, or one sheet of
an xlsx workbook rendered as CSV)."""
type: Literal[SectionType.TABULAR] = SectionType.TABULAR
text: str # CSV representation in a string
link: str
def __sizeof__(self) -> int:
return sys.getsizeof(self.text) + sys.getsizeof(self.link)
class BasicExpertInfo(BaseModel):
"""Basic Information for the owner of a document, any of the fields can be left as None
Display fallback goes as follows:
@@ -160,6 +134,7 @@ class BasicExpertInfo(BaseModel):
@classmethod
def from_dict(cls, model_dict: dict[str, Any]) -> "BasicExpertInfo":
first_name = cast(str, model_dict.get("FirstName"))
last_name = cast(str, model_dict.get("LastName"))
email = cast(str, model_dict.get("Email"))
@@ -186,7 +161,7 @@ class DocumentBase(BaseModel):
"""Used for Onyx ingestion api, the ID is inferred before use if not provided"""
id: str | None = None
sections: Sequence[TextSection | ImageSection | TabularSection]
sections: list[TextSection | ImageSection]
source: DocumentSource | None = None
semantic_identifier: str # displayed in the UI as the main identifier for the doc
# TODO(andrei): Ideally we could improve this to where each value is just a
@@ -396,9 +371,12 @@ class IndexingDocument(Document):
)
else:
section_len = sum(
len(section.text) if section.text is not None else 0
(
len(section.text)
if isinstance(section, TextSection) and section.text is not None
else 0
)
for section in self.sections
if isinstance(section, (TextSection, TabularSection))
)
return title_len + section_len

View File

@@ -41,10 +41,6 @@ from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
from onyx.configs.app_configs import SHAREPOINT_CONNECTOR_SIZE_THRESHOLD
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.connectors.cross_connector_utils.tabular_section_utils import is_tabular_file
from onyx.connectors.cross_connector_utils.tabular_section_utils import (
tabular_file_to_sections,
)
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
from onyx.connectors.interfaces import CheckpointOutput
@@ -64,7 +60,6 @@ from onyx.connectors.models import ExternalAccess
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TabularSection
from onyx.connectors.models import TextSection
from onyx.connectors.sharepoint.connector_utils import get_sharepoint_external_access
from onyx.db.enums import HierarchyNodeType
@@ -591,7 +586,7 @@ def _convert_driveitem_to_document_with_permissions(
driveitem, f"Failed to download via graph api: {e}", e
)
sections: list[TextSection | ImageSection | TabularSection] = []
sections: list[TextSection | ImageSection] = []
file_ext = get_file_ext(driveitem.name)
if not content_bytes:
@@ -607,19 +602,6 @@ def _convert_driveitem_to_document_with_permissions(
)
image_section.link = driveitem.web_url
sections.append(image_section)
elif is_tabular_file(driveitem.name):
try:
sections.extend(
tabular_file_to_sections(
file=io.BytesIO(content_bytes),
file_name=driveitem.name,
link=driveitem.web_url or "",
)
)
except Exception as e:
logger.warning(
f"Failed to extract tabular sections for '{driveitem.name}': {e}"
)
else:
def _store_embedded_image(img_data: bytes, img_name: str) -> None:

View File

@@ -1,4 +1,3 @@
from dataclasses import dataclass
from datetime import datetime
from typing import TypedDict
@@ -7,14 +6,6 @@ from pydantic import BaseModel
from onyx.onyxbot.slack.models import ChannelType
@dataclass(frozen=True)
class DirectThreadFetch:
"""Request to fetch a Slack thread directly by channel and timestamp."""
channel_id: str
thread_ts: str
class ChannelMetadata(TypedDict):
"""Type definition for cached channel metadata."""

View File

@@ -19,7 +19,6 @@ from onyx.configs.chat_configs import DOC_TIME_DECAY
from onyx.connectors.models import IndexingDocument
from onyx.connectors.models import TextSection
from onyx.context.search.federated.models import ChannelMetadata
from onyx.context.search.federated.models import DirectThreadFetch
from onyx.context.search.federated.models import SlackMessage
from onyx.context.search.federated.slack_search_utils import ALL_CHANNEL_TYPES
from onyx.context.search.federated.slack_search_utils import build_channel_query_filter
@@ -50,6 +49,7 @@ from onyx.server.federated.models import FederatedConnectorDetail
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.timing import log_function_time
from shared_configs.configs import DOC_EMBEDDING_CONTEXT_SIZE
logger = setup_logger()
@@ -58,6 +58,7 @@ HIGHLIGHT_END_CHAR = "\ue001"
CHANNEL_METADATA_CACHE_TTL = 60 * 60 * 24 # 24 hours
USER_PROFILE_CACHE_TTL = 60 * 60 * 24 # 24 hours
SLACK_THREAD_CONTEXT_WINDOW = 3 # Number of messages before matched message to include
CHANNEL_METADATA_MAX_RETRIES = 3 # Maximum retry attempts for channel metadata fetching
CHANNEL_METADATA_RETRY_DELAY = 1 # Initial retry delay in seconds (exponential backoff)
@@ -420,94 +421,6 @@ class SlackQueryResult(BaseModel):
filtered_channels: list[str] # Channels filtered out during this query
def _fetch_thread_from_url(
thread_fetch: DirectThreadFetch,
access_token: str,
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
) -> SlackQueryResult:
"""Fetch a thread directly from a Slack URL via conversations.replies."""
channel_id = thread_fetch.channel_id
thread_ts = thread_fetch.thread_ts
slack_client = WebClient(token=access_token)
try:
response = slack_client.conversations_replies(
channel=channel_id,
ts=thread_ts,
)
response.validate()
messages: list[dict[str, Any]] = response.get("messages", [])
except SlackApiError as e:
logger.warning(
f"Failed to fetch thread from URL (channel={channel_id}, ts={thread_ts}): {e}"
)
return SlackQueryResult(messages=[], filtered_channels=[])
if not messages:
logger.warning(
f"No messages found for URL override (channel={channel_id}, ts={thread_ts})"
)
return SlackQueryResult(messages=[], filtered_channels=[])
# Build thread text from all messages
thread_text = _build_thread_text(messages, access_token, None, slack_client)
# Get channel name from metadata cache or API
channel_name = "unknown"
if channel_metadata_dict and channel_id in channel_metadata_dict:
channel_name = channel_metadata_dict[channel_id].get("name", "unknown")
else:
try:
ch_response = slack_client.conversations_info(channel=channel_id)
ch_response.validate()
channel_info: dict[str, Any] = ch_response.get("channel", {})
channel_name = channel_info.get("name", "unknown")
except SlackApiError:
pass
# Build the SlackMessage
parent_msg = messages[0]
message_ts = parent_msg.get("ts", thread_ts)
username = parent_msg.get("user", "unknown_user")
parent_text = parent_msg.get("text", "")
snippet = (
parent_text[:50].rstrip() + "..." if len(parent_text) > 50 else parent_text
).replace("\n", " ")
doc_time = datetime.fromtimestamp(float(message_ts))
decay_factor = DOC_TIME_DECAY
doc_age_years = (datetime.now() - doc_time).total_seconds() / (365 * 24 * 60 * 60)
recency_bias = max(1 / (1 + decay_factor * doc_age_years), 0.75)
permalink = (
f"https://slack.com/archives/{channel_id}/p{message_ts.replace('.', '')}"
)
slack_message = SlackMessage(
document_id=f"{channel_id}_{message_ts}",
channel_id=channel_id,
message_id=message_ts,
thread_id=None, # Prevent double-enrichment in thread context fetch
link=permalink,
metadata={
"channel": channel_name,
"time": doc_time.isoformat(),
},
timestamp=doc_time,
recency_bias=recency_bias,
semantic_identifier=f"{username} in #{channel_name}: {snippet}",
text=thread_text,
highlighted_texts=set(),
slack_score=100000.0, # High priority — user explicitly asked for this thread
)
logger.info(
f"URL override: fetched thread from channel={channel_id}, ts={thread_ts}, {len(messages)} messages"
)
return SlackQueryResult(messages=[slack_message], filtered_channels=[])
def query_slack(
query_string: str,
access_token: str,
@@ -519,6 +432,7 @@ def query_slack(
available_channels: list[str] | None = None,
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
) -> SlackQueryResult:
# Check if query has channel override (user specified channels in query)
has_channel_override = query_string.startswith("__CHANNEL_OVERRIDE__")
@@ -748,6 +662,7 @@ def _fetch_thread_context(
"""
channel_id = message.channel_id
thread_id = message.thread_id
message_id = message.message_id
# If not a thread, return original text as success
if thread_id is None:
@@ -780,37 +695,62 @@ def _fetch_thread_context(
if len(messages) <= 1:
return ThreadContextResult.success(message.text)
# Build thread text from thread starter + all replies
thread_text = _build_thread_text(messages, access_token, team_id, slack_client)
# Build thread text from thread starter + context window around matched message
thread_text = _build_thread_text(
messages, message_id, thread_id, access_token, team_id, slack_client
)
return ThreadContextResult.success(thread_text)
def _build_thread_text(
messages: list[dict[str, Any]],
message_id: str,
thread_id: str,
access_token: str,
team_id: str | None,
slack_client: WebClient,
) -> str:
"""Build thread text including all replies.
Includes the thread parent message followed by all replies in order.
"""
"""Build the thread text from messages."""
msg_text = messages[0].get("text", "")
msg_sender = messages[0].get("user", "")
thread_text = f"<@{msg_sender}>: {msg_text}"
# All messages after index 0 are replies
replies = messages[1:]
if not replies:
return thread_text
logger.debug(f"Thread {messages[0].get('ts')}: {len(replies)} replies included")
thread_text += "\n\nReplies:"
if thread_id == message_id:
message_id_idx = 0
else:
message_id_idx = next(
(i for i, msg in enumerate(messages) if msg.get("ts") == message_id), 0
)
if not message_id_idx:
return thread_text
for msg in replies:
start_idx = max(1, message_id_idx - SLACK_THREAD_CONTEXT_WINDOW)
if start_idx > 1:
thread_text += "\n..."
for i in range(start_idx, message_id_idx):
msg_text = messages[i].get("text", "")
msg_sender = messages[i].get("user", "")
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
msg_text = messages[message_id_idx].get("text", "")
msg_sender = messages[message_id_idx].get("user", "")
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
# Add following replies
len_replies = 0
for msg in messages[message_id_idx + 1 :]:
msg_text = msg.get("text", "")
msg_sender = msg.get("user", "")
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
reply = f"\n\n<@{msg_sender}>: {msg_text}"
thread_text += reply
len_replies += len(reply)
if len_replies >= DOC_EMBEDDING_CONTEXT_SIZE * 4:
thread_text += "\n..."
break
# Replace user IDs with names using cached lookups
userids: set[str] = set(re.findall(r"<@([A-Z0-9]+)>", thread_text))
@@ -1036,16 +976,7 @@ def slack_retrieval(
# Query slack with entity filtering
llm = get_default_llm()
query_items = build_slack_queries(query, llm, entities, available_channels)
# Partition into direct thread fetches and search query strings
direct_fetches: list[DirectThreadFetch] = []
query_strings: list[str] = []
for item in query_items:
if isinstance(item, DirectThreadFetch):
direct_fetches.append(item)
else:
query_strings.append(item)
query_strings = build_slack_queries(query, llm, entities, available_channels)
# Determine filtering based on entities OR context (bot)
include_dm = False
@@ -1062,16 +993,8 @@ def slack_retrieval(
f"Private channel context: will only allow messages from {allowed_private_channel} + public channels"
)
# Build search tasks — direct thread fetches + keyword searches
search_tasks: list[tuple] = [
(
_fetch_thread_from_url,
(fetch, access_token, channel_metadata_dict),
)
for fetch in direct_fetches
]
search_tasks.extend(
# Build search tasks
search_tasks = [
(
query_slack,
(
@@ -1087,7 +1010,7 @@ def slack_retrieval(
),
)
for query_string in query_strings
)
]
# If include_dm is True AND we're not already searching all channels,
# add additional searches without channel filters.

View File

@@ -10,7 +10,6 @@ from pydantic import ValidationError
from onyx.configs.app_configs import MAX_SLACK_QUERY_EXPANSIONS
from onyx.context.search.federated.models import ChannelMetadata
from onyx.context.search.federated.models import DirectThreadFetch
from onyx.context.search.models import ChunkIndexRequest
from onyx.federated_connectors.slack.models import SlackEntities
from onyx.llm.interfaces import LLM
@@ -639,38 +638,12 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
return [query_text]
SLACK_URL_PATTERN = re.compile(
    r"https?://[a-z0-9-]+\.slack\.com/archives/([A-Z0-9]+)/p(\d{16})"
)


def extract_slack_message_urls(
    query_text: str,
) -> list[tuple[str, str]]:
    """Extract Slack message URLs from query text.

    Parses URLs like:
        https://onyx-company.slack.com/archives/C097NBWMY8Y/p1775491616524769

    Returns list of (channel_id, thread_ts) tuples. The 16-digit timestamp
    is converted to Slack ts format (with a dot), e.g.
    p1775491616524769 -> 1775491616.524769.
    """

    def _p_to_ts(raw_ts: str) -> str:
        # Insert the dot after the 10-digit seconds portion.
        return f"{raw_ts[:10]}.{raw_ts[10:]}"

    return [
        (match.group(1), _p_to_ts(match.group(2)))
        for match in SLACK_URL_PATTERN.finditer(query_text)
    ]
def build_slack_queries(
query: ChunkIndexRequest,
llm: LLM,
entities: dict[str, Any] | None = None,
available_channels: list[str] | None = None,
) -> list[str | DirectThreadFetch]:
) -> list[str]:
"""Build Slack query strings with date filtering and query expansion."""
default_search_days = 30
if entities:
@@ -695,15 +668,6 @@ def build_slack_queries(
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_back)
time_filter = f" after:{cutoff_date.strftime('%Y-%m-%d')}"
# Check for Slack message URLs — if found, add direct fetch requests
url_fetches: list[DirectThreadFetch] = []
slack_urls = extract_slack_message_urls(query.query)
for channel_id, thread_ts in slack_urls:
url_fetches.append(
DirectThreadFetch(channel_id=channel_id, thread_ts=thread_ts)
)
logger.info(f"Detected Slack URL: channel={channel_id}, ts={thread_ts}")
# ALWAYS extract channel references from the query (not just for recency queries)
channel_references = extract_channel_references_from_query(query.query)
@@ -720,9 +684,7 @@ def build_slack_queries(
# If valid channels detected, use ONLY those channels with NO keywords
# Return query with ONLY time filter + channel filter (no keywords)
return url_fetches + [
build_channel_override_query(channel_references, time_filter)
]
return [build_channel_override_query(channel_references, time_filter)]
except ValueError as e:
# If validation fails, log the error and continue with normal flow
logger.warning(f"Channel reference validation failed: {e}")
@@ -740,8 +702,7 @@ def build_slack_queries(
rephrased_queries = expand_query_with_llm(query.query, llm)
# Build final query strings with time filters
search_queries = [
return [
rephrased_query.strip() + time_filter
for rephrased_query in rephrased_queries[:MAX_SLACK_QUERY_EXPANSIONS]
]
return url_fetches + search_queries

View File

@@ -750,3 +750,31 @@ def resync_cc_pair(
)
db_session.commit()
# ── Metrics query helpers ──────────────────────────────────────────────
def get_connector_health_for_metrics(
    db_session: Session,
) -> list:  # Returns list of Row tuples
    """Return connector health data for Prometheus metrics.

    Each row is (cc_pair_id, status, in_repeated_error_state,
    last_successful_index_time, name, source).
    """
    # Join to Connector so each health row also carries the source type.
    health_query = db_session.query(
        ConnectorCredentialPair.id,
        ConnectorCredentialPair.status,
        ConnectorCredentialPair.in_repeated_error_state,
        ConnectorCredentialPair.last_successful_index_time,
        ConnectorCredentialPair.name,
        Connector.source,
    ).join(
        Connector,
        ConnectorCredentialPair.connector_id == Connector.id,
    )
    return health_query.all()

View File

@@ -335,7 +335,6 @@ def update_document_set(
"Cannot update document set while it is syncing. Please wait for it to finish syncing, and then try again."
)
document_set_row.name = document_set_update_request.name
document_set_row.description = document_set_update_request.description
if not DISABLE_VECTOR_DB:
document_set_row.is_up_to_date = False

View File

@@ -11,7 +11,6 @@ from sqlalchemy import event
from sqlalchemy import pool
from sqlalchemy.engine import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.exc import DBAPIError
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DB_READONLY_PASSWORD
@@ -347,25 +346,6 @@ def get_session_with_shared_schema() -> Generator[Session, None, None]:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def _safe_close_session(session: Session) -> None:
"""Close a session, catching connection-closed errors during cleanup.
Long-running operations (e.g. multi-model LLM loops) can hold a session
open for minutes. If the underlying connection is dropped by cloud
infrastructure (load-balancer timeouts, PgBouncer, idle-in-transaction
timeouts, etc.), the implicit rollback in Session.close() raises
OperationalError or InterfaceError. Since the work is already complete,
we log and move on — SQLAlchemy internally invalidates the connection
for pool recycling.
"""
try:
session.close()
except DBAPIError:
logger.warning(
"DB connection lost during session cleanup — the connection will be invalidated and recycled by the pool."
)
@contextmanager
def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]:
"""
@@ -378,11 +358,8 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
# no need to use the schema translation map for self-hosted + default schema
if not MULTI_TENANT and tenant_id == POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE:
session = Session(bind=engine, expire_on_commit=False)
try:
with Session(bind=engine, expire_on_commit=False) as session:
yield session
finally:
_safe_close_session(session)
return
# Create connection with schema translation to handle querying the right schema
@@ -390,11 +367,8 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
with engine.connect().execution_options(
schema_translate_map=schema_translate_map
) as connection:
session = Session(bind=connection, expire_on_commit=False)
try:
with Session(bind=connection, expire_on_commit=False) as session:
yield session
finally:
_safe_close_session(session)
def get_session() -> Generator[Session, None, None]:

View File

@@ -2,6 +2,8 @@ from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import NamedTuple
from typing import TYPE_CHECKING
from typing import TypeVarTuple
from sqlalchemy import and_
@@ -28,6 +30,17 @@ from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
if TYPE_CHECKING:
from onyx.configs.constants import DocumentSource
# from sqlalchemy.sql.selectable import Select
# Comment out unused imports that cause mypy errors
# from onyx.auth.models import UserRole
# from onyx.configs.constants import MAX_LAST_VALID_CHECKPOINT_AGE_SECONDS
# from onyx.db.connector_credential_pair import ConnectorCredentialPairIdentifier
# from onyx.db.engine import async_query_for_dms
logger = setup_logger()
@@ -886,7 +899,6 @@ def create_index_attempt_error(
failure: ConnectorFailure,
db_session: Session,
) -> int:
exc = failure.exception
new_error = IndexAttemptError(
index_attempt_id=index_attempt_id,
connector_credential_pair_id=connector_credential_pair_id,
@@ -909,7 +921,6 @@ def create_index_attempt_error(
),
failure_message=failure.failure_message,
is_resolved=False,
error_type=type(exc).__name__ if exc else None,
)
db_session.add(new_error)
db_session.commit()
@@ -968,48 +979,104 @@ def get_index_attempt_errors_for_cc_pair(
return list(db_session.scalars(stmt).all())
def get_index_attempt_errors_across_connectors(
# ── Metrics query helpers ──────────────────────────────────────────────
class ActiveIndexAttemptMetric(NamedTuple):
    """Row returned by get_active_index_attempts_for_metrics."""

    status: IndexingStatus  # current (non-terminal) status of the index attempt group
    source: "DocumentSource"  # connector source; forward ref — imported under TYPE_CHECKING only
    cc_pair_id: int  # ConnectorCredentialPair primary key
    cc_pair_name: str | None  # human-readable cc-pair name, if one was set
    attempt_count: int  # number of attempts in this (status, source, cc-pair) group
def get_active_index_attempts_for_metrics(
db_session: Session,
cc_pair_id: int | None = None,
error_type: str | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
unresolved_only: bool = True,
page: int = 0,
page_size: int = 25,
) -> tuple[list[IndexAttemptError], int]:
"""Query index attempt errors across all connectors with optional filters.
) -> list[ActiveIndexAttemptMetric]:
"""Return non-terminal index attempts grouped by status, source, and connector.
Returns (errors, total_count) for pagination.
Each row is (status, source, cc_pair_id, cc_pair_name, attempt_count).
"""
stmt = select(IndexAttemptError)
count_stmt = select(func.count()).select_from(IndexAttemptError)
from onyx.db.models import Connector
if cc_pair_id is not None:
stmt = stmt.where(IndexAttemptError.connector_credential_pair_id == cc_pair_id)
count_stmt = count_stmt.where(
IndexAttemptError.connector_credential_pair_id == cc_pair_id
terminal_statuses = [s for s in IndexingStatus if s.is_terminal()]
rows = (
db_session.query(
IndexAttempt.status,
Connector.source,
ConnectorCredentialPair.id,
ConnectorCredentialPair.name,
func.count(),
)
.join(
ConnectorCredentialPair,
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
)
.join(
Connector,
ConnectorCredentialPair.connector_id == Connector.id,
)
.filter(IndexAttempt.status.notin_(terminal_statuses))
.group_by(
IndexAttempt.status,
Connector.source,
ConnectorCredentialPair.id,
ConnectorCredentialPair.name,
)
.all()
)
return [ActiveIndexAttemptMetric(*row) for row in rows]
if error_type is not None:
stmt = stmt.where(IndexAttemptError.error_type == error_type)
count_stmt = count_stmt.where(IndexAttemptError.error_type == error_type)
if unresolved_only:
stmt = stmt.where(IndexAttemptError.is_resolved.is_(False))
count_stmt = count_stmt.where(IndexAttemptError.is_resolved.is_(False))
def get_failed_attempt_counts_by_cc_pair(
db_session: Session,
since: datetime | None = None,
) -> dict[int, int]:
"""Return {cc_pair_id: failed_attempt_count} for all connectors.
if start_time is not None:
stmt = stmt.where(IndexAttemptError.time_created >= start_time)
count_stmt = count_stmt.where(IndexAttemptError.time_created >= start_time)
When ``since`` is provided, only attempts created after that timestamp
are counted. Defaults to the last 90 days to avoid unbounded historical
aggregation.
"""
if since is None:
since = datetime.now(timezone.utc) - timedelta(days=90)
if end_time is not None:
stmt = stmt.where(IndexAttemptError.time_created <= end_time)
count_stmt = count_stmt.where(IndexAttemptError.time_created <= end_time)
rows = (
db_session.query(
IndexAttempt.connector_credential_pair_id,
func.count(),
)
.filter(IndexAttempt.status == IndexingStatus.FAILED)
.filter(IndexAttempt.time_created >= since)
.group_by(IndexAttempt.connector_credential_pair_id)
.all()
)
return {cc_id: count for cc_id, count in rows}
stmt = stmt.order_by(desc(IndexAttemptError.time_created))
stmt = stmt.offset(page * page_size).limit(page_size)
total = db_session.scalar(count_stmt) or 0
errors = list(db_session.scalars(stmt).all())
return errors, total
def get_docs_indexed_by_cc_pair(
    db_session: Session,
    since: datetime | None = None,
) -> dict[int, int]:
    """Return {cc_pair_id: total_new_docs_indexed} across successful attempts.

    Only counts attempts with status SUCCESS to avoid inflating counts with
    partial results from failed attempts. When ``since`` is provided, only
    attempts created after that timestamp are included.
    """
    # Default window: last 90 days, to avoid unbounded historical aggregation.
    cutoff = (
        since
        if since is not None
        else datetime.now(timezone.utc) - timedelta(days=90)
    )

    rows = (
        db_session.query(
            IndexAttempt.connector_credential_pair_id,
            func.sum(func.coalesce(IndexAttempt.new_docs_indexed, 0)),
        )
        .filter(IndexAttempt.status == IndexingStatus.SUCCESS)
        .filter(IndexAttempt.time_created >= cutoff)
        .group_by(IndexAttempt.connector_credential_pair_id)
        .all()
    )
    # SUM may come back as NULL/None for empty groups — normalize to int 0.
    return {cc_pair_id: int(doc_total or 0) for cc_pair_id, doc_total in rows}

View File

@@ -5,7 +5,6 @@ from pydantic import ConfigDict
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant_if_none
from onyx.db.models import Memory
from onyx.db.models import User
@@ -84,51 +83,47 @@ def get_memories(user: User, db_session: Session) -> UserMemoryContext:
def add_memory(
user_id: UUID,
memory_text: str,
db_session: Session | None = None,
) -> int:
db_session: Session,
) -> Memory:
"""Insert a new Memory row for the given user.
If the user already has MAX_MEMORIES_PER_USER memories, the oldest
one (lowest id) is deleted before inserting the new one.
Returns the id of the newly created Memory row.
"""
with get_session_with_current_tenant_if_none(db_session) as db_session:
existing = db_session.scalars(
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
).all()
existing = db_session.scalars(
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
).all()
if len(existing) >= MAX_MEMORIES_PER_USER:
db_session.delete(existing[0])
if len(existing) >= MAX_MEMORIES_PER_USER:
db_session.delete(existing[0])
memory = Memory(
user_id=user_id,
memory_text=memory_text,
)
db_session.add(memory)
db_session.commit()
return memory.id
memory = Memory(
user_id=user_id,
memory_text=memory_text,
)
db_session.add(memory)
db_session.commit()
return memory
def update_memory_at_index(
user_id: UUID,
index: int,
new_text: str,
db_session: Session | None = None,
) -> int | None:
db_session: Session,
) -> Memory | None:
"""Update the memory at the given 0-based index (ordered by id ASC, matching get_memories()).
Returns the id of the updated Memory row, or None if the index is out of range.
Returns the updated Memory row, or None if the index is out of range.
"""
with get_session_with_current_tenant_if_none(db_session) as db_session:
memory_rows = db_session.scalars(
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
).all()
memory_rows = db_session.scalars(
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
).all()
if index < 0 or index >= len(memory_rows):
return None
if index < 0 or index >= len(memory_rows):
return None
memory = memory_rows[index]
memory.memory_text = new_text
db_session.commit()
return memory.id
memory = memory_rows[index]
memory.memory_text = new_text
db_session.commit()
return memory

View File

@@ -2422,8 +2422,6 @@ class IndexAttemptError(Base):
failure_message: Mapped[str] = mapped_column(Text)
is_resolved: Mapped[bool] = mapped_column(Boolean, default=False)
error_type: Mapped[str | None] = mapped_column(String, nullable=True)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),

View File

@@ -7,6 +7,8 @@ import time
from collections.abc import Callable
from typing import cast
from sqlalchemy.orm import Session
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.citation_processor import CitationMapping
from onyx.chat.citation_processor import DynamicCitationProcessor
@@ -20,7 +22,6 @@ from onyx.chat.models import LlmStepResult
from onyx.chat.models import ToolCallSimple
from onyx.configs.chat_configs import SKIP_DEEP_RESEARCH_CLARIFICATION
from onyx.configs.constants import MessageType
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.tools import get_tool_by_name
from onyx.deep_research.dr_mock_tools import get_clarification_tool_definitions
from onyx.deep_research.dr_mock_tools import get_orchestrator_tools
@@ -183,14 +184,6 @@ def generate_final_report(
return has_reasoned
def _get_research_agent_tool_id() -> int:
with get_session_with_current_tenant() as db_session:
return get_tool_by_name(
tool_name=RESEARCH_AGENT_TOOL_NAME,
db_session=db_session,
).id
@log_function_time(print_only=True)
def run_deep_research_llm_loop(
emitter: Emitter,
@@ -200,6 +193,7 @@ def run_deep_research_llm_loop(
custom_agent_prompt: str | None, # noqa: ARG001
llm: LLM,
token_counter: Callable[[str], int],
db_session: Session,
skip_clarification: bool = False,
user_identity: LLMUserIdentity | None = None,
chat_session_id: str | None = None,
@@ -723,7 +717,6 @@ def run_deep_research_llm_loop(
simple_chat_history.append(assistant_with_tools)
# Now add TOOL_CALL_RESPONSE messages and tool call info for each result
research_agent_tool_id = _get_research_agent_tool_id()
for tab_index, report in enumerate(
research_results.intermediate_reports
):
@@ -744,7 +737,10 @@ def run_deep_research_llm_loop(
tab_index=tab_index,
tool_name=current_tool_call.tool_name,
tool_call_id=current_tool_call.tool_call_id,
tool_id=research_agent_tool_id,
tool_id=get_tool_by_name(
tool_name=RESEARCH_AGENT_TOOL_NAME,
db_session=db_session,
).id,
reasoning_tokens=llm_step_result.reasoning
or most_recent_reasoning,
tool_call_arguments=current_tool_call.tool_args,

View File

@@ -379,25 +379,13 @@ def _worksheet_to_matrix(
worksheet: Worksheet,
) -> list[list[str]]:
"""
Converts a singular worksheet to a matrix of values.
Rows are padded to a uniform width. In openpyxl's read_only mode,
iter_rows can yield rows of differing lengths (trailing empty cells
are sometimes omitted), and downstream column cleanup assumes a
rectangular matrix.
Converts a singular worksheet to a matrix of values
"""
rows: list[list[str]] = []
max_len = 0
for worksheet_row in worksheet.iter_rows(min_row=1, values_only=True):
row = ["" if cell is None else str(cell) for cell in worksheet_row]
if len(row) > max_len:
max_len = len(row)
rows.append(row)
for row in rows:
if len(row) < max_len:
row.extend([""] * (max_len - len(row)))
return rows
@@ -475,13 +463,29 @@ def _remove_empty_runs(
return result
def xlsx_sheet_extraction(file: IO[Any], file_name: str = "") -> list[tuple[str, str]]:
"""
Converts each sheet in the excel file to a csv condensed string.
Returns a string and the worksheet title for each worksheet
def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
# TODO: switch back to this approach in a few months when markitdown
# fixes their handling of excel files
Returns a list of (csv_text, sheet)
"""
# md = get_markitdown_converter()
# stream_info = StreamInfo(
# mimetype=SPREADSHEET_MIME_TYPE, filename=file_name or None, extension=".xlsx"
# )
# try:
# workbook = md.convert(to_bytesio(file), stream_info=stream_info)
# except (
# BadZipFile,
# ValueError,
# FileConversionException,
# UnsupportedFormatException,
# ) as e:
# error_str = f"Failed to extract text from {file_name or 'xlsx file'}: {e}"
# if file_name.startswith("~"):
# logger.debug(error_str + " (this is expected for files with ~)")
# else:
# logger.warning(error_str)
# return ""
# return workbook.markdown
try:
workbook = openpyxl.load_workbook(file, read_only=True)
except BadZipFile as e:
@@ -490,30 +494,23 @@ def xlsx_sheet_extraction(file: IO[Any], file_name: str = "") -> list[tuple[str,
logger.debug(error_str + " (this is expected for files with ~)")
else:
logger.warning(error_str)
return []
return ""
except Exception as e:
if any(s in str(e) for s in KNOWN_OPENPYXL_BUGS):
logger.error(
f"Failed to extract text from {file_name or 'xlsx file'}. This happens due to a bug in openpyxl. {e}"
)
return []
return ""
raise
sheets: list[tuple[str, str]] = []
text_content = []
for sheet in workbook.worksheets:
sheet_matrix = _clean_worksheet_matrix(_worksheet_to_matrix(sheet))
buf = io.StringIO()
writer = csv.writer(buf, lineterminator="\n")
writer.writerows(sheet_matrix)
csv_text = buf.getvalue().rstrip("\n")
if csv_text.strip():
sheets.append((csv_text, sheet.title))
return sheets
def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
sheets = xlsx_sheet_extraction(file, file_name)
return TEXT_SECTION_SEPARATOR.join(csv_text for csv_text, _title in sheets)
text_content.append(buf.getvalue().rstrip("\n"))
return TEXT_SECTION_SEPARATOR.join(text_content)
def eml_to_text(file: IO[Any]) -> str:

View File

@@ -1,3 +1,5 @@
from typing import cast
from chonkie import SentenceChunker
from onyx.configs.app_configs import AVERAGE_SUMMARY_EMBEDDINGS
@@ -14,14 +16,16 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
get_metadata_keys_to_ignore,
)
from onyx.connectors.models import IndexingDocument
from onyx.indexing.chunking import DocumentChunker
from onyx.indexing.chunking import extract_blurb
from onyx.connectors.models import Section
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.models import DocAwareChunk
from onyx.llm.utils import MAX_CONTEXT_TOKENS
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.utils.logger import setup_logger
from onyx.utils.text_processing import clean_text
from onyx.utils.text_processing import shared_precompare_cleanup
from shared_configs.configs import DOC_EMBEDDING_CONTEXT_SIZE
from shared_configs.configs import STRICT_CHUNK_TOKEN_LIMIT
# Not supporting overlaps, we need a clean combination of chunks and it is unclear if overlaps
# actually help quality at all
@@ -150,6 +154,9 @@ class Chunker:
self.tokenizer = tokenizer
self.callback = callback
self.max_context = 0
self.prompt_tokens = 0
# Create a token counter function that returns the count instead of the tokens
def token_counter(text: str) -> int:
return len(tokenizer.encode(text))
@@ -179,12 +186,234 @@ class Chunker:
else None
)
self._document_chunker = DocumentChunker(
tokenizer=tokenizer,
blurb_splitter=self.blurb_splitter,
chunk_splitter=self.chunk_splitter,
mini_chunk_splitter=self.mini_chunk_splitter,
def _split_oversized_chunk(self, text: str, content_token_limit: int) -> list[str]:
"""
Splits the text into smaller chunks based on token count to ensure
no chunk exceeds the content_token_limit.
"""
tokens = self.tokenizer.tokenize(text)
chunks = []
start = 0
total_tokens = len(tokens)
while start < total_tokens:
end = min(start + content_token_limit, total_tokens)
token_chunk = tokens[start:end]
chunk_text = " ".join(token_chunk)
chunks.append(chunk_text)
start = end
return chunks
def _extract_blurb(self, text: str) -> str:
"""
Extract a short blurb from the text (first chunk of size `blurb_size`).
"""
# chunker is in `text` mode
texts = cast(list[str], self.blurb_splitter.chunk(text))
if not texts:
return ""
return texts[0]
def _get_mini_chunk_texts(self, chunk_text: str) -> list[str] | None:
"""
For "multipass" mode: additional sub-chunks (mini-chunks) for use in certain embeddings.
"""
if self.mini_chunk_splitter and chunk_text.strip():
# chunker is in `text` mode
return cast(list[str], self.mini_chunk_splitter.chunk(chunk_text))
return None
# ADDED: extra param image_url to store in the chunk
def _create_chunk(
self,
document: IndexingDocument,
chunks_list: list[DocAwareChunk],
text: str,
links: dict[int, str],
is_continuation: bool = False,
title_prefix: str = "",
metadata_suffix_semantic: str = "",
metadata_suffix_keyword: str = "",
image_file_id: str | None = None,
) -> None:
"""
Helper to create a new DocAwareChunk, append it to chunks_list.
"""
new_chunk = DocAwareChunk(
source_document=document,
chunk_id=len(chunks_list),
blurb=self._extract_blurb(text),
content=text,
source_links=links or {0: ""},
image_file_id=image_file_id,
section_continuation=is_continuation,
title_prefix=title_prefix,
metadata_suffix_semantic=metadata_suffix_semantic,
metadata_suffix_keyword=metadata_suffix_keyword,
mini_chunk_texts=self._get_mini_chunk_texts(text),
large_chunk_id=None,
doc_summary="",
chunk_context="",
contextual_rag_reserved_tokens=0, # set per-document in _handle_single_document
)
chunks_list.append(new_chunk)
    def _chunk_document_with_sections(
        self,
        document: IndexingDocument,
        sections: list[Section],
        title_prefix: str,
        metadata_suffix_semantic: str,
        metadata_suffix_keyword: str,
        content_token_limit: int,
    ) -> list[DocAwareChunk]:
        """
        Loops through sections of the document, converting them into one or more chunks.
        Works with processed sections that are base Section objects.

        Accumulates consecutive small text sections into a single chunk until
        content_token_limit would be exceeded; image sections and oversized
        text sections are flushed into their own chunks immediately.
        """
        chunks: list[DocAwareChunk] = []
        # Maps character offset (after shared_precompare_cleanup) -> link for
        # each section merged into the accumulating chunk.
        link_offsets: dict[int, str] = {}
        chunk_text = ""

        for section_idx, section in enumerate(sections):
            # Get section text and other attributes
            section_text = clean_text(str(section.text or ""))
            section_link_text = section.link or ""
            image_url = section.image_file_id

            # If there is no useful content, skip
            # (a titled document's first section is kept even if empty)
            if not section_text and (not document.title or section_idx > 0):
                logger.warning(
                    f"Skipping empty or irrelevant section in doc {document.semantic_identifier}, link={section_link_text}"
                )
                continue

            # CASE 1: If this section has an image, force a separate chunk
            if image_url:
                # First, if we have any partially built text chunk, finalize it
                if chunk_text.strip():
                    self._create_chunk(
                        document,
                        chunks,
                        chunk_text,
                        link_offsets,
                        is_continuation=False,
                        title_prefix=title_prefix,
                        metadata_suffix_semantic=metadata_suffix_semantic,
                        metadata_suffix_keyword=metadata_suffix_keyword,
                    )
                    chunk_text = ""
                    link_offsets = {}

                # Create a chunk specifically for this image section
                # (Using the text summary that was generated during processing)
                self._create_chunk(
                    document,
                    chunks,
                    section_text,
                    links={0: section_link_text} if section_link_text else {},
                    image_file_id=image_url,
                    title_prefix=title_prefix,
                    metadata_suffix_semantic=metadata_suffix_semantic,
                    metadata_suffix_keyword=metadata_suffix_keyword,
                )
                # Continue to next section
                continue

            # CASE 2: Normal text section
            section_token_count = len(self.tokenizer.encode(section_text))

            # If the section is large on its own, split it separately
            if section_token_count > content_token_limit:
                # Flush any accumulated text first so the split pieces stay
                # contiguous in their own chunks.
                if chunk_text.strip():
                    self._create_chunk(
                        document,
                        chunks,
                        chunk_text,
                        link_offsets,
                        False,
                        title_prefix,
                        metadata_suffix_semantic,
                        metadata_suffix_keyword,
                    )
                    chunk_text = ""
                    link_offsets = {}

                # chunker is in `text` mode
                split_texts = cast(list[str], self.chunk_splitter.chunk(section_text))

                for i, split_text in enumerate(split_texts):
                    # If even the split_text is bigger than strict limit, further split
                    if (
                        STRICT_CHUNK_TOKEN_LIMIT
                        and len(self.tokenizer.encode(split_text)) > content_token_limit
                    ):
                        smaller_chunks = self._split_oversized_chunk(
                            split_text, content_token_limit
                        )
                        for j, small_chunk in enumerate(smaller_chunks):
                            # Only the first piece is NOT a continuation.
                            self._create_chunk(
                                document,
                                chunks,
                                small_chunk,
                                {0: section_link_text},
                                is_continuation=(j != 0),
                                title_prefix=title_prefix,
                                metadata_suffix_semantic=metadata_suffix_semantic,
                                metadata_suffix_keyword=metadata_suffix_keyword,
                            )
                    else:
                        self._create_chunk(
                            document,
                            chunks,
                            split_text,
                            {0: section_link_text},
                            is_continuation=(i != 0),
                            title_prefix=title_prefix,
                            metadata_suffix_semantic=metadata_suffix_semantic,
                            metadata_suffix_keyword=metadata_suffix_keyword,
                        )
                continue

            # If we can still fit this section into the current chunk, do so
            current_token_count = len(self.tokenizer.encode(chunk_text))
            current_offset = len(shared_precompare_cleanup(chunk_text))
            next_section_tokens = (
                len(self.tokenizer.encode(SECTION_SEPARATOR)) + section_token_count
            )

            if next_section_tokens + current_token_count <= content_token_limit:
                if chunk_text:
                    chunk_text += SECTION_SEPARATOR
                chunk_text += section_text
                # Record where this section starts so its link can be resolved.
                link_offsets[current_offset] = section_link_text
            else:
                # finalize the existing chunk
                self._create_chunk(
                    document,
                    chunks,
                    chunk_text,
                    link_offsets,
                    False,
                    title_prefix,
                    metadata_suffix_semantic,
                    metadata_suffix_keyword,
                )
                # start a new chunk
                link_offsets = {0: section_link_text}
                chunk_text = section_text

        # finalize any leftover text chunk
        # (`not chunks` guarantees every document yields at least one chunk)
        if chunk_text.strip() or not chunks:
            self._create_chunk(
                document,
                chunks,
                chunk_text,
                link_offsets or {0: ""},  # safe default
                False,
                title_prefix,
                metadata_suffix_semantic,
                metadata_suffix_keyword,
            )

        return chunks
def _handle_single_document(
self, document: IndexingDocument
@@ -194,10 +423,7 @@ class Chunker:
logger.debug(f"Chunking {document.semantic_identifier}")
# Title prep
title = extract_blurb(
document.get_title_for_document_index() or "",
self.blurb_splitter,
)
title = self._extract_blurb(document.get_title_for_document_index() or "")
title_prefix = title + RETURN_SEPARATOR if title else ""
title_tokens = len(self.tokenizer.encode(title_prefix))
@@ -265,7 +491,7 @@ class Chunker:
# Use processed_sections if available (IndexingDocument), otherwise use original sections
sections_to_chunk = document.processed_sections
normal_chunks = self._document_chunker.chunk(
normal_chunks = self._chunk_document_with_sections(
document,
sections_to_chunk,
title_prefix,

View File

@@ -1,7 +0,0 @@
from onyx.indexing.chunking.document_chunker import DocumentChunker
from onyx.indexing.chunking.section_chunker import extract_blurb
__all__ = [
"DocumentChunker",
"extract_blurb",
]

View File

@@ -1,113 +0,0 @@
from chonkie import SentenceChunker
from onyx.connectors.models import IndexingDocument
from onyx.connectors.models import Section
from onyx.connectors.models import SectionType
from onyx.indexing.chunking.image_section_chunker import ImageChunker
from onyx.indexing.chunking.section_chunker import AccumulatorState
from onyx.indexing.chunking.section_chunker import ChunkPayload
from onyx.indexing.chunking.section_chunker import SectionChunker
from onyx.indexing.chunking.tabular_section_chunker import TabularChunker
from onyx.indexing.chunking.text_section_chunker import TextChunker
from onyx.indexing.models import DocAwareChunk
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.utils.logger import setup_logger
from onyx.utils.text_processing import clean_text
logger = setup_logger()
class DocumentChunker:
    """Converts a document's processed sections into DocAwareChunks.
    Drop-in replacement for `Chunker._chunk_document_with_sections`.
    """
    def __init__(
        self,
        tokenizer: BaseTokenizer,
        blurb_splitter: SentenceChunker,
        chunk_splitter: SentenceChunker,
        mini_chunk_splitter: SentenceChunker | None = None,
    ) -> None:
        # Splitters used when payloads are upgraded to DocAwareChunks.
        self.blurb_splitter = blurb_splitter
        self.mini_chunk_splitter = mini_chunk_splitter
        # One chunking strategy per section type; resolved per section in
        # `_select_chunker`. Only TEXT/TABULAR need the tokenizer.
        self._dispatch: dict[SectionType, SectionChunker] = {
            SectionType.TEXT: TextChunker(
                tokenizer=tokenizer,
                chunk_splitter=chunk_splitter,
            ),
            SectionType.IMAGE: ImageChunker(),
            SectionType.TABULAR: TabularChunker(tokenizer=tokenizer),
        }
    def chunk(
        self,
        document: IndexingDocument,
        sections: list[Section],
        title_prefix: str,
        metadata_suffix_semantic: str,
        metadata_suffix_keyword: str,
        content_token_limit: int,
    ) -> list[DocAwareChunk]:
        """Chunk `sections` and return fully-formed DocAwareChunks.

        Always yields at least one chunk (an empty one if no section
        produced content) so every document is represented downstream.
        chunk_ids are assigned by enumeration order.
        """
        payloads = self._collect_section_payloads(
            document=document,
            sections=sections,
            content_token_limit=content_token_limit,
        )
        if not payloads:
            payloads.append(ChunkPayload(text="", links={0: ""}))
        return [
            payload.to_doc_aware_chunk(
                document=document,
                chunk_id=idx,
                blurb_splitter=self.blurb_splitter,
                mini_chunk_splitter=self.mini_chunk_splitter,
                title_prefix=title_prefix,
                metadata_suffix_semantic=metadata_suffix_semantic,
                metadata_suffix_keyword=metadata_suffix_keyword,
            )
            for idx, payload in enumerate(payloads)
        ]
    def _collect_section_payloads(
        self,
        document: IndexingDocument,
        sections: list[Section],
        content_token_limit: int,
    ) -> list[ChunkPayload]:
        """Run each section through its type-specific chunker, threading a
        shared cross-section text accumulator between them."""
        accumulator = AccumulatorState()
        payloads: list[ChunkPayload] = []
        for section_idx, section in enumerate(sections):
            section_text = clean_text(str(section.text or ""))
            # Skip empty sections — except an empty FIRST section of a
            # titled document, presumably so the title still yields a
            # chunk (TODO confirm against callers).
            if not section_text and (not document.title or section_idx > 0):
                logger.warning(
                    f"Skipping empty or irrelevant section in doc "
                    f"{document.semantic_identifier}, link={section.link}"
                )
                continue
            chunker = self._select_chunker(section)
            result = chunker.chunk_section(
                section=section,
                accumulator=accumulator,
                content_token_limit=content_token_limit,
            )
            payloads.extend(result.payloads)
            accumulator = result.accumulator
        # Final flush — any leftover buffered text becomes one last payload.
        payloads.extend(accumulator.flush_to_list())
        return payloads
    def _select_chunker(self, section: Section) -> SectionChunker:
        """Look up the SectionChunker registered for this section's type.

        Raises:
            ValueError: when no chunker is registered for `section.type`.
        """
        try:
            return self._dispatch[section.type]
        except KeyError:
            raise ValueError(f"No SectionChunker registered for type={section.type}")

View File

@@ -1,35 +0,0 @@
from onyx.connectors.models import Section
from onyx.indexing.chunking.section_chunker import AccumulatorState
from onyx.indexing.chunking.section_chunker import ChunkPayload
from onyx.indexing.chunking.section_chunker import SectionChunker
from onyx.indexing.chunking.section_chunker import SectionChunkerOutput
from onyx.utils.text_processing import clean_text
class ImageChunker(SectionChunker):
    """Turns one image section into exactly one standalone chunk payload.

    Image sections never merge with neighboring text: any buffered text in
    the accumulator is flushed first, and the image becomes its own payload
    carrying `image_file_id`.
    """
    def chunk_section(
        self,
        section: Section,
        accumulator: AccumulatorState,
        content_token_limit: int,  # noqa: ARG002
    ) -> SectionChunkerOutput:
        """Emit flushed buffered text (if any) plus one image payload.

        Raises:
            ValueError: if the section carries no image_file_id — the
                dispatcher must only route image-typed sections here.
        """
        # Explicit check instead of `assert`: assertions are stripped under
        # `python -O`, and a missing image_file_id is a real contract
        # violation we want surfaced in production too.
        if section.image_file_id is None:
            raise ValueError(
                "ImageChunker received a section without an image_file_id"
            )
        section_text = clean_text(str(section.text or ""))
        section_link = section.link or ""
        # Flush any partially built text chunks so the image stands alone.
        payloads = accumulator.flush_to_list()
        payloads.append(
            ChunkPayload(
                text=section_text,
                links={0: section_link} if section_link else {},
                image_file_id=section.image_file_id,
                is_continuation=False,
            )
        )
        return SectionChunkerOutput(
            payloads=payloads,
            accumulator=AccumulatorState(),
        )

View File

@@ -1,100 +0,0 @@
from abc import ABC
from abc import abstractmethod
from collections.abc import Sequence
from typing import cast
from chonkie import SentenceChunker
from pydantic import BaseModel
from pydantic import Field
from onyx.connectors.models import IndexingDocument
from onyx.connectors.models import Section
from onyx.indexing.models import DocAwareChunk
def extract_blurb(text: str, blurb_splitter: SentenceChunker) -> str:
    """Return the first sentence-chunk of ``text``, or "" when the
    splitter produces nothing."""
    pieces = cast(list[str], blurb_splitter.chunk(text))
    return pieces[0] if pieces else ""
def get_mini_chunk_texts(
    chunk_text: str,
    mini_chunk_splitter: SentenceChunker | None,
) -> list[str] | None:
    """Split ``chunk_text`` into mini-chunks, or None when mini-chunking
    is disabled or the text is blank."""
    if not mini_chunk_splitter or not chunk_text.strip():
        return None
    return list(cast(Sequence[str], mini_chunk_splitter.chunk(chunk_text)))
class ChunkPayload(BaseModel):
    """Section-local chunk content without document-scoped fields.
    The orchestrator upgrades these to DocAwareChunks via
    `to_doc_aware_chunk` after assigning chunk_ids and attaching
    title/metadata.
    """
    # Chunk body text.
    text: str
    # Character offset into the (cleaned-up) text -> source link for
    # the span starting there.
    links: dict[int, str]
    # True when this payload continues a split-up section.
    is_continuation: bool = False
    # Set only for image sections.
    image_file_id: str | None = None
    def to_doc_aware_chunk(
        self,
        document: IndexingDocument,
        chunk_id: int,
        blurb_splitter: SentenceChunker,
        title_prefix: str = "",
        metadata_suffix_semantic: str = "",
        metadata_suffix_keyword: str = "",
        mini_chunk_splitter: SentenceChunker | None = None,
    ) -> DocAwareChunk:
        """Upgrade this payload to a DocAwareChunk bound to `document`.

        Contextual-RAG fields (doc_summary, chunk_context, large_chunk_id)
        are initialized empty here — presumably filled in by later pipeline
        stages; confirm against callers.
        """
        return DocAwareChunk(
            source_document=document,
            chunk_id=chunk_id,
            blurb=extract_blurb(self.text, blurb_splitter),
            content=self.text,
            # Image payloads may carry {} links; fall back to a single
            # empty link at offset 0.
            source_links=self.links or {0: ""},
            image_file_id=self.image_file_id,
            section_continuation=self.is_continuation,
            title_prefix=title_prefix,
            metadata_suffix_semantic=metadata_suffix_semantic,
            metadata_suffix_keyword=metadata_suffix_keyword,
            mini_chunk_texts=get_mini_chunk_texts(self.text, mini_chunk_splitter),
            large_chunk_id=None,
            doc_summary="",
            chunk_context="",
            contextual_rag_reserved_tokens=0,
        )
class AccumulatorState(BaseModel):
    """Cross-section text buffer threaded through SectionChunkers."""
    # Buffered chunk text built up across adjacent sections.
    text: str = ""
    # Offset-into-text -> source link, mirroring ChunkPayload.links.
    link_offsets: dict[int, str] = Field(default_factory=dict)
    def is_empty(self) -> bool:
        """True when the buffer holds no non-whitespace text."""
        return self.text.strip() == ""
    def flush_to_list(self) -> list[ChunkPayload]:
        """Drain the buffer into at most one payload ([] when blank)."""
        if not self.is_empty():
            return [ChunkPayload(text=self.text, links=self.link_offsets)]
        return []
class SectionChunkerOutput(BaseModel):
    """Result of chunking one section: finished payloads plus the
    (possibly updated) accumulator to thread into the next section."""
    payloads: list[ChunkPayload]
    accumulator: AccumulatorState
class SectionChunker(ABC):
    """Strategy interface: convert one Section into chunk payloads.

    Implementations receive the running cross-section accumulator and must
    return both any completed payloads and the accumulator state to carry
    forward to the next section.
    """
    @abstractmethod
    def chunk_section(
        self,
        section: Section,
        accumulator: AccumulatorState,
        content_token_limit: int,
    ) -> SectionChunkerOutput: ...

View File

@@ -1,5 +0,0 @@
from onyx.indexing.chunking.tabular_section_chunker.tabular_section_chunker import (
TabularChunker,
)
__all__ = ["TabularChunker"]

View File

@@ -1,229 +0,0 @@
"""Per-section sheet descriptor chunk builder."""
from dataclasses import dataclass
from dataclasses import field
from datetime import date
from itertools import zip_longest
from dateutil.parser import parse as parse_dt
from onyx.connectors.models import Section
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.natural_language_processing.utils import count_tokens
from onyx.utils.csv_utils import parse_csv_string
from onyx.utils.csv_utils import ParsedRow
from onyx.utils.csv_utils import read_csv_header
MAX_NUMERIC_COLS = 12
MAX_CATEGORICAL_COLS = 6
MAX_CATEGORICAL_WITH_SAMPLES = 4
MAX_DISTINCT_SAMPLES = 8
CATEGORICAL_DISTINCT_THRESHOLD = 20
ID_NAME_TOKENS = {"id", "uuid", "uid", "guid", "key"}
@dataclass
class SheetAnalysis:
    """Aggregate per-column statistics for one parsed sheet (see _analyze)."""
    row_count: int  # number of parsed rows (presumably data rows; header comes separately)
    num_cols: int  # number of header columns
    numeric_cols: list[int] = field(default_factory=list)  # indices whose non-empty values all parse as numbers
    categorical_cols: list[int] = field(default_factory=list)  # low-cardinality column indices
    categorical_values: dict[int, list[str]] = field(default_factory=dict)  # distinct value samples per categorical column
    id_col: int | None = None  # first id-named column with all-unique values, if any
    date_min: date | None = None  # sheet-wide earliest date across date columns
    date_max: date | None = None  # sheet-wide latest date across date columns
def build_sheet_descriptor_chunks(
    section: Section,
    tokenizer: BaseTokenizer,
    max_tokens: int,
) -> list[str]:
    """Build sheet descriptor chunk(s) from a parsed CSV section.
    Output (lines joined by "\\n"; lines that overflow ``max_tokens`` on
    their own are skipped; ``section.heading`` is prepended to every
    emitted chunk so retrieval keeps sheet context after a split):
    {section.heading} # optional
    Sheet overview.
    This sheet has {N} rows and {M} columns.
    Columns: {col1}, {col2}, ...
    Time range: {start} to {end}. # optional
    Numeric columns (aggregatable by sum, average, min, max): ... # optional
    Categorical columns (groupable, can be counted by value): ... # optional
    Identifier column: {col}. # optional
    Values seen in {col}: {v1}, {v2}, ... # optional, repeated
    """
    text = section.text or ""
    parsed_rows = list(parse_csv_string(text))
    # Header-only sheets parse to no rows; fall back to reading just the header.
    headers = parsed_rows[0].header if parsed_rows else read_csv_header(text)
    if not headers:
        # Not parseable as CSV at all — nothing to describe.
        return []
    a = _analyze(headers, parsed_rows)
    # Optional line builders return "" when inapplicable; filtered below.
    lines = [
        _overview_line(a),
        _columns_line(headers),
        _time_range_line(a),
        _numeric_cols_line(headers, a),
        _categorical_cols_line(headers, a),
        _id_col_line(headers, a),
        _values_seen_line(headers, a),
    ]
    return _pack_lines(
        [line for line in lines if line],
        prefix=section.heading or "",
        tokenizer=tokenizer,
        max_tokens=max_tokens,
    )
def _overview_line(a: SheetAnalysis) -> str:
    """The always-present two-line sheet summary."""
    summary = f"This sheet has {a.row_count} rows and {a.num_cols} columns."
    return "Sheet overview.\n" + summary
def _columns_line(headers: list[str]) -> str:
    """The "Columns: ..." line listing every header with its friendly label."""
    labels = [_label(header) for header in headers]
    return "Columns: " + ", ".join(labels)
def _time_range_line(a: SheetAnalysis) -> str:
    """One "Time range: ..." line, or "" when no date column was found."""
    if a.date_min and a.date_max:
        return f"Time range: {a.date_min} to {a.date_max}."
    return ""
def _numeric_cols_line(headers: list[str], a: SheetAnalysis) -> str:
    """Describe aggregatable numeric columns, capped at MAX_NUMERIC_COLS."""
    if not a.numeric_cols:
        return ""
    shown = a.numeric_cols[:MAX_NUMERIC_COLS]
    names = ", ".join(_label(headers[col]) for col in shown)
    return f"Numeric columns (aggregatable by sum, average, min, max): {names}"
def _categorical_cols_line(headers: list[str], a: SheetAnalysis) -> str:
    """Describe groupable categorical columns, capped at MAX_CATEGORICAL_COLS."""
    if not a.categorical_cols:
        return ""
    shown = a.categorical_cols[:MAX_CATEGORICAL_COLS]
    names = ", ".join(_label(headers[col]) for col in shown)
    return f"Categorical columns (groupable, can be counted by value): {names}"
def _id_col_line(headers: list[str], a: SheetAnalysis) -> str:
    """Name the detected identifier column, or "" when none was found."""
    if a.id_col is not None:
        return f"Identifier column: {_label(headers[a.id_col])}."
    return ""
def _values_seen_line(headers: list[str], a: SheetAnalysis) -> str:
    """One "Values seen in ..." line per sampled categorical column.

    Samples are sorted and capped at MAX_DISTINCT_SAMPLES; only the first
    MAX_CATEGORICAL_WITH_SAMPLES categorical columns are covered.
    """
    lines = [
        f"Values seen in {_label(headers[col])}: " + ", ".join(sample)
        for col in a.categorical_cols[:MAX_CATEGORICAL_WITH_SAMPLES]
        if (sample := sorted(a.categorical_values.get(col, []))[:MAX_DISTINCT_SAMPLES])
    ]
    return "\n".join(lines)
def _label(name: str) -> str:
return f"{name} ({name.replace('_', ' ')})" if "_" in name else name
def _is_numeric(value: str) -> bool:
try:
float(value.replace(",", ""))
return True
except ValueError:
return False
def _try_date(value: str) -> date | None:
if len(value) < 4 or not any(c in value for c in "-/T"):
return None
try:
return parse_dt(value).date()
except (ValueError, OverflowError, TypeError):
return None
def _is_id_name(name: str) -> bool:
    """True when a header name looks like an identifier column ("id",
    "user_id", "order-key", ...) after lowercasing and dash-normalizing."""
    normalized = name.lower().strip().replace("-", "_")
    if normalized in ID_NAME_TOKENS:
        return True
    return any(normalized.endswith("_" + token) for token in ID_NAME_TOKENS)
def _analyze(headers: list[str], parsed_rows: list[ParsedRow]) -> SheetAnalysis:
    """Classify each column as numeric / date / categorical and detect an
    identifier column, producing the SheetAnalysis used by the descriptor
    line builders. Classification order matters: numeric wins over date,
    date over categorical."""
    a = SheetAnalysis(row_count=len(parsed_rows), num_cols=len(headers))
    # Transpose rows into columns; short rows are padded with "".
    columns = zip_longest(*(pr.row for pr in parsed_rows), fillvalue="")
    for idx, (header, raw_values) in enumerate(zip(headers, columns)):
        # Pull the column's non-empty values; skip if the column is blank.
        values = [v.strip() for v in raw_values if v.strip()]
        if not values:
            continue
        # Identifier: id-named column whose values are all unique. Detected
        # before classification so a numeric `id` column still gets flagged.
        distinct = set(values)
        if a.id_col is None and len(distinct) == len(values) and _is_id_name(header):
            a.id_col = idx
        # Numeric: every value parses as a number.
        if all(_is_numeric(v) for v in values):
            a.numeric_cols.append(idx)
            continue
        # Date: every value parses as a date — fold into the sheet-wide range.
        dates = [_try_date(v) for v in values]
        if all(d is not None for d in dates):
            dmin = min(filter(None, dates))
            dmax = max(filter(None, dates))
            a.date_min = dmin if a.date_min is None else min(a.date_min, dmin)
            a.date_max = dmax if a.date_max is None else max(a.date_max, dmax)
            continue
        # Categorical: low-cardinality column — keep distinct values for samples.
        if len(distinct) <= max(CATEGORICAL_DISTINCT_THRESHOLD, len(values) // 2):
            a.categorical_cols.append(idx)
            a.categorical_values[idx] = list(distinct)
    return a
def _pack_lines(
    lines: list[str],
    prefix: str,
    tokenizer: BaseTokenizer,
    max_tokens: int,
) -> list[str]:
    """Greedily pack lines into chunks ≤ max_tokens. Lines that on
    their own exceed max_tokens (after accounting for the prefix) are
    skipped. ``prefix`` is prepended to every emitted chunk."""
    # +1 approximates the newline joining the prefix to the body.
    prefix_tokens = count_tokens(prefix, tokenizer) + 1 if prefix else 0
    budget = max_tokens - prefix_tokens
    chunks: list[str] = []
    current: list[str] = []
    current_tokens = 0
    for line in lines:
        line_tokens = count_tokens(line, tokenizer)
        if line_tokens > budget:
            # Line overflows even an empty chunk — drop it entirely.
            continue
        # Approximate one-token cost of the joining newline.
        sep = 1 if current else 0
        if current_tokens + sep + line_tokens > budget:
            # Flush the full chunk and start a fresh one with this line.
            chunks.append(_join_with_prefix(current, prefix))
            current = [line]
            current_tokens = line_tokens
        else:
            current.append(line)
            current_tokens += sep + line_tokens
    if current:
        chunks.append(_join_with_prefix(current, prefix))
    return chunks
def _join_with_prefix(lines: list[str], prefix: str) -> str:
body = "\n".join(lines)
return f"{prefix}\n{body}" if prefix else body

View File

@@ -1,272 +0,0 @@
from collections.abc import Iterable
from pydantic import BaseModel
from onyx.connectors.models import Section
from onyx.indexing.chunking.section_chunker import AccumulatorState
from onyx.indexing.chunking.section_chunker import ChunkPayload
from onyx.indexing.chunking.section_chunker import SectionChunker
from onyx.indexing.chunking.section_chunker import SectionChunkerOutput
from onyx.indexing.chunking.tabular_section_chunker.sheet_descriptor import (
build_sheet_descriptor_chunks,
)
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.natural_language_processing.utils import count_tokens
from onyx.natural_language_processing.utils import split_text_by_tokens
from onyx.utils.csv_utils import parse_csv_string
from onyx.utils.csv_utils import ParsedRow
from onyx.utils.logger import setup_logger
logger = setup_logger()
COLUMNS_MARKER = "Columns:"
FIELD_VALUE_SEPARATOR = ", "
ROW_JOIN = "\n"
NEWLINE_TOKENS = 1
class _TokenizedText(BaseModel):
    """A text piece paired with its token count (counts are combined
    additively elsewhere, so they may be slightly approximate)."""
    text: str
    token_count: int
def format_row(header: list[str], row: list[str]) -> str:
    """
    A header-row combination is formatted like this:
    field1=value1, field2=value2, field3=value3

    Blank values are dropped entirely (no dangling "field=").
    """
    rendered = [f"{field}={value}" for field, value in _row_to_pairs(header, row)]
    return FIELD_VALUE_SEPARATOR.join(rendered)
def format_columns_header(headers: list[str]) -> str:
    """
    Format the column header line. Underscored headers get a
    space-substituted friendly alias in parens.
    Example:
    headers = ["id", "MTTR_hours"]
    => "Columns: id, MTTR_hours (MTTR hours)"
    """
    labels: list[str] = []
    for header in headers:
        if "_" in header:
            labels.append(f'{header} ({header.replace("_", " ")})')
        else:
            labels.append(header)
    return f"{COLUMNS_MARKER} " + FIELD_VALUE_SEPARATOR.join(labels)
def _row_to_pairs(headers: list[str], row: list[str]) -> list[tuple[str, str]]:
return [(h, v) for h, v in zip(headers, row) if v.strip()]
def pack_chunk(chunk: str, new_row: str) -> str:
    """Append one formatted row to an existing chunk on a new line."""
    return f"{chunk}\n{new_row}"
def _split_row_by_pairs(
    pairs: list[tuple[str, str]],
    tokenizer: BaseTokenizer,
    max_tokens: int,
) -> list[_TokenizedText]:
    """Greedily pack pairs into max-sized pieces. Any single pair that
    itself exceeds ``max_tokens`` is token-split at id boundaries.
    No headers."""
    separator_tokens = count_tokens(FIELD_VALUE_SEPARATOR, tokenizer)
    pieces: list[_TokenizedText] = []
    current_parts: list[str] = []
    current_tokens = 0
    for pair in pairs:
        pair_str = f"{pair[0]}={pair[1]}"
        pair_tokens = count_tokens(pair_str, tokenizer)
        # The first pair in a piece pays no separator cost.
        increment = pair_tokens if not current_parts else separator_tokens + pair_tokens
        if current_tokens + increment <= max_tokens:
            current_parts.append(pair_str)
            current_tokens += increment
            continue
        # Doesn't fit — flush the piece accumulated so far.
        if current_parts:
            pieces.append(
                _TokenizedText(
                    text=FIELD_VALUE_SEPARATOR.join(current_parts),
                    token_count=current_tokens,
                )
            )
            current_parts = []
            current_tokens = 0
        # Oversized single pair: hard token-split and emit directly
        # (never buffered); otherwise the pair seeds the next piece.
        if pair_tokens > max_tokens:
            for split_text in split_text_by_tokens(pair_str, tokenizer, max_tokens):
                pieces.append(
                    _TokenizedText(
                        text=split_text,
                        token_count=count_tokens(split_text, tokenizer),
                    )
                )
        else:
            current_parts = [pair_str]
            current_tokens = pair_tokens
    # Flush the trailing partially-filled piece.
    if current_parts:
        pieces.append(
            _TokenizedText(
                text=FIELD_VALUE_SEPARATOR.join(current_parts),
                token_count=current_tokens,
            )
        )
    return pieces
def _build_chunk_from_scratch(
    pairs: list[tuple[str, str]],
    formatted_row: str,
    row_tokens: int,
    column_header: str,
    column_header_tokens: int,
    sheet_header: str,
    sheet_header_tokens: int,
    tokenizer: BaseTokenizer,
    max_tokens: int,
) -> list[_TokenizedText]:
    """Start a new chunk from one row, attaching headers when they fit.

    Returns a single piece in the common case; multiple header-less
    pieces when the row alone exceeds ``max_tokens``. Header order is
    sheet header, then column header, then the row — each added only
    while the budget allows.
    """
    # 1. Row alone is too large — split by pairs, no headers.
    if row_tokens > max_tokens:
        return _split_row_by_pairs(pairs, tokenizer, max_tokens)
    chunk = formatted_row
    chunk_tokens = row_tokens
    # 2. Attempt to add column header
    candidate_tokens = column_header_tokens + NEWLINE_TOKENS + chunk_tokens
    if candidate_tokens <= max_tokens:
        chunk = column_header + ROW_JOIN + chunk
        chunk_tokens = candidate_tokens
    # 3. Attempt to add sheet header
    if sheet_header:
        candidate_tokens = sheet_header_tokens + NEWLINE_TOKENS + chunk_tokens
        if candidate_tokens <= max_tokens:
            chunk = sheet_header + ROW_JOIN + chunk
            chunk_tokens = candidate_tokens
    return [_TokenizedText(text=chunk, token_count=chunk_tokens)]
def parse_to_chunks(
    rows: Iterable[ParsedRow],
    sheet_header: str,
    tokenizer: BaseTokenizer,
    max_tokens: int,
) -> list[str]:
    """Pack formatted CSV rows into text chunks of ≤ ``max_tokens`` tokens.

    Each chunk starts from a fresh row via `_build_chunk_from_scratch`
    (which attaches sheet/column headers when they fit); subsequent rows
    are appended while the additive token estimate stays within budget.
    Returns [] when there are no rows.
    """
    rows_list = list(rows)
    if not rows_list:
        return []
    # Headers are formatted and tokenized once up front; the per-row loop
    # reuses the counts.
    column_header = format_columns_header(rows_list[0].header)
    column_header_tokens = count_tokens(column_header, tokenizer)
    sheet_header_tokens = count_tokens(sheet_header, tokenizer) if sheet_header else 0
    chunks: list[str] = []
    current_chunk = ""
    current_chunk_tokens = 0
    for row in rows_list:
        pairs: list[tuple[str, str]] = _row_to_pairs(row.header, row.row)
        formatted = format_row(row.header, row.row)
        row_tokens = count_tokens(formatted, tokenizer)
        if current_chunk:
            # Attempt to pack it in (additive approximation)
            if current_chunk_tokens + NEWLINE_TOKENS + row_tokens <= max_tokens:
                current_chunk = pack_chunk(current_chunk, formatted)
                current_chunk_tokens += NEWLINE_TOKENS + row_tokens
                continue
            # Doesn't fit — flush and start new
            chunks.append(current_chunk)
            current_chunk = ""
            current_chunk_tokens = 0
        # Build chunk from scratch; all pieces except the last are final,
        # the last stays open so following rows can pack into it.
        for piece in _build_chunk_from_scratch(
            pairs=pairs,
            formatted_row=formatted,
            row_tokens=row_tokens,
            column_header=column_header,
            column_header_tokens=column_header_tokens,
            sheet_header=sheet_header,
            sheet_header_tokens=sheet_header_tokens,
            tokenizer=tokenizer,
            max_tokens=max_tokens,
        ):
            if current_chunk:
                chunks.append(current_chunk)
            current_chunk = piece.text
            current_chunk_tokens = piece.token_count
    # Flush remaining
    if current_chunk:
        chunks.append(current_chunk)
    return chunks
class TabularChunker(SectionChunker):
    """Chunks tabular (CSV-text) sections.

    Emits row-packed chunks via `parse_to_chunks` plus, unless disabled,
    sheet-descriptor summary chunks. Buffered accumulator text is flushed
    first; tabular chunks never merge with neighboring text sections.
    """
    def __init__(
        self,
        tokenizer: BaseTokenizer,
        ignore_metadata_chunks: bool = False,
    ) -> None:
        # Tokenizer used for token-budget accounting of rows/descriptors.
        self.tokenizer = tokenizer
        # When True, skip the sheet-descriptor (metadata) chunks.
        self.ignore_metadata_chunks = ignore_metadata_chunks
    def chunk_section(
        self,
        section: Section,
        accumulator: AccumulatorState,
        content_token_limit: int,
    ) -> SectionChunkerOutput:
        """Convert one tabular section into payloads.

        Every chunk after the first is marked as a continuation; all
        chunks share the section's single link at offset 0. Unparseable
        sections are logged and skipped.
        """
        payloads = accumulator.flush_to_list()
        parsed_rows = list(parse_csv_string(section.text or ""))
        sheet_header = section.heading or ""
        chunk_texts: list[str] = []
        if parsed_rows:
            chunk_texts.extend(
                parse_to_chunks(
                    rows=parsed_rows,
                    sheet_header=sheet_header,
                    tokenizer=self.tokenizer,
                    max_tokens=content_token_limit,
                )
            )
        if not self.ignore_metadata_chunks:
            chunk_texts.extend(
                build_sheet_descriptor_chunks(
                    section=section,
                    tokenizer=self.tokenizer,
                    max_tokens=content_token_limit,
                )
            )
        if not chunk_texts:
            logger.warning(
                f"TabularChunker: skipping unparseable section (link={section.link})"
            )
            return SectionChunkerOutput(
                payloads=payloads, accumulator=AccumulatorState()
            )
        for i, text in enumerate(chunk_texts):
            payloads.append(
                ChunkPayload(
                    text=text,
                    links={0: section.link or ""},
                    is_continuation=(i > 0),
                )
            )
        return SectionChunkerOutput(payloads=payloads, accumulator=AccumulatorState())

View File

@@ -1,117 +0,0 @@
from typing import cast
from chonkie import SentenceChunker
from onyx.configs.constants import SECTION_SEPARATOR
from onyx.connectors.models import Section
from onyx.indexing.chunking.section_chunker import AccumulatorState
from onyx.indexing.chunking.section_chunker import ChunkPayload
from onyx.indexing.chunking.section_chunker import SectionChunker
from onyx.indexing.chunking.section_chunker import SectionChunkerOutput
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.natural_language_processing.utils import count_tokens
from onyx.natural_language_processing.utils import split_text_by_tokens
from onyx.utils.text_processing import clean_text
from onyx.utils.text_processing import shared_precompare_cleanup
from shared_configs.configs import STRICT_CHUNK_TOKEN_LIMIT
class TextChunker(SectionChunker):
    """Accumulates text sections into token-budgeted chunk payloads.

    Sections that fit are appended to the shared accumulator (recording a
    link offset per section); a section too large for the budget flushes
    the accumulator and is sentence-split on its own.
    """
    def __init__(
        self,
        tokenizer: BaseTokenizer,
        chunk_splitter: SentenceChunker,
    ) -> None:
        self.tokenizer = tokenizer
        self.chunk_splitter = chunk_splitter
        # Cached so the separator's cost isn't re-tokenized per section.
        self.section_separator_token_count = count_tokens(
            SECTION_SEPARATOR,
            self.tokenizer,
        )
    def chunk_section(
        self,
        section: Section,
        accumulator: AccumulatorState,
        content_token_limit: int,
    ) -> SectionChunkerOutput:
        """Fold one text section into the running accumulator, returning
        any completed payloads plus the accumulator state to thread into
        the next section."""
        section_text = clean_text(str(section.text or ""))
        section_link = section.link or ""
        section_token_count = len(self.tokenizer.encode(section_text))
        # Oversized — flush buffer and split the section
        if section_token_count > content_token_limit:
            return self._handle_oversized_section(
                section_text=section_text,
                section_link=section_link,
                accumulator=accumulator,
                content_token_limit=content_token_limit,
            )
        current_token_count = count_tokens(accumulator.text, self.tokenizer)
        next_section_tokens = self.section_separator_token_count + section_token_count
        # Fits — extend the accumulator
        if next_section_tokens + current_token_count <= content_token_limit:
            # Record where this section's text will start, measured in the
            # cleaned-up form used for offset comparison downstream.
            offset = len(shared_precompare_cleanup(accumulator.text))
            new_text = accumulator.text
            if new_text:
                new_text += SECTION_SEPARATOR
            new_text += section_text
            return SectionChunkerOutput(
                payloads=[],
                accumulator=AccumulatorState(
                    text=new_text,
                    link_offsets={**accumulator.link_offsets, offset: section_link},
                ),
            )
        # Doesn't fit — flush buffer and restart with this section
        return SectionChunkerOutput(
            payloads=accumulator.flush_to_list(),
            accumulator=AccumulatorState(
                text=section_text,
                link_offsets={0: section_link},
            ),
        )
    def _handle_oversized_section(
        self,
        section_text: str,
        section_link: str,
        accumulator: AccumulatorState,
        content_token_limit: int,
    ) -> SectionChunkerOutput:
        """Sentence-split a section that alone exceeds the token budget.

        With STRICT_CHUNK_TOKEN_LIMIT enabled, sentence pieces that still
        overflow are further token-split. Every piece after the first is
        flagged as a continuation.
        """
        payloads = accumulator.flush_to_list()
        split_texts = cast(list[str], self.chunk_splitter.chunk(section_text))
        for i, split_text in enumerate(split_texts):
            if (
                STRICT_CHUNK_TOKEN_LIMIT
                and count_tokens(split_text, self.tokenizer) > content_token_limit
            ):
                smaller_chunks = split_text_by_tokens(
                    split_text, self.tokenizer, content_token_limit
                )
                for j, small_chunk in enumerate(smaller_chunks):
                    payloads.append(
                        ChunkPayload(
                            text=small_chunk,
                            links={0: section_link},
                            is_continuation=(j != 0),
                        )
                    )
            else:
                payloads.append(
                    ChunkPayload(
                        text=split_text,
                        links={0: section_link},
                        is_continuation=(i != 0),
                    )
                )
        return SectionChunkerOutput(
            payloads=payloads,
            accumulator=AccumulatorState(),
        )

View File

@@ -3,8 +3,6 @@ from abc import ABC
from abc import abstractmethod
from collections import defaultdict
import sentry_sdk
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorStopSignal
from onyx.connectors.models import DocumentFailure
@@ -293,13 +291,6 @@ def embed_chunks_with_failure_handling(
)
embedded_chunks.extend(doc_embedded_chunks)
except Exception as e:
with sentry_sdk.new_scope() as scope:
scope.set_tag("stage", "embedding")
scope.set_tag("doc_id", doc_id)
if tenant_id:
scope.set_tag("tenant_id", tenant_id)
scope.fingerprint = ["embedding-failure", type(e).__name__]
sentry_sdk.capture_exception(e)
logger.exception(f"Failed to embed chunks for document '{doc_id}'")
failures.append(
ConnectorFailure(

View File

@@ -5,7 +5,6 @@ from collections.abc import Iterator
from contextlib import contextmanager
from typing import Protocol
import sentry_sdk
from pydantic import BaseModel
from pydantic import ConfigDict
from sqlalchemy.orm import Session
@@ -333,13 +332,6 @@ def index_doc_batch_with_handler(
except Exception as e:
# don't log the batch directly, it's too much text
document_ids = [doc.id for doc in document_batch]
with sentry_sdk.new_scope() as scope:
scope.set_tag("stage", "indexing_pipeline")
scope.set_tag("tenant_id", tenant_id)
scope.set_tag("batch_size", str(len(document_batch)))
scope.set_extra("document_ids", document_ids)
scope.fingerprint = ["indexing-pipeline-failure", type(e).__name__]
sentry_sdk.capture_exception(e)
logger.exception(f"Failed to index document batch: {document_ids}")
index_pipeline_result = IndexingPipelineResult(
@@ -550,8 +542,7 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
**document.model_dump(),
processed_sections=[
Section(
type=section.type,
text="" if isinstance(section, ImageSection) else section.text,
text=section.text if isinstance(section, TextSection) else "",
link=section.link,
image_file_id=(
section.image_file_id
@@ -575,7 +566,6 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
if isinstance(section, ImageSection):
# Default section with image path preserved - ensure text is always a string
processed_section = Section(
type=section.type,
link=section.link,
image_file_id=section.image_file_id,
text="", # Initialize with empty string
@@ -617,9 +607,8 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
processed_sections.append(processed_section)
# For TextSection, create a base Section with text and link
else:
elif isinstance(section, TextSection):
processed_section = Section(
type=section.type,
text=section.text or "", # Ensure text is always a string, not None
link=section.link,
image_file_id=None,

View File

@@ -6,7 +6,6 @@ from itertools import chain
from itertools import groupby
import httpx
import sentry_sdk
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import DocumentFailure
@@ -89,12 +88,6 @@ def write_chunks_to_vector_db_with_backoff(
)
)
except Exception as e:
with sentry_sdk.new_scope() as scope:
scope.set_tag("stage", "vector_db_write")
scope.set_tag("doc_id", doc_id)
scope.set_tag("tenant_id", index_batch_params.tenant_id)
scope.fingerprint = ["vector-db-write-failure", type(e).__name__]
sentry_sdk.capture_exception(e)
logger.exception(
f"Failed to write document chunks for '{doc_id}' to vector db"
)

View File

@@ -434,14 +434,11 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
lifespan=lifespan_override or lifespan,
)
if SENTRY_DSN:
from onyx.configs.sentry import _add_instance_tags
sentry_sdk.init(
dsn=SENTRY_DSN,
integrations=[StarletteIntegration(), FastApiIntegration()],
traces_sample_rate=0.1,
release=__version__,
before_send=_add_instance_tags,
)
logger.info("Sentry initialized")
else:

View File

@@ -201,33 +201,6 @@ def count_tokens(
return total
def split_text_by_tokens(
    text: str,
    tokenizer: BaseTokenizer,
    max_tokens: int,
) -> list[str]:
    """Best-effort split of ``text`` into pieces of ≤ ``max_tokens`` tokens
    via encode/decode at token-id boundaries.

    Note: pieces are not strictly guaranteed to re-tokenize to
    ≤ max_tokens. BPE merges at window boundaries may drift by a few
    tokens, and cuts landing mid-multi-byte-UTF-8-character produce
    replacement characters on decode. Good enough for "best-effort"
    splitting of oversized content, not for hard limit enforcement.
    """
    if not text:
        return []
    # Encode in fixed-size character windows to bound per-call cost.
    token_ids: list[int] = []
    position = 0
    while position < len(text):
        window = text[position : position + _ENCODE_CHUNK_SIZE]
        token_ids.extend(tokenizer.encode(window))
        position += _ENCODE_CHUNK_SIZE
    # Decode back in max_tokens-sized id slices.
    pieces: list[str] = []
    for begin in range(0, len(token_ids), max_tokens):
        pieces.append(tokenizer.decode(token_ids[begin : begin + max_tokens]))
    return pieces
def tokenizer_trim_content(
content: str, desired_length: int, tokenizer: BaseTokenizer
) -> str:

View File

@@ -15,7 +15,7 @@
"date-fns": "^4.1.0",
"embla-carousel-react": "^8.6.0",
"lucide-react": "^0.562.0",
"next": "16.2.3",
"next": "16.1.7",
"next-themes": "^0.4.6",
"radix-ui": "^1.4.3",
"react": "19.2.3",
@@ -961,9 +961,9 @@
"license": "MIT"
},
"node_modules/@hono/node-server": {
"version": "1.19.13",
"resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.19.13.tgz",
"integrity": "sha512-TsQLe4i2gvoTtrHje625ngThGBySOgSK3Xo2XRYOdqGN1teR8+I7vchQC46uLJi8OF62YTYA3AhSpumtkhsaKQ==",
"version": "1.19.10",
"resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.19.10.tgz",
"integrity": "sha512-hZ7nOssGqRgyV3FVVQdfi+U4q02uB23bpnYpdvNXkYTRRyWx84b7yf1ans+dnJ/7h41sGL3CeQTfO+ZGxuO+Iw==",
"license": "MIT",
"engines": {
"node": ">=18.14.1"
@@ -1711,9 +1711,9 @@
}
},
"node_modules/@next/env": {
"version": "16.2.3",
"resolved": "https://registry.npmjs.org/@next/env/-/env-16.2.3.tgz",
"integrity": "sha512-ZWXyj4uNu4GCWQw9cjRxWlbD+33mcDszIo9iQxFnBX3Wmgq9ulaSJcl6VhuWx5pCWqqD+9W6Wfz7N0lM5lYPMA==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/env/-/env-16.1.7.tgz",
"integrity": "sha512-rJJbIdJB/RQr2F1nylZr/PJzamvNNhfr3brdKP6s/GW850jbtR70QlSfFselvIBbcPUOlQwBakexjFzqLzF6pg==",
"license": "MIT"
},
"node_modules/@next/eslint-plugin-next": {
@@ -1727,9 +1727,9 @@
}
},
"node_modules/@next/swc-darwin-arm64": {
"version": "16.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-16.2.3.tgz",
"integrity": "sha512-u37KDKTKQ+OQLvY+z7SNXixwo4Q2/IAJFDzU1fYe66IbCE51aDSAzkNDkWmLN0yjTUh4BKBd+hb69jYn6qqqSg==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-16.1.7.tgz",
"integrity": "sha512-b2wWIE8sABdyafc4IM8r5Y/dS6kD80JRtOGrUiKTsACFQfWWgUQ2NwoUX1yjFMXVsAwcQeNpnucF2ZrujsBBPg==",
"cpu": [
"arm64"
],
@@ -1743,9 +1743,9 @@
}
},
"node_modules/@next/swc-darwin-x64": {
"version": "16.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-16.2.3.tgz",
"integrity": "sha512-gHjL/qy6Q6CG3176FWbAKyKh9IfntKZTB3RY/YOJdDFpHGsUDXVH38U4mMNpHVGXmeYW4wj22dMp1lTfmu/bTQ==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-16.1.7.tgz",
"integrity": "sha512-zcnVaaZulS1WL0Ss38R5Q6D2gz7MtBu8GZLPfK+73D/hp4GFMrC2sudLky1QibfV7h6RJBJs/gOFvYP0X7UVlQ==",
"cpu": [
"x64"
],
@@ -1759,9 +1759,9 @@
}
},
"node_modules/@next/swc-linux-arm64-gnu": {
"version": "16.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-16.2.3.tgz",
"integrity": "sha512-U6vtblPtU/P14Y/b/n9ZY0GOxbbIhTFuaFR7F4/uMBidCi2nSdaOFhA0Go81L61Zd6527+yvuX44T4ksnf8T+Q==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-16.1.7.tgz",
"integrity": "sha512-2ant89Lux/Q3VyC8vNVg7uBaFVP9SwoK2jJOOR0L8TQnX8CAYnh4uctAScy2Hwj2dgjVHqHLORQZJ2wH6VxhSQ==",
"cpu": [
"arm64"
],
@@ -1775,9 +1775,9 @@
}
},
"node_modules/@next/swc-linux-arm64-musl": {
"version": "16.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-16.2.3.tgz",
"integrity": "sha512-/YV0LgjHUmfhQpn9bVoGc4x4nan64pkhWR5wyEV8yCOfwwrH630KpvRg86olQHTwHIn1z59uh6JwKvHq1h4QEw==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-16.1.7.tgz",
"integrity": "sha512-uufcze7LYv0FQg9GnNeZ3/whYfo+1Q3HnQpm16o6Uyi0OVzLlk2ZWoY7j07KADZFY8qwDbsmFnMQP3p3+Ftprw==",
"cpu": [
"arm64"
],
@@ -1791,9 +1791,9 @@
}
},
"node_modules/@next/swc-linux-x64-gnu": {
"version": "16.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-16.2.3.tgz",
"integrity": "sha512-/HiWEcp+WMZ7VajuiMEFGZ6cg0+aYZPqCJD3YJEfpVWQsKYSjXQG06vJP6F1rdA03COD9Fef4aODs3YxKx+RDQ==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-16.1.7.tgz",
"integrity": "sha512-KWVf2gxYvHtvuT+c4MBOGxuse5TD7DsMFYSxVxRBnOzok/xryNeQSjXgxSv9QpIVlaGzEn/pIuI6Koosx8CGWA==",
"cpu": [
"x64"
],
@@ -1807,9 +1807,9 @@
}
},
"node_modules/@next/swc-linux-x64-musl": {
"version": "16.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-16.2.3.tgz",
"integrity": "sha512-Kt44hGJfZSefebhk/7nIdivoDr3Ugp5+oNz9VvF3GUtfxutucUIHfIO0ZYO8QlOPDQloUVQn4NVC/9JvHRk9hw==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-16.1.7.tgz",
"integrity": "sha512-HguhaGwsGr1YAGs68uRKc4aGWxLET+NevJskOcCAwXbwj0fYX0RgZW2gsOCzr9S11CSQPIkxmoSbuVaBp4Z3dA==",
"cpu": [
"x64"
],
@@ -1823,9 +1823,9 @@
}
},
"node_modules/@next/swc-win32-arm64-msvc": {
"version": "16.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-16.2.3.tgz",
"integrity": "sha512-O2NZ9ie3Tq6xj5Z5CSwBT3+aWAMW2PIZ4egUi9MaWLkwaehgtB7YZjPm+UpcNpKOme0IQuqDcor7BsW6QBiQBw==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-16.1.7.tgz",
"integrity": "sha512-S0n3KrDJokKTeFyM/vGGGR8+pCmXYrjNTk2ZozOL1C/JFdfUIL9O1ATaJOl5r2POe56iRChbsszrjMAdWSv7kQ==",
"cpu": [
"arm64"
],
@@ -1839,9 +1839,9 @@
}
},
"node_modules/@next/swc-win32-x64-msvc": {
"version": "16.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-16.2.3.tgz",
"integrity": "sha512-Ibm29/GgB/ab5n7XKqlStkm54qqZE8v2FnijUPBgrd67FWrac45o/RsNlaOWjme/B5UqeWt/8KM4aWBwA1D2Kw==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-16.1.7.tgz",
"integrity": "sha512-mwgtg8CNZGYm06LeEd+bNnOUfwOyNem/rOiP14Lsz+AnUY92Zq/LXwtebtUiaeVkhbroRCQ0c8GlR4UT1U+0yg==",
"cpu": [
"x64"
],
@@ -7427,9 +7427,9 @@
}
},
"node_modules/hono": {
"version": "4.12.12",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.12.tgz",
"integrity": "sha512-p1JfQMKaceuCbpJKAPKVqyqviZdS0eUxH9v82oWo1kb9xjQ5wA6iP3FNVAPDFlz5/p7d45lO+BpSk1tuSZMF4Q==",
"version": "4.12.7",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.7.tgz",
"integrity": "sha512-jq9l1DM0zVIvsm3lv9Nw9nlJnMNPOcAtsbsgiUhWcFzPE99Gvo6yRTlszSLLYacMeQ6quHD6hMfId8crVHvexw==",
"license": "MIT",
"engines": {
"node": ">=16.9.0"
@@ -8637,9 +8637,9 @@
}
},
"node_modules/lodash": {
"version": "4.18.1",
"resolved": "https://registry.npmjs.org/lodash/-/lodash-4.18.1.tgz",
"integrity": "sha512-dMInicTPVE8d1e5otfwmmjlxkZoUpiVLwyeTdUsi/Caj/gfzzblBcCE5sRHV/AsjuCmxWrte2TNGSYuCeCq+0Q==",
"version": "4.17.23",
"resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.23.tgz",
"integrity": "sha512-LgVTMpQtIopCi79SJeDiP0TfWi5CNEc/L/aRdTh3yIvmZXTnheWpKjSZhnvMl8iXbC1tFg9gdHHDMLoV7CnG+w==",
"license": "MIT"
},
"node_modules/lodash.merge": {
@@ -8978,12 +8978,12 @@
}
},
"node_modules/next": {
"version": "16.2.3",
"resolved": "https://registry.npmjs.org/next/-/next-16.2.3.tgz",
"integrity": "sha512-9V3zV4oZFza3PVev5/poB9g0dEafVcgNyQ8eTRop8GvxZjV2G15FC5ARuG1eFD42QgeYkzJBJzHghNP8Ad9xtA==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/next/-/next-16.1.7.tgz",
"integrity": "sha512-WM0L7WrSvKwoLegLYr6V+mz+RIofqQgVAfHhMp9a88ms0cFX8iX9ew+snpWlSBwpkURJOUdvCEt3uLl3NNzvWg==",
"license": "MIT",
"dependencies": {
"@next/env": "16.2.3",
"@next/env": "16.1.7",
"@swc/helpers": "0.5.15",
"baseline-browser-mapping": "^2.9.19",
"caniuse-lite": "^1.0.30001579",
@@ -8997,15 +8997,15 @@
"node": ">=20.9.0"
},
"optionalDependencies": {
"@next/swc-darwin-arm64": "16.2.3",
"@next/swc-darwin-x64": "16.2.3",
"@next/swc-linux-arm64-gnu": "16.2.3",
"@next/swc-linux-arm64-musl": "16.2.3",
"@next/swc-linux-x64-gnu": "16.2.3",
"@next/swc-linux-x64-musl": "16.2.3",
"@next/swc-win32-arm64-msvc": "16.2.3",
"@next/swc-win32-x64-msvc": "16.2.3",
"sharp": "^0.34.5"
"@next/swc-darwin-arm64": "16.1.7",
"@next/swc-darwin-x64": "16.1.7",
"@next/swc-linux-arm64-gnu": "16.1.7",
"@next/swc-linux-arm64-musl": "16.1.7",
"@next/swc-linux-x64-gnu": "16.1.7",
"@next/swc-linux-x64-musl": "16.1.7",
"@next/swc-win32-arm64-msvc": "16.1.7",
"@next/swc-win32-x64-msvc": "16.1.7",
"sharp": "^0.34.4"
},
"peerDependencies": {
"@opentelemetry/api": "^1.1.0",

View File

@@ -16,7 +16,7 @@
"date-fns": "^4.1.0",
"embla-carousel-react": "^8.6.0",
"lucide-react": "^0.562.0",
"next": "16.2.3",
"next": "16.1.7",
"next-themes": "^0.4.6",
"radix-ui": "^1.4.3",
"react": "19.2.3",

View File

@@ -618,7 +618,6 @@ done
"app.kubernetes.io/managed-by": "onyx",
"onyx.app/sandbox-id": sandbox_id,
"onyx.app/tenant-id": tenant_id,
"admission.datadoghq.com/enabled": "false",
},
),
spec=pod_spec,

View File

@@ -63,7 +63,6 @@ class DocumentSetCreationRequest(BaseModel):
class DocumentSetUpdateRequest(BaseModel):
id: int
name: str
description: str
cc_pair_ids: list[int]
is_public: bool

View File

@@ -96,32 +96,6 @@ def _truncate_description(description: str | None, max_length: int = 500) -> str
return description[: max_length - 3] + "..."
# TODO: Replace mask-comparison approach with an explicit Unset sentinel from the
# frontend indicating whether each credential field was actually modified. The current
# approach is brittle (e.g. short credentials produce a fixed-length mask that could
# collide) and mutates request values, which is surprising. The frontend should signal
# "unchanged" vs "new value" directly rather than relying on masked-string equality.
def _restore_masked_oauth_credentials(
request_client_id: str | None,
request_client_secret: str | None,
existing_client: OAuthClientInformationFull,
) -> tuple[str | None, str | None]:
"""If the frontend sent back masked credentials, restore the real stored values."""
if (
request_client_id
and existing_client.client_id
and request_client_id == mask_string(existing_client.client_id)
):
request_client_id = existing_client.client_id
if (
request_client_secret
and existing_client.client_secret
and request_client_secret == mask_string(existing_client.client_secret)
):
request_client_secret = existing_client.client_secret
return request_client_id, request_client_secret
router = APIRouter(prefix="/mcp")
admin_router = APIRouter(prefix="/admin/mcp")
STATE_TTL_SECONDS = 60 * 5 # 5 minutes
@@ -418,26 +392,6 @@ async def _connect_oauth(
detail=f"Server was configured with authentication type {auth_type_str}",
)
# If the frontend sent back masked credentials (unchanged by the user),
# restore the real stored values so we don't overwrite them with masks.
if mcp_server.admin_connection_config:
existing_data = extract_connection_data(
mcp_server.admin_connection_config, apply_mask=False
)
existing_client_raw = existing_data.get(MCPOAuthKeys.CLIENT_INFO.value)
if existing_client_raw:
existing_client = OAuthClientInformationFull.model_validate(
existing_client_raw
)
(
request.oauth_client_id,
request.oauth_client_secret,
) = _restore_masked_oauth_credentials(
request.oauth_client_id,
request.oauth_client_secret,
existing_client,
)
# Create admin config with client info if provided
config_data = MCPConnectionData(headers={})
if request.oauth_client_id and request.oauth_client_secret:
@@ -1402,19 +1356,6 @@ def _upsert_mcp_server(
if client_info_raw:
client_info = OAuthClientInformationFull.model_validate(client_info_raw)
# If the frontend sent back masked credentials (unchanged by the user),
# restore the real stored values so the comparison below sees no change
# and the credentials aren't overwritten with masked strings.
if client_info and request.auth_type == MCPAuthenticationType.OAUTH:
(
request.oauth_client_id,
request.oauth_client_secret,
) = _restore_masked_oauth_credentials(
request.oauth_client_id,
request.oauth_client_secret,
client_info,
)
changing_connection_config = (
not mcp_server.admin_connection_config
or (

View File

@@ -11,9 +11,6 @@ from onyx.db.notification import dismiss_notification
from onyx.db.notification import get_notification_by_id
from onyx.db.notification import get_notifications
from onyx.server.features.build.utils import ensure_build_mode_intro_notification
from onyx.server.features.notifications.utils import (
ensure_permissions_migration_notification,
)
from onyx.server.features.release_notes.utils import (
ensure_release_notes_fresh_and_notify,
)
@@ -52,13 +49,6 @@ def get_notifications_api(
except Exception:
logger.exception("Failed to check for release notes in notifications endpoint")
try:
ensure_permissions_migration_notification(user, db_session)
except Exception:
logger.exception(
"Failed to create permissions_migration_v1 announcement in notifications endpoint"
)
notifications = [
NotificationModel.from_model(notif)
for notif in get_notifications(user, db_session, include_dismissed=True)

View File

@@ -1,21 +0,0 @@
from sqlalchemy.orm import Session
from onyx.configs.constants import NotificationType
from onyx.db.models import User
from onyx.db.notification import create_notification
def ensure_permissions_migration_notification(user: User, db_session: Session) -> None:
# Feature id "permissions_migration_v1" must not change after shipping —
# it is the dedup key on (user_id, notif_type, additional_data).
create_notification(
user_id=user.id,
notif_type=NotificationType.FEATURE_ANNOUNCEMENT,
db_session=db_session,
title="Permissions are changing in Onyx",
description="Roles are moving to group-based permissions. Click for details.",
additional_data={
"feature": "permissions_migration_v1",
"link": "https://docs.onyx.app/admins/permissions/whats_changing",
},
)

View File

@@ -185,10 +185,6 @@ class MinimalPersonaSnapshot(BaseModel):
for doc_set in persona.document_sets:
for cc_pair in doc_set.connector_credential_pairs:
sources.add(cc_pair.connector.source)
for fed_ds in doc_set.federated_connectors:
non_fed = fed_ds.federated_connector.source.to_non_federated_source()
if non_fed is not None:
sources.add(non_fed)
# Sources from hierarchy nodes
for node in persona.hierarchy_nodes:
@@ -199,9 +195,6 @@ class MinimalPersonaSnapshot(BaseModel):
if doc.parent_hierarchy_node:
sources.add(doc.parent_hierarchy_node.source)
if persona.user_files:
sources.add(DocumentSource.USER_FILE)
return MinimalPersonaSnapshot(
# Core fields actually used by ChatPage
id=persona.id,

View File

@@ -11,7 +11,6 @@ from sqlalchemy.orm import Session
from onyx.auth.permissions import require_permission
from onyx.auth.users import current_curator_or_admin_user
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.background.indexing.models import IndexAttemptErrorPydantic
from onyx.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import KV_GEN_AI_KEY_CHECK_TIME
@@ -29,7 +28,6 @@ from onyx.db.feedback import fetch_docs_ranked_by_boost_for_user
from onyx.db.feedback import update_document_boost_for_user
from onyx.db.feedback import update_document_hidden_for_user
from onyx.db.index_attempt import cancel_indexing_attempts_for_ccpair
from onyx.db.index_attempt import get_index_attempt_errors_across_connectors
from onyx.db.models import User
from onyx.file_store.file_store import get_default_file_store
from onyx.key_value_store.factory import get_kv_store
@@ -37,7 +35,6 @@ from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.llm.factory import get_default_llm
from onyx.llm.utils import test_llm
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
from onyx.server.documents.models import PaginatedReturn
from onyx.server.manage.models import BoostDoc
from onyx.server.manage.models import BoostUpdateRequest
from onyx.server.manage.models import HiddenUpdateRequest
@@ -209,40 +206,3 @@ def create_deletion_attempt_for_connector_id(
file_store = get_default_file_store()
for file_id in connector.connector_specific_config.get("file_locations", []):
file_store.delete_file(file_id)
@router.get("/admin/indexing/failed-documents")
def get_failed_documents(
cc_pair_id: int | None = None,
error_type: str | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
include_resolved: bool = False,
page_num: int = 0,
page_size: int = 25,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
db_session: Session = Depends(get_session),
) -> PaginatedReturn[IndexAttemptErrorPydantic]:
"""Get indexing errors across all connectors with optional filters.
Provides a cross-connector view of document indexing failures.
Defaults to last 30 days if no start_time is provided to avoid
unbounded count queries.
"""
if start_time is None:
start_time = datetime.now(tz=timezone.utc) - timedelta(days=30)
errors, total = get_index_attempt_errors_across_connectors(
db_session=db_session,
cc_pair_id=cc_pair_id,
error_type=error_type,
start_time=start_time,
end_time=end_time,
unresolved_only=not include_resolved,
page=page_num,
page_size=page_size,
)
return PaginatedReturn(
items=[IndexAttemptErrorPydantic.from_model(e) for e in errors],
total_items=total,
)

View File

@@ -111,43 +111,6 @@ def _mask_string(value: str) -> str:
return value[:4] + "****" + value[-4:]
def _resolve_api_key(
api_key: str | None,
provider_name: str | None,
api_base: str | None,
db_session: Session,
) -> str | None:
"""Return the real API key for model-fetch endpoints.
When editing an existing provider the form value is masked (e.g.
``sk-a****b1c2``). If *provider_name* is supplied we can look up
the unmasked key from the database so the external request succeeds.
The stored key is only returned when the request's *api_base*
matches the value stored in the database.
"""
if not provider_name:
return api_key
existing_provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
if existing_provider and existing_provider.api_key:
# Normalise both URLs before comparing so trailing-slash
# differences don't cause a false mismatch.
stored_base = (existing_provider.api_base or "").strip().rstrip("/")
request_base = (api_base or "").strip().rstrip("/")
if stored_base != request_base:
return api_key
stored_key = existing_provider.api_key.get_value(apply_mask=False)
# Only resolve when the incoming value is the masked form of the
# stored key — i.e. the user hasn't typed a new key.
if api_key and api_key == _mask_string(stored_key):
return stored_key
return api_key
def _sync_fetched_models(
db_session: Session,
provider_name: str,
@@ -1211,17 +1174,16 @@ def get_ollama_available_models(
return sorted_results
def _get_openrouter_models_response(api_base: str, api_key: str | None) -> dict:
def _get_openrouter_models_response(api_base: str, api_key: str) -> dict:
"""Perform GET to OpenRouter /models and return parsed JSON."""
cleaned_api_base = api_base.strip().rstrip("/")
url = f"{cleaned_api_base}/models"
headers: dict[str, str] = {
headers = {
"Authorization": f"Bearer {api_key}",
# Optional headers recommended by OpenRouter for attribution
"HTTP-Referer": "https://onyx.app",
"X-Title": "Onyx",
}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
try:
response = httpx.get(url, headers=headers, timeout=10.0)
response.raise_for_status()
@@ -1244,12 +1206,8 @@ def get_openrouter_available_models(
Parses id, name (display), context_length, and architecture.input_modalities.
"""
api_key = _resolve_api_key(
request.api_key, request.provider_name, request.api_base, db_session
)
response_json = _get_openrouter_models_response(
api_base=request.api_base, api_key=api_key
api_base=request.api_base, api_key=request.api_key
)
data = response_json.get("data", [])
@@ -1342,18 +1300,13 @@ def get_lm_studio_available_models(
# If provider_name is given and the api_key hasn't been changed by the user,
# fall back to the stored API key from the database (the form value is masked).
# Only do so when the api_base matches what is stored.
api_key = request.api_key
if request.provider_name and not request.api_key_changed:
existing_provider = fetch_existing_llm_provider(
name=request.provider_name, db_session=db_session
)
if existing_provider and existing_provider.custom_config:
stored_base = (existing_provider.api_base or "").strip().rstrip("/")
if stored_base == cleaned_api_base:
api_key = existing_provider.custom_config.get(
LM_STUDIO_API_KEY_CONFIG_KEY
)
api_key = existing_provider.custom_config.get(LM_STUDIO_API_KEY_CONFIG_KEY)
url = f"{cleaned_api_base}/api/v1/models"
headers: dict[str, str] = {}
@@ -1437,12 +1390,8 @@ def get_litellm_available_models(
db_session: Session = Depends(get_session),
) -> list[LitellmFinalModelResponse]:
"""Fetch available models from Litellm proxy /v1/models endpoint."""
api_key = _resolve_api_key(
request.api_key, request.provider_name, request.api_base, db_session
)
response_json = _get_litellm_models_response(
api_key=api_key, api_base=request.api_base
api_key=request.api_key, api_base=request.api_base
)
models = response_json.get("data", [])
@@ -1499,7 +1448,7 @@ def get_litellm_available_models(
return sorted_results
def _get_litellm_models_response(api_key: str | None, api_base: str) -> dict:
def _get_litellm_models_response(api_key: str, api_base: str) -> dict:
"""Perform GET to Litellm proxy /api/v1/models and return parsed JSON."""
cleaned_api_base = api_base.strip().rstrip("/")
url = f"{cleaned_api_base}/v1/models"
@@ -1574,12 +1523,8 @@ def get_bifrost_available_models(
db_session: Session = Depends(get_session),
) -> list[BifrostFinalModelResponse]:
"""Fetch available models from Bifrost gateway /v1/models endpoint."""
api_key = _resolve_api_key(
request.api_key, request.provider_name, request.api_base, db_session
)
response_json = _get_bifrost_models_response(
api_base=request.api_base, api_key=api_key
api_base=request.api_base, api_key=request.api_key
)
models = response_json.get("data", [])
@@ -1668,12 +1613,8 @@ def get_openai_compatible_server_available_models(
db_session: Session = Depends(get_session),
) -> list[OpenAICompatibleFinalModelResponse]:
"""Fetch available models from a generic OpenAI-compatible /v1/models endpoint."""
api_key = _resolve_api_key(
request.api_key, request.provider_name, request.api_base, db_session
)
response_json = _get_openai_compatible_server_response(
api_base=request.api_base, api_key=api_key
api_base=request.api_base, api_key=request.api_key
)
models = response_json.get("data", [])

View File

@@ -183,9 +183,6 @@ def generate_ollama_display_name(model_name: str) -> str:
"qwen2.5:7b""Qwen 2.5 7B"
"mistral:latest""Mistral"
"deepseek-r1:14b""DeepSeek R1 14B"
"gemma4:e4b""Gemma 4 E4B"
"deepseek-v3.1:671b-cloud""DeepSeek V3.1 671B Cloud"
"qwen3-vl:235b-instruct-cloud""Qwen 3-vl 235B Instruct Cloud"
"""
# Split into base name and tag
if ":" in model_name:
@@ -212,24 +209,13 @@ def generate_ollama_display_name(model_name: str) -> str:
# Default: Title case with dashes converted to spaces
display_name = base.replace("-", " ").title()
# Process tag (skip "latest")
# Process tag to extract size info (skip "latest")
if tag and tag.lower() != "latest":
# Check for size prefix like "7b", "70b", optionally followed by modifiers
size_match = re.match(r"^(\d+(?:\.\d+)?[bBmM])(-.+)?$", tag)
# Extract size like "7b", "70b", "14b"
size_match = re.match(r"^(\d+(?:\.\d+)?[bBmM])", tag)
if size_match:
size = size_match.group(1).upper()
remainder = size_match.group(2)
if remainder:
# Format modifiers like "-cloud", "-instruct-cloud"
modifiers = " ".join(
p.title() for p in remainder.strip("-").split("-") if p
)
display_name = f"{display_name} {size} {modifiers}"
else:
display_name = f"{display_name} {size}"
else:
# Non-size tags like "e4b", "q4_0", "fp16", "cloud"
display_name = f"{display_name} {tag.upper()}"
display_name = f"{display_name} {size}"
return display_name

View File

@@ -1,14 +1,13 @@
import json
import secrets
from collections.abc import AsyncIterator
from fastapi import APIRouter
from fastapi import Depends
from fastapi import File
from fastapi import Query
from fastapi import UploadFile
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session
from onyx.auth.permissions import require_permission
@@ -114,47 +113,28 @@ async def transcribe_audio(
) from exc
def _extract_provider_error(exc: Exception) -> str:
"""Extract a human-readable message from a provider exception.
Provider errors often embed JSON from upstream APIs (e.g. ElevenLabs).
This tries to parse a readable ``message`` field out of common JSON
error shapes; falls back to ``str(exc)`` if nothing better is found.
"""
raw = str(exc)
try:
# Many providers embed JSON after a prefix like "ElevenLabs TTS failed: {...}"
json_start = raw.find("{")
if json_start == -1:
return raw
parsed = json.loads(raw[json_start:])
# Shape: {"detail": {"message": "..."}} (ElevenLabs)
detail = parsed.get("detail", parsed)
if isinstance(detail, dict):
return detail.get("message") or detail.get("error") or raw
if isinstance(detail, str):
return detail
except (json.JSONDecodeError, AttributeError, TypeError):
pass
return raw
class SynthesizeRequest(BaseModel):
text: str = Field(..., min_length=1)
voice: str | None = None
speed: float | None = Field(default=None, ge=0.5, le=2.0)
@router.post("/synthesize")
async def synthesize_speech(
body: SynthesizeRequest,
text: str | None = Query(
default=None, description="Text to synthesize", max_length=4096
),
voice: str | None = Query(default=None, description="Voice ID to use"),
speed: float | None = Query(
default=None, description="Playback speed (0.5-2.0)", ge=0.5, le=2.0
),
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
) -> StreamingResponse:
"""Synthesize text to speech using the default TTS provider."""
text = body.text
voice = body.voice
speed = body.speed
logger.info(f"TTS request: text length={len(text)}, voice={voice}, speed={speed}")
"""
Synthesize text to speech using the default TTS provider.
Accepts parameters via query string for streaming compatibility.
"""
logger.info(
f"TTS request: text length={len(text) if text else 0}, voice={voice}, speed={speed}"
)
if not text:
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Text is required")
# Use short-lived session to fetch provider config, then release connection
# before starting the long-running streaming response
@@ -197,36 +177,31 @@ async def synthesize_speech(
logger.error(f"Failed to get voice provider: {exc}")
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, str(exc)) from exc
# Pull the first chunk before returning the StreamingResponse. If the
# provider rejects the request (e.g. text too long), the error surfaces
# as a proper HTTP error instead of a broken audio stream.
stream_iter = provider.synthesize_stream(
text=text, voice=final_voice, speed=final_speed
)
try:
first_chunk = await stream_iter.__anext__()
except StopAsyncIteration:
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, "TTS provider returned no audio")
except Exception as exc:
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY, _extract_provider_error(exc)
) from exc
# Session is now closed - streaming response won't hold DB connection
async def audio_stream() -> AsyncIterator[bytes]:
yield first_chunk
chunk_count = 1
async for chunk in stream_iter:
chunk_count += 1
yield chunk
logger.info(f"TTS streaming complete: {chunk_count} chunks sent")
try:
chunk_count = 0
async for chunk in provider.synthesize_stream(
text=text, voice=final_voice, speed=final_speed
):
chunk_count += 1
yield chunk
logger.info(f"TTS streaming complete: {chunk_count} chunks sent")
except NotImplementedError as exc:
logger.error(f"TTS not implemented: {exc}")
raise
except Exception as exc:
logger.error(f"Synthesis failed: {exc}")
raise
return StreamingResponse(
audio_stream(),
media_type="audio/mpeg",
headers={
"Content-Disposition": "inline; filename=speech.mp3",
# Allow streaming by not setting content-length
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no",
"X-Accel-Buffering": "no", # Disable nginx buffering
},
)

View File

@@ -1,8 +1,7 @@
"""Generic Celery task lifecycle Prometheus metrics.
Provides signal handlers that track task started/completed/failed counts,
active task gauge, task duration histograms, queue wait time histograms,
and retry/reject/revoke counts.
active task gauge, task duration histograms, and retry/reject/revoke counts.
These fire for ALL tasks on the worker — no per-connector enrichment
(see indexing_task_metrics.py for that).
@@ -72,32 +71,6 @@ TASK_REJECTED = Counter(
["task_name"],
)
TASK_QUEUE_WAIT = Histogram(
"onyx_celery_task_queue_wait_seconds",
"Time a Celery task spent waiting in the queue before execution started",
["task_name", "queue"],
buckets=[
0.1,
0.5,
1,
5,
30,
60,
300,
600,
1800,
3600,
7200,
14400,
28800,
43200,
86400,
172800,
432000,
864000,
],
)
# task_id → (monotonic start time, metric labels)
_task_start_times: dict[str, tuple[float, dict[str, str]]] = {}
@@ -160,13 +133,6 @@ def on_celery_task_prerun(
with _task_start_times_lock:
_evict_stale_start_times()
_task_start_times[task_id] = (time.monotonic(), labels)
headers = getattr(task.request, "headers", None) or {}
enqueued_at = headers.get("enqueued_at")
if isinstance(enqueued_at, (int, float)):
TASK_QUEUE_WAIT.labels(**labels).observe(
max(0.0, time.time() - enqueued_at)
)
except Exception:
logger.debug("Failed to record celery task prerun metrics", exc_info=True)

View File

@@ -1,110 +0,0 @@
"""Prometheus metrics for connector health and index attempts.
Emitted by docfetching and docprocessing workers when connector or
index attempt state changes. All functions silently catch exceptions
to avoid disrupting the caller's business logic.
Gauge metrics (error state, last success timestamp) are per-process.
With multiple worker pods, use max() aggregation in PromQL to get the
correct value across instances, e.g.:
max by (cc_pair_id) (onyx_connector_in_error_state)
"""
from prometheus_client import Counter
from prometheus_client import Gauge
from onyx.utils.logger import setup_logger
logger = setup_logger()
# --- Index attempt lifecycle ---
INDEX_ATTEMPT_STATUS = Counter(
"onyx_index_attempt_transitions_total",
"Index attempt status transitions",
["tenant_id", "source", "cc_pair_id", "status"],
)
# --- Connector health ---
CONNECTOR_IN_ERROR_STATE = Gauge(
"onyx_connector_in_error_state",
"Whether the connector is in a repeated error state (1=yes, 0=no)",
["tenant_id", "source", "cc_pair_id"],
)
CONNECTOR_LAST_SUCCESS_TIMESTAMP = Gauge(
"onyx_connector_last_success_timestamp_seconds",
"Unix timestamp of last successful indexing for this connector",
["tenant_id", "source", "cc_pair_id"],
)
CONNECTOR_DOCS_INDEXED = Counter(
"onyx_connector_docs_indexed_total",
"Total documents indexed per connector (monotonic)",
["tenant_id", "source", "cc_pair_id"],
)
CONNECTOR_INDEXING_ERRORS = Counter(
"onyx_connector_indexing_errors_total",
"Total failed index attempts per connector (monotonic)",
["tenant_id", "source", "cc_pair_id"],
)
def on_index_attempt_status_change(
tenant_id: str,
source: str,
cc_pair_id: int,
status: str,
) -> None:
"""Called on any index attempt status transition."""
try:
labels = {
"tenant_id": tenant_id,
"source": source,
"cc_pair_id": str(cc_pair_id),
}
INDEX_ATTEMPT_STATUS.labels(**labels, status=status).inc()
if status == "failed":
CONNECTOR_INDEXING_ERRORS.labels(**labels).inc()
except Exception:
logger.debug("Failed to record index attempt status metric", exc_info=True)
def on_connector_error_state_change(
tenant_id: str,
source: str,
cc_pair_id: int,
in_error: bool,
) -> None:
"""Called when a connector's in_repeated_error_state changes."""
try:
CONNECTOR_IN_ERROR_STATE.labels(
tenant_id=tenant_id,
source=source,
cc_pair_id=str(cc_pair_id),
).set(1.0 if in_error else 0.0)
except Exception:
logger.debug("Failed to record connector error state metric", exc_info=True)
def on_connector_indexing_success(
tenant_id: str,
source: str,
cc_pair_id: int,
docs_indexed: int,
success_timestamp: float,
) -> None:
"""Called when an indexing run completes successfully."""
try:
labels = {
"tenant_id": tenant_id,
"source": source,
"cc_pair_id": str(cc_pair_id),
}
CONNECTOR_LAST_SUCCESS_TIMESTAMP.labels(**labels).set(success_timestamp)
if docs_indexed > 0:
CONNECTOR_DOCS_INDEXED.labels(**labels).inc(docs_indexed)
except Exception:
logger.debug("Failed to record connector success metric", exc_info=True)

View File

@@ -1,104 +0,0 @@
"""Connector-deletion-specific Prometheus metrics.
Tracks the deletion lifecycle:
1. Deletions started (taskset generated)
2. Deletions completed (success or failure)
3. Taskset duration (from taskset generation to completion or failure).
Note: this measures the most recent taskset execution, NOT wall-clock
time since the user triggered the deletion. When deletion is blocked by
indexing/pruning/permissions, the fence is cleared and a fresh taskset
is generated on each retry, resetting this timer.
4. Deletion blocked by dependencies (indexing, pruning, permissions, etc.)
5. Fence resets (stuck deletion recovery)
All metrics are labeled by tenant_id. cc_pair_id is intentionally excluded
to avoid unbounded cardinality.
Usage:
from onyx.server.metrics.deletion_metrics import (
inc_deletion_started,
inc_deletion_completed,
observe_deletion_taskset_duration,
inc_deletion_blocked,
inc_deletion_fence_reset,
)
"""
from prometheus_client import Counter
from prometheus_client import Histogram
from onyx.utils.logger import setup_logger
logger = setup_logger()
DELETION_STARTED = Counter(
"onyx_deletion_started_total",
"Connector deletions initiated (taskset generated)",
["tenant_id"],
)
DELETION_COMPLETED = Counter(
"onyx_deletion_completed_total",
"Connector deletions completed",
["tenant_id", "outcome"],
)
DELETION_TASKSET_DURATION = Histogram(
"onyx_deletion_taskset_duration_seconds",
"Duration of a connector deletion taskset, from taskset generation "
"to completion or failure. Does not include time spent blocked on "
"indexing/pruning/permissions before the taskset was generated.",
["tenant_id", "outcome"],
buckets=[10, 30, 60, 120, 300, 600, 1800, 3600, 7200, 21600],
)
DELETION_BLOCKED = Counter(
"onyx_deletion_blocked_total",
"Times deletion was blocked by a dependency",
["tenant_id", "blocker"],
)
DELETION_FENCE_RESET = Counter(
"onyx_deletion_fence_reset_total",
"Deletion fences reset due to missing celery tasks",
["tenant_id"],
)
def inc_deletion_started(tenant_id: str) -> None:
try:
DELETION_STARTED.labels(tenant_id=tenant_id).inc()
except Exception:
logger.debug("Failed to record deletion started", exc_info=True)
def inc_deletion_completed(tenant_id: str, outcome: str) -> None:
try:
DELETION_COMPLETED.labels(tenant_id=tenant_id, outcome=outcome).inc()
except Exception:
logger.debug("Failed to record deletion completed", exc_info=True)
def observe_deletion_taskset_duration(
tenant_id: str, outcome: str, duration_seconds: float
) -> None:
try:
DELETION_TASKSET_DURATION.labels(tenant_id=tenant_id, outcome=outcome).observe(
duration_seconds
)
except Exception:
logger.debug("Failed to record deletion taskset duration", exc_info=True)
def inc_deletion_blocked(tenant_id: str, blocker: str) -> None:
try:
DELETION_BLOCKED.labels(tenant_id=tenant_id, blocker=blocker).inc()
except Exception:
logger.debug("Failed to record deletion blocked", exc_info=True)
def inc_deletion_fence_reset(tenant_id: str) -> None:
    """Record a deletion fence reset for *tenant_id*.

    Best-effort: metric failures are swallowed and logged at debug level.
    """
    try:
        counter = DELETION_FENCE_RESET.labels(tenant_id=tenant_id)
        counter.inc()
    except Exception:
        logger.debug("Failed to record deletion fence reset", exc_info=True)

View File

@@ -1,30 +1,25 @@
"""Prometheus collectors for Celery queue depths and infrastructure health.
"""Prometheus collectors for Celery queue depths and indexing pipeline state.
These collectors query Redis at scrape time (the Collector pattern),
These collectors query Redis and Postgres at scrape time (the Collector pattern),
so metrics are always fresh when Prometheus scrapes /metrics. They run inside the
monitoring celery worker which already has Redis access.
monitoring celery worker which already has Redis and DB access.
To avoid hammering Redis on every 15s scrape, results are cached with
To avoid hammering Redis/Postgres on every 15s scrape, results are cached with
a configurable TTL (default 30s). This means metrics may be up to TTL seconds
stale, which is fine for monitoring dashboards.
Note: connector health and index attempt metrics are push-based (emitted by
workers at state-change time) and live in connector_health_metrics.py.
"""
from __future__ import annotations
import concurrent.futures
import json
import threading
import time
from datetime import datetime
from datetime import timezone
from typing import Any
from prometheus_client.core import GaugeMetricFamily
from prometheus_client.registry import Collector
from redis import Redis
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.configs.constants import OnyxCeleryQueues
@@ -36,11 +31,6 @@ logger = setup_logger()
# the previous result without re-querying Redis/Postgres.
_DEFAULT_CACHE_TTL = 30.0
# Maximum time (seconds) a single _collect_fresh() call may take before
# the collector gives up and returns stale/empty results. Prevents the
# /metrics endpoint from hanging indefinitely when a DB or Redis query stalls.
_DEFAULT_COLLECT_TIMEOUT = 120.0
_QUEUE_LABEL_MAP: dict[str, str] = {
OnyxCeleryQueues.PRIMARY: "primary",
OnyxCeleryQueues.DOCPROCESSING: "docprocessing",
@@ -72,32 +62,18 @@ _UNACKED_QUEUES: list[str] = [
class _CachedCollector(Collector):
"""Base collector with TTL-based caching and timeout protection.
"""Base collector with TTL-based caching.
Subclasses implement ``_collect_fresh()`` to query the actual data source.
The base ``collect()`` returns cached results if the TTL hasn't expired,
avoiding repeated queries when Prometheus scrapes frequently.
A per-collection timeout prevents a slow DB or Redis query from blocking
the /metrics endpoint indefinitely. If _collect_fresh() exceeds the
timeout, stale cached results are returned instead.
"""
def __init__(
self,
cache_ttl: float = _DEFAULT_CACHE_TTL,
collect_timeout: float = _DEFAULT_COLLECT_TIMEOUT,
) -> None:
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
self._cache_ttl = cache_ttl
self._collect_timeout = collect_timeout
self._cached_result: list[GaugeMetricFamily] | None = None
self._last_collect_time: float = 0.0
self._lock = threading.Lock()
self._executor = concurrent.futures.ThreadPoolExecutor(
max_workers=1,
thread_name_prefix=type(self).__name__,
)
self._inflight: concurrent.futures.Future | None = None
def collect(self) -> list[GaugeMetricFamily]:
with self._lock:
@@ -108,28 +84,12 @@ class _CachedCollector(Collector):
):
return self._cached_result
# If a previous _collect_fresh() is still running, wait on it
# rather than queuing another. This prevents unbounded task
# accumulation in the executor during extended DB outages.
if self._inflight is not None and not self._inflight.done():
future = self._inflight
else:
future = self._executor.submit(self._collect_fresh)
self._inflight = future
try:
result = future.result(timeout=self._collect_timeout)
self._inflight = None
result = self._collect_fresh()
self._cached_result = result
self._last_collect_time = now
return result
except concurrent.futures.TimeoutError:
logger.warning(
f"{type(self).__name__}._collect_fresh() timed out after {self._collect_timeout}s, returning stale cache"
)
return self._cached_result if self._cached_result is not None else []
except Exception:
self._inflight = None
logger.exception(f"Error in {type(self).__name__}.collect()")
# Return stale cache on error rather than nothing — avoids
# metrics disappearing during transient failures.
@@ -157,6 +117,8 @@ class QueueDepthCollector(_CachedCollector):
if self._celery_app is None:
return []
from onyx.background.celery.celery_redis import celery_get_broker_client
redis_client = celery_get_broker_client(self._celery_app)
depth = GaugeMetricFamily(
@@ -232,6 +194,208 @@ class QueueDepthCollector(_CachedCollector):
return None
class IndexAttemptCollector(_CachedCollector):
    """Queries Postgres for index attempt state on each scrape."""

    def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
        super().__init__(cache_ttl)
        # configure() must run (after DB engine init) before collection works.
        self._configured: bool = False
        self._terminal_statuses: list = []

    def configure(self) -> None:
        """Call once DB engine is initialized."""
        from onyx.db.enums import IndexingStatus

        self._terminal_statuses = [
            status for status in IndexingStatus if status.is_terminal()
        ]
        self._configured = True

    def _collect_fresh(self) -> list[GaugeMetricFamily]:
        # Until configure() has run there is no DB engine to query.
        if not self._configured:
            return []

        # Imported lazily so importing this module never requires the DB layer.
        from onyx.db.engine.sql_engine import get_session_with_current_tenant
        from onyx.db.engine.tenant_utils import get_all_tenant_ids
        from onyx.db.index_attempt import get_active_index_attempts_for_metrics
        from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

        gauge = GaugeMetricFamily(
            "onyx_index_attempts_active",
            "Number of non-terminal index attempts",
            labels=[
                "status",
                "source",
                "tenant_id",
                "connector_name",
                "cc_pair_id",
            ],
        )

        for tenant_id in get_all_tenant_ids():
            # Defensive guard — get_all_tenant_ids() should never yield None,
            # but we guard here for API stability in case the contract changes.
            if tenant_id is None:
                continue
            ctx_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
            try:
                with get_session_with_current_tenant() as db_session:
                    for (
                        status,
                        source,
                        cc_pair_id,
                        connector_name,
                        count,
                    ) in get_active_index_attempts_for_metrics(db_session):
                        gauge.add_metric(
                            [
                                status.value,
                                source.value,
                                tenant_id,
                                connector_name or f"cc_pair_{cc_pair_id}",
                                str(cc_pair_id),
                            ],
                            count,
                        )
            finally:
                CURRENT_TENANT_ID_CONTEXTVAR.reset(ctx_token)

        return [gauge]
class ConnectorHealthCollector(_CachedCollector):
    """Queries Postgres for connector health state on each scrape."""

    def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
        super().__init__(cache_ttl)
        # Set by configure(); _collect_fresh() is a no-op until then.
        self._configured: bool = False

    def configure(self) -> None:
        """Call once DB engine is initialized."""
        self._configured = True

    def _collect_fresh(self) -> list[GaugeMetricFamily]:
        # Until configure() has run there is no DB engine to query.
        if not self._configured:
            return []

        # Imported lazily so importing this module never requires the DB layer.
        from onyx.db.connector_credential_pair import (
            get_connector_health_for_metrics,
        )
        from onyx.db.engine.sql_engine import get_session_with_current_tenant
        from onyx.db.engine.tenant_utils import get_all_tenant_ids
        from onyx.db.index_attempt import get_docs_indexed_by_cc_pair
        from onyx.db.index_attempt import get_failed_attempt_counts_by_cc_pair
        from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

        staleness_gauge = GaugeMetricFamily(
            "onyx_connector_last_success_age_seconds",
            "Seconds since last successful index for this connector",
            labels=["tenant_id", "source", "cc_pair_id", "connector_name"],
        )
        error_state_gauge = GaugeMetricFamily(
            "onyx_connector_in_error_state",
            "Whether the connector is in a repeated error state (1=yes, 0=no)",
            labels=["tenant_id", "source", "cc_pair_id", "connector_name"],
        )
        by_status_gauge = GaugeMetricFamily(
            "onyx_connectors_by_status",
            "Number of connectors grouped by status",
            labels=["tenant_id", "status"],
        )
        error_total_gauge = GaugeMetricFamily(
            "onyx_connectors_in_error_total",
            "Total number of connectors in repeated error state",
            labels=["tenant_id"],
        )
        # Shared label set for the per-connector document/error gauges below.
        per_connector_labels = [
            "tenant_id",
            "source",
            "cc_pair_id",
            "connector_name",
        ]
        docs_success_gauge = GaugeMetricFamily(
            "onyx_connector_docs_indexed",
            "Total new documents indexed (90-day rolling sum) per connector",
            labels=per_connector_labels,
        )
        docs_error_gauge = GaugeMetricFamily(
            "onyx_connector_error_count",
            "Total number of failed index attempts per connector",
            labels=per_connector_labels,
        )

        # Single timestamp for the whole scrape so staleness values are
        # consistent across tenants and connectors.
        now = datetime.now(tz=timezone.utc)
        tenant_ids = get_all_tenant_ids()
        for tid in tenant_ids:
            # Defensive guard — get_all_tenant_ids() should never yield None,
            # but we guard here for API stability in case the contract changes.
            if tid is None:
                continue
            # Scope all DB access below to this tenant via the contextvar;
            # reset in finally so one tenant's failure can't leak its context.
            token = CURRENT_TENANT_ID_CONTEXTVAR.set(tid)
            try:
                with get_session_with_current_tenant() as session:
                    pairs = get_connector_health_for_metrics(session)
                    error_counts_by_cc = get_failed_attempt_counts_by_cc_pair(session)
                    docs_by_cc = get_docs_indexed_by_cc_pair(session)
                    # Per-tenant accumulators for the aggregate gauges.
                    status_counts: dict[str, int] = {}
                    error_count = 0
                    for (
                        cc_id,
                        status,
                        in_error,
                        last_success,
                        cc_name,
                        source,
                    ) in pairs:
                        cc_id_str = str(cc_id)
                        source_val = source.value
                        # Fall back to a synthetic name when none is set.
                        name_val = cc_name or f"cc_pair_{cc_id}"
                        label_vals = [tid, source_val, cc_id_str, name_val]
                        # Connectors that never succeeded emit no staleness
                        # sample at all (absence rather than a sentinel value).
                        if last_success is not None:
                            # Both `now` and `last_success` are timezone-aware
                            # (the DB column uses DateTime(timezone=True)),
                            # so subtraction is safe.
                            age = (now - last_success).total_seconds()
                            staleness_gauge.add_metric(label_vals, age)
                        error_state_gauge.add_metric(
                            label_vals,
                            1.0 if in_error else 0.0,
                        )
                        if in_error:
                            error_count += 1
                        docs_success_gauge.add_metric(
                            label_vals,
                            docs_by_cc.get(cc_id, 0),
                        )
                        docs_error_gauge.add_metric(
                            label_vals,
                            error_counts_by_cc.get(cc_id, 0),
                        )
                        status_val = status.value
                        status_counts[status_val] = status_counts.get(status_val, 0) + 1
                    # Emit the per-tenant aggregates computed above.
                    for status_val, count in status_counts.items():
                        by_status_gauge.add_metric([tid, status_val], count)
                    error_total_gauge.add_metric([tid], error_count)
            finally:
                CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

        return [
            staleness_gauge,
            error_state_gauge,
            by_status_gauge,
            error_total_gauge,
            docs_success_gauge,
            docs_error_gauge,
        ]
class RedisHealthCollector(_CachedCollector):
"""Collects Redis server health metrics (memory, clients, etc.)."""
@@ -247,6 +411,8 @@ class RedisHealthCollector(_CachedCollector):
if self._celery_app is None:
return []
from onyx.background.celery.celery_redis import celery_get_broker_client
redis_client = celery_get_broker_client(self._celery_app)
memory_used = GaugeMetricFamily(
@@ -329,9 +495,7 @@ class WorkerHeartbeatMonitor:
},
)
recv.capture(
limit=None,
timeout=self._HEARTBEAT_TIMEOUT_SECONDS,
wakeup=True,
limit=None, timeout=self._HEARTBEAT_TIMEOUT_SECONDS, wakeup=True
)
except Exception:
if self._running:

View File

@@ -6,6 +6,8 @@ Called once by the monitoring celery worker after Redis and DB are ready.
from celery import Celery
from prometheus_client.registry import REGISTRY
from onyx.server.metrics.indexing_pipeline import ConnectorHealthCollector
from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
from onyx.server.metrics.indexing_pipeline import RedisHealthCollector
from onyx.server.metrics.indexing_pipeline import WorkerHealthCollector
@@ -19,6 +21,8 @@ logger = setup_logger()
# module level ensures they survive the lifetime of the worker process and are
# only registered with the Prometheus registry once.
_queue_collector = QueueDepthCollector()
_attempt_collector = IndexAttemptCollector()
_connector_collector = ConnectorHealthCollector()
_redis_health_collector = RedisHealthCollector()
_worker_health_collector = WorkerHealthCollector()
_heartbeat_monitor: WorkerHeartbeatMonitor | None = None
@@ -30,9 +34,6 @@ def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
Args:
celery_app: The Celery application instance. Used to obtain a
broker Redis client on each scrape for queue depth metrics.
Note: connector health and index attempt metrics are push-based
(see connector_health_metrics.py) and do not use collectors.
"""
_queue_collector.set_celery_app(celery_app)
_redis_health_collector.set_celery_app(celery_app)
@@ -46,8 +47,13 @@ def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
_heartbeat_monitor.start()
_worker_health_collector.set_monitor(_heartbeat_monitor)
_attempt_collector.configure()
_connector_collector.configure()
for collector in (
_queue_collector,
_attempt_collector,
_connector_collector,
_redis_health_collector,
_worker_health_collector,
):

View File

@@ -27,7 +27,6 @@ _DEFAULT_PORTS: dict[str, int] = {
"docfetching": 9092,
"docprocessing": 9093,
"heavy": 9094,
"light": 9095,
}
_server_started = False

View File

@@ -28,14 +28,14 @@ PRUNING_ENUMERATION_DURATION = Histogram(
"onyx_pruning_enumeration_duration_seconds",
"Duration of document ID enumeration from the source connector during pruning",
["connector_type"],
buckets=[5, 60, 600, 1800, 3600, 10800, 21600],
buckets=[1, 5, 15, 30, 60, 120, 300, 600, 1800, 3600],
)
PRUNING_DIFF_DURATION = Histogram(
"onyx_pruning_diff_duration_seconds",
"Duration of diff computation and subtask dispatch during pruning",
["connector_type"],
buckets=[0.1, 0.25, 0.5, 1, 2, 5, 15, 30, 60],
buckets=[1, 5, 15, 30, 60, 120, 300, 600, 1800, 3600],
)
PRUNING_RATE_LIMIT_ERRORS = Counter(

Some files were not shown because too many files have changed in this diff Show More