mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-16 23:16:46 +00:00
Compare commits
141 Commits
ods/v0.7.3
...
jamison/ti
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7e9d73b485 | ||
|
|
73f9a47364 | ||
|
|
a808445d96 | ||
|
|
c31215197a | ||
|
|
9ebd9ebd73 | ||
|
|
f0bb0a6bb0 | ||
|
|
01bec19d19 | ||
|
|
7b40c2cde7 | ||
|
|
e2c38d2899 | ||
|
|
24768f9e4f | ||
|
|
aec1c169b6 | ||
|
|
5a16ad3473 | ||
|
|
7e28e59f23 | ||
|
|
879ae6c02d | ||
|
|
f84f367eb4 | ||
|
|
d81efe3877 | ||
|
|
d4619f93c4 | ||
|
|
70fcfb1d73 | ||
|
|
32ba393b32 | ||
|
|
f9d2bf78ed | ||
|
|
5567a078fe | ||
|
|
fc0e8560bc | ||
|
|
60b2701eed | ||
|
|
3682d9844b | ||
|
|
a420f9a37c | ||
|
|
20c5107ba6 | ||
|
|
357bc91aee | ||
|
|
09653872a2 | ||
|
|
ff01a53f83 | ||
|
|
03ddd5ca9b | ||
|
|
8c49e4573c | ||
|
|
f1696ffa16 | ||
|
|
a427cb5b0c | ||
|
|
f7e4be18dd | ||
|
|
0f31c490fa | ||
|
|
c9a4a6e42b | ||
|
|
558c9df3c7 | ||
|
|
30003036d3 | ||
|
|
4b2f18c239 | ||
|
|
4290b097f5 | ||
|
|
b0f621a08b | ||
|
|
112edf41c5 | ||
|
|
74eb1d7212 | ||
|
|
e62d592b11 | ||
|
|
57a0d25321 | ||
|
|
887f79d7a5 | ||
|
|
65fd1c3ec8 | ||
|
|
6e3ee287b9 | ||
|
|
dee0b7867e | ||
|
|
77beb8044e | ||
|
|
750d3ac4ed | ||
|
|
6c02087ba4 | ||
|
|
0425283ed0 | ||
|
|
da97a57c58 | ||
|
|
8087ddb97c | ||
|
|
d9d5943dc4 | ||
|
|
97a7fa6f7f | ||
|
|
8027e62446 | ||
|
|
571e860d4f | ||
|
|
89b91ac384 | ||
|
|
069b1f3efb | ||
|
|
ef2fffcd6e | ||
|
|
925be18424 | ||
|
|
38fffc8ad8 | ||
|
|
3e9e2f08d5 | ||
|
|
243d93ecd8 | ||
|
|
4effe77225 | ||
|
|
ef2df458a3 | ||
|
|
d3000da3d0 | ||
|
|
a5c703f9ca | ||
|
|
d10c901c43 | ||
|
|
f1ac555c57 | ||
|
|
ed52384c21 | ||
|
|
cb10376a0d | ||
|
|
5a25b70b9c | ||
|
|
8cbc37f281 | ||
|
|
9d78f71f23 | ||
|
|
fbf3179d84 | ||
|
|
779470b553 | ||
|
|
151e189898 | ||
|
|
72e08f81a4 | ||
|
|
65792a8ad8 | ||
|
|
497b700b3d | ||
|
|
c3ed2135f1 | ||
|
|
a969d56818 | ||
|
|
a31d862f48 | ||
|
|
a4e6d4cf43 | ||
|
|
1e6f94e00d | ||
|
|
a769b87a9d | ||
|
|
278fc7e9b1 | ||
|
|
eb34df470f | ||
|
|
9d1785273f | ||
|
|
ef69b17d26 | ||
|
|
787c961802 | ||
|
|
62bc4fa2a3 | ||
|
|
bb1c44daff | ||
|
|
f26ecafb51 | ||
|
|
9fdb425c0d | ||
|
|
47e20e89c5 | ||
|
|
8b28c127f2 | ||
|
|
9a861a71ad | ||
|
|
b4bc12f6dc | ||
|
|
9af9148ca7 | ||
|
|
8a517c4f10 | ||
|
|
6959d851ea | ||
|
|
6a2550fc2d | ||
|
|
b1cc0c2bf9 | ||
|
|
c28b17064b | ||
|
|
4dab92ab52 | ||
|
|
7eb68d61b0 | ||
|
|
8c7810d688 | ||
|
|
712e6fdf5e | ||
|
|
f1a9a3b41e | ||
|
|
c3405fb6bf | ||
|
|
3e962935f4 | ||
|
|
0aa1aa7ea0 | ||
|
|
771d2cf101 | ||
|
|
7ec50280ed | ||
|
|
5b2ba5caeb | ||
|
|
4a96ef13d7 | ||
|
|
822b0c99be | ||
|
|
bcf2851a85 | ||
|
|
a5a59bd8f0 | ||
|
|
32d2e7985a | ||
|
|
c4f8d5370b | ||
|
|
9e434f6a5a | ||
|
|
67dc819319 | ||
|
|
2d12274050 | ||
|
|
c727ba13ee | ||
|
|
6193dd5326 | ||
|
|
387a7d1cea | ||
|
|
869578eeed | ||
|
|
e68648ab74 | ||
|
|
da01002099 | ||
|
|
f5d66f389c | ||
|
|
82d89f78c6 | ||
|
|
6f49c5e32c | ||
|
|
41f2bd2f19 | ||
|
|
bfa2f672f9 | ||
|
|
a823c3ead1 | ||
|
|
bd7d378a9a |
63
.devcontainer/Dockerfile
Normal file
63
.devcontainer/Dockerfile
Normal file
@@ -0,0 +1,63 @@
|
||||
FROM ubuntu:26.04@sha256:cc925e589b7543b910fea57a240468940003fbfc0515245a495dd0ad8fe7cef1
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
curl \
|
||||
default-jre \
|
||||
fd-find \
|
||||
fzf \
|
||||
git \
|
||||
jq \
|
||||
less \
|
||||
make \
|
||||
neovim \
|
||||
openssh-client \
|
||||
python3-venv \
|
||||
ripgrep \
|
||||
sudo \
|
||||
ca-certificates \
|
||||
iptables \
|
||||
ipset \
|
||||
iproute2 \
|
||||
dnsutils \
|
||||
unzip \
|
||||
wget \
|
||||
zsh \
|
||||
&& curl -fsSL https://deb.nodesource.com/setup_20.x | bash - \
|
||||
&& apt-get install -y nodejs \
|
||||
&& install -m 0755 -d /etc/apt/keyrings \
|
||||
&& curl -fsSL https://cli.github.com/packages/githubcli-archive-keyring.gpg -o /etc/apt/keyrings/githubcli-archive-keyring.gpg \
|
||||
&& chmod go+r /etc/apt/keyrings/githubcli-archive-keyring.gpg \
|
||||
&& echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" > /etc/apt/sources.list.d/github-cli.list \
|
||||
&& apt-get update \
|
||||
&& apt-get install -y --no-install-recommends gh \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# fd-find installs as fdfind on Debian/Ubuntu — symlink to fd
|
||||
RUN ln -sf "$(which fdfind)" /usr/local/bin/fd
|
||||
|
||||
# Install uv (Python package manager)
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /usr/local/bin/
|
||||
|
||||
# Create non-root dev user with passwordless sudo
|
||||
RUN useradd -m -s /bin/zsh dev && \
|
||||
echo "dev ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/dev && \
|
||||
chmod 0440 /etc/sudoers.d/dev
|
||||
|
||||
ENV DEVCONTAINER=true
|
||||
|
||||
RUN mkdir -p /workspace && \
|
||||
chown -R dev:dev /workspace
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
# Install Claude Code
|
||||
ARG CLAUDE_CODE_VERSION=latest
|
||||
RUN npm install -g @anthropic-ai/claude-code@${CLAUDE_CODE_VERSION}
|
||||
|
||||
# Configure zsh — source the repo-local zshrc so shell customization
|
||||
# doesn't require an image rebuild.
|
||||
RUN chsh -s /bin/zsh root && \
|
||||
for rc in /root/.zshrc /home/dev/.zshrc; do \
|
||||
echo '[ -f /workspace/.devcontainer/zshrc ] && . /workspace/.devcontainer/zshrc' >> "$rc"; \
|
||||
done && \
|
||||
chown dev:dev /home/dev/.zshrc
|
||||
86
.devcontainer/README.md
Normal file
86
.devcontainer/README.md
Normal file
@@ -0,0 +1,86 @@
|
||||
# Onyx Dev Container
|
||||
|
||||
A containerized development environment for working on Onyx.
|
||||
|
||||
## What's included
|
||||
|
||||
- Ubuntu 26.04 base image
|
||||
- Node.js 20, uv, Claude Code
|
||||
- GitHub CLI (`gh`)
|
||||
- Neovim, ripgrep, fd, fzf, jq, make, wget, unzip
|
||||
- Zsh as default shell (sources host `~/.zshrc` if available)
|
||||
- Python venv auto-activation
|
||||
- Network firewall (default-deny, whitelists npm, GitHub, Anthropic APIs, Sentry, and VS Code update servers)
|
||||
|
||||
## Usage
|
||||
|
||||
### CLI (`ods dev`)
|
||||
|
||||
The [`ods` devtools CLI](../tools/ods/README.md) provides workspace-aware wrappers
|
||||
for all devcontainer operations (also available as `ods dc`):
|
||||
|
||||
```bash
|
||||
# Start the container
|
||||
ods dev up
|
||||
|
||||
# Open a shell
|
||||
ods dev into
|
||||
|
||||
# Run a command
|
||||
ods dev exec npm test
|
||||
|
||||
# Stop the container
|
||||
ods dev stop
|
||||
```
|
||||
|
||||
## Restarting the container
|
||||
|
||||
```bash
|
||||
# Restart the container
|
||||
ods dev restart
|
||||
|
||||
# Pull the latest published image and recreate
|
||||
ods dev rebuild
|
||||
```
|
||||
|
||||
## Image
|
||||
|
||||
The devcontainer uses a prebuilt image published to `onyxdotapp/onyx-devcontainer`.
|
||||
The tag is pinned in `devcontainer.json` — no local build is required.
|
||||
|
||||
To build the image locally (e.g. while iterating on the Dockerfile):
|
||||
|
||||
```bash
|
||||
docker buildx bake devcontainer
|
||||
```
|
||||
|
||||
The `devcontainer` target is defined in `docker-bake.hcl` at the repo root.
|
||||
|
||||
## User & permissions
|
||||
|
||||
The container runs as the `dev` user by default (`remoteUser` in devcontainer.json).
|
||||
An init script (`init-dev-user.sh`) runs at container start to ensure the active
|
||||
user has read/write access to the bind-mounted workspace:
|
||||
|
||||
- **Standard Docker** — `dev`'s UID/GID is remapped to match the workspace owner,
|
||||
so file permissions work seamlessly.
|
||||
- **Rootless Docker** — The workspace appears as root-owned (UID 0) inside the
|
||||
container due to user-namespace mapping. `ods dev up` auto-detects rootless Docker
|
||||
and sets `DEVCONTAINER_REMOTE_USER=root` so the container runs as root — which
|
||||
maps back to your host user via the user namespace. New files are owned by your
|
||||
host UID and no ACL workarounds are needed.
|
||||
|
||||
To override the auto-detection, set `DEVCONTAINER_REMOTE_USER` before running
|
||||
`ods dev up`.
|
||||
|
||||
## Firewall
|
||||
|
||||
The container starts with a default-deny firewall (`init-firewall.sh`) that only allows outbound traffic to:
|
||||
|
||||
- npm registry
|
||||
- GitHub
|
||||
- Anthropic API
|
||||
- Sentry
|
||||
- VS Code update servers
|
||||
|
||||
This requires the `NET_ADMIN` and `NET_RAW` capabilities, which are added via `runArgs` in `devcontainer.json`.
|
||||
26
.devcontainer/devcontainer.json
Normal file
26
.devcontainer/devcontainer.json
Normal file
@@ -0,0 +1,26 @@
|
||||
{
|
||||
"name": "Onyx Dev Sandbox",
|
||||
"image": "onyxdotapp/onyx-devcontainer@sha256:0f02d9299928849c7b15f3b348dcfdcdcb64411ff7a4580cbc026a6ee7aa1554",
|
||||
"runArgs": ["--cap-add=NET_ADMIN", "--cap-add=NET_RAW", "--network=onyx_default"],
|
||||
"mounts": [
|
||||
"source=${localEnv:HOME}/.claude,target=/home/dev/.claude,type=bind",
|
||||
"source=${localEnv:HOME}/.claude.json,target=/home/dev/.claude.json,type=bind",
|
||||
"source=${localEnv:HOME}/.zshrc,target=/home/dev/.zshrc.host,type=bind,readonly",
|
||||
"source=${localEnv:HOME}/.gitconfig,target=/home/dev/.gitconfig,type=bind,readonly",
|
||||
"source=${localEnv:HOME}/.config/nvim,target=/home/dev/.config/nvim,type=bind,readonly",
|
||||
"source=onyx-devcontainer-cache,target=/home/dev/.cache,type=volume",
|
||||
"source=onyx-devcontainer-local,target=/home/dev/.local,type=volume"
|
||||
],
|
||||
"containerEnv": {
|
||||
"SSH_AUTH_SOCK": "/tmp/ssh-agent.sock",
|
||||
"POSTGRES_HOST": "relational_db",
|
||||
"REDIS_HOST": "cache"
|
||||
},
|
||||
"remoteUser": "${localEnv:DEVCONTAINER_REMOTE_USER:dev}",
|
||||
"updateRemoteUserUID": false,
|
||||
"initializeCommand": "docker network create onyx_default 2>/dev/null || true",
|
||||
"workspaceMount": "source=${localWorkspaceFolder},target=/workspace,type=bind,consistency=delegated",
|
||||
"workspaceFolder": "/workspace",
|
||||
"postStartCommand": "sudo bash /workspace/.devcontainer/init-dev-user.sh && sudo bash /workspace/.devcontainer/init-firewall.sh",
|
||||
"waitFor": "postStartCommand"
|
||||
}
|
||||
107
.devcontainer/init-dev-user.sh
Normal file
107
.devcontainer/init-dev-user.sh
Normal file
@@ -0,0 +1,107 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
# Remap the dev user's UID/GID to match the workspace owner so that
|
||||
# bind-mounted files are accessible without running as root.
|
||||
#
|
||||
# Standard Docker: Workspace is owned by the host user's UID (e.g. 1000).
|
||||
# We remap dev to that UID -- fast and seamless.
|
||||
#
|
||||
# Rootless Docker: Workspace appears as root-owned (UID 0) inside the
|
||||
# container due to user-namespace mapping. Requires
|
||||
# DEVCONTAINER_REMOTE_USER=root (set automatically by
|
||||
# ods dev up). Container root IS the host user, so
|
||||
# bind-mounts and named volumes are symlinked into /root.
|
||||
|
||||
WORKSPACE=/workspace
|
||||
TARGET_USER=dev
|
||||
REMOTE_USER="${SUDO_USER:-$TARGET_USER}"
|
||||
|
||||
WS_UID=$(stat -c '%u' "$WORKSPACE")
|
||||
WS_GID=$(stat -c '%g' "$WORKSPACE")
|
||||
DEV_UID=$(id -u "$TARGET_USER")
|
||||
DEV_GID=$(id -g "$TARGET_USER")
|
||||
|
||||
# devcontainer.json bind-mounts and named volumes target /home/dev regardless
|
||||
# of remoteUser. When running as root ($HOME=/root), Phase 1 bridges the gap
|
||||
# with symlinks from ACTIVE_HOME → MOUNT_HOME.
|
||||
MOUNT_HOME=/home/"$TARGET_USER"
|
||||
|
||||
if [ "$REMOTE_USER" = "root" ]; then
|
||||
ACTIVE_HOME="/root"
|
||||
else
|
||||
ACTIVE_HOME="$MOUNT_HOME"
|
||||
fi
|
||||
|
||||
# ── Phase 1: home directory setup ───────────────────────────────────
|
||||
|
||||
# ~/.local and ~/.cache are named Docker volumes mounted under MOUNT_HOME.
|
||||
mkdir -p "$MOUNT_HOME"/.local/state "$MOUNT_HOME"/.local/share
|
||||
|
||||
# When running as root, symlink bind-mounts and named volumes into /root
|
||||
# so that $HOME-relative tools (Claude Code, git, etc.) find them.
|
||||
if [ "$ACTIVE_HOME" != "$MOUNT_HOME" ]; then
|
||||
for item in .claude .cache .local; do
|
||||
[ -d "$MOUNT_HOME/$item" ] || continue
|
||||
if [ -e "$ACTIVE_HOME/$item" ] && [ ! -L "$ACTIVE_HOME/$item" ]; then
|
||||
echo "warning: replacing $ACTIVE_HOME/$item with symlink to $MOUNT_HOME/$item" >&2
|
||||
rm -rf "$ACTIVE_HOME/$item"
|
||||
fi
|
||||
ln -sfn "$MOUNT_HOME/$item" "$ACTIVE_HOME/$item"
|
||||
done
|
||||
# Symlink files (not directories).
|
||||
for file in .claude.json .gitconfig .zshrc.host; do
|
||||
[ -f "$MOUNT_HOME/$file" ] && ln -sf "$MOUNT_HOME/$file" "$ACTIVE_HOME/$file"
|
||||
done
|
||||
|
||||
# Nested mount: .config/nvim
|
||||
if [ -d "$MOUNT_HOME/.config/nvim" ]; then
|
||||
mkdir -p "$ACTIVE_HOME/.config"
|
||||
if [ -e "$ACTIVE_HOME/.config/nvim" ] && [ ! -L "$ACTIVE_HOME/.config/nvim" ]; then
|
||||
echo "warning: replacing $ACTIVE_HOME/.config/nvim with symlink" >&2
|
||||
rm -rf "$ACTIVE_HOME/.config/nvim"
|
||||
fi
|
||||
ln -sfn "$MOUNT_HOME/.config/nvim" "$ACTIVE_HOME/.config/nvim"
|
||||
fi
|
||||
fi
|
||||
|
||||
# ── Phase 2: workspace access ───────────────────────────────────────
|
||||
|
||||
# Root always has workspace access; Phase 1 handled home setup.
|
||||
if [ "$REMOTE_USER" = "root" ]; then
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Already matching -- nothing to do.
|
||||
if [ "$WS_UID" = "$DEV_UID" ] && [ "$WS_GID" = "$DEV_GID" ]; then
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [ "$WS_UID" != "0" ]; then
|
||||
# ── Standard Docker ──────────────────────────────────────────────
|
||||
# Workspace is owned by a non-root UID (the host user).
|
||||
# Remap dev's UID/GID to match.
|
||||
if [ "$DEV_GID" != "$WS_GID" ]; then
|
||||
if ! groupmod -g "$WS_GID" "$TARGET_USER" 2>&1; then
|
||||
echo "warning: failed to remap $TARGET_USER GID to $WS_GID" >&2
|
||||
fi
|
||||
fi
|
||||
if [ "$DEV_UID" != "$WS_UID" ]; then
|
||||
if ! usermod -u "$WS_UID" -g "$WS_GID" "$TARGET_USER" 2>&1; then
|
||||
echo "warning: failed to remap $TARGET_USER UID to $WS_UID" >&2
|
||||
fi
|
||||
fi
|
||||
if ! chown -R "$TARGET_USER":"$TARGET_USER" "$MOUNT_HOME" 2>&1; then
|
||||
echo "warning: failed to chown $MOUNT_HOME" >&2
|
||||
fi
|
||||
else
|
||||
# ── Rootless Docker ──────────────────────────────────────────────
|
||||
# Workspace is root-owned (UID 0) due to user-namespace mapping.
|
||||
# The supported path is remoteUser=root (set DEVCONTAINER_REMOTE_USER=root),
|
||||
# which is handled above. If we reach here, the user is running as dev
|
||||
# under rootless Docker without the override.
|
||||
echo "error: rootless Docker detected but remoteUser is not root." >&2
|
||||
echo " Set DEVCONTAINER_REMOTE_USER=root before starting the container," >&2
|
||||
echo " or use 'ods dev up' which sets it automatically." >&2
|
||||
exit 1
|
||||
fi
|
||||
104
.devcontainer/init-firewall.sh
Executable file
104
.devcontainer/init-firewall.sh
Executable file
@@ -0,0 +1,104 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
echo "Setting up firewall..."
|
||||
|
||||
# Only flush the filter table. The nat and mangle tables are managed by Docker
|
||||
# (DNS DNAT to 127.0.0.11, container networking, etc.) and must not be touched —
|
||||
# flushing them breaks Docker's embedded DNS resolver.
|
||||
iptables -F
|
||||
iptables -X
|
||||
|
||||
# Create ipset for allowed destinations
|
||||
ipset create allowed-domains hash:net || true
|
||||
ipset flush allowed-domains
|
||||
|
||||
# Fetch GitHub IP ranges (IPv4 only -- ipset hash:net and iptables are IPv4)
|
||||
GITHUB_IPS=$(curl -s https://api.github.com/meta | jq -r '.api[]' 2>/dev/null | grep -v ':' || echo "")
|
||||
for ip in $GITHUB_IPS; do
|
||||
if ! ipset add allowed-domains "$ip" -exist 2>&1; then
|
||||
echo "warning: failed to add GitHub IP $ip to allowlist" >&2
|
||||
fi
|
||||
done
|
||||
|
||||
# Resolve allowed domains
|
||||
ALLOWED_DOMAINS=(
|
||||
"github.com"
|
||||
"registry.npmjs.org"
|
||||
"api.anthropic.com"
|
||||
"api-staging.anthropic.com"
|
||||
"files.anthropic.com"
|
||||
"sentry.io"
|
||||
"update.code.visualstudio.com"
|
||||
"pypi.org"
|
||||
"files.pythonhosted.org"
|
||||
"go.dev"
|
||||
"storage.googleapis.com"
|
||||
"static.rust-lang.org"
|
||||
)
|
||||
|
||||
for domain in "${ALLOWED_DOMAINS[@]}"; do
|
||||
IPS=$(getent ahosts "$domain" 2>/dev/null | awk '{print $1}' | grep -v ':' | sort -u || echo "")
|
||||
for ip in $IPS; do
|
||||
if ! ipset add allowed-domains "$ip/32" -exist 2>&1; then
|
||||
echo "warning: failed to add $domain ($ip) to allowlist" >&2
|
||||
fi
|
||||
done
|
||||
done
|
||||
|
||||
# Allow traffic to the Docker gateway so the container can reach host services
|
||||
# (e.g. the Onyx stack at localhost:3000, localhost:8080, etc.)
|
||||
DOCKER_GATEWAY=$(ip -4 route show default | awk '{print $3}')
|
||||
if [ -n "$DOCKER_GATEWAY" ]; then
|
||||
if ! ipset add allowed-domains "$DOCKER_GATEWAY/32" -exist 2>&1; then
|
||||
echo "warning: failed to add Docker gateway $DOCKER_GATEWAY to allowlist" >&2
|
||||
fi
|
||||
fi
|
||||
|
||||
# Allow traffic to all attached Docker network subnets so the container can
|
||||
# reach sibling services (e.g. relational_db, cache) on shared compose networks.
|
||||
for subnet in $(ip -4 -o addr show scope global | awk '{print $4}'); do
|
||||
if ! ipset add allowed-domains "$subnet" -exist 2>&1; then
|
||||
echo "warning: failed to add Docker subnet $subnet to allowlist" >&2
|
||||
fi
|
||||
done
|
||||
|
||||
# Set default policies to DROP
|
||||
iptables -P FORWARD DROP
|
||||
iptables -P INPUT DROP
|
||||
iptables -P OUTPUT DROP
|
||||
|
||||
# Allow established connections
|
||||
iptables -A INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT
|
||||
iptables -A OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT
|
||||
|
||||
# Allow loopback
|
||||
iptables -A INPUT -i lo -j ACCEPT
|
||||
iptables -A OUTPUT -o lo -j ACCEPT
|
||||
|
||||
# Allow DNS
|
||||
iptables -A OUTPUT -p udp --dport 53 -j ACCEPT
|
||||
iptables -A OUTPUT -p tcp --dport 53 -j ACCEPT
|
||||
|
||||
# Allow outbound to allowed destinations
|
||||
iptables -A OUTPUT -m set --match-set allowed-domains dst -j ACCEPT
|
||||
|
||||
# Reject unauthorized outbound
|
||||
iptables -A OUTPUT -j REJECT --reject-with icmp-host-unreachable
|
||||
|
||||
# Validate firewall configuration
|
||||
echo "Validating firewall configuration..."
|
||||
|
||||
BLOCKED_SITES=("example.com" "google.com" "facebook.com")
|
||||
for site in "${BLOCKED_SITES[@]}"; do
|
||||
if timeout 2 ping -c 1 "$site" &>/dev/null; then
|
||||
echo "Warning: $site is still reachable"
|
||||
fi
|
||||
done
|
||||
|
||||
if ! timeout 5 curl -s https://api.github.com/meta > /dev/null; then
|
||||
echo "Warning: GitHub API is not accessible"
|
||||
fi
|
||||
|
||||
echo "Firewall setup complete"
|
||||
10
.devcontainer/zshrc
Normal file
10
.devcontainer/zshrc
Normal file
@@ -0,0 +1,10 @@
|
||||
# Devcontainer zshrc — sourced automatically for both root and dev users.
|
||||
# Edit this file to customize the shell without rebuilding the image.
|
||||
|
||||
# Auto-activate Python venv
|
||||
if [ -f /workspace/.venv/bin/activate ]; then
|
||||
. /workspace/.venv/bin/activate
|
||||
fi
|
||||
|
||||
# Source host zshrc if bind-mounted
|
||||
[ -f ~/.zshrc.host ] && . ~/.zshrc.host
|
||||
6
.github/workflows/deployment.yml
vendored
6
.github/workflows/deployment.yml
vendored
@@ -44,7 +44,7 @@ jobs:
|
||||
fetch-tags: true
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
with:
|
||||
version: "0.9.9"
|
||||
enable-cache: false
|
||||
@@ -156,7 +156,7 @@ jobs:
|
||||
check-version-tag:
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 10
|
||||
if: ${{ !startsWith(github.ref_name, 'nightly-latest') && github.event_name != 'workflow_dispatch' }}
|
||||
if: ${{ !startsWith(github.ref_name, 'nightly-latest') && github.ref_name != 'edge' && github.event_name != 'workflow_dispatch' }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
@@ -165,7 +165,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
with:
|
||||
version: "0.9.9"
|
||||
# NOTE: This isn't caching much and zizmor suggests this could be poisoned, so disable.
|
||||
|
||||
@@ -114,7 +114,7 @@ jobs:
|
||||
ref: main
|
||||
|
||||
- name: Install the latest version of uv
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
4
.github/workflows/pr-playwright-tests.yml
vendored
4
.github/workflows/pr-playwright-tests.yml
vendored
@@ -471,7 +471,7 @@ jobs:
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: always()
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
@@ -710,7 +710,7 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Download visual diff summaries
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3
|
||||
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c
|
||||
with:
|
||||
pattern: screenshot-diff-summary-*
|
||||
path: summaries/
|
||||
|
||||
2
.github/workflows/pr-quality-checks.yml
vendored
2
.github/workflows/pr-quality-checks.yml
vendored
@@ -38,7 +38,7 @@ jobs:
|
||||
- name: Install node dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
- uses: j178/prek-action@0bb87d7f00b0c99306c8bcb8b8beba1eb581c037 # ratchet:j178/prek-action@v1
|
||||
- uses: j178/prek-action@cbc2f23eb5539cf20d82d1aabd0d0ecbcc56f4e3
|
||||
with:
|
||||
prek-version: '0.3.4'
|
||||
extra-args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || github.event_name == 'merge_group' && format('--from-ref {0} --to-ref {1}', github.event.merge_group.base_sha, github.event.merge_group.head_sha) || github.ref_name == 'main' && '--all-files' || '' }}
|
||||
|
||||
2
.github/workflows/release-cli.yml
vendored
2
.github/workflows/release-cli.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
2
.github/workflows/release-devtools.yml
vendored
2
.github/workflows/release-devtools.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
2
.github/workflows/zizmor.yml
vendored
2
.github/workflows/zizmor.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install the latest version of uv
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
@@ -1,64 +1,57 @@
|
||||
{
|
||||
"labels": [],
|
||||
"comment": "",
|
||||
"fixWithAI": true,
|
||||
"hideFooter": false,
|
||||
"strictness": 3,
|
||||
"statusCheck": true,
|
||||
"commentTypes": [
|
||||
"logic",
|
||||
"syntax",
|
||||
"style"
|
||||
],
|
||||
"instructions": "",
|
||||
"disabledLabels": [],
|
||||
"excludeAuthors": [
|
||||
"dependabot[bot]",
|
||||
"renovate[bot]"
|
||||
],
|
||||
"ignoreKeywords": "",
|
||||
"ignorePatterns": "",
|
||||
"includeAuthors": [],
|
||||
"summarySection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
"labels": [],
|
||||
"comment": "",
|
||||
"fixWithAI": true,
|
||||
"hideFooter": false,
|
||||
"strictness": 3,
|
||||
"statusCheck": true,
|
||||
"commentTypes": ["logic", "syntax", "style"],
|
||||
"instructions": "",
|
||||
"disabledLabels": [],
|
||||
"excludeAuthors": ["dependabot[bot]", "renovate[bot]"],
|
||||
"ignoreKeywords": "",
|
||||
"ignorePatterns": "",
|
||||
"includeAuthors": [],
|
||||
"summarySection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"excludeBranches": [],
|
||||
"fileChangeLimit": 300,
|
||||
"includeBranches": [],
|
||||
"includeKeywords": "",
|
||||
"triggerOnUpdates": false,
|
||||
"updateExistingSummaryComment": true,
|
||||
"updateSummaryOnly": false,
|
||||
"issuesTableSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"statusCommentsEnabled": true,
|
||||
"confidenceScoreSection": {
|
||||
"included": true,
|
||||
"collapsible": false
|
||||
},
|
||||
"sequenceDiagramSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"shouldUpdateDescription": false,
|
||||
"rules": [
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx's Next.js app, the `app/ee/admin/` directory is a filesystem convention for Enterprise Edition route overrides — it does NOT add an `/ee/` prefix to the URL. Both `app/admin/groups/page.tsx` and `app/ee/admin/groups/page.tsx` serve the same URL `/admin/groups`. Hardcoded `/admin/...` paths in router.push() calls are correct and do NOT break EE deployments. Do not flag hardcoded admin paths as bugs."
|
||||
},
|
||||
"excludeBranches": [],
|
||||
"fileChangeLimit": 300,
|
||||
"includeBranches": [],
|
||||
"includeKeywords": "",
|
||||
"triggerOnUpdates": true,
|
||||
"updateExistingSummaryComment": true,
|
||||
"updateSummaryOnly": false,
|
||||
"issuesTableSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx, each API key creates a unique user row in the database with a unique `user_id` (UUID). There is a 1:1 mapping between API keys and their backing user records. Multiple API keys do NOT share the same `user_id`. Do not flag potential duplicate row IDs when using `user_id` from API key descriptors."
|
||||
},
|
||||
"statusCommentsEnabled": true,
|
||||
"confidenceScoreSection": {
|
||||
"included": true,
|
||||
"collapsible": false
|
||||
},
|
||||
"sequenceDiagramSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"shouldUpdateDescription": false,
|
||||
"rules": [
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx's Next.js app, the `app/ee/admin/` directory is a filesystem convention for Enterprise Edition route overrides — it does NOT add an `/ee/` prefix to the URL. Both `app/admin/groups/page.tsx` and `app/ee/admin/groups/page.tsx` serve the same URL `/admin/groups`. Hardcoded `/admin/...` paths in router.push() calls are correct and do NOT break EE deployments. Do not flag hardcoded admin paths as bugs."
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx, each API key creates a unique user row in the database with a unique `user_id` (UUID). There is a 1:1 mapping between API keys and their backing user records. Multiple API keys do NOT share the same `user_id`. Do not flag potential duplicate row IDs when using `user_id` from API key descriptors."
|
||||
},
|
||||
{
|
||||
"scope": ["backend/**/*.py"],
|
||||
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
|
||||
}
|
||||
]
|
||||
{
|
||||
"scope": ["backend/**/*.py"],
|
||||
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ repos:
|
||||
rev: d30b4298e4fb63ce8609e29acdbcf4c9018a483c
|
||||
hooks:
|
||||
- id: uv-sync
|
||||
args: ["--locked", "--all-extras"]
|
||||
- id: uv-lock
|
||||
- id: uv-export
|
||||
name: uv-export default.txt
|
||||
@@ -18,7 +17,7 @@ repos:
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--extra",
|
||||
"--group",
|
||||
"backend",
|
||||
"-o",
|
||||
"backend/requirements/default.txt",
|
||||
@@ -31,7 +30,7 @@ repos:
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--extra",
|
||||
"--group",
|
||||
"dev",
|
||||
"-o",
|
||||
"backend/requirements/dev.txt",
|
||||
@@ -44,7 +43,7 @@ repos:
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--extra",
|
||||
"--group",
|
||||
"ee",
|
||||
"-o",
|
||||
"backend/requirements/ee.txt",
|
||||
@@ -57,7 +56,7 @@ repos:
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--extra",
|
||||
"--group",
|
||||
"model_server",
|
||||
"-o",
|
||||
"backend/requirements/model_server.txt",
|
||||
|
||||
15
.vscode/launch.json
vendored
15
.vscode/launch.json
vendored
@@ -475,6 +475,18 @@
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Start Monitoring Stack (Prometheus + Grafana)",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "docker",
|
||||
"runtimeArgs": ["compose", "up", "-d"],
|
||||
"cwd": "${workspaceFolder}/profiling",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Clear and Restart External Volumes and Containers",
|
||||
"type": "node",
|
||||
@@ -531,8 +543,7 @@
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "uv",
|
||||
"runtimeArgs": [
|
||||
"sync",
|
||||
"--all-extras"
|
||||
"sync"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
|
||||
@@ -49,12 +49,12 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
|
||||
|
||||
4. **Light Worker** (`light`)
|
||||
- Handles lightweight, fast operations
|
||||
- Tasks: vespa operations, document permissions sync, external group sync
|
||||
- Tasks: vespa metadata sync, connector deletion, doc permissions upsert, checkpoint cleanup, index attempt cleanup
|
||||
- Higher concurrency for quick tasks
|
||||
|
||||
5. **Heavy Worker** (`heavy`)
|
||||
- Handles resource-intensive operations
|
||||
- Primary task: document pruning operations
|
||||
- Tasks: connector pruning, document permissions sync, external group sync, CSV generation
|
||||
- Runs with 4 threads concurrency
|
||||
|
||||
6. **KG Processing Worker** (`kg_processing`)
|
||||
|
||||
@@ -117,7 +117,7 @@ If using PowerShell, the command slightly differs:
|
||||
Install the required Python dependencies:
|
||||
|
||||
```bash
|
||||
uv sync --all-extras
|
||||
uv sync
|
||||
```
|
||||
|
||||
Install Playwright for Python (headless browser required by the Web Connector):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM python:3.11.7-slim-bookworm
|
||||
FROM python:3.11-slim-bookworm@sha256:9c6f90801e6b68e772b7c0ca74260cbf7af9f320acec894e26fccdaccfbe3b47
|
||||
|
||||
LABEL com.danswer.maintainer="founders@onyx.app"
|
||||
LABEL com.danswer.description="This image is the web/frontend container of Onyx which \
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Base stage with dependencies
|
||||
FROM python:3.11.7-slim-bookworm AS base
|
||||
FROM python:3.11-slim-bookworm@sha256:9c6f90801e6b68e772b7c0ca74260cbf7af9f320acec894e26fccdaccfbe3b47 AS base
|
||||
|
||||
ENV DANSWER_RUNNING_IN_DOCKER="true" \
|
||||
HF_HOME=/app/.cache/huggingface
|
||||
|
||||
@@ -208,7 +208,7 @@ def do_run_migrations(
|
||||
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
target_metadata=target_metadata,
|
||||
version_table_schema=schema_name,
|
||||
include_schemas=True,
|
||||
compare_type=True,
|
||||
@@ -380,7 +380,7 @@ def run_migrations_offline() -> None:
|
||||
logger.info(f"Migrating schema: {schema}")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
version_table_schema=schema,
|
||||
include_schemas=True,
|
||||
@@ -421,7 +421,7 @@ def run_migrations_offline() -> None:
|
||||
logger.info(f"Migrating schema: {schema}")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
version_table_schema=schema,
|
||||
include_schemas=True,
|
||||
@@ -464,7 +464,7 @@ def run_migrations_online() -> None:
|
||||
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
target_metadata=target_metadata,
|
||||
version_table_schema=schema_name,
|
||||
include_schemas=True,
|
||||
compare_type=True,
|
||||
|
||||
@@ -25,7 +25,7 @@ def upgrade() -> None:
|
||||
|
||||
# Use batch mode to modify the enum type
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.alter_column( # type: ignore[attr-defined]
|
||||
batch_op.alter_column(
|
||||
"role",
|
||||
type_=sa.Enum(
|
||||
"BASIC",
|
||||
@@ -71,7 +71,7 @@ def downgrade() -> None:
|
||||
op.drop_column("user__user_group", "is_curator")
|
||||
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.alter_column( # type: ignore[attr-defined]
|
||||
batch_op.alter_column(
|
||||
"role",
|
||||
type_=sa.Enum(
|
||||
"BASIC", "ADMIN", name="userrole", native_enum=False, length=20
|
||||
|
||||
@@ -63,7 +63,7 @@ def upgrade() -> None:
|
||||
"time_created",
|
||||
existing_type=postgresql.TIMESTAMP(timezone=True),
|
||||
nullable=False,
|
||||
existing_server_default=sa.text("now()"), # type: ignore
|
||||
existing_server_default=sa.text("now()"),
|
||||
)
|
||||
op.alter_column(
|
||||
"index_attempt",
|
||||
@@ -85,7 +85,7 @@ def downgrade() -> None:
|
||||
"time_created",
|
||||
existing_type=postgresql.TIMESTAMP(timezone=True),
|
||||
nullable=True,
|
||||
existing_server_default=sa.text("now()"), # type: ignore
|
||||
existing_server_default=sa.text("now()"),
|
||||
)
|
||||
op.drop_index(op.f("ix_accesstoken_created_at"), table_name="accesstoken")
|
||||
op.drop_table("accesstoken")
|
||||
|
||||
@@ -19,7 +19,7 @@ depends_on: None = None
|
||||
|
||||
def upgrade() -> None:
|
||||
sequence = Sequence("connector_credential_pair_id_seq")
|
||||
op.execute(CreateSequence(sequence)) # type: ignore
|
||||
op.execute(CreateSequence(sequence))
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
"""add_error_tracking_fields_to_index_attempt_errors
|
||||
|
||||
Revision ID: d129f37b3d87
|
||||
Revises: 503883791c39
|
||||
Create Date: 2026-04-06 19:11:18.261800
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d129f37b3d87"
|
||||
down_revision = "503883791c39"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"index_attempt_errors",
|
||||
sa.Column("error_type", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("index_attempt_errors", "error_type")
|
||||
@@ -49,7 +49,7 @@ def run_migrations_offline() -> None:
|
||||
url = build_connection_string()
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
@@ -61,7 +61,7 @@ def run_migrations_offline() -> None:
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore[arg-type]
|
||||
target_metadata=target_metadata,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
|
||||
@@ -13,6 +13,7 @@ from ee.onyx.server.license.models import LicenseSource
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import License
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -107,12 +108,13 @@ def get_used_seats(tenant_id: str | None = None) -> int:
|
||||
Get current seat usage directly from database.
|
||||
|
||||
For multi-tenant: counts users in UserTenantMapping for this tenant.
|
||||
For self-hosted: counts all active users (excludes EXT_PERM_USER role
|
||||
and the anonymous system user).
|
||||
For self-hosted: counts all active users.
|
||||
|
||||
TODO: Exclude API key dummy users from seat counting. API keys create
|
||||
users with emails like `__DANSWER_API_KEY_*` that should not count toward
|
||||
seat limits. See: https://linear.app/onyx-app/issue/ENG-3518
|
||||
Only human accounts count toward seat limits.
|
||||
SERVICE_ACCOUNT (API key dummy users), EXT_PERM_USER, and the
|
||||
anonymous system user are excluded. BOT (Slack users) ARE counted
|
||||
because they represent real humans and get upgraded to STANDARD
|
||||
when they log in via web.
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_count
|
||||
@@ -129,6 +131,7 @@ def get_used_seats(tenant_id: str | None = None) -> int:
|
||||
User.is_active == True, # type: ignore # noqa: E712
|
||||
User.role != UserRole.EXT_PERM_USER,
|
||||
User.email != ANONYMOUS_USER_EMAIL, # type: ignore
|
||||
User.account_type != AccountType.SERVICE_ACCOUNT,
|
||||
)
|
||||
)
|
||||
return result.scalar() or 0
|
||||
|
||||
@@ -11,6 +11,8 @@ require a valid SCIM bearer token.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import struct
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
@@ -22,6 +24,7 @@ from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -65,12 +68,25 @@ from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Group names reserved for system default groups (seeded by migration).
|
||||
_RESERVED_GROUP_NAMES = frozenset({"Admin", "Basic"})
|
||||
|
||||
# Namespace prefix for the seat-allocation advisory lock. Hashed together
|
||||
# with the tenant ID so the lock is scoped per-tenant (unrelated tenants
|
||||
# never block each other) and cannot collide with unrelated advisory locks.
|
||||
_SEAT_LOCK_NAMESPACE = "onyx_scim_seat_lock"
|
||||
|
||||
|
||||
def _seat_lock_id_for_tenant(tenant_id: str) -> int:
|
||||
"""Derive a stable 64-bit signed int lock id for this tenant's seat lock."""
|
||||
digest = hashlib.sha256(f"{_SEAT_LOCK_NAMESPACE}:{tenant_id}".encode()).digest()
|
||||
# pg_advisory_xact_lock takes a signed 8-byte int; unpack as such.
|
||||
return struct.unpack("q", digest[:8])[0]
|
||||
|
||||
|
||||
class ScimJSONResponse(JSONResponse):
|
||||
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
|
||||
@@ -209,12 +225,37 @@ def _apply_exclusions(
|
||||
|
||||
|
||||
def _check_seat_availability(dal: ScimDAL) -> str | None:
|
||||
"""Return an error message if seat limit is reached, else None."""
|
||||
"""Return an error message if seat limit is reached, else None.
|
||||
|
||||
Acquires a transaction-scoped advisory lock so that concurrent
|
||||
SCIM requests are serialized. IdPs like Okta send provisioning
|
||||
requests in parallel batches — without serialization the check is
|
||||
vulnerable to a TOCTOU race where N concurrent requests each see
|
||||
"seats available", all insert, and the tenant ends up over its
|
||||
seat limit.
|
||||
|
||||
The lock is held until the caller's next COMMIT or ROLLBACK, which
|
||||
means the seat count cannot change between the check here and the
|
||||
subsequent INSERT/UPDATE. Each call site in this module follows
|
||||
the pattern: _check_seat_availability → write → dal.commit()
|
||||
(which releases the lock for the next waiting request).
|
||||
"""
|
||||
check_fn = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license", "check_seat_availability", None
|
||||
)
|
||||
if check_fn is None:
|
||||
return None
|
||||
|
||||
# Transaction-scoped advisory lock — released on dal.commit() / dal.rollback().
|
||||
# The lock id is derived from the tenant so unrelated tenants never block
|
||||
# each other, and from a namespace string so it cannot collide with
|
||||
# unrelated advisory locks elsewhere in the codebase.
|
||||
lock_id = _seat_lock_id_for_tenant(get_current_tenant_id())
|
||||
dal.session.execute(
|
||||
text("SELECT pg_advisory_xact_lock(:lock_id)"),
|
||||
{"lock_id": lock_id},
|
||||
)
|
||||
|
||||
result = check_fn(dal.session, seats_needed=1)
|
||||
if not result.available:
|
||||
return result.error_message or "Seat limit reached"
|
||||
|
||||
@@ -96,11 +96,14 @@ def get_model_app() -> FastAPI:
|
||||
title="Onyx Model Server", version=__version__, lifespan=lifespan
|
||||
)
|
||||
if SENTRY_DSN:
|
||||
from onyx.configs.sentry import _add_instance_tags
|
||||
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
integrations=[StarletteIntegration(), FastApiIntegration()],
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
before_send=_add_instance_tags,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
|
||||
@@ -10,6 +10,7 @@ from celery import bootsteps # type: ignore
|
||||
from celery import Task
|
||||
from celery.app import trace
|
||||
from celery.exceptions import WorkerShutdown
|
||||
from celery.signals import before_task_publish
|
||||
from celery.signals import task_postrun
|
||||
from celery.signals import task_prerun
|
||||
from celery.states import READY_STATES
|
||||
@@ -62,11 +63,14 @@ logger = setup_logger()
|
||||
task_logger = get_task_logger(__name__)
|
||||
|
||||
if SENTRY_DSN:
|
||||
from onyx.configs.sentry import _add_instance_tags
|
||||
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
integrations=[CeleryIntegration()],
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
before_send=_add_instance_tags,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
@@ -94,6 +98,17 @@ class TenantAwareTask(Task):
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(None)
|
||||
|
||||
|
||||
@before_task_publish.connect
|
||||
def on_before_task_publish(
|
||||
headers: dict[str, Any] | None = None,
|
||||
**kwargs: Any, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Stamp the current wall-clock time into the task message headers so that
|
||||
workers can compute queue wait time (time between publish and execution)."""
|
||||
if headers is not None:
|
||||
headers["enqueued_at"] = time.time()
|
||||
|
||||
|
||||
@task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None, # noqa: ARG001
|
||||
|
||||
@@ -16,6 +16,12 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_retry
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
@@ -36,6 +42,7 @@ def on_task_prerun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
on_celery_task_prerun(task_id, task)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
@@ -50,6 +57,31 @@ def on_task_postrun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
on_celery_task_postrun(task_id, task, state)
|
||||
|
||||
|
||||
@signals.task_retry.connect
|
||||
def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
task_id = getattr(getattr(sender, "request", None), "id", None)
|
||||
on_celery_task_retry(task_id, sender)
|
||||
|
||||
|
||||
@signals.task_revoked.connect
|
||||
def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None:
|
||||
task_name = getattr(sender, "name", None) or str(sender)
|
||||
on_celery_task_revoked(kwargs.get("task_id"), task_name)
|
||||
|
||||
|
||||
@signals.task_rejected.connect
|
||||
def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
message = kwargs.get("message")
|
||||
task_name: str | None = None
|
||||
if message is not None:
|
||||
headers = getattr(message, "headers", None) or {}
|
||||
task_name = headers.get("task")
|
||||
if task_name is None:
|
||||
task_name = "unknown"
|
||||
on_celery_task_rejected(None, task_name)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
@@ -90,6 +122,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
start_metrics_server("light")
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -59,6 +59,11 @@ from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDeletePayload
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_blocked
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_completed
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_fence_reset
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_started
|
||||
from onyx.server.metrics.deletion_metrics import observe_deletion_taskset_duration
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
@@ -102,7 +107,7 @@ def revoke_tasks_blocking_deletion(
|
||||
f"Revoked permissions sync task {permissions_sync_payload.celery_task_id}."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Exception while revoking pruning task")
|
||||
task_logger.exception("Exception while revoking permissions sync task")
|
||||
|
||||
try:
|
||||
prune_payload = redis_connector.prune.payload
|
||||
@@ -110,7 +115,7 @@ def revoke_tasks_blocking_deletion(
|
||||
app.control.revoke(prune_payload.celery_task_id)
|
||||
task_logger.info(f"Revoked pruning task {prune_payload.celery_task_id}.")
|
||||
except Exception:
|
||||
task_logger.exception("Exception while revoking permissions sync task")
|
||||
task_logger.exception("Exception while revoking pruning task")
|
||||
|
||||
try:
|
||||
external_group_sync_payload = redis_connector.external_group_sync.payload
|
||||
@@ -300,6 +305,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
recent_index_attempts
|
||||
and recent_index_attempts[0].status == IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
inc_deletion_blocked(tenant_id, "indexing")
|
||||
raise TaskDependencyError(
|
||||
"Connector deletion - Delayed (indexing in progress): "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
@@ -307,11 +313,13 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
)
|
||||
|
||||
if redis_connector.prune.fenced:
|
||||
inc_deletion_blocked(tenant_id, "pruning")
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (pruning in progress): cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
if redis_connector.permissions.fenced:
|
||||
inc_deletion_blocked(tenant_id, "permissions")
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (permissions in progress): cc_pair={cc_pair_id}"
|
||||
)
|
||||
@@ -359,6 +367,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
# set this only after all tasks have been added
|
||||
fence_payload.num_tasks = tasks_generated
|
||||
redis_connector.delete.set_fence(fence_payload)
|
||||
inc_deletion_started(tenant_id)
|
||||
|
||||
return tasks_generated
|
||||
|
||||
@@ -508,7 +517,11 @@ def monitor_connector_deletion_taskset(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id_to_delete,
|
||||
)
|
||||
if not connector or not len(connector.credentials):
|
||||
if not connector:
|
||||
task_logger.info(
|
||||
"Connector deletion - Connector already deleted, skipping connector cleanup"
|
||||
)
|
||||
elif not len(connector.credentials):
|
||||
task_logger.info(
|
||||
"Connector deletion - Found no credentials left for connector, deleting connector"
|
||||
)
|
||||
@@ -523,6 +536,12 @@ def monitor_connector_deletion_taskset(
|
||||
num_docs_synced=fence_data.num_tasks,
|
||||
)
|
||||
|
||||
duration = (
|
||||
datetime.now(timezone.utc) - fence_data.submitted
|
||||
).total_seconds()
|
||||
observe_deletion_taskset_duration(tenant_id, "success", duration)
|
||||
inc_deletion_completed(tenant_id, "success")
|
||||
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
stack_trace = traceback.format_exc()
|
||||
@@ -541,6 +560,11 @@ def monitor_connector_deletion_taskset(
|
||||
f"Connector deletion exceptioned: "
|
||||
f"cc_pair={cc_pair_id} connector={connector_id_to_delete} credential={credential_id_to_delete}"
|
||||
)
|
||||
duration = (
|
||||
datetime.now(timezone.utc) - fence_data.submitted
|
||||
).total_seconds()
|
||||
observe_deletion_taskset_duration(tenant_id, "failure", duration)
|
||||
inc_deletion_completed(tenant_id, "failure")
|
||||
raise e
|
||||
|
||||
task_logger.info(
|
||||
@@ -717,5 +741,6 @@ def validate_connector_deletion_fence(
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
|
||||
inc_deletion_fence_reset(tenant_id)
|
||||
redis_connector.delete.reset()
|
||||
return
|
||||
|
||||
@@ -135,10 +135,13 @@ def _docfetching_task(
|
||||
# Since connector_indexing_proxy_task spawns a new process using this function as
|
||||
# the entrypoint, we init Sentry here.
|
||||
if SENTRY_DSN:
|
||||
from onyx.configs.sentry import _add_instance_tags
|
||||
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
before_send=_add_instance_tags,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
|
||||
@@ -3,6 +3,7 @@ import os
|
||||
import time
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
@@ -50,6 +51,7 @@ from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -85,6 +87,8 @@ from onyx.db.indexing_coordination import INDEXING_PROGRESS_TIMEOUT_HOURS
|
||||
from onyx.db.indexing_coordination import IndexingCoordination
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.db.notification import get_notifications
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.search_settings import get_secondary_search_settings
|
||||
from onyx.db.swap_index import check_and_perform_index_swap
|
||||
@@ -105,6 +109,9 @@ from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.redis.redis_utils import is_fence
|
||||
from onyx.server.metrics.connector_health_metrics import on_connector_error_state_change
|
||||
from onyx.server.metrics.connector_health_metrics import on_connector_indexing_success
|
||||
from onyx.server.metrics.connector_health_metrics import on_index_attempt_status_change
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.middleware import make_randomized_onyx_request_id
|
||||
@@ -400,7 +407,6 @@ def check_indexing_completion(
|
||||
tenant_id: str,
|
||||
task: Task,
|
||||
) -> None:
|
||||
|
||||
logger.info(
|
||||
f"Checking for indexing completion: attempt={index_attempt_id} tenant={tenant_id}"
|
||||
)
|
||||
@@ -521,13 +527,23 @@ def check_indexing_completion(
|
||||
|
||||
# Update CC pair status if successful
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session, attempt.connector_credential_pair_id
|
||||
db_session,
|
||||
attempt.connector_credential_pair_id,
|
||||
eager_load_connector=True,
|
||||
)
|
||||
if cc_pair is None:
|
||||
raise RuntimeError(
|
||||
f"CC pair {attempt.connector_credential_pair_id} not found in database"
|
||||
)
|
||||
|
||||
source = cc_pair.connector.source.value
|
||||
on_index_attempt_status_change(
|
||||
tenant_id=tenant_id,
|
||||
source=source,
|
||||
cc_pair_id=cc_pair.id,
|
||||
status=attempt.status.value,
|
||||
)
|
||||
|
||||
if attempt.status.is_successful():
|
||||
# NOTE: we define the last successful index time as the time the last successful
|
||||
# attempt finished. This is distinct from the poll_range_end of the last successful
|
||||
@@ -548,10 +564,39 @@ def check_indexing_completion(
|
||||
event=MilestoneRecordType.CONNECTOR_SUCCEEDED,
|
||||
)
|
||||
|
||||
on_connector_indexing_success(
|
||||
tenant_id=tenant_id,
|
||||
source=source,
|
||||
cc_pair_id=cc_pair.id,
|
||||
docs_indexed=attempt.new_docs_indexed or 0,
|
||||
success_timestamp=attempt.time_updated.timestamp(),
|
||||
)
|
||||
|
||||
# Clear repeated error state on success
|
||||
if cc_pair.in_repeated_error_state:
|
||||
cc_pair.in_repeated_error_state = False
|
||||
|
||||
# Delete any existing error notification for this CC pair so a
|
||||
# fresh one is created if the connector fails again later.
|
||||
for notif in get_notifications(
|
||||
user=None,
|
||||
db_session=db_session,
|
||||
notif_type=NotificationType.CONNECTOR_REPEATED_ERRORS,
|
||||
include_dismissed=True,
|
||||
):
|
||||
if (
|
||||
notif.additional_data
|
||||
and notif.additional_data.get("cc_pair_id") == cc_pair.id
|
||||
):
|
||||
db_session.delete(notif)
|
||||
|
||||
db_session.commit()
|
||||
on_connector_error_state_change(
|
||||
tenant_id=tenant_id,
|
||||
source=source,
|
||||
cc_pair_id=cc_pair.id,
|
||||
in_error=False,
|
||||
)
|
||||
|
||||
if attempt.status == IndexingStatus.SUCCESS:
|
||||
logger.info(
|
||||
@@ -608,6 +653,27 @@ def active_indexing_attempt(
|
||||
return bool(active_indexing_attempt)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _KickoffResult:
|
||||
"""Tracks diagnostic counts from a _kickoff_indexing_tasks run."""
|
||||
|
||||
created: int = 0
|
||||
skipped_active: int = 0
|
||||
skipped_not_found: int = 0
|
||||
skipped_not_indexable: int = 0
|
||||
failed_to_create: int = 0
|
||||
|
||||
@property
|
||||
def evaluated(self) -> int:
|
||||
return (
|
||||
self.created
|
||||
+ self.skipped_active
|
||||
+ self.skipped_not_found
|
||||
+ self.skipped_not_indexable
|
||||
+ self.failed_to_create
|
||||
)
|
||||
|
||||
|
||||
def _kickoff_indexing_tasks(
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
@@ -617,12 +683,12 @@ def _kickoff_indexing_tasks(
|
||||
redis_client: Redis,
|
||||
lock_beat: RedisLock,
|
||||
tenant_id: str,
|
||||
) -> int:
|
||||
) -> _KickoffResult:
|
||||
"""Kick off indexing tasks for the given cc_pair_ids and search_settings.
|
||||
|
||||
Returns the number of tasks successfully created.
|
||||
Returns a _KickoffResult with diagnostic counts.
|
||||
"""
|
||||
tasks_created = 0
|
||||
result = _KickoffResult()
|
||||
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
lock_beat.reacquire()
|
||||
@@ -633,6 +699,7 @@ def _kickoff_indexing_tasks(
|
||||
search_settings_id=search_settings.id,
|
||||
db_session=db_session,
|
||||
):
|
||||
result.skipped_active += 1
|
||||
continue
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
@@ -643,6 +710,7 @@ def _kickoff_indexing_tasks(
|
||||
task_logger.warning(
|
||||
f"_kickoff_indexing_tasks - CC pair not found: cc_pair={cc_pair_id}"
|
||||
)
|
||||
result.skipped_not_found += 1
|
||||
continue
|
||||
|
||||
# Heavyweight check after fetching cc pair
|
||||
@@ -657,6 +725,7 @@ def _kickoff_indexing_tasks(
|
||||
f"search_settings={search_settings.id}, "
|
||||
f"secondary_index_building={secondary_index_building}"
|
||||
)
|
||||
result.skipped_not_indexable += 1
|
||||
continue
|
||||
|
||||
task_logger.debug(
|
||||
@@ -696,13 +765,14 @@ def _kickoff_indexing_tasks(
|
||||
task_logger.info(
|
||||
f"Connector indexing queued: index_attempt={attempt_id} cc_pair={cc_pair.id} search_settings={search_settings.id}"
|
||||
)
|
||||
tasks_created += 1
|
||||
result.created += 1
|
||||
else:
|
||||
task_logger.error(
|
||||
f"Failed to create indexing task: cc_pair={cc_pair.id} search_settings={search_settings.id}"
|
||||
)
|
||||
result.failed_to_create += 1
|
||||
|
||||
return tasks_created
|
||||
return result
|
||||
|
||||
|
||||
@shared_task(
|
||||
@@ -728,6 +798,8 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
task_logger.warning("check_for_indexing - Starting")
|
||||
|
||||
tasks_created = 0
|
||||
primary_result = _KickoffResult()
|
||||
secondary_result: _KickoffResult | None = None
|
||||
locked = False
|
||||
redis_client = get_redis_client()
|
||||
redis_client_replica = get_redis_replica_client()
|
||||
@@ -848,6 +920,39 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
cc_pair_id=cc_pair_id,
|
||||
in_repeated_error_state=True,
|
||||
)
|
||||
on_connector_error_state_change(
|
||||
tenant_id=tenant_id,
|
||||
source=cc_pair.connector.source.value,
|
||||
cc_pair_id=cc_pair_id,
|
||||
in_error=True,
|
||||
)
|
||||
|
||||
connector_name = (
|
||||
cc_pair.name
|
||||
or cc_pair.connector.name
|
||||
or f"CC pair {cc_pair.id}"
|
||||
)
|
||||
source = cc_pair.connector.source.value
|
||||
connector_url = f"/admin/connector/{cc_pair.id}"
|
||||
create_notification(
|
||||
user_id=None,
|
||||
notif_type=NotificationType.CONNECTOR_REPEATED_ERRORS,
|
||||
db_session=db_session,
|
||||
title=f"Connector '{connector_name}' has entered repeated error state",
|
||||
description=(
|
||||
f"The {source} connector has failed repeatedly and "
|
||||
f"has been flagged. View indexing history in the "
|
||||
f"Advanced section: {connector_url}"
|
||||
),
|
||||
additional_data={"cc_pair_id": cc_pair.id},
|
||||
)
|
||||
|
||||
task_logger.error(
|
||||
f"Connector entered repeated error state: "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"connector={cc_pair.connector.name} "
|
||||
f"source={source}"
|
||||
)
|
||||
# When entering repeated error state, also pause the connector
|
||||
# to prevent continued indexing retry attempts burning through embedding credits.
|
||||
# NOTE: only for Cloud, since most self-hosted users use self-hosted embedding
|
||||
@@ -863,7 +968,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
# Heavy check, should_index(), is called in _kickoff_indexing_tasks
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Primary first
|
||||
tasks_created += _kickoff_indexing_tasks(
|
||||
primary_result = _kickoff_indexing_tasks(
|
||||
celery_app=self.app,
|
||||
db_session=db_session,
|
||||
search_settings=current_search_settings,
|
||||
@@ -873,6 +978,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
lock_beat=lock_beat,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
tasks_created += primary_result.created
|
||||
|
||||
# Secondary indexing (only if secondary search settings exist and switchover_type is not INSTANT)
|
||||
if (
|
||||
@@ -880,7 +986,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
and secondary_search_settings.switchover_type != SwitchoverType.INSTANT
|
||||
and secondary_cc_pair_ids
|
||||
):
|
||||
tasks_created += _kickoff_indexing_tasks(
|
||||
secondary_result = _kickoff_indexing_tasks(
|
||||
celery_app=self.app,
|
||||
db_session=db_session,
|
||||
search_settings=secondary_search_settings,
|
||||
@@ -890,6 +996,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
lock_beat=lock_beat,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
tasks_created += secondary_result.created
|
||||
elif (
|
||||
secondary_search_settings
|
||||
and secondary_search_settings.switchover_type == SwitchoverType.INSTANT
|
||||
@@ -1002,7 +1109,26 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
redis_lock_dump(lock_beat, redis_client)
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.info(f"check_for_indexing finished: elapsed={time_elapsed:.2f}")
|
||||
task_logger.info(
|
||||
f"check_for_indexing finished: "
|
||||
f"elapsed={time_elapsed:.2f}s "
|
||||
f"primary=[evaluated={primary_result.evaluated} "
|
||||
f"created={primary_result.created} "
|
||||
f"skipped_active={primary_result.skipped_active} "
|
||||
f"skipped_not_found={primary_result.skipped_not_found} "
|
||||
f"skipped_not_indexable={primary_result.skipped_not_indexable} "
|
||||
f"failed={primary_result.failed_to_create}]"
|
||||
+ (
|
||||
f" secondary=[evaluated={secondary_result.evaluated} "
|
||||
f"created={secondary_result.created} "
|
||||
f"skipped_active={secondary_result.skipped_active} "
|
||||
f"skipped_not_found={secondary_result.skipped_not_found} "
|
||||
f"skipped_not_indexable={secondary_result.skipped_not_indexable} "
|
||||
f"failed={secondary_result.failed_to_create}]"
|
||||
if secondary_result
|
||||
else ""
|
||||
)
|
||||
)
|
||||
return tasks_created
|
||||
|
||||
|
||||
|
||||
@@ -172,6 +172,10 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
|
||||
task_logger.debug(
|
||||
"Verified tenant info, migration record, and search settings."
|
||||
)
|
||||
|
||||
# 2.e. Build sanitized to original doc ID mapping to check for
|
||||
# conflicts in the event we sanitize a doc ID to an
|
||||
# already-existing doc ID.
|
||||
@@ -325,6 +329,7 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
task_logger.debug("Released the OpenSearch migration lock.")
|
||||
else:
|
||||
task_logger.warning(
|
||||
"The OpenSearch migration lock was not owned on completion of the migration task."
|
||||
|
||||
@@ -38,6 +38,7 @@ from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.connector import mark_ccpair_as_pruned
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair
|
||||
@@ -525,6 +526,14 @@ def connector_pruning_generator_task(
|
||||
return None
|
||||
|
||||
try:
|
||||
# Session 1: pre-enumeration — load cc_pair and instantiate the connector.
|
||||
# The session is closed before enumeration so the DB connection is not held
|
||||
# open during the 10–30+ minute connector crawl.
|
||||
connector_source: DocumentSource | None = None
|
||||
connector_type: str = ""
|
||||
is_connector_public: bool = False
|
||||
runnable_connector: BaseConnector | None = None
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
cc_pair = get_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
@@ -550,49 +559,51 @@ def connector_pruning_generator_task(
|
||||
)
|
||||
redis_connector.prune.set_fence(new_payload)
|
||||
|
||||
connector_source = cc_pair.connector.source
|
||||
connector_type = connector_source.value
|
||||
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
|
||||
|
||||
task_logger.info(
|
||||
f"Pruning generator running connector: cc_pair={cc_pair_id} connector_source={cc_pair.connector.source}"
|
||||
f"Pruning generator running connector: cc_pair={cc_pair_id} connector_source={connector_source}"
|
||||
)
|
||||
|
||||
runnable_connector = instantiate_connector(
|
||||
db_session,
|
||||
cc_pair.connector.source,
|
||||
connector_source,
|
||||
InputType.SLIM_RETRIEVAL,
|
||||
cc_pair.connector.connector_specific_config,
|
||||
cc_pair.credential,
|
||||
)
|
||||
# Session 1 closed here — connection released before enumeration.
|
||||
|
||||
callback = PruneCallback(
|
||||
0,
|
||||
redis_connector,
|
||||
lock,
|
||||
r,
|
||||
timeout_seconds=JOB_TIMEOUT,
|
||||
)
|
||||
callback = PruneCallback(
|
||||
0,
|
||||
redis_connector,
|
||||
lock,
|
||||
r,
|
||||
timeout_seconds=JOB_TIMEOUT,
|
||||
)
|
||||
|
||||
# Extract docs and hierarchy nodes from the source
|
||||
connector_type = cc_pair.connector.source.value
|
||||
extraction_result = extract_ids_from_runnable_connector(
|
||||
runnable_connector, callback, connector_type=connector_type
|
||||
)
|
||||
all_connector_doc_ids = extraction_result.raw_id_to_parent
|
||||
# Extract docs and hierarchy nodes from the source (no DB session held).
|
||||
extraction_result = extract_ids_from_runnable_connector(
|
||||
runnable_connector, callback, connector_type=connector_type
|
||||
)
|
||||
all_connector_doc_ids = extraction_result.raw_id_to_parent
|
||||
|
||||
# Process hierarchy nodes (same as docfetching):
|
||||
# upsert to Postgres and cache in Redis
|
||||
source = cc_pair.connector.source
|
||||
# Session 2: post-enumeration — hierarchy upserts, diff computation, task dispatch.
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
source = connector_source
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
ensure_source_node_exists(redis_client, db_session, source)
|
||||
|
||||
upserted_nodes: list[DBHierarchyNode] = []
|
||||
if extraction_result.hierarchy_nodes:
|
||||
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
|
||||
|
||||
upserted_nodes = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=extraction_result.hierarchy_nodes,
|
||||
source=source,
|
||||
commit=True,
|
||||
commit=False,
|
||||
is_connector_public=is_connector_public,
|
||||
)
|
||||
|
||||
@@ -601,9 +612,13 @@ def connector_pruning_generator_task(
|
||||
hierarchy_node_ids=[n.id for n in upserted_nodes],
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
commit=True,
|
||||
commit=False,
|
||||
)
|
||||
|
||||
# Single commit so the FK reference in the join table can never
|
||||
# outrun the parent hierarchy_node insert.
|
||||
db_session.commit()
|
||||
|
||||
cache_entries = [
|
||||
HierarchyNodeCacheEntry.from_db_model(node)
|
||||
for node in upserted_nodes
|
||||
@@ -658,7 +673,7 @@ def connector_pruning_generator_task(
|
||||
task_logger.info(
|
||||
"Pruning set collected: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"connector_source={cc_pair.connector.source} "
|
||||
f"connector_source={connector_source} "
|
||||
f"docs_to_remove={len(doc_ids_to_remove)}"
|
||||
)
|
||||
|
||||
|
||||
@@ -23,6 +23,8 @@ class IndexAttemptErrorPydantic(BaseModel):
|
||||
|
||||
index_attempt_id: int
|
||||
|
||||
error_type: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, model: IndexAttemptError) -> "IndexAttemptErrorPydantic":
|
||||
return cls(
|
||||
@@ -37,4 +39,5 @@ class IndexAttemptErrorPydantic(BaseModel):
|
||||
is_resolved=model.is_resolved,
|
||||
time_created=model.time_created,
|
||||
index_attempt_id=model.index_attempt_id,
|
||||
error_type=model.error_type,
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
import sentry_sdk
|
||||
from celery import Celery
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -68,6 +69,7 @@ from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.features.build.indexing.persistent_document_writer import (
|
||||
get_persistent_document_writer,
|
||||
)
|
||||
from onyx.server.metrics.connector_health_metrics import on_index_attempt_status_change
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.middleware import make_randomized_onyx_request_id
|
||||
from onyx.utils.postgres_sanitization import sanitize_document_for_postgres
|
||||
@@ -267,6 +269,13 @@ def run_docfetching_entrypoint(
|
||||
)
|
||||
credential_id = attempt.connector_credential_pair.credential_id
|
||||
|
||||
on_index_attempt_status_change(
|
||||
tenant_id=tenant_id,
|
||||
source=attempt.connector_credential_pair.connector.source.value,
|
||||
cc_pair_id=connector_credential_pair_id,
|
||||
status="in_progress",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Docfetching starting{tenant_str}: "
|
||||
f"connector='{connector_name}' "
|
||||
@@ -556,6 +565,27 @@ def connector_document_extraction(
|
||||
|
||||
# save record of any failures at the connector level
|
||||
if failure is not None:
|
||||
if failure.exception is not None:
|
||||
with sentry_sdk.new_scope() as scope:
|
||||
scope.set_tag("stage", "connector_fetch")
|
||||
scope.set_tag("connector_source", db_connector.source.value)
|
||||
scope.set_tag("cc_pair_id", str(cc_pair_id))
|
||||
scope.set_tag("index_attempt_id", str(index_attempt_id))
|
||||
scope.set_tag("tenant_id", tenant_id)
|
||||
if failure.failed_document:
|
||||
scope.set_tag(
|
||||
"doc_id", failure.failed_document.document_id
|
||||
)
|
||||
if failure.failed_entity:
|
||||
scope.set_tag(
|
||||
"entity_id", failure.failed_entity.entity_id
|
||||
)
|
||||
scope.fingerprint = [
|
||||
"connector-fetch-failure",
|
||||
db_connector.source.value,
|
||||
type(failure.exception).__name__,
|
||||
]
|
||||
sentry_sdk.capture_exception(failure.exception)
|
||||
total_failures += 1
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
create_index_attempt_error(
|
||||
|
||||
@@ -364,7 +364,7 @@ def _get_or_extract_plaintext(
|
||||
plaintext_io = file_store.read_file(plaintext_key, mode="b")
|
||||
return plaintext_io.read().decode("utf-8")
|
||||
except Exception:
|
||||
logger.exception(f"Error when reading file, id={file_id}")
|
||||
logger.info(f"Cache miss for file with id={file_id}")
|
||||
|
||||
# Cache miss — extract and store.
|
||||
content_text = extract_fn()
|
||||
|
||||
@@ -4,8 +4,6 @@ from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_utils import create_tool_call_failure_messages
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
@@ -635,7 +633,6 @@ def run_llm_loop(
|
||||
user_memory_context: UserMemoryContext | None,
|
||||
llm: LLM,
|
||||
token_counter: Callable[[str], int],
|
||||
db_session: Session,
|
||||
forced_tool_id: int | None = None,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
chat_session_id: str | None = None,
|
||||
@@ -1020,20 +1017,16 @@ def run_llm_loop(
|
||||
persisted_memory_id: int | None = None
|
||||
if user_memory_context and user_memory_context.user_id:
|
||||
if tool_response.rich_response.index_to_replace is not None:
|
||||
memory = update_memory_at_index(
|
||||
persisted_memory_id = update_memory_at_index(
|
||||
user_id=user_memory_context.user_id,
|
||||
index=tool_response.rich_response.index_to_replace,
|
||||
new_text=tool_response.rich_response.memory_text,
|
||||
db_session=db_session,
|
||||
)
|
||||
persisted_memory_id = memory.id if memory else None
|
||||
else:
|
||||
memory = add_memory(
|
||||
persisted_memory_id = add_memory(
|
||||
user_id=user_memory_context.user_id,
|
||||
memory_text=tool_response.rich_response.memory_text,
|
||||
db_session=db_session,
|
||||
)
|
||||
persisted_memory_id = memory.id
|
||||
operation: Literal["add", "update"] = (
|
||||
"update"
|
||||
if tool_response.rich_response.index_to_replace is not None
|
||||
|
||||
@@ -67,7 +67,6 @@ from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.chat import reserve_message_id
|
||||
from onyx.db.chat import reserve_multi_model_message_ids
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.memory import get_memories
|
||||
from onyx.db.models import ChatMessage
|
||||
@@ -94,6 +93,7 @@ from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.llm.factory import get_llm_token_counter
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.multi_llm import LLMTimeoutError
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.request_context import reset_llm_mock_response
|
||||
from onyx.llm.request_context import set_llm_mock_response
|
||||
@@ -1006,93 +1006,86 @@ def _run_models(
|
||||
model_llm = setup.llms[model_idx]
|
||||
|
||||
try:
|
||||
# Each worker opens its own session — SQLAlchemy sessions are not thread-safe.
|
||||
# Do NOT write to the outer db_session (or any shared DB state) from here;
|
||||
# all DB writes in this thread must go through thread_db_session.
|
||||
with get_session_with_current_tenant() as thread_db_session:
|
||||
thread_tool_dict = construct_tools(
|
||||
persona=setup.persona,
|
||||
db_session=thread_db_session,
|
||||
emitter=model_emitter,
|
||||
user=user,
|
||||
llm=model_llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=setup.new_msg_req.internal_search_filters,
|
||||
project_id_filter=setup.search_params.project_id_filter,
|
||||
persona_id_filter=setup.search_params.persona_id_filter,
|
||||
bypass_acl=setup.bypass_acl,
|
||||
slack_context=setup.slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
setup.persona, setup.new_msg_req.internal_search_filters
|
||||
),
|
||||
# Each function opens short-lived DB sessions on demand.
|
||||
# Do NOT pass a long-lived session here — it would hold a
|
||||
# connection for the entire LLM loop (minutes), and cloud
|
||||
# infrastructure may drop idle connections.
|
||||
thread_tool_dict = construct_tools(
|
||||
persona=setup.persona,
|
||||
emitter=model_emitter,
|
||||
user=user,
|
||||
llm=model_llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=setup.new_msg_req.internal_search_filters,
|
||||
project_id_filter=setup.search_params.project_id_filter,
|
||||
persona_id_filter=setup.search_params.persona_id_filter,
|
||||
bypass_acl=setup.bypass_acl,
|
||||
slack_context=setup.slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
setup.persona, setup.new_msg_req.internal_search_filters
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=setup.chat_session.id,
|
||||
message_id=setup.user_message.id,
|
||||
additional_headers=setup.custom_tool_additional_headers,
|
||||
mcp_headers=setup.mcp_headers,
|
||||
),
|
||||
file_reader_tool_config=FileReaderToolConfig(
|
||||
user_file_ids=setup.available_files.user_file_ids,
|
||||
chat_file_ids=setup.available_files.chat_file_ids,
|
||||
),
|
||||
allowed_tool_ids=setup.new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=setup.search_params.search_usage,
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=setup.chat_session.id,
|
||||
message_id=setup.user_message.id,
|
||||
additional_headers=setup.custom_tool_additional_headers,
|
||||
mcp_headers=setup.mcp_headers,
|
||||
),
|
||||
file_reader_tool_config=FileReaderToolConfig(
|
||||
user_file_ids=setup.available_files.user_file_ids,
|
||||
chat_file_ids=setup.available_files.chat_file_ids,
|
||||
),
|
||||
allowed_tool_ids=setup.new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=setup.search_params.search_usage,
|
||||
)
|
||||
model_tools = [
|
||||
tool for tool_list in thread_tool_dict.values() for tool in tool_list
|
||||
]
|
||||
|
||||
if setup.forced_tool_id and setup.forced_tool_id not in {
|
||||
tool.id for tool in model_tools
|
||||
}:
|
||||
raise ValueError(
|
||||
f"Forced tool {setup.forced_tool_id} not found in tools"
|
||||
)
|
||||
model_tools = [
|
||||
tool
|
||||
for tool_list in thread_tool_dict.values()
|
||||
for tool in tool_list
|
||||
]
|
||||
|
||||
if setup.forced_tool_id and setup.forced_tool_id not in {
|
||||
tool.id for tool in model_tools
|
||||
}:
|
||||
raise ValueError(
|
||||
f"Forced tool {setup.forced_tool_id} not found in tools"
|
||||
)
|
||||
|
||||
# Per-thread copy: run_llm_loop mutates simple_chat_history in-place.
|
||||
if n_models == 1 and setup.new_msg_req.deep_research:
|
||||
if setup.chat_session.project_id:
|
||||
raise RuntimeError(
|
||||
"Deep research is not supported for projects"
|
||||
)
|
||||
run_deep_research_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
db_session=thread_db_session,
|
||||
skip_clarification=setup.skip_clarification,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
)
|
||||
else:
|
||||
run_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
context_files=setup.extracted_context_files,
|
||||
persona=setup.persona,
|
||||
user_memory_context=setup.user_memory_context,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
db_session=thread_db_session,
|
||||
forced_tool_id=setup.forced_tool_id,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
chat_files=setup.chat_files_for_tools,
|
||||
include_citations=setup.new_msg_req.include_citations,
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
)
|
||||
# Per-thread copy: run_llm_loop mutates simple_chat_history in-place.
|
||||
if n_models == 1 and setup.new_msg_req.deep_research:
|
||||
if setup.chat_session.project_id:
|
||||
raise RuntimeError("Deep research is not supported for projects")
|
||||
run_deep_research_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
skip_clarification=setup.skip_clarification,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
)
|
||||
else:
|
||||
run_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
context_files=setup.extracted_context_files,
|
||||
persona=setup.persona,
|
||||
user_memory_context=setup.user_memory_context,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
forced_tool_id=setup.forced_tool_id,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
chat_files=setup.chat_files_for_tools,
|
||||
include_citations=setup.new_msg_req.include_citations,
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
)
|
||||
|
||||
model_succeeded[model_idx] = True
|
||||
|
||||
@@ -1174,6 +1167,32 @@ def _run_models(
|
||||
else:
|
||||
if item is _MODEL_DONE:
|
||||
models_remaining -= 1
|
||||
elif isinstance(item, LLMTimeoutError):
|
||||
model_llm = setup.llms[model_idx]
|
||||
error_msg = (
|
||||
"The LLM took too long to respond. "
|
||||
"If you're running a local model, try increasing the "
|
||||
"LLM_SOCKET_READ_TIMEOUT environment variable "
|
||||
"(current default: 120 seconds)."
|
||||
)
|
||||
stack_trace = "".join(
|
||||
traceback.format_exception(type(item), item, item.__traceback__)
|
||||
)
|
||||
if model_llm.config.api_key and len(model_llm.config.api_key) > 2:
|
||||
stack_trace = stack_trace.replace(
|
||||
model_llm.config.api_key, "[REDACTED_API_KEY]"
|
||||
)
|
||||
yield StreamingError(
|
||||
error=error_msg,
|
||||
stack_trace=stack_trace,
|
||||
error_code="CONNECTION_ERROR",
|
||||
is_retryable=True,
|
||||
details={
|
||||
"model": model_llm.config.model_name,
|
||||
"provider": model_llm.config.model_provider,
|
||||
"model_index": model_idx,
|
||||
},
|
||||
)
|
||||
elif isinstance(item, Exception):
|
||||
# Yield a tagged error for this model but keep the other models running.
|
||||
# Do NOT decrement models_remaining — _run_model's finally always posts
|
||||
|
||||
@@ -283,6 +283,7 @@ class NotificationType(str, Enum):
|
||||
RELEASE_NOTES = "release_notes"
|
||||
ASSISTANT_FILES_READY = "assistant_files_ready"
|
||||
FEATURE_ANNOUNCEMENT = "feature_announcement"
|
||||
CONNECTOR_REPEATED_ERRORS = "connector_repeated_errors"
|
||||
|
||||
|
||||
class BlobType(str, Enum):
|
||||
|
||||
48
backend/onyx/configs/sentry.py
Normal file
48
backend/onyx/configs/sentry.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from typing import Any
|
||||
|
||||
from sentry_sdk.types import Event
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_instance_id_resolved = False
|
||||
|
||||
|
||||
def _add_instance_tags(
|
||||
event: Event,
|
||||
hint: dict[str, Any], # noqa: ARG001
|
||||
) -> Event | None:
|
||||
"""Sentry before_send hook that lazily attaches instance identification tags.
|
||||
|
||||
On the first event, resolves the instance UUID from the KV store (requires DB)
|
||||
and sets it as a global Sentry tag. Subsequent events pick it up automatically.
|
||||
"""
|
||||
global _instance_id_resolved
|
||||
|
||||
if _instance_id_resolved:
|
||||
return event
|
||||
|
||||
try:
|
||||
import sentry_sdk
|
||||
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
if MULTI_TENANT:
|
||||
instance_id = "multi-tenant-cloud"
|
||||
else:
|
||||
from onyx.utils.telemetry import get_or_generate_uuid
|
||||
|
||||
instance_id = get_or_generate_uuid()
|
||||
|
||||
sentry_sdk.set_tag("instance_id", instance_id)
|
||||
|
||||
# Also set on this event since set_tag won't retroactively apply
|
||||
event.setdefault("tags", {})["instance_id"] = instance_id
|
||||
|
||||
# Only mark resolved after success — if DB wasn't ready, retry next event
|
||||
_instance_id_resolved = True
|
||||
except Exception:
|
||||
logger.debug("Failed to resolve instance_id for Sentry tagging")
|
||||
|
||||
return event
|
||||
@@ -27,16 +27,19 @@ _STATUS_TO_ERROR_CODE: dict[int, OnyxErrorCode] = {
|
||||
401: OnyxErrorCode.CREDENTIAL_EXPIRED,
|
||||
403: OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
404: OnyxErrorCode.BAD_GATEWAY,
|
||||
429: OnyxErrorCode.RATE_LIMITED,
|
||||
}
|
||||
|
||||
|
||||
def _error_code_for_status(status_code: int) -> OnyxErrorCode:
|
||||
"""Map an HTTP status code to the appropriate OnyxErrorCode.
|
||||
|
||||
Expects a >= 400 status code. Known codes (401, 403, 404, 429) are
|
||||
Expects a >= 400 status code. Known codes (401, 403, 404) are
|
||||
mapped to specific error codes; all other codes (unrecognised 4xx
|
||||
and 5xx) map to BAD_GATEWAY as unexpected upstream errors.
|
||||
|
||||
Note: 429 is intentionally omitted — the rl_requests wrapper
|
||||
handles rate limits transparently at the HTTP layer, so 429
|
||||
responses never reach this function.
|
||||
"""
|
||||
if status_code in _STATUS_TO_ERROR_CODE:
|
||||
return _STATUS_TO_ERROR_CODE[status_code]
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
from typing import NoReturn
|
||||
from typing import TypeAlias
|
||||
|
||||
from pydantic import BaseModel
|
||||
from retry import retry
|
||||
@@ -25,8 +24,11 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import EntityFailure
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
@@ -47,10 +49,6 @@ def _handle_canvas_api_error(e: OnyxError) -> NoReturn:
|
||||
raise InsufficientPermissionsError(
|
||||
"Canvas API token does not have sufficient permissions (HTTP 403)."
|
||||
)
|
||||
elif e.status_code == 429:
|
||||
raise ConnectorValidationError(
|
||||
"Canvas rate-limit exceeded (HTTP 429). Please try again later."
|
||||
)
|
||||
elif e.status_code >= 500:
|
||||
raise UnexpectedValidationError(
|
||||
f"Unexpected Canvas HTTP error (status={e.status_code}): {e}"
|
||||
@@ -61,6 +59,60 @@ def _handle_canvas_api_error(e: OnyxError) -> NoReturn:
|
||||
)
|
||||
|
||||
|
||||
class CanvasStage(StrEnum):
|
||||
PAGES = "pages"
|
||||
ASSIGNMENTS = "assignments"
|
||||
ANNOUNCEMENTS = "announcements"
|
||||
|
||||
|
||||
_STAGE_CONFIG: dict[CanvasStage, dict[str, Any]] = {
|
||||
CanvasStage.PAGES: {
|
||||
"endpoint": "courses/{course_id}/pages",
|
||||
"params": {
|
||||
"per_page": "100",
|
||||
"include[]": "body",
|
||||
"published": "true",
|
||||
"sort": "updated_at",
|
||||
"order": "desc",
|
||||
},
|
||||
},
|
||||
CanvasStage.ASSIGNMENTS: {
|
||||
"endpoint": "courses/{course_id}/assignments",
|
||||
"params": {"per_page": "100", "published": "true"},
|
||||
},
|
||||
CanvasStage.ANNOUNCEMENTS: {
|
||||
"endpoint": "announcements",
|
||||
"params": {
|
||||
"per_page": "100",
|
||||
"context_codes[]": "course_{course_id}",
|
||||
"active_only": "true",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _parse_canvas_dt(timestamp_str: str) -> datetime:
|
||||
"""Parse a Canvas ISO-8601 timestamp (e.g. '2025-06-15T12:00:00Z')
|
||||
into a timezone-aware UTC datetime.
|
||||
|
||||
Canvas returns timestamps with a trailing 'Z' instead of '+00:00',
|
||||
so we normalise before parsing.
|
||||
"""
|
||||
return datetime.fromisoformat(timestamp_str.replace("Z", "+00:00")).astimezone(
|
||||
timezone.utc
|
||||
)
|
||||
|
||||
|
||||
def _unix_to_canvas_time(epoch: float) -> str:
|
||||
"""Convert a Unix timestamp to Canvas ISO-8601 format (e.g. '2025-06-15T12:00:00Z')."""
|
||||
return datetime.fromtimestamp(epoch, tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
|
||||
|
||||
def _in_time_window(timestamp_str: str, start: float, end: float) -> bool:
|
||||
"""Check whether a Canvas ISO-8601 timestamp falls within (start, end]."""
|
||||
return start < _parse_canvas_dt(timestamp_str).timestamp() <= end
|
||||
|
||||
|
||||
class CanvasCourse(BaseModel):
|
||||
id: int
|
||||
name: str | None = None
|
||||
@@ -145,9 +197,6 @@ class CanvasAnnouncement(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
CanvasStage: TypeAlias = Literal["pages", "assignments", "announcements"]
|
||||
|
||||
|
||||
class CanvasConnectorCheckpoint(ConnectorCheckpoint):
|
||||
"""Checkpoint state for resumable Canvas indexing.
|
||||
|
||||
@@ -165,15 +214,30 @@ class CanvasConnectorCheckpoint(ConnectorCheckpoint):
|
||||
|
||||
course_ids: list[int] = []
|
||||
current_course_index: int = 0
|
||||
stage: CanvasStage = "pages"
|
||||
stage: CanvasStage = CanvasStage.PAGES
|
||||
next_url: str | None = None
|
||||
|
||||
def advance_course(self) -> None:
|
||||
"""Move to the next course and reset within-course state."""
|
||||
self.current_course_index += 1
|
||||
self.stage = "pages"
|
||||
self.stage = CanvasStage.PAGES
|
||||
self.next_url = None
|
||||
|
||||
def advance_stage(self) -> None:
|
||||
"""Advance past the current stage.
|
||||
|
||||
Moves to the next stage within the same course, or to the next
|
||||
course if the current stage is the last one. Resets next_url so
|
||||
the next call starts fresh on the new stage.
|
||||
"""
|
||||
self.next_url = None
|
||||
stages: list[CanvasStage] = list(CanvasStage)
|
||||
next_idx = stages.index(self.stage) + 1
|
||||
if next_idx < len(stages):
|
||||
self.stage = stages[next_idx]
|
||||
else:
|
||||
self.advance_course()
|
||||
|
||||
|
||||
class CanvasConnector(
|
||||
CheckpointedConnectorWithPermSync[CanvasConnectorCheckpoint],
|
||||
@@ -295,13 +359,7 @@ class CanvasConnector(
|
||||
if body_text:
|
||||
text_parts.append(body_text)
|
||||
|
||||
doc_updated_at = (
|
||||
datetime.fromisoformat(page.updated_at.replace("Z", "+00:00")).astimezone(
|
||||
timezone.utc
|
||||
)
|
||||
if page.updated_at
|
||||
else None
|
||||
)
|
||||
doc_updated_at = _parse_canvas_dt(page.updated_at) if page.updated_at else None
|
||||
|
||||
document = self._build_document(
|
||||
doc_id=f"canvas-page-{page.course_id}-{page.page_id}",
|
||||
@@ -325,17 +383,11 @@ class CanvasConnector(
|
||||
if desc_text:
|
||||
text_parts.append(desc_text)
|
||||
if assignment.due_at:
|
||||
due_dt = datetime.fromisoformat(
|
||||
assignment.due_at.replace("Z", "+00:00")
|
||||
).astimezone(timezone.utc)
|
||||
due_dt = _parse_canvas_dt(assignment.due_at)
|
||||
text_parts.append(f"Due: {due_dt.strftime('%B %d, %Y %H:%M UTC')}")
|
||||
|
||||
doc_updated_at = (
|
||||
datetime.fromisoformat(
|
||||
assignment.updated_at.replace("Z", "+00:00")
|
||||
).astimezone(timezone.utc)
|
||||
if assignment.updated_at
|
||||
else None
|
||||
_parse_canvas_dt(assignment.updated_at) if assignment.updated_at else None
|
||||
)
|
||||
|
||||
document = self._build_document(
|
||||
@@ -361,11 +413,7 @@ class CanvasConnector(
|
||||
text_parts.append(msg_text)
|
||||
|
||||
doc_updated_at = (
|
||||
datetime.fromisoformat(
|
||||
announcement.posted_at.replace("Z", "+00:00")
|
||||
).astimezone(timezone.utc)
|
||||
if announcement.posted_at
|
||||
else None
|
||||
_parse_canvas_dt(announcement.posted_at) if announcement.posted_at else None
|
||||
)
|
||||
|
||||
document = self._build_document(
|
||||
@@ -400,6 +448,314 @@ class CanvasConnector(
|
||||
self._canvas_client = client
|
||||
return None
|
||||
|
||||
def _fetch_stage_page(
|
||||
self,
|
||||
next_url: str | None,
|
||||
endpoint: str,
|
||||
params: dict[str, Any],
|
||||
) -> tuple[list[Any], str | None]:
|
||||
"""Fetch one page of API results for the current stage.
|
||||
|
||||
Returns (items, next_url). All error handling is done by the
|
||||
caller (_load_from_checkpoint).
|
||||
"""
|
||||
if next_url:
|
||||
# Resuming mid-pagination: the next_url from Canvas's
|
||||
# Link header already contains endpoint + query params.
|
||||
response, result_next_url = self.canvas_client.get(full_url=next_url)
|
||||
else:
|
||||
# First request for this stage: build from endpoint + params.
|
||||
response, result_next_url = self.canvas_client.get(
|
||||
endpoint=endpoint, params=params
|
||||
)
|
||||
return response or [], result_next_url
|
||||
|
||||
def _process_items(
|
||||
self,
|
||||
response: list[Any],
|
||||
stage: CanvasStage,
|
||||
course_id: int,
|
||||
start: float,
|
||||
end: float,
|
||||
include_permissions: bool,
|
||||
) -> tuple[list[Document | ConnectorFailure], bool]:
|
||||
"""Process a page of API results into documents.
|
||||
|
||||
Returns (docs, early_exit). early_exit is True when pages
|
||||
(sorted desc by updated_at) hit an item older than start,
|
||||
signaling that pagination should stop.
|
||||
"""
|
||||
results: list[Document | ConnectorFailure] = []
|
||||
early_exit = False
|
||||
|
||||
for item in response:
|
||||
try:
|
||||
if stage == CanvasStage.PAGES:
|
||||
page = CanvasPage.from_api(item, course_id=course_id)
|
||||
if not page.updated_at:
|
||||
continue
|
||||
# Pages are sorted by updated_at desc — once we see
|
||||
# an item at or before `start`, all remaining items
|
||||
# on this and subsequent pages are older too.
|
||||
if not _in_time_window(page.updated_at, start, end):
|
||||
if _parse_canvas_dt(page.updated_at).timestamp() <= start:
|
||||
early_exit = True
|
||||
break
|
||||
# ts > end: page is newer than our window, skip it
|
||||
continue
|
||||
doc = self._convert_page_to_document(page)
|
||||
results.append(
|
||||
self._maybe_attach_permissions(
|
||||
doc, course_id, include_permissions
|
||||
)
|
||||
)
|
||||
|
||||
elif stage == CanvasStage.ASSIGNMENTS:
|
||||
assignment = CanvasAssignment.from_api(item, course_id=course_id)
|
||||
if not assignment.updated_at or not _in_time_window(
|
||||
assignment.updated_at, start, end
|
||||
):
|
||||
continue
|
||||
doc = self._convert_assignment_to_document(assignment)
|
||||
results.append(
|
||||
self._maybe_attach_permissions(
|
||||
doc, course_id, include_permissions
|
||||
)
|
||||
)
|
||||
|
||||
elif stage == CanvasStage.ANNOUNCEMENTS:
|
||||
announcement = CanvasAnnouncement.from_api(
|
||||
item, course_id=course_id
|
||||
)
|
||||
if not announcement.posted_at:
|
||||
logger.debug(
|
||||
f"Skipping announcement {announcement.id} in "
|
||||
f"course {course_id}: no posted_at"
|
||||
)
|
||||
continue
|
||||
if not _in_time_window(announcement.posted_at, start, end):
|
||||
continue
|
||||
doc = self._convert_announcement_to_document(announcement)
|
||||
results.append(
|
||||
self._maybe_attach_permissions(
|
||||
doc, course_id, include_permissions
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
item_id = item.get("id") or item.get("page_id", "unknown")
|
||||
if stage == CanvasStage.PAGES:
|
||||
doc_link = (
|
||||
f"{self.canvas_base_url}/courses/{course_id}"
|
||||
f"/pages/{item.get('url', '')}"
|
||||
)
|
||||
else:
|
||||
doc_link = item.get("html_url", "")
|
||||
results.append(
|
||||
ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=f"canvas-{stage.removesuffix('s')}-{course_id}-{item_id}",
|
||||
document_link=doc_link,
|
||||
),
|
||||
failure_message=f"Failed to process {stage.removesuffix('s')}: {e}",
|
||||
exception=e,
|
||||
)
|
||||
)
|
||||
|
||||
return results, early_exit
|
||||
|
||||
def _maybe_attach_permissions(
|
||||
self,
|
||||
document: Document,
|
||||
course_id: int,
|
||||
include_permissions: bool,
|
||||
) -> Document:
|
||||
if include_permissions:
|
||||
document.external_access = self._get_course_permissions(course_id)
|
||||
return document
|
||||
|
||||
def _load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: CanvasConnectorCheckpoint,
|
||||
include_permissions: bool = False,
|
||||
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
|
||||
"""Shared implementation for load_from_checkpoint and load_from_checkpoint_with_perm_sync."""
|
||||
new_checkpoint = checkpoint.model_copy(deep=True)
|
||||
|
||||
# First call: materialize the list of course IDs.
|
||||
# On failure, let the exception propagate so the framework fails the
|
||||
# attempt cleanly. Swallowing errors here would leave the checkpoint
|
||||
# state unchanged and cause an infinite retry loop.
|
||||
if not new_checkpoint.course_ids:
|
||||
try:
|
||||
courses = self._list_courses()
|
||||
except OnyxError as e:
|
||||
if e.status_code in (401, 403):
|
||||
_handle_canvas_api_error(e) # NoReturn — always raises
|
||||
raise
|
||||
new_checkpoint.course_ids = [c.id for c in courses]
|
||||
logger.info(f"Found {len(courses)} Canvas courses to process")
|
||||
new_checkpoint.has_more = len(new_checkpoint.course_ids) > 0
|
||||
return new_checkpoint
|
||||
|
||||
# All courses done.
|
||||
if new_checkpoint.current_course_index >= len(new_checkpoint.course_ids):
|
||||
new_checkpoint.has_more = False
|
||||
return new_checkpoint
|
||||
|
||||
course_id = new_checkpoint.course_ids[new_checkpoint.current_course_index]
|
||||
try:
|
||||
stage = CanvasStage(new_checkpoint.stage)
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"Invalid checkpoint stage: {new_checkpoint.stage!r}. "
|
||||
f"Valid stages: {[s.value for s in CanvasStage]}"
|
||||
) from e
|
||||
|
||||
# Build endpoint + params from the static template.
|
||||
config = _STAGE_CONFIG[stage]
|
||||
endpoint = config["endpoint"].format(course_id=course_id)
|
||||
params = {k: v.format(course_id=course_id) for k, v in config["params"].items()}
|
||||
# Only the announcements API supports server-side date filtering
|
||||
# (start_date/end_date). Pages support server-side sorting
|
||||
# (sort=updated_at desc) enabling early exit, but not date
|
||||
# filtering. Assignments support neither. Both are filtered
|
||||
# client-side via _in_time_window after fetching.
|
||||
if stage == CanvasStage.ANNOUNCEMENTS:
|
||||
params["start_date"] = _unix_to_canvas_time(start)
|
||||
params["end_date"] = _unix_to_canvas_time(end)
|
||||
|
||||
try:
|
||||
response, result_next_url = self._fetch_stage_page(
|
||||
next_url=new_checkpoint.next_url,
|
||||
endpoint=endpoint,
|
||||
params=params,
|
||||
)
|
||||
except OnyxError as oe:
|
||||
# Security errors from _parse_next_link (host/scheme
|
||||
# mismatch on pagination URLs) have no status code override
|
||||
# and must not be silenced.
|
||||
is_api_error = oe._status_code_override is not None
|
||||
if not is_api_error:
|
||||
raise
|
||||
if oe.status_code in (401, 403):
|
||||
_handle_canvas_api_error(oe) # NoReturn — always raises
|
||||
|
||||
# 404 means the course itself is gone or inaccessible. The
|
||||
# other stages on this course will hit the same 404, so skip
|
||||
# the whole course rather than burning API calls on each stage.
|
||||
if oe.status_code == 404:
|
||||
logger.warning(
|
||||
f"Canvas course {course_id} not found while fetching "
|
||||
f"{stage} (HTTP 404). Skipping course."
|
||||
)
|
||||
yield ConnectorFailure(
|
||||
failed_entity=EntityFailure(
|
||||
entity_id=f"canvas-course-{course_id}",
|
||||
),
|
||||
failure_message=(f"Canvas course {course_id} not found: {oe}"),
|
||||
exception=oe,
|
||||
)
|
||||
new_checkpoint.advance_course()
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to fetch {stage} for course {course_id}: {oe}. "
|
||||
f"Skipping remainder of this stage."
|
||||
)
|
||||
yield ConnectorFailure(
|
||||
failed_entity=EntityFailure(
|
||||
entity_id=f"canvas-{stage}-{course_id}",
|
||||
),
|
||||
failure_message=(
|
||||
f"Failed to fetch {stage} for course {course_id}: {oe}"
|
||||
),
|
||||
exception=oe,
|
||||
)
|
||||
new_checkpoint.advance_stage()
|
||||
new_checkpoint.has_more = new_checkpoint.current_course_index < len(
|
||||
new_checkpoint.course_ids
|
||||
)
|
||||
return new_checkpoint
|
||||
except Exception as e:
|
||||
# Unknown error — skip the stage and try to continue.
|
||||
logger.warning(
|
||||
f"Failed to fetch {stage} for course {course_id}: {e}. "
|
||||
f"Skipping remainder of this stage."
|
||||
)
|
||||
yield ConnectorFailure(
|
||||
failed_entity=EntityFailure(
|
||||
entity_id=f"canvas-{stage}-{course_id}",
|
||||
),
|
||||
failure_message=(
|
||||
f"Failed to fetch {stage} for course {course_id}: {e}"
|
||||
),
|
||||
exception=e,
|
||||
)
|
||||
new_checkpoint.advance_stage()
|
||||
new_checkpoint.has_more = new_checkpoint.current_course_index < len(
|
||||
new_checkpoint.course_ids
|
||||
)
|
||||
return new_checkpoint
|
||||
|
||||
# Process fetched items
|
||||
results, early_exit = self._process_items(
|
||||
response, stage, course_id, start, end, include_permissions
|
||||
)
|
||||
for result in results:
|
||||
yield result
|
||||
|
||||
# If we hit an item older than our window (pages sorted desc),
|
||||
# skip remaining pagination and advance to the next stage.
|
||||
if early_exit:
|
||||
result_next_url = None
|
||||
|
||||
# If there are more pages, save the cursor and return
|
||||
if result_next_url:
|
||||
new_checkpoint.next_url = result_next_url
|
||||
else:
|
||||
# Stage complete — advance to next stage (or next course if last).
|
||||
new_checkpoint.advance_stage()
|
||||
|
||||
new_checkpoint.has_more = new_checkpoint.current_course_index < len(
|
||||
new_checkpoint.course_ids
|
||||
)
|
||||
return new_checkpoint
|
||||
|
||||
@override
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: CanvasConnectorCheckpoint,
|
||||
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
|
||||
return self._load_from_checkpoint(
|
||||
start, end, checkpoint, include_permissions=False
|
||||
)
|
||||
|
||||
@override
|
||||
def load_from_checkpoint_with_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: CanvasConnectorCheckpoint,
|
||||
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
|
||||
"""Load documents from checkpoint with permission information included."""
|
||||
return self._load_from_checkpoint(
|
||||
start, end, checkpoint, include_permissions=True
|
||||
)
|
||||
|
||||
@override
|
||||
def build_dummy_checkpoint(self) -> CanvasConnectorCheckpoint:
|
||||
return CanvasConnectorCheckpoint(has_more=True)
|
||||
|
||||
@override
|
||||
def validate_checkpoint_json(
|
||||
self, checkpoint_json: str
|
||||
) -> CanvasConnectorCheckpoint:
|
||||
return CanvasConnectorCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
@override
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate Canvas connector settings by testing API access."""
|
||||
@@ -415,38 +771,6 @@ class CanvasConnector(
|
||||
f"Unexpected error during Canvas settings validation: {exc}"
|
||||
)
|
||||
|
||||
@override
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: CanvasConnectorCheckpoint,
|
||||
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
|
||||
# TODO(benwu408): implemented in PR3 (checkpoint)
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def load_from_checkpoint_with_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: CanvasConnectorCheckpoint,
|
||||
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
|
||||
# TODO(benwu408): implemented in PR3 (checkpoint)
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def build_dummy_checkpoint(self) -> CanvasConnectorCheckpoint:
|
||||
# TODO(benwu408): implemented in PR3 (checkpoint)
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def validate_checkpoint_json(
|
||||
self, checkpoint_json: str
|
||||
) -> CanvasConnectorCheckpoint:
|
||||
# TODO(benwu408): implemented in PR3 (checkpoint)
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
|
||||
@@ -171,7 +171,10 @@ class ClickupConnector(LoadConnector, PollConnector):
|
||||
document.metadata[extra_field] = task[extra_field]
|
||||
|
||||
if self.retrieve_task_comments:
|
||||
document.sections.extend(self._get_task_comments(task["id"]))
|
||||
document.sections = [
|
||||
*document.sections,
|
||||
*self._get_task_comments(task["id"]),
|
||||
]
|
||||
|
||||
doc_batch.append(document)
|
||||
|
||||
|
||||
@@ -61,6 +61,9 @@ _USER_NOT_FOUND = "Unknown Confluence User"
|
||||
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
|
||||
_USER_EMAIL_CACHE: dict[str, str | None] = {}
|
||||
_DEFAULT_PAGINATION_LIMIT = 1000
|
||||
_MINIMUM_PAGINATION_LIMIT = 5
|
||||
|
||||
_SERVER_ERROR_CODES = {500, 502, 503, 504}
|
||||
|
||||
_CONFLUENCE_SPACES_API_V1 = "rest/api/space"
|
||||
_CONFLUENCE_SPACES_API_V2 = "wiki/api/v2/spaces"
|
||||
@@ -569,7 +572,8 @@ class OnyxConfluence:
|
||||
if not limit:
|
||||
limit = _DEFAULT_PAGINATION_LIMIT
|
||||
|
||||
url_suffix = update_param_in_path(url_suffix, "limit", str(limit))
|
||||
current_limit = limit
|
||||
url_suffix = update_param_in_path(url_suffix, "limit", str(current_limit))
|
||||
|
||||
while url_suffix:
|
||||
logger.debug(f"Making confluence call to {url_suffix}")
|
||||
@@ -609,40 +613,61 @@ class OnyxConfluence:
|
||||
)
|
||||
continue
|
||||
|
||||
# If we fail due to a 500, try one by one.
|
||||
# NOTE: this iterative approach only works for server, since cloud uses cursor-based
|
||||
# pagination
|
||||
if raw_response.status_code == 500 and not self._is_cloud:
|
||||
initial_start = get_start_param_from_url(url_suffix)
|
||||
if initial_start is None:
|
||||
# can't handle this if we don't have offset-based pagination
|
||||
raise
|
||||
if raw_response.status_code in _SERVER_ERROR_CODES:
|
||||
# Try reducing the page size -- Confluence often times out
|
||||
# on large result sets (especially Cloud 504s).
|
||||
if current_limit > _MINIMUM_PAGINATION_LIMIT:
|
||||
old_limit = current_limit
|
||||
current_limit = max(
|
||||
current_limit // 2, _MINIMUM_PAGINATION_LIMIT
|
||||
)
|
||||
logger.warning(
|
||||
f"Confluence returned {raw_response.status_code}. "
|
||||
f"Reducing limit from {old_limit} to {current_limit} "
|
||||
f"and retrying."
|
||||
)
|
||||
url_suffix = update_param_in_path(
|
||||
url_suffix, "limit", str(current_limit)
|
||||
)
|
||||
continue
|
||||
|
||||
# this will just yield the successful items from the batch
|
||||
new_url_suffix = yield from self._try_one_by_one_for_paginated_url(
|
||||
url_suffix,
|
||||
initial_start=initial_start,
|
||||
limit=limit,
|
||||
)
|
||||
# Limit reduction exhausted -- for Server, fall back to
|
||||
# one-by-one offset pagination as a last resort.
|
||||
if not self._is_cloud:
|
||||
initial_start = get_start_param_from_url(url_suffix)
|
||||
# this will just yield the successful items from the batch
|
||||
new_url_suffix = (
|
||||
yield from self._try_one_by_one_for_paginated_url(
|
||||
url_suffix,
|
||||
initial_start=initial_start,
|
||||
limit=current_limit,
|
||||
)
|
||||
)
|
||||
# this means we ran into an empty page
|
||||
if new_url_suffix is None:
|
||||
if next_page_callback:
|
||||
next_page_callback("")
|
||||
break
|
||||
|
||||
# this means we ran into an empty page
|
||||
if new_url_suffix is None:
|
||||
if next_page_callback:
|
||||
next_page_callback("")
|
||||
break
|
||||
url_suffix = new_url_suffix
|
||||
continue
|
||||
|
||||
url_suffix = new_url_suffix
|
||||
continue
|
||||
|
||||
else:
|
||||
logger.exception(
|
||||
f"Error in confluence call to {url_suffix} \n"
|
||||
f"Raw Response Text: {raw_response.text} \n"
|
||||
f"Full Response: {raw_response.__dict__} \n"
|
||||
f"Error: {e} \n"
|
||||
f"Error in confluence call to {url_suffix} "
|
||||
f"after reducing limit to {current_limit}.\n"
|
||||
f"Raw Response Text: {raw_response.text}\n"
|
||||
f"Error: {e}\n"
|
||||
)
|
||||
raise
|
||||
|
||||
logger.exception(
|
||||
f"Error in confluence call to {url_suffix} \n"
|
||||
f"Raw Response Text: {raw_response.text} \n"
|
||||
f"Full Response: {raw_response.__dict__} \n"
|
||||
f"Error: {e} \n"
|
||||
)
|
||||
raise
|
||||
|
||||
try:
|
||||
next_response = raw_response.json()
|
||||
except Exception as e:
|
||||
@@ -680,6 +705,10 @@ class OnyxConfluence:
|
||||
old_url_suffix = url_suffix
|
||||
updated_start = get_start_param_from_url(old_url_suffix)
|
||||
url_suffix = cast(str, next_response.get("_links", {}).get("next", ""))
|
||||
if url_suffix and current_limit != limit:
|
||||
url_suffix = update_param_in_path(
|
||||
url_suffix, "limit", str(current_limit)
|
||||
)
|
||||
for i, result in enumerate(results):
|
||||
updated_start += 1
|
||||
if url_suffix and next_page_callback and i == len(results) - 1:
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
import csv
|
||||
import io
|
||||
from typing import IO
|
||||
|
||||
from onyx.connectors.models import TabularSection
|
||||
from onyx.file_processing.extract_file_text import file_io_to_text
|
||||
from onyx.file_processing.extract_file_text import xlsx_sheet_extraction
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def is_tabular_file(file_name: str) -> bool:
|
||||
lowered = file_name.lower()
|
||||
return any(lowered.endswith(ext) for ext in OnyxFileExtensions.TABULAR_EXTENSIONS)
|
||||
|
||||
|
||||
def _tsv_to_csv(tsv_text: str) -> str:
|
||||
"""Re-serialize tab-separated text as CSV so downstream parsers that
|
||||
assume the default Excel dialect read the columns correctly."""
|
||||
out = io.StringIO()
|
||||
csv.writer(out, lineterminator="\n").writerows(
|
||||
csv.reader(io.StringIO(tsv_text), dialect="excel-tab")
|
||||
)
|
||||
return out.getvalue().rstrip("\n")
|
||||
|
||||
|
||||
def tabular_file_to_sections(
|
||||
file: IO[bytes],
|
||||
file_name: str,
|
||||
link: str = "",
|
||||
) -> list[TabularSection]:
|
||||
"""Convert a tabular file into one or more TabularSections.
|
||||
|
||||
- .xlsx → one TabularSection per non-empty sheet.
|
||||
- .csv / .tsv → a single TabularSection containing the full decoded
|
||||
file.
|
||||
|
||||
Returns an empty list when the file yields no extractable content.
|
||||
"""
|
||||
lowered = file_name.lower()
|
||||
|
||||
if lowered.endswith(".xlsx"):
|
||||
return [
|
||||
TabularSection(link=f"{file_name} :: {sheet_title}", text=csv_text)
|
||||
for csv_text, sheet_title in xlsx_sheet_extraction(
|
||||
file, file_name=file_name
|
||||
)
|
||||
]
|
||||
|
||||
if not lowered.endswith((".csv", ".tsv")):
|
||||
raise ValueError(f"{file_name!r} is not a tabular file")
|
||||
|
||||
try:
|
||||
text = file_io_to_text(file).strip()
|
||||
except Exception:
|
||||
logger.exception(f"Failure decoding {file_name}")
|
||||
raise
|
||||
|
||||
if not text:
|
||||
return []
|
||||
if lowered.endswith(".tsv"):
|
||||
text = _tsv_to_csv(text)
|
||||
return [TabularSection(link=link or file_name, text=text)]
|
||||
@@ -42,6 +42,9 @@ from onyx.connectors.google_drive.file_retrieval import (
|
||||
get_all_files_in_my_drive_and_shared,
|
||||
)
|
||||
from onyx.connectors.google_drive.file_retrieval import get_external_access_for_folder
|
||||
from onyx.connectors.google_drive.file_retrieval import (
|
||||
get_files_by_web_view_links_batch,
|
||||
)
|
||||
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
|
||||
from onyx.connectors.google_drive.file_retrieval import get_folder_metadata
|
||||
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
|
||||
@@ -70,11 +73,14 @@ from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import NormalizationResult
|
||||
from onyx.connectors.interfaces import Resolver
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import EntityFailure
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
@@ -202,7 +208,10 @@ class DriveIdStatus(Enum):
|
||||
|
||||
|
||||
class GoogleDriveConnector(
|
||||
SlimConnectorWithPermSync, CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint]
|
||||
SlimConnector,
|
||||
SlimConnectorWithPermSync,
|
||||
CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint],
|
||||
Resolver,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1665,12 +1674,89 @@ class GoogleDriveConnector(
|
||||
start, end, checkpoint, include_permissions=True
|
||||
)
|
||||
|
||||
@override
|
||||
def resolve_errors(
|
||||
self,
|
||||
errors: list[ConnectorFailure],
|
||||
include_permissions: bool = False,
|
||||
) -> Generator[Document | ConnectorFailure | HierarchyNode, None, None]:
|
||||
if self._creds is None or self._primary_admin_email is None:
|
||||
raise RuntimeError(
|
||||
"Credentials missing, should not call this method before calling load_credentials"
|
||||
)
|
||||
|
||||
logger.info(f"Resolving {len(errors)} errors")
|
||||
doc_ids = [
|
||||
failure.failed_document.document_id
|
||||
for failure in errors
|
||||
if failure.failed_document
|
||||
]
|
||||
service = get_drive_service(self.creds, self.primary_admin_email)
|
||||
field_type = (
|
||||
DriveFileFieldType.WITH_PERMISSIONS
|
||||
if include_permissions or self.exclude_domain_link_only
|
||||
else DriveFileFieldType.STANDARD
|
||||
)
|
||||
batch_result = get_files_by_web_view_links_batch(service, doc_ids, field_type)
|
||||
|
||||
for doc_id, error in batch_result.errors.items():
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=doc_id,
|
||||
document_link=doc_id,
|
||||
),
|
||||
failure_message=f"Failed to retrieve file during error resolution: {error}",
|
||||
exception=error,
|
||||
)
|
||||
|
||||
permission_sync_context = (
|
||||
PermissionSyncContext(
|
||||
primary_admin_email=self.primary_admin_email,
|
||||
google_domain=self.google_domain,
|
||||
)
|
||||
if include_permissions
|
||||
else None
|
||||
)
|
||||
|
||||
retrieved_files = [
|
||||
RetrievedDriveFile(
|
||||
drive_file=file,
|
||||
user_email=self.primary_admin_email,
|
||||
completion_stage=DriveRetrievalStage.DONE,
|
||||
)
|
||||
for file in batch_result.files.values()
|
||||
]
|
||||
|
||||
yield from self._get_new_ancestors_for_files(
|
||||
files=retrieved_files,
|
||||
seen_hierarchy_node_raw_ids=ThreadSafeSet(),
|
||||
fully_walked_hierarchy_node_raw_ids=ThreadSafeSet(),
|
||||
permission_sync_context=permission_sync_context,
|
||||
add_prefix=True,
|
||||
)
|
||||
|
||||
func_with_args = [
|
||||
(
|
||||
self._convert_retrieved_file_to_document,
|
||||
(rf, permission_sync_context),
|
||||
)
|
||||
for rf in retrieved_files
|
||||
]
|
||||
results = cast(
|
||||
list[Document | ConnectorFailure | None],
|
||||
run_functions_tuples_in_parallel(func_with_args, max_workers=8),
|
||||
)
|
||||
for result in results:
|
||||
if result is not None:
|
||||
yield result
|
||||
|
||||
def _extract_slim_docs_from_google_drive(
|
||||
self,
|
||||
checkpoint: GoogleDriveCheckpoint,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
include_permissions: bool = True,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
files_batch: list[RetrievedDriveFile] = []
|
||||
slim_batch: list[SlimDocument | HierarchyNode] = []
|
||||
@@ -1680,9 +1766,13 @@ class GoogleDriveConnector(
|
||||
nonlocal files_batch, slim_batch
|
||||
|
||||
# Get new ancestor hierarchy nodes first
|
||||
permission_sync_context = PermissionSyncContext(
|
||||
primary_admin_email=self.primary_admin_email,
|
||||
google_domain=self.google_domain,
|
||||
permission_sync_context = (
|
||||
PermissionSyncContext(
|
||||
primary_admin_email=self.primary_admin_email,
|
||||
google_domain=self.google_domain,
|
||||
)
|
||||
if include_permissions
|
||||
else None
|
||||
)
|
||||
new_ancestors = self._get_new_ancestors_for_files(
|
||||
files=files_batch,
|
||||
@@ -1696,10 +1786,7 @@ class GoogleDriveConnector(
|
||||
if doc := build_slim_document(
|
||||
self.creds,
|
||||
file.drive_file,
|
||||
PermissionSyncContext(
|
||||
primary_admin_email=self.primary_admin_email,
|
||||
google_domain=self.google_domain,
|
||||
),
|
||||
permission_sync_context,
|
||||
retriever_email=file.user_email,
|
||||
):
|
||||
slim_batch.append(doc)
|
||||
@@ -1739,11 +1826,12 @@ class GoogleDriveConnector(
|
||||
if files_batch:
|
||||
yield _yield_slim_batch()
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
def _retrieve_all_slim_docs_impl(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
include_permissions: bool = True,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
try:
|
||||
checkpoint = self.build_dummy_checkpoint()
|
||||
@@ -1753,13 +1841,34 @@ class GoogleDriveConnector(
|
||||
start=start,
|
||||
end=end,
|
||||
callback=callback,
|
||||
include_permissions=include_permissions,
|
||||
)
|
||||
logger.info("Drive perm sync: Slim doc retrieval complete")
|
||||
|
||||
logger.info("Drive slim doc retrieval complete")
|
||||
except Exception as e:
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
|
||||
raise e
|
||||
raise
|
||||
|
||||
@override
|
||||
def retrieve_all_slim_docs(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
return self._retrieve_all_slim_docs_impl(
|
||||
start=start, end=end, callback=callback, include_permissions=False
|
||||
)
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
return self._retrieve_all_slim_docs_impl(
|
||||
start=start, end=end, callback=callback, include_permissions=True
|
||||
)
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
if self._creds is None:
|
||||
|
||||
@@ -9,6 +9,7 @@ from urllib.parse import urlparse
|
||||
|
||||
from googleapiclient.discovery import Resource # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
from googleapiclient.http import BatchHttpRequest # type: ignore
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
|
||||
@@ -60,6 +61,8 @@ SLIM_FILE_FIELDS = (
|
||||
)
|
||||
FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)"
|
||||
|
||||
MAX_BATCH_SIZE = 100
|
||||
|
||||
HIERARCHY_FIELDS = "id, name, parents, webViewLink, mimeType, driveId"
|
||||
|
||||
HIERARCHY_FIELDS_WITH_PERMISSIONS = (
|
||||
@@ -216,7 +219,7 @@ def get_external_access_for_folder(
|
||||
|
||||
|
||||
def _get_fields_for_file_type(field_type: DriveFileFieldType) -> str:
|
||||
"""Get the appropriate fields string based on the field type enum"""
|
||||
"""Get the appropriate fields string for files().list() based on the field type enum."""
|
||||
if field_type == DriveFileFieldType.SLIM:
|
||||
return SLIM_FILE_FIELDS
|
||||
elif field_type == DriveFileFieldType.WITH_PERMISSIONS:
|
||||
@@ -225,6 +228,25 @@ def _get_fields_for_file_type(field_type: DriveFileFieldType) -> str:
|
||||
return FILE_FIELDS
|
||||
|
||||
|
||||
def _extract_single_file_fields(list_fields: str) -> str:
|
||||
"""Convert a files().list() fields string to one suitable for files().get().
|
||||
|
||||
List fields look like "nextPageToken, files(field1, field2, ...)"
|
||||
Single-file fields should be just "field1, field2, ..."
|
||||
"""
|
||||
start = list_fields.find("files(")
|
||||
if start == -1:
|
||||
return list_fields
|
||||
inner_start = start + len("files(")
|
||||
inner_end = list_fields.rfind(")")
|
||||
return list_fields[inner_start:inner_end]
|
||||
|
||||
|
||||
def _get_single_file_fields(field_type: DriveFileFieldType) -> str:
|
||||
"""Get the appropriate fields string for files().get() based on the field type enum."""
|
||||
return _extract_single_file_fields(_get_fields_for_file_type(field_type))
|
||||
|
||||
|
||||
def _get_files_in_parent(
|
||||
service: Resource,
|
||||
parent_id: str,
|
||||
@@ -536,3 +558,74 @@ def get_file_by_web_view_link(
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
|
||||
class BatchRetrievalResult:
|
||||
"""Result of a batch file retrieval, separating successes from errors."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.files: dict[str, GoogleDriveFileType] = {}
|
||||
self.errors: dict[str, Exception] = {}
|
||||
|
||||
|
||||
def get_files_by_web_view_links_batch(
|
||||
service: GoogleDriveService,
|
||||
web_view_links: list[str],
|
||||
field_type: DriveFileFieldType,
|
||||
) -> BatchRetrievalResult:
|
||||
"""Retrieve multiple Google Drive files by webViewLink using the batch API.
|
||||
|
||||
Returns a BatchRetrievalResult containing successful file retrievals
|
||||
and errors for any files that could not be fetched.
|
||||
Automatically splits into chunks of MAX_BATCH_SIZE.
|
||||
"""
|
||||
fields = _get_single_file_fields(field_type)
|
||||
if len(web_view_links) <= MAX_BATCH_SIZE:
|
||||
return _get_files_by_web_view_links_batch(service, web_view_links, fields)
|
||||
|
||||
combined = BatchRetrievalResult()
|
||||
for i in range(0, len(web_view_links), MAX_BATCH_SIZE):
|
||||
chunk = web_view_links[i : i + MAX_BATCH_SIZE]
|
||||
chunk_result = _get_files_by_web_view_links_batch(service, chunk, fields)
|
||||
combined.files.update(chunk_result.files)
|
||||
combined.errors.update(chunk_result.errors)
|
||||
return combined
|
||||
|
||||
|
||||
def _get_files_by_web_view_links_batch(
|
||||
service: GoogleDriveService,
|
||||
web_view_links: list[str],
|
||||
fields: str,
|
||||
) -> BatchRetrievalResult:
|
||||
"""Single-batch implementation."""
|
||||
|
||||
result = BatchRetrievalResult()
|
||||
|
||||
def callback(
|
||||
request_id: str,
|
||||
response: GoogleDriveFileType,
|
||||
exception: Exception | None,
|
||||
) -> None:
|
||||
if exception:
|
||||
logger.warning(f"Error retrieving file {request_id}: {exception}")
|
||||
result.errors[request_id] = exception
|
||||
else:
|
||||
result.files[request_id] = response
|
||||
|
||||
batch = cast(BatchHttpRequest, service.new_batch_http_request(callback=callback))
|
||||
|
||||
for web_view_link in web_view_links:
|
||||
try:
|
||||
file_id = _extract_file_id_from_web_view_link(web_view_link)
|
||||
request = service.files().get(
|
||||
fileId=file_id,
|
||||
supportsAllDrives=True,
|
||||
fields=fields,
|
||||
)
|
||||
batch.add(request, request_id=web_view_link)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to extract file ID from {web_view_link}: {e}")
|
||||
result.errors[web_view_link] = e
|
||||
|
||||
batch.execute()
|
||||
return result
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import ParseResult
|
||||
@@ -53,6 +54,21 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _load_google_json(raw: object) -> dict[str, Any]:
|
||||
"""Accept both the current (dict) and legacy (JSON string) KV payload shapes.
|
||||
|
||||
Payloads written before the fix for serializing Google credentials into
|
||||
``EncryptedJson`` columns are stored as JSON strings; new writes store dicts.
|
||||
Once every install has re-uploaded their Google credentials the legacy
|
||||
``str`` branch can be removed.
|
||||
"""
|
||||
if isinstance(raw, dict):
|
||||
return raw
|
||||
if isinstance(raw, str):
|
||||
return json.loads(raw)
|
||||
raise ValueError(f"Unexpected Google credential payload type: {type(raw)!r}")
|
||||
|
||||
|
||||
def _build_frontend_google_drive_redirect(source: DocumentSource) -> str:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback"
|
||||
@@ -162,12 +178,13 @@ def build_service_account_creds(
|
||||
|
||||
def get_auth_url(credential_id: int, source: DocumentSource) -> str:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
credential_json = _load_google_json(
|
||||
get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY)
|
||||
)
|
||||
elif source == DocumentSource.GMAIL:
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
credential_json = _load_google_json(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
credential_json = json.loads(creds_str)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
credential_json,
|
||||
scopes=GOOGLE_SCOPES[source],
|
||||
@@ -188,12 +205,12 @@ def get_auth_url(credential_id: int, source: DocumentSource) -> str:
|
||||
|
||||
def get_google_app_cred(source: DocumentSource) -> GoogleAppCredentials:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
creds = _load_google_json(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
elif source == DocumentSource.GMAIL:
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
creds = _load_google_json(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
return GoogleAppCredentials(**json.loads(creds_str))
|
||||
return GoogleAppCredentials(**creds)
|
||||
|
||||
|
||||
def upsert_google_app_cred(
|
||||
@@ -201,10 +218,14 @@ def upsert_google_app_cred(
|
||||
) -> None:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
get_kv_store().store(
|
||||
KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True
|
||||
KV_GOOGLE_DRIVE_CRED_KEY,
|
||||
app_credentials.model_dump(mode="json"),
|
||||
encrypt=True,
|
||||
)
|
||||
elif source == DocumentSource.GMAIL:
|
||||
get_kv_store().store(KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True)
|
||||
get_kv_store().store(
|
||||
KV_GMAIL_CRED_KEY, app_credentials.model_dump(mode="json"), encrypt=True
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
|
||||
@@ -220,12 +241,14 @@ def delete_google_app_cred(source: DocumentSource) -> None:
|
||||
|
||||
def get_service_account_key(source: DocumentSource) -> GoogleServiceAccountKey:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY))
|
||||
creds = _load_google_json(
|
||||
get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
|
||||
)
|
||||
elif source == DocumentSource.GMAIL:
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
|
||||
creds = _load_google_json(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
return GoogleServiceAccountKey(**json.loads(creds_str))
|
||||
return GoogleServiceAccountKey(**creds)
|
||||
|
||||
|
||||
def upsert_service_account_key(
|
||||
@@ -234,12 +257,14 @@ def upsert_service_account_key(
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
get_kv_store().store(
|
||||
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY,
|
||||
service_account_key.json(),
|
||||
service_account_key.model_dump(mode="json"),
|
||||
encrypt=True,
|
||||
)
|
||||
elif source == DocumentSource.GMAIL:
|
||||
get_kv_store().store(
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY,
|
||||
service_account_key.model_dump(mode="json"),
|
||||
encrypt=True,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
|
||||
@@ -123,6 +123,9 @@ class SlimConnector(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
def retrieve_all_slim_docs(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -298,6 +301,22 @@ class CheckpointedConnectorWithPermSync(CheckpointedConnector[CT]):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Resolver(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
def resolve_errors(
|
||||
self,
|
||||
errors: list[ConnectorFailure],
|
||||
include_permissions: bool = False,
|
||||
) -> Generator[Document | ConnectorFailure | HierarchyNode, None, None]:
|
||||
"""Attempts to yield back ALL the documents described by the errors, no checkpointing.
|
||||
|
||||
Caller's responsibility is to delete the old ConnectorFailures and replace with the new ones.
|
||||
If include_permissions is True, the documents will have permissions synced.
|
||||
May also yield HierarchyNode objects for ancestor folders of resolved documents.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class HierarchyConnector(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
def load_hierarchy(
|
||||
|
||||
@@ -60,8 +60,10 @@ logger = setup_logger()
|
||||
|
||||
ONE_HOUR = 3600
|
||||
|
||||
_MAX_RESULTS_FETCH_IDS = 5000 # 5000
|
||||
_MAX_RESULTS_FETCH_IDS = 5000
|
||||
_JIRA_FULL_PAGE_SIZE = 50
|
||||
# https://developer.atlassian.com/cloud/jira/platform/rest/v3/api-group-issues/
|
||||
_JIRA_BULK_FETCH_LIMIT = 100
|
||||
|
||||
# Constants for Jira field names
|
||||
_FIELD_REPORTER = "reporter"
|
||||
@@ -255,15 +257,13 @@ def _bulk_fetch_request(
|
||||
return resp.json()["issues"]
|
||||
|
||||
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO(evan): move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
|
||||
def _bulk_fetch_batch(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch a single batch (must be <= _JIRA_BULK_FETCH_LIMIT).
|
||||
On JSONDecodeError, recursively bisects until it succeeds or reaches size 1."""
|
||||
try:
|
||||
raw_issues = _bulk_fetch_request(jira_client, issue_ids, fields)
|
||||
return _bulk_fetch_request(jira_client, issue_ids, fields)
|
||||
except requests.exceptions.JSONDecodeError:
|
||||
if len(issue_ids) <= 1:
|
||||
logger.exception(
|
||||
@@ -277,12 +277,25 @@ def bulk_fetch_issues(
|
||||
f"Jira bulk-fetch JSON decode failed for batch of {len(issue_ids)} issues. "
|
||||
f"Splitting into sub-batches of {mid} and {len(issue_ids) - mid}."
|
||||
)
|
||||
left = bulk_fetch_issues(jira_client, issue_ids[:mid], fields)
|
||||
right = bulk_fetch_issues(jira_client, issue_ids[mid:], fields)
|
||||
left = _bulk_fetch_batch(jira_client, issue_ids[:mid], fields)
|
||||
right = _bulk_fetch_batch(jira_client, issue_ids[mid:], fields)
|
||||
return left + right
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching issues: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO(evan): move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
|
||||
raw_issues: list[dict[str, Any]] = []
|
||||
for batch in chunked(issue_ids, _JIRA_BULK_FETCH_LIMIT):
|
||||
try:
|
||||
raw_issues.extend(_bulk_fetch_batch(jira_client, list(batch), fields))
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching issues: {e}")
|
||||
raise
|
||||
|
||||
return [
|
||||
Issue(jira_client._options, jira_client._session, raw=issue)
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import sys
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
@@ -33,9 +35,18 @@ class ConnectorMissingCredentialError(PermissionError):
|
||||
)
|
||||
|
||||
|
||||
class SectionType(str, Enum):
|
||||
"""Discriminator for Section subclasses."""
|
||||
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
TABULAR = "tabular"
|
||||
|
||||
|
||||
class Section(BaseModel):
|
||||
"""Base section class with common attributes"""
|
||||
|
||||
type: SectionType
|
||||
link: str | None = None
|
||||
text: str | None = None
|
||||
image_file_id: str | None = None
|
||||
@@ -44,6 +55,7 @@ class Section(BaseModel):
|
||||
class TextSection(Section):
|
||||
"""Section containing text content"""
|
||||
|
||||
type: Literal[SectionType.TEXT] = SectionType.TEXT
|
||||
text: str
|
||||
|
||||
def __sizeof__(self) -> int:
|
||||
@@ -53,12 +65,25 @@ class TextSection(Section):
|
||||
class ImageSection(Section):
|
||||
"""Section containing an image reference"""
|
||||
|
||||
type: Literal[SectionType.IMAGE] = SectionType.IMAGE
|
||||
image_file_id: str
|
||||
|
||||
def __sizeof__(self) -> int:
|
||||
return sys.getsizeof(self.image_file_id) + sys.getsizeof(self.link)
|
||||
|
||||
|
||||
class TabularSection(Section):
|
||||
"""Section containing tabular data (csv/tsv content, or one sheet of
|
||||
an xlsx workbook rendered as CSV)."""
|
||||
|
||||
type: Literal[SectionType.TABULAR] = SectionType.TABULAR
|
||||
text: str # CSV representation in a string
|
||||
link: str
|
||||
|
||||
def __sizeof__(self) -> int:
|
||||
return sys.getsizeof(self.text) + sys.getsizeof(self.link)
|
||||
|
||||
|
||||
class BasicExpertInfo(BaseModel):
|
||||
"""Basic Information for the owner of a document, any of the fields can be left as None
|
||||
Display fallback goes as follows:
|
||||
@@ -134,7 +159,6 @@ class BasicExpertInfo(BaseModel):
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, model_dict: dict[str, Any]) -> "BasicExpertInfo":
|
||||
|
||||
first_name = cast(str, model_dict.get("FirstName"))
|
||||
last_name = cast(str, model_dict.get("LastName"))
|
||||
email = cast(str, model_dict.get("Email"))
|
||||
@@ -161,7 +185,7 @@ class DocumentBase(BaseModel):
|
||||
"""Used for Onyx ingestion api, the ID is inferred before use if not provided"""
|
||||
|
||||
id: str | None = None
|
||||
sections: list[TextSection | ImageSection]
|
||||
sections: Sequence[TextSection | ImageSection | TabularSection]
|
||||
source: DocumentSource | None = None
|
||||
semantic_identifier: str # displayed in the UI as the main identifier for the doc
|
||||
# TODO(andrei): Ideally we could improve this to where each value is just a
|
||||
@@ -371,12 +395,9 @@ class IndexingDocument(Document):
|
||||
)
|
||||
else:
|
||||
section_len = sum(
|
||||
(
|
||||
len(section.text)
|
||||
if isinstance(section, TextSection) and section.text is not None
|
||||
else 0
|
||||
)
|
||||
len(section.text) if section.text is not None else 0
|
||||
for section in self.sections
|
||||
if isinstance(section, (TextSection, TabularSection))
|
||||
)
|
||||
|
||||
return title_len + section_len
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import TypedDict
|
||||
|
||||
@@ -6,6 +7,14 @@ from pydantic import BaseModel
|
||||
from onyx.onyxbot.slack.models import ChannelType
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DirectThreadFetch:
|
||||
"""Request to fetch a Slack thread directly by channel and timestamp."""
|
||||
|
||||
channel_id: str
|
||||
thread_ts: str
|
||||
|
||||
|
||||
class ChannelMetadata(TypedDict):
|
||||
"""Type definition for cached channel metadata."""
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from onyx.configs.chat_configs import DOC_TIME_DECAY
|
||||
from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.context.search.federated.models import ChannelMetadata
|
||||
from onyx.context.search.federated.models import DirectThreadFetch
|
||||
from onyx.context.search.federated.models import SlackMessage
|
||||
from onyx.context.search.federated.slack_search_utils import ALL_CHANNEL_TYPES
|
||||
from onyx.context.search.federated.slack_search_utils import build_channel_query_filter
|
||||
@@ -49,7 +50,6 @@ from onyx.server.federated.models import FederatedConnectorDetail
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.timing import log_function_time
|
||||
from shared_configs.configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -58,7 +58,6 @@ HIGHLIGHT_END_CHAR = "\ue001"
|
||||
|
||||
CHANNEL_METADATA_CACHE_TTL = 60 * 60 * 24 # 24 hours
|
||||
USER_PROFILE_CACHE_TTL = 60 * 60 * 24 # 24 hours
|
||||
SLACK_THREAD_CONTEXT_WINDOW = 3 # Number of messages before matched message to include
|
||||
CHANNEL_METADATA_MAX_RETRIES = 3 # Maximum retry attempts for channel metadata fetching
|
||||
CHANNEL_METADATA_RETRY_DELAY = 1 # Initial retry delay in seconds (exponential backoff)
|
||||
|
||||
@@ -421,6 +420,94 @@ class SlackQueryResult(BaseModel):
|
||||
filtered_channels: list[str] # Channels filtered out during this query
|
||||
|
||||
|
||||
def _fetch_thread_from_url(
|
||||
thread_fetch: DirectThreadFetch,
|
||||
access_token: str,
|
||||
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
|
||||
) -> SlackQueryResult:
|
||||
"""Fetch a thread directly from a Slack URL via conversations.replies."""
|
||||
channel_id = thread_fetch.channel_id
|
||||
thread_ts = thread_fetch.thread_ts
|
||||
|
||||
slack_client = WebClient(token=access_token)
|
||||
try:
|
||||
response = slack_client.conversations_replies(
|
||||
channel=channel_id,
|
||||
ts=thread_ts,
|
||||
)
|
||||
response.validate()
|
||||
messages: list[dict[str, Any]] = response.get("messages", [])
|
||||
except SlackApiError as e:
|
||||
logger.warning(
|
||||
f"Failed to fetch thread from URL (channel={channel_id}, ts={thread_ts}): {e}"
|
||||
)
|
||||
return SlackQueryResult(messages=[], filtered_channels=[])
|
||||
|
||||
if not messages:
|
||||
logger.warning(
|
||||
f"No messages found for URL override (channel={channel_id}, ts={thread_ts})"
|
||||
)
|
||||
return SlackQueryResult(messages=[], filtered_channels=[])
|
||||
|
||||
# Build thread text from all messages
|
||||
thread_text = _build_thread_text(messages, access_token, None, slack_client)
|
||||
|
||||
# Get channel name from metadata cache or API
|
||||
channel_name = "unknown"
|
||||
if channel_metadata_dict and channel_id in channel_metadata_dict:
|
||||
channel_name = channel_metadata_dict[channel_id].get("name", "unknown")
|
||||
else:
|
||||
try:
|
||||
ch_response = slack_client.conversations_info(channel=channel_id)
|
||||
ch_response.validate()
|
||||
channel_info: dict[str, Any] = ch_response.get("channel", {})
|
||||
channel_name = channel_info.get("name", "unknown")
|
||||
except SlackApiError:
|
||||
pass
|
||||
|
||||
# Build the SlackMessage
|
||||
parent_msg = messages[0]
|
||||
message_ts = parent_msg.get("ts", thread_ts)
|
||||
username = parent_msg.get("user", "unknown_user")
|
||||
parent_text = parent_msg.get("text", "")
|
||||
snippet = (
|
||||
parent_text[:50].rstrip() + "..." if len(parent_text) > 50 else parent_text
|
||||
).replace("\n", " ")
|
||||
|
||||
doc_time = datetime.fromtimestamp(float(message_ts))
|
||||
decay_factor = DOC_TIME_DECAY
|
||||
doc_age_years = (datetime.now() - doc_time).total_seconds() / (365 * 24 * 60 * 60)
|
||||
recency_bias = max(1 / (1 + decay_factor * doc_age_years), 0.75)
|
||||
|
||||
permalink = (
|
||||
f"https://slack.com/archives/{channel_id}/p{message_ts.replace('.', '')}"
|
||||
)
|
||||
|
||||
slack_message = SlackMessage(
|
||||
document_id=f"{channel_id}_{message_ts}",
|
||||
channel_id=channel_id,
|
||||
message_id=message_ts,
|
||||
thread_id=None, # Prevent double-enrichment in thread context fetch
|
||||
link=permalink,
|
||||
metadata={
|
||||
"channel": channel_name,
|
||||
"time": doc_time.isoformat(),
|
||||
},
|
||||
timestamp=doc_time,
|
||||
recency_bias=recency_bias,
|
||||
semantic_identifier=f"{username} in #{channel_name}: {snippet}",
|
||||
text=thread_text,
|
||||
highlighted_texts=set(),
|
||||
slack_score=100000.0, # High priority — user explicitly asked for this thread
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"URL override: fetched thread from channel={channel_id}, ts={thread_ts}, {len(messages)} messages"
|
||||
)
|
||||
|
||||
return SlackQueryResult(messages=[slack_message], filtered_channels=[])
|
||||
|
||||
|
||||
def query_slack(
|
||||
query_string: str,
|
||||
access_token: str,
|
||||
@@ -432,7 +519,6 @@ def query_slack(
|
||||
available_channels: list[str] | None = None,
|
||||
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
|
||||
) -> SlackQueryResult:
|
||||
|
||||
# Check if query has channel override (user specified channels in query)
|
||||
has_channel_override = query_string.startswith("__CHANNEL_OVERRIDE__")
|
||||
|
||||
@@ -662,7 +748,6 @@ def _fetch_thread_context(
|
||||
"""
|
||||
channel_id = message.channel_id
|
||||
thread_id = message.thread_id
|
||||
message_id = message.message_id
|
||||
|
||||
# If not a thread, return original text as success
|
||||
if thread_id is None:
|
||||
@@ -695,62 +780,37 @@ def _fetch_thread_context(
|
||||
if len(messages) <= 1:
|
||||
return ThreadContextResult.success(message.text)
|
||||
|
||||
# Build thread text from thread starter + context window around matched message
|
||||
thread_text = _build_thread_text(
|
||||
messages, message_id, thread_id, access_token, team_id, slack_client
|
||||
)
|
||||
# Build thread text from thread starter + all replies
|
||||
thread_text = _build_thread_text(messages, access_token, team_id, slack_client)
|
||||
return ThreadContextResult.success(thread_text)
|
||||
|
||||
|
||||
def _build_thread_text(
|
||||
messages: list[dict[str, Any]],
|
||||
message_id: str,
|
||||
thread_id: str,
|
||||
access_token: str,
|
||||
team_id: str | None,
|
||||
slack_client: WebClient,
|
||||
) -> str:
|
||||
"""Build the thread text from messages."""
|
||||
"""Build thread text including all replies.
|
||||
|
||||
Includes the thread parent message followed by all replies in order.
|
||||
"""
|
||||
msg_text = messages[0].get("text", "")
|
||||
msg_sender = messages[0].get("user", "")
|
||||
thread_text = f"<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# All messages after index 0 are replies
|
||||
replies = messages[1:]
|
||||
if not replies:
|
||||
return thread_text
|
||||
|
||||
logger.debug(f"Thread {messages[0].get('ts')}: {len(replies)} replies included")
|
||||
thread_text += "\n\nReplies:"
|
||||
if thread_id == message_id:
|
||||
message_id_idx = 0
|
||||
else:
|
||||
message_id_idx = next(
|
||||
(i for i, msg in enumerate(messages) if msg.get("ts") == message_id), 0
|
||||
)
|
||||
if not message_id_idx:
|
||||
return thread_text
|
||||
|
||||
start_idx = max(1, message_id_idx - SLACK_THREAD_CONTEXT_WINDOW)
|
||||
|
||||
if start_idx > 1:
|
||||
thread_text += "\n..."
|
||||
|
||||
for i in range(start_idx, message_id_idx):
|
||||
msg_text = messages[i].get("text", "")
|
||||
msg_sender = messages[i].get("user", "")
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
|
||||
msg_text = messages[message_id_idx].get("text", "")
|
||||
msg_sender = messages[message_id_idx].get("user", "")
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# Add following replies
|
||||
len_replies = 0
|
||||
for msg in messages[message_id_idx + 1 :]:
|
||||
for msg in replies:
|
||||
msg_text = msg.get("text", "")
|
||||
msg_sender = msg.get("user", "")
|
||||
reply = f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
thread_text += reply
|
||||
|
||||
len_replies += len(reply)
|
||||
if len_replies >= DOC_EMBEDDING_CONTEXT_SIZE * 4:
|
||||
thread_text += "\n..."
|
||||
break
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# Replace user IDs with names using cached lookups
|
||||
userids: set[str] = set(re.findall(r"<@([A-Z0-9]+)>", thread_text))
|
||||
@@ -976,7 +1036,16 @@ def slack_retrieval(
|
||||
|
||||
# Query slack with entity filtering
|
||||
llm = get_default_llm()
|
||||
query_strings = build_slack_queries(query, llm, entities, available_channels)
|
||||
query_items = build_slack_queries(query, llm, entities, available_channels)
|
||||
|
||||
# Partition into direct thread fetches and search query strings
|
||||
direct_fetches: list[DirectThreadFetch] = []
|
||||
query_strings: list[str] = []
|
||||
for item in query_items:
|
||||
if isinstance(item, DirectThreadFetch):
|
||||
direct_fetches.append(item)
|
||||
else:
|
||||
query_strings.append(item)
|
||||
|
||||
# Determine filtering based on entities OR context (bot)
|
||||
include_dm = False
|
||||
@@ -993,8 +1062,16 @@ def slack_retrieval(
|
||||
f"Private channel context: will only allow messages from {allowed_private_channel} + public channels"
|
||||
)
|
||||
|
||||
# Build search tasks
|
||||
search_tasks = [
|
||||
# Build search tasks — direct thread fetches + keyword searches
|
||||
search_tasks: list[tuple] = [
|
||||
(
|
||||
_fetch_thread_from_url,
|
||||
(fetch, access_token, channel_metadata_dict),
|
||||
)
|
||||
for fetch in direct_fetches
|
||||
]
|
||||
|
||||
search_tasks.extend(
|
||||
(
|
||||
query_slack,
|
||||
(
|
||||
@@ -1010,7 +1087,7 @@ def slack_retrieval(
|
||||
),
|
||||
)
|
||||
for query_string in query_strings
|
||||
]
|
||||
)
|
||||
|
||||
# If include_dm is True AND we're not already searching all channels,
|
||||
# add additional searches without channel filters.
|
||||
|
||||
@@ -10,6 +10,7 @@ from pydantic import ValidationError
|
||||
|
||||
from onyx.configs.app_configs import MAX_SLACK_QUERY_EXPANSIONS
|
||||
from onyx.context.search.federated.models import ChannelMetadata
|
||||
from onyx.context.search.federated.models import DirectThreadFetch
|
||||
from onyx.context.search.models import ChunkIndexRequest
|
||||
from onyx.federated_connectors.slack.models import SlackEntities
|
||||
from onyx.llm.interfaces import LLM
|
||||
@@ -638,12 +639,38 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
return [query_text]
|
||||
|
||||
|
||||
SLACK_URL_PATTERN = re.compile(
|
||||
r"https?://[a-z0-9-]+\.slack\.com/archives/([A-Z0-9]+)/p(\d{16})"
|
||||
)
|
||||
|
||||
|
||||
def extract_slack_message_urls(
|
||||
query_text: str,
|
||||
) -> list[tuple[str, str]]:
|
||||
"""Extract Slack message URLs from query text.
|
||||
|
||||
Parses URLs like:
|
||||
https://onyx-company.slack.com/archives/C097NBWMY8Y/p1775491616524769
|
||||
|
||||
Returns list of (channel_id, thread_ts) tuples.
|
||||
The 16-digit timestamp is converted to Slack ts format (with dot).
|
||||
"""
|
||||
results = []
|
||||
for match in SLACK_URL_PATTERN.finditer(query_text):
|
||||
channel_id = match.group(1)
|
||||
raw_ts = match.group(2)
|
||||
# Convert p1775491616524769 -> 1775491616.524769
|
||||
thread_ts = f"{raw_ts[:10]}.{raw_ts[10:]}"
|
||||
results.append((channel_id, thread_ts))
|
||||
return results
|
||||
|
||||
|
||||
def build_slack_queries(
|
||||
query: ChunkIndexRequest,
|
||||
llm: LLM,
|
||||
entities: dict[str, Any] | None = None,
|
||||
available_channels: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
) -> list[str | DirectThreadFetch]:
|
||||
"""Build Slack query strings with date filtering and query expansion."""
|
||||
default_search_days = 30
|
||||
if entities:
|
||||
@@ -668,6 +695,15 @@ def build_slack_queries(
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_back)
|
||||
time_filter = f" after:{cutoff_date.strftime('%Y-%m-%d')}"
|
||||
|
||||
# Check for Slack message URLs — if found, add direct fetch requests
|
||||
url_fetches: list[DirectThreadFetch] = []
|
||||
slack_urls = extract_slack_message_urls(query.query)
|
||||
for channel_id, thread_ts in slack_urls:
|
||||
url_fetches.append(
|
||||
DirectThreadFetch(channel_id=channel_id, thread_ts=thread_ts)
|
||||
)
|
||||
logger.info(f"Detected Slack URL: channel={channel_id}, ts={thread_ts}")
|
||||
|
||||
# ALWAYS extract channel references from the query (not just for recency queries)
|
||||
channel_references = extract_channel_references_from_query(query.query)
|
||||
|
||||
@@ -684,7 +720,9 @@ def build_slack_queries(
|
||||
|
||||
# If valid channels detected, use ONLY those channels with NO keywords
|
||||
# Return query with ONLY time filter + channel filter (no keywords)
|
||||
return [build_channel_override_query(channel_references, time_filter)]
|
||||
return url_fetches + [
|
||||
build_channel_override_query(channel_references, time_filter)
|
||||
]
|
||||
except ValueError as e:
|
||||
# If validation fails, log the error and continue with normal flow
|
||||
logger.warning(f"Channel reference validation failed: {e}")
|
||||
@@ -702,7 +740,8 @@ def build_slack_queries(
|
||||
rephrased_queries = expand_query_with_llm(query.query, llm)
|
||||
|
||||
# Build final query strings with time filters
|
||||
return [
|
||||
search_queries = [
|
||||
rephrased_query.strip() + time_filter
|
||||
for rephrased_query in rephrased_queries[:MAX_SLACK_QUERY_EXPANSIONS]
|
||||
]
|
||||
return url_fetches + search_queries
|
||||
|
||||
@@ -750,31 +750,3 @@ def resync_cc_pair(
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
# ── Metrics query helpers ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_connector_health_for_metrics(
|
||||
db_session: Session,
|
||||
) -> list: # Returns list of Row tuples
|
||||
"""Return connector health data for Prometheus metrics.
|
||||
|
||||
Each row is (cc_pair_id, status, in_repeated_error_state,
|
||||
last_successful_index_time, name, source).
|
||||
"""
|
||||
return (
|
||||
db_session.query(
|
||||
ConnectorCredentialPair.id,
|
||||
ConnectorCredentialPair.status,
|
||||
ConnectorCredentialPair.in_repeated_error_state,
|
||||
ConnectorCredentialPair.last_successful_index_time,
|
||||
ConnectorCredentialPair.name,
|
||||
Connector.source,
|
||||
)
|
||||
.join(
|
||||
Connector,
|
||||
ConnectorCredentialPair.connector_id == Connector.id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
@@ -335,6 +335,7 @@ def update_document_set(
|
||||
"Cannot update document set while it is syncing. Please wait for it to finish syncing, and then try again."
|
||||
)
|
||||
|
||||
document_set_row.name = document_set_update_request.name
|
||||
document_set_row.description = document_set_update_request.description
|
||||
if not DISABLE_VECTOR_DB:
|
||||
document_set_row.is_up_to_date = False
|
||||
|
||||
@@ -11,6 +11,7 @@ from sqlalchemy import event
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import create_engine
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.exc import DBAPIError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DB_READONLY_PASSWORD
|
||||
@@ -346,6 +347,25 @@ def get_session_with_shared_schema() -> Generator[Session, None, None]:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def _safe_close_session(session: Session) -> None:
|
||||
"""Close a session, catching connection-closed errors during cleanup.
|
||||
|
||||
Long-running operations (e.g. multi-model LLM loops) can hold a session
|
||||
open for minutes. If the underlying connection is dropped by cloud
|
||||
infrastructure (load-balancer timeouts, PgBouncer, idle-in-transaction
|
||||
timeouts, etc.), the implicit rollback in Session.close() raises
|
||||
OperationalError or InterfaceError. Since the work is already complete,
|
||||
we log and move on — SQLAlchemy internally invalidates the connection
|
||||
for pool recycling.
|
||||
"""
|
||||
try:
|
||||
session.close()
|
||||
except DBAPIError:
|
||||
logger.warning(
|
||||
"DB connection lost during session cleanup — the connection will be invalidated and recycled by the pool."
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]:
|
||||
"""
|
||||
@@ -358,8 +378,11 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
|
||||
|
||||
# no need to use the schema translation map for self-hosted + default schema
|
||||
if not MULTI_TENANT and tenant_id == POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE:
|
||||
with Session(bind=engine, expire_on_commit=False) as session:
|
||||
session = Session(bind=engine, expire_on_commit=False)
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
_safe_close_session(session)
|
||||
return
|
||||
|
||||
# Create connection with schema translation to handle querying the right schema
|
||||
@@ -367,8 +390,11 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
|
||||
with engine.connect().execution_options(
|
||||
schema_translate_map=schema_translate_map
|
||||
) as connection:
|
||||
with Session(bind=connection, expire_on_commit=False) as session:
|
||||
session = Session(bind=connection, expire_on_commit=False)
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
_safe_close_session(session)
|
||||
|
||||
|
||||
def get_session() -> Generator[Session, None, None]:
|
||||
|
||||
@@ -2,8 +2,6 @@ from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import NamedTuple
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVarTuple
|
||||
|
||||
from sqlalchemy import and_
|
||||
@@ -30,17 +28,6 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
# from sqlalchemy.sql.selectable import Select
|
||||
|
||||
# Comment out unused imports that cause mypy errors
|
||||
# from onyx.auth.models import UserRole
|
||||
# from onyx.configs.constants import MAX_LAST_VALID_CHECKPOINT_AGE_SECONDS
|
||||
# from onyx.db.connector_credential_pair import ConnectorCredentialPairIdentifier
|
||||
# from onyx.db.engine import async_query_for_dms
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -899,6 +886,7 @@ def create_index_attempt_error(
|
||||
failure: ConnectorFailure,
|
||||
db_session: Session,
|
||||
) -> int:
|
||||
exc = failure.exception
|
||||
new_error = IndexAttemptError(
|
||||
index_attempt_id=index_attempt_id,
|
||||
connector_credential_pair_id=connector_credential_pair_id,
|
||||
@@ -921,6 +909,7 @@ def create_index_attempt_error(
|
||||
),
|
||||
failure_message=failure.failure_message,
|
||||
is_resolved=False,
|
||||
error_type=type(exc).__name__ if exc else None,
|
||||
)
|
||||
db_session.add(new_error)
|
||||
db_session.commit()
|
||||
@@ -979,104 +968,48 @@ def get_index_attempt_errors_for_cc_pair(
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
# ── Metrics query helpers ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class ActiveIndexAttemptMetric(NamedTuple):
|
||||
"""Row returned by get_active_index_attempts_for_metrics."""
|
||||
|
||||
status: IndexingStatus
|
||||
source: "DocumentSource"
|
||||
cc_pair_id: int
|
||||
cc_pair_name: str | None
|
||||
attempt_count: int
|
||||
|
||||
|
||||
def get_active_index_attempts_for_metrics(
|
||||
def get_index_attempt_errors_across_connectors(
|
||||
db_session: Session,
|
||||
) -> list[ActiveIndexAttemptMetric]:
|
||||
"""Return non-terminal index attempts grouped by status, source, and connector.
|
||||
cc_pair_id: int | None = None,
|
||||
error_type: str | None = None,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
unresolved_only: bool = True,
|
||||
page: int = 0,
|
||||
page_size: int = 25,
|
||||
) -> tuple[list[IndexAttemptError], int]:
|
||||
"""Query index attempt errors across all connectors with optional filters.
|
||||
|
||||
Each row is (status, source, cc_pair_id, cc_pair_name, attempt_count).
|
||||
Returns (errors, total_count) for pagination.
|
||||
"""
|
||||
from onyx.db.models import Connector
|
||||
stmt = select(IndexAttemptError)
|
||||
count_stmt = select(func.count()).select_from(IndexAttemptError)
|
||||
|
||||
terminal_statuses = [s for s in IndexingStatus if s.is_terminal()]
|
||||
rows = (
|
||||
db_session.query(
|
||||
IndexAttempt.status,
|
||||
Connector.source,
|
||||
ConnectorCredentialPair.id,
|
||||
ConnectorCredentialPair.name,
|
||||
func.count(),
|
||||
if cc_pair_id is not None:
|
||||
stmt = stmt.where(IndexAttemptError.connector_credential_pair_id == cc_pair_id)
|
||||
count_stmt = count_stmt.where(
|
||||
IndexAttemptError.connector_credential_pair_id == cc_pair_id
|
||||
)
|
||||
.join(
|
||||
ConnectorCredentialPair,
|
||||
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
|
||||
)
|
||||
.join(
|
||||
Connector,
|
||||
ConnectorCredentialPair.connector_id == Connector.id,
|
||||
)
|
||||
.filter(IndexAttempt.status.notin_(terminal_statuses))
|
||||
.group_by(
|
||||
IndexAttempt.status,
|
||||
Connector.source,
|
||||
ConnectorCredentialPair.id,
|
||||
ConnectorCredentialPair.name,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
return [ActiveIndexAttemptMetric(*row) for row in rows]
|
||||
|
||||
if error_type is not None:
|
||||
stmt = stmt.where(IndexAttemptError.error_type == error_type)
|
||||
count_stmt = count_stmt.where(IndexAttemptError.error_type == error_type)
|
||||
|
||||
def get_failed_attempt_counts_by_cc_pair(
|
||||
db_session: Session,
|
||||
since: datetime | None = None,
|
||||
) -> dict[int, int]:
|
||||
"""Return {cc_pair_id: failed_attempt_count} for all connectors.
|
||||
if unresolved_only:
|
||||
stmt = stmt.where(IndexAttemptError.is_resolved.is_(False))
|
||||
count_stmt = count_stmt.where(IndexAttemptError.is_resolved.is_(False))
|
||||
|
||||
When ``since`` is provided, only attempts created after that timestamp
|
||||
are counted. Defaults to the last 90 days to avoid unbounded historical
|
||||
aggregation.
|
||||
"""
|
||||
if since is None:
|
||||
since = datetime.now(timezone.utc) - timedelta(days=90)
|
||||
if start_time is not None:
|
||||
stmt = stmt.where(IndexAttemptError.time_created >= start_time)
|
||||
count_stmt = count_stmt.where(IndexAttemptError.time_created >= start_time)
|
||||
|
||||
rows = (
|
||||
db_session.query(
|
||||
IndexAttempt.connector_credential_pair_id,
|
||||
func.count(),
|
||||
)
|
||||
.filter(IndexAttempt.status == IndexingStatus.FAILED)
|
||||
.filter(IndexAttempt.time_created >= since)
|
||||
.group_by(IndexAttempt.connector_credential_pair_id)
|
||||
.all()
|
||||
)
|
||||
return {cc_id: count for cc_id, count in rows}
|
||||
if end_time is not None:
|
||||
stmt = stmt.where(IndexAttemptError.time_created <= end_time)
|
||||
count_stmt = count_stmt.where(IndexAttemptError.time_created <= end_time)
|
||||
|
||||
stmt = stmt.order_by(desc(IndexAttemptError.time_created))
|
||||
stmt = stmt.offset(page * page_size).limit(page_size)
|
||||
|
||||
def get_docs_indexed_by_cc_pair(
|
||||
db_session: Session,
|
||||
since: datetime | None = None,
|
||||
) -> dict[int, int]:
|
||||
"""Return {cc_pair_id: total_new_docs_indexed} across successful attempts.
|
||||
|
||||
Only counts attempts with status SUCCESS to avoid inflating counts with
|
||||
partial results from failed attempts. When ``since`` is provided, only
|
||||
attempts created after that timestamp are included.
|
||||
"""
|
||||
if since is None:
|
||||
since = datetime.now(timezone.utc) - timedelta(days=90)
|
||||
|
||||
query = (
|
||||
db_session.query(
|
||||
IndexAttempt.connector_credential_pair_id,
|
||||
func.sum(func.coalesce(IndexAttempt.new_docs_indexed, 0)),
|
||||
)
|
||||
.filter(IndexAttempt.status == IndexingStatus.SUCCESS)
|
||||
.filter(IndexAttempt.time_created >= since)
|
||||
.group_by(IndexAttempt.connector_credential_pair_id)
|
||||
)
|
||||
rows = query.all()
|
||||
return {cc_id: int(total or 0) for cc_id, total in rows}
|
||||
total = db_session.scalar(count_stmt) or 0
|
||||
errors = list(db_session.scalars(stmt).all())
|
||||
return errors, total
|
||||
|
||||
@@ -5,6 +5,7 @@ from pydantic import ConfigDict
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant_if_none
|
||||
from onyx.db.models import Memory
|
||||
from onyx.db.models import User
|
||||
|
||||
@@ -83,47 +84,51 @@ def get_memories(user: User, db_session: Session) -> UserMemoryContext:
|
||||
def add_memory(
|
||||
user_id: UUID,
|
||||
memory_text: str,
|
||||
db_session: Session,
|
||||
) -> Memory:
|
||||
db_session: Session | None = None,
|
||||
) -> int:
|
||||
"""Insert a new Memory row for the given user.
|
||||
|
||||
If the user already has MAX_MEMORIES_PER_USER memories, the oldest
|
||||
one (lowest id) is deleted before inserting the new one.
|
||||
|
||||
Returns the id of the newly created Memory row.
|
||||
"""
|
||||
existing = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
existing = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
|
||||
if len(existing) >= MAX_MEMORIES_PER_USER:
|
||||
db_session.delete(existing[0])
|
||||
if len(existing) >= MAX_MEMORIES_PER_USER:
|
||||
db_session.delete(existing[0])
|
||||
|
||||
memory = Memory(
|
||||
user_id=user_id,
|
||||
memory_text=memory_text,
|
||||
)
|
||||
db_session.add(memory)
|
||||
db_session.commit()
|
||||
return memory
|
||||
memory = Memory(
|
||||
user_id=user_id,
|
||||
memory_text=memory_text,
|
||||
)
|
||||
db_session.add(memory)
|
||||
db_session.commit()
|
||||
return memory.id
|
||||
|
||||
|
||||
def update_memory_at_index(
|
||||
user_id: UUID,
|
||||
index: int,
|
||||
new_text: str,
|
||||
db_session: Session,
|
||||
) -> Memory | None:
|
||||
db_session: Session | None = None,
|
||||
) -> int | None:
|
||||
"""Update the memory at the given 0-based index (ordered by id ASC, matching get_memories()).
|
||||
|
||||
Returns the updated Memory row, or None if the index is out of range.
|
||||
Returns the id of the updated Memory row, or None if the index is out of range.
|
||||
"""
|
||||
memory_rows = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
memory_rows = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
|
||||
if index < 0 or index >= len(memory_rows):
|
||||
return None
|
||||
if index < 0 or index >= len(memory_rows):
|
||||
return None
|
||||
|
||||
memory = memory_rows[index]
|
||||
memory.memory_text = new_text
|
||||
db_session.commit()
|
||||
return memory
|
||||
memory = memory_rows[index]
|
||||
memory.memory_text = new_text
|
||||
db_session.commit()
|
||||
return memory.id
|
||||
|
||||
@@ -2422,6 +2422,8 @@ class IndexAttemptError(Base):
|
||||
failure_message: Mapped[str] = mapped_column(Text)
|
||||
is_resolved: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
error_type: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
|
||||
@@ -7,8 +7,6 @@ import time
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
from onyx.chat.citation_processor import DynamicCitationProcessor
|
||||
@@ -22,6 +20,7 @@ from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.configs.chat_configs import SKIP_DEEP_RESEARCH_CLARIFICATION
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.tools import get_tool_by_name
|
||||
from onyx.deep_research.dr_mock_tools import get_clarification_tool_definitions
|
||||
from onyx.deep_research.dr_mock_tools import get_orchestrator_tools
|
||||
@@ -184,6 +183,14 @@ def generate_final_report(
|
||||
return has_reasoned
|
||||
|
||||
|
||||
def _get_research_agent_tool_id() -> int:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
return get_tool_by_name(
|
||||
tool_name=RESEARCH_AGENT_TOOL_NAME,
|
||||
db_session=db_session,
|
||||
).id
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def run_deep_research_llm_loop(
|
||||
emitter: Emitter,
|
||||
@@ -193,7 +200,6 @@ def run_deep_research_llm_loop(
|
||||
custom_agent_prompt: str | None, # noqa: ARG001
|
||||
llm: LLM,
|
||||
token_counter: Callable[[str], int],
|
||||
db_session: Session,
|
||||
skip_clarification: bool = False,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
chat_session_id: str | None = None,
|
||||
@@ -717,6 +723,7 @@ def run_deep_research_llm_loop(
|
||||
simple_chat_history.append(assistant_with_tools)
|
||||
|
||||
# Now add TOOL_CALL_RESPONSE messages and tool call info for each result
|
||||
research_agent_tool_id = _get_research_agent_tool_id()
|
||||
for tab_index, report in enumerate(
|
||||
research_results.intermediate_reports
|
||||
):
|
||||
@@ -737,10 +744,7 @@ def run_deep_research_llm_loop(
|
||||
tab_index=tab_index,
|
||||
tool_name=current_tool_call.tool_name,
|
||||
tool_call_id=current_tool_call.tool_call_id,
|
||||
tool_id=get_tool_by_name(
|
||||
tool_name=RESEARCH_AGENT_TOOL_NAME,
|
||||
db_session=db_session,
|
||||
).id,
|
||||
tool_id=research_agent_tool_id,
|
||||
reasoning_tokens=llm_step_result.reasoning
|
||||
or most_recent_reasoning,
|
||||
tool_call_arguments=current_tool_call.tool_args,
|
||||
|
||||
@@ -379,13 +379,25 @@ def _worksheet_to_matrix(
|
||||
worksheet: Worksheet,
|
||||
) -> list[list[str]]:
|
||||
"""
|
||||
Converts a singular worksheet to a matrix of values
|
||||
Converts a singular worksheet to a matrix of values.
|
||||
|
||||
Rows are padded to a uniform width. In openpyxl's read_only mode,
|
||||
iter_rows can yield rows of differing lengths (trailing empty cells
|
||||
are sometimes omitted), and downstream column cleanup assumes a
|
||||
rectangular matrix.
|
||||
"""
|
||||
rows: list[list[str]] = []
|
||||
max_len = 0
|
||||
for worksheet_row in worksheet.iter_rows(min_row=1, values_only=True):
|
||||
row = ["" if cell is None else str(cell) for cell in worksheet_row]
|
||||
if len(row) > max_len:
|
||||
max_len = len(row)
|
||||
rows.append(row)
|
||||
|
||||
for row in rows:
|
||||
if len(row) < max_len:
|
||||
row.extend([""] * (max_len - len(row)))
|
||||
|
||||
return rows
|
||||
|
||||
|
||||
@@ -463,29 +475,13 @@ def _remove_empty_runs(
|
||||
return result
|
||||
|
||||
|
||||
def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
# TODO: switch back to this approach in a few months when markitdown
|
||||
# fixes their handling of excel files
|
||||
def xlsx_sheet_extraction(file: IO[Any], file_name: str = "") -> list[tuple[str, str]]:
|
||||
"""
|
||||
Converts each sheet in the excel file to a csv condensed string.
|
||||
Returns a string and the worksheet title for each worksheet
|
||||
|
||||
# md = get_markitdown_converter()
|
||||
# stream_info = StreamInfo(
|
||||
# mimetype=SPREADSHEET_MIME_TYPE, filename=file_name or None, extension=".xlsx"
|
||||
# )
|
||||
# try:
|
||||
# workbook = md.convert(to_bytesio(file), stream_info=stream_info)
|
||||
# except (
|
||||
# BadZipFile,
|
||||
# ValueError,
|
||||
# FileConversionException,
|
||||
# UnsupportedFormatException,
|
||||
# ) as e:
|
||||
# error_str = f"Failed to extract text from {file_name or 'xlsx file'}: {e}"
|
||||
# if file_name.startswith("~"):
|
||||
# logger.debug(error_str + " (this is expected for files with ~)")
|
||||
# else:
|
||||
# logger.warning(error_str)
|
||||
# return ""
|
||||
# return workbook.markdown
|
||||
Returns a list of (csv_text, sheet)
|
||||
"""
|
||||
try:
|
||||
workbook = openpyxl.load_workbook(file, read_only=True)
|
||||
except BadZipFile as e:
|
||||
@@ -494,23 +490,30 @@ def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
logger.debug(error_str + " (this is expected for files with ~)")
|
||||
else:
|
||||
logger.warning(error_str)
|
||||
return ""
|
||||
return []
|
||||
except Exception as e:
|
||||
if any(s in str(e) for s in KNOWN_OPENPYXL_BUGS):
|
||||
logger.error(
|
||||
f"Failed to extract text from {file_name or 'xlsx file'}. This happens due to a bug in openpyxl. {e}"
|
||||
)
|
||||
return ""
|
||||
return []
|
||||
raise
|
||||
|
||||
text_content = []
|
||||
sheets: list[tuple[str, str]] = []
|
||||
for sheet in workbook.worksheets:
|
||||
sheet_matrix = _clean_worksheet_matrix(_worksheet_to_matrix(sheet))
|
||||
buf = io.StringIO()
|
||||
writer = csv.writer(buf, lineterminator="\n")
|
||||
writer.writerows(sheet_matrix)
|
||||
text_content.append(buf.getvalue().rstrip("\n"))
|
||||
return TEXT_SECTION_SEPARATOR.join(text_content)
|
||||
csv_text = buf.getvalue().rstrip("\n")
|
||||
if csv_text.strip():
|
||||
sheets.append((csv_text, sheet.title))
|
||||
return sheets
|
||||
|
||||
|
||||
def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
sheets = xlsx_sheet_extraction(file, file_name)
|
||||
return TEXT_SECTION_SEPARATOR.join(csv_text for csv_text, _title in sheets)
|
||||
|
||||
|
||||
def eml_to_text(file: IO[Any]) -> str:
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import cast
|
||||
|
||||
from chonkie import SentenceChunker
|
||||
|
||||
from onyx.configs.app_configs import AVERAGE_SUMMARY_EMBEDDINGS
|
||||
@@ -16,16 +14,14 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_metadata_keys_to_ignore,
|
||||
)
|
||||
from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.indexing.chunking import DocumentChunker
|
||||
from onyx.indexing.chunking import extract_blurb
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.llm.utils import MAX_CONTEXT_TOKENS
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.text_processing import clean_text
|
||||
from onyx.utils.text_processing import shared_precompare_cleanup
|
||||
from shared_configs.configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from shared_configs.configs import STRICT_CHUNK_TOKEN_LIMIT
|
||||
|
||||
# Not supporting overlaps, we need a clean combination of chunks and it is unclear if overlaps
|
||||
# actually help quality at all
|
||||
@@ -154,9 +150,6 @@ class Chunker:
|
||||
self.tokenizer = tokenizer
|
||||
self.callback = callback
|
||||
|
||||
self.max_context = 0
|
||||
self.prompt_tokens = 0
|
||||
|
||||
# Create a token counter function that returns the count instead of the tokens
|
||||
def token_counter(text: str) -> int:
|
||||
return len(tokenizer.encode(text))
|
||||
@@ -186,234 +179,12 @@ class Chunker:
|
||||
else None
|
||||
)
|
||||
|
||||
def _split_oversized_chunk(self, text: str, content_token_limit: int) -> list[str]:
|
||||
"""
|
||||
Splits the text into smaller chunks based on token count to ensure
|
||||
no chunk exceeds the content_token_limit.
|
||||
"""
|
||||
tokens = self.tokenizer.tokenize(text)
|
||||
chunks = []
|
||||
start = 0
|
||||
total_tokens = len(tokens)
|
||||
while start < total_tokens:
|
||||
end = min(start + content_token_limit, total_tokens)
|
||||
token_chunk = tokens[start:end]
|
||||
chunk_text = " ".join(token_chunk)
|
||||
chunks.append(chunk_text)
|
||||
start = end
|
||||
return chunks
|
||||
|
||||
def _extract_blurb(self, text: str) -> str:
|
||||
"""
|
||||
Extract a short blurb from the text (first chunk of size `blurb_size`).
|
||||
"""
|
||||
# chunker is in `text` mode
|
||||
texts = cast(list[str], self.blurb_splitter.chunk(text))
|
||||
if not texts:
|
||||
return ""
|
||||
return texts[0]
|
||||
|
||||
def _get_mini_chunk_texts(self, chunk_text: str) -> list[str] | None:
|
||||
"""
|
||||
For "multipass" mode: additional sub-chunks (mini-chunks) for use in certain embeddings.
|
||||
"""
|
||||
if self.mini_chunk_splitter and chunk_text.strip():
|
||||
# chunker is in `text` mode
|
||||
return cast(list[str], self.mini_chunk_splitter.chunk(chunk_text))
|
||||
return None
|
||||
|
||||
# ADDED: extra param image_url to store in the chunk
|
||||
def _create_chunk(
|
||||
self,
|
||||
document: IndexingDocument,
|
||||
chunks_list: list[DocAwareChunk],
|
||||
text: str,
|
||||
links: dict[int, str],
|
||||
is_continuation: bool = False,
|
||||
title_prefix: str = "",
|
||||
metadata_suffix_semantic: str = "",
|
||||
metadata_suffix_keyword: str = "",
|
||||
image_file_id: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Helper to create a new DocAwareChunk, append it to chunks_list.
|
||||
"""
|
||||
new_chunk = DocAwareChunk(
|
||||
source_document=document,
|
||||
chunk_id=len(chunks_list),
|
||||
blurb=self._extract_blurb(text),
|
||||
content=text,
|
||||
source_links=links or {0: ""},
|
||||
image_file_id=image_file_id,
|
||||
section_continuation=is_continuation,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
mini_chunk_texts=self._get_mini_chunk_texts(text),
|
||||
large_chunk_id=None,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
contextual_rag_reserved_tokens=0, # set per-document in _handle_single_document
|
||||
self._document_chunker = DocumentChunker(
|
||||
tokenizer=tokenizer,
|
||||
blurb_splitter=self.blurb_splitter,
|
||||
chunk_splitter=self.chunk_splitter,
|
||||
mini_chunk_splitter=self.mini_chunk_splitter,
|
||||
)
|
||||
chunks_list.append(new_chunk)
|
||||
|
||||
def _chunk_document_with_sections(
|
||||
self,
|
||||
document: IndexingDocument,
|
||||
sections: list[Section],
|
||||
title_prefix: str,
|
||||
metadata_suffix_semantic: str,
|
||||
metadata_suffix_keyword: str,
|
||||
content_token_limit: int,
|
||||
) -> list[DocAwareChunk]:
|
||||
"""
|
||||
Loops through sections of the document, converting them into one or more chunks.
|
||||
Works with processed sections that are base Section objects.
|
||||
"""
|
||||
chunks: list[DocAwareChunk] = []
|
||||
link_offsets: dict[int, str] = {}
|
||||
chunk_text = ""
|
||||
|
||||
for section_idx, section in enumerate(sections):
|
||||
# Get section text and other attributes
|
||||
section_text = clean_text(str(section.text or ""))
|
||||
section_link_text = section.link or ""
|
||||
image_url = section.image_file_id
|
||||
|
||||
# If there is no useful content, skip
|
||||
if not section_text and (not document.title or section_idx > 0):
|
||||
logger.warning(
|
||||
f"Skipping empty or irrelevant section in doc {document.semantic_identifier}, link={section_link_text}"
|
||||
)
|
||||
continue
|
||||
|
||||
# CASE 1: If this section has an image, force a separate chunk
|
||||
if image_url:
|
||||
# First, if we have any partially built text chunk, finalize it
|
||||
if chunk_text.strip():
|
||||
self._create_chunk(
|
||||
document,
|
||||
chunks,
|
||||
chunk_text,
|
||||
link_offsets,
|
||||
is_continuation=False,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
)
|
||||
chunk_text = ""
|
||||
link_offsets = {}
|
||||
|
||||
# Create a chunk specifically for this image section
|
||||
# (Using the text summary that was generated during processing)
|
||||
self._create_chunk(
|
||||
document,
|
||||
chunks,
|
||||
section_text,
|
||||
links={0: section_link_text} if section_link_text else {},
|
||||
image_file_id=image_url,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
)
|
||||
# Continue to next section
|
||||
continue
|
||||
|
||||
# CASE 2: Normal text section
|
||||
section_token_count = len(self.tokenizer.encode(section_text))
|
||||
|
||||
# If the section is large on its own, split it separately
|
||||
if section_token_count > content_token_limit:
|
||||
if chunk_text.strip():
|
||||
self._create_chunk(
|
||||
document,
|
||||
chunks,
|
||||
chunk_text,
|
||||
link_offsets,
|
||||
False,
|
||||
title_prefix,
|
||||
metadata_suffix_semantic,
|
||||
metadata_suffix_keyword,
|
||||
)
|
||||
chunk_text = ""
|
||||
link_offsets = {}
|
||||
|
||||
# chunker is in `text` mode
|
||||
split_texts = cast(list[str], self.chunk_splitter.chunk(section_text))
|
||||
for i, split_text in enumerate(split_texts):
|
||||
# If even the split_text is bigger than strict limit, further split
|
||||
if (
|
||||
STRICT_CHUNK_TOKEN_LIMIT
|
||||
and len(self.tokenizer.encode(split_text)) > content_token_limit
|
||||
):
|
||||
smaller_chunks = self._split_oversized_chunk(
|
||||
split_text, content_token_limit
|
||||
)
|
||||
for j, small_chunk in enumerate(smaller_chunks):
|
||||
self._create_chunk(
|
||||
document,
|
||||
chunks,
|
||||
small_chunk,
|
||||
{0: section_link_text},
|
||||
is_continuation=(j != 0),
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
)
|
||||
else:
|
||||
self._create_chunk(
|
||||
document,
|
||||
chunks,
|
||||
split_text,
|
||||
{0: section_link_text},
|
||||
is_continuation=(i != 0),
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
)
|
||||
continue
|
||||
|
||||
# If we can still fit this section into the current chunk, do so
|
||||
current_token_count = len(self.tokenizer.encode(chunk_text))
|
||||
current_offset = len(shared_precompare_cleanup(chunk_text))
|
||||
next_section_tokens = (
|
||||
len(self.tokenizer.encode(SECTION_SEPARATOR)) + section_token_count
|
||||
)
|
||||
|
||||
if next_section_tokens + current_token_count <= content_token_limit:
|
||||
if chunk_text:
|
||||
chunk_text += SECTION_SEPARATOR
|
||||
chunk_text += section_text
|
||||
link_offsets[current_offset] = section_link_text
|
||||
else:
|
||||
# finalize the existing chunk
|
||||
self._create_chunk(
|
||||
document,
|
||||
chunks,
|
||||
chunk_text,
|
||||
link_offsets,
|
||||
False,
|
||||
title_prefix,
|
||||
metadata_suffix_semantic,
|
||||
metadata_suffix_keyword,
|
||||
)
|
||||
# start a new chunk
|
||||
link_offsets = {0: section_link_text}
|
||||
chunk_text = section_text
|
||||
|
||||
# finalize any leftover text chunk
|
||||
if chunk_text.strip() or not chunks:
|
||||
self._create_chunk(
|
||||
document,
|
||||
chunks,
|
||||
chunk_text,
|
||||
link_offsets or {0: ""}, # safe default
|
||||
False,
|
||||
title_prefix,
|
||||
metadata_suffix_semantic,
|
||||
metadata_suffix_keyword,
|
||||
)
|
||||
return chunks
|
||||
|
||||
def _handle_single_document(
|
||||
self, document: IndexingDocument
|
||||
@@ -423,7 +194,10 @@ class Chunker:
|
||||
logger.debug(f"Chunking {document.semantic_identifier}")
|
||||
|
||||
# Title prep
|
||||
title = self._extract_blurb(document.get_title_for_document_index() or "")
|
||||
title = extract_blurb(
|
||||
document.get_title_for_document_index() or "",
|
||||
self.blurb_splitter,
|
||||
)
|
||||
title_prefix = title + RETURN_SEPARATOR if title else ""
|
||||
title_tokens = len(self.tokenizer.encode(title_prefix))
|
||||
|
||||
@@ -491,7 +265,7 @@ class Chunker:
|
||||
# Use processed_sections if available (IndexingDocument), otherwise use original sections
|
||||
sections_to_chunk = document.processed_sections
|
||||
|
||||
normal_chunks = self._chunk_document_with_sections(
|
||||
normal_chunks = self._document_chunker.chunk(
|
||||
document,
|
||||
sections_to_chunk,
|
||||
title_prefix,
|
||||
|
||||
7
backend/onyx/indexing/chunking/__init__.py
Normal file
7
backend/onyx/indexing/chunking/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from onyx.indexing.chunking.document_chunker import DocumentChunker
|
||||
from onyx.indexing.chunking.section_chunker import extract_blurb
|
||||
|
||||
__all__ = [
|
||||
"DocumentChunker",
|
||||
"extract_blurb",
|
||||
]
|
||||
113
backend/onyx/indexing/chunking/document_chunker.py
Normal file
113
backend/onyx/indexing/chunking/document_chunker.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from chonkie import SentenceChunker
|
||||
|
||||
from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SectionType
|
||||
from onyx.indexing.chunking.image_section_chunker import ImageChunker
|
||||
from onyx.indexing.chunking.section_chunker import AccumulatorState
|
||||
from onyx.indexing.chunking.section_chunker import ChunkPayload
|
||||
from onyx.indexing.chunking.section_chunker import SectionChunker
|
||||
from onyx.indexing.chunking.tabular_section_chunker import TabularChunker
|
||||
from onyx.indexing.chunking.text_section_chunker import TextChunker
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.text_processing import clean_text
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class DocumentChunker:
|
||||
"""Converts a document's processed sections into DocAwareChunks.
|
||||
|
||||
Drop-in replacement for `Chunker._chunk_document_with_sections`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: BaseTokenizer,
|
||||
blurb_splitter: SentenceChunker,
|
||||
chunk_splitter: SentenceChunker,
|
||||
mini_chunk_splitter: SentenceChunker | None = None,
|
||||
) -> None:
|
||||
self.blurb_splitter = blurb_splitter
|
||||
self.mini_chunk_splitter = mini_chunk_splitter
|
||||
|
||||
self._dispatch: dict[SectionType, SectionChunker] = {
|
||||
SectionType.TEXT: TextChunker(
|
||||
tokenizer=tokenizer,
|
||||
chunk_splitter=chunk_splitter,
|
||||
),
|
||||
SectionType.IMAGE: ImageChunker(),
|
||||
SectionType.TABULAR: TabularChunker(tokenizer=tokenizer),
|
||||
}
|
||||
|
||||
def chunk(
|
||||
self,
|
||||
document: IndexingDocument,
|
||||
sections: list[Section],
|
||||
title_prefix: str,
|
||||
metadata_suffix_semantic: str,
|
||||
metadata_suffix_keyword: str,
|
||||
content_token_limit: int,
|
||||
) -> list[DocAwareChunk]:
|
||||
payloads = self._collect_section_payloads(
|
||||
document=document,
|
||||
sections=sections,
|
||||
content_token_limit=content_token_limit,
|
||||
)
|
||||
|
||||
if not payloads:
|
||||
payloads.append(ChunkPayload(text="", links={0: ""}))
|
||||
|
||||
return [
|
||||
payload.to_doc_aware_chunk(
|
||||
document=document,
|
||||
chunk_id=idx,
|
||||
blurb_splitter=self.blurb_splitter,
|
||||
mini_chunk_splitter=self.mini_chunk_splitter,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
)
|
||||
for idx, payload in enumerate(payloads)
|
||||
]
|
||||
|
||||
def _collect_section_payloads(
|
||||
self,
|
||||
document: IndexingDocument,
|
||||
sections: list[Section],
|
||||
content_token_limit: int,
|
||||
) -> list[ChunkPayload]:
|
||||
accumulator = AccumulatorState()
|
||||
payloads: list[ChunkPayload] = []
|
||||
|
||||
for section_idx, section in enumerate(sections):
|
||||
section_text = clean_text(str(section.text or ""))
|
||||
|
||||
if not section_text and (not document.title or section_idx > 0):
|
||||
logger.warning(
|
||||
f"Skipping empty or irrelevant section in doc "
|
||||
f"{document.semantic_identifier}, link={section.link}"
|
||||
)
|
||||
continue
|
||||
|
||||
chunker = self._select_chunker(section)
|
||||
result = chunker.chunk_section(
|
||||
section=section,
|
||||
accumulator=accumulator,
|
||||
content_token_limit=content_token_limit,
|
||||
)
|
||||
payloads.extend(result.payloads)
|
||||
accumulator = result.accumulator
|
||||
|
||||
# Final flush — any leftover buffered text becomes one last payload.
|
||||
payloads.extend(accumulator.flush_to_list())
|
||||
|
||||
return payloads
|
||||
|
||||
def _select_chunker(self, section: Section) -> SectionChunker:
|
||||
try:
|
||||
return self._dispatch[section.type]
|
||||
except KeyError:
|
||||
raise ValueError(f"No SectionChunker registered for type={section.type}")
|
||||
35
backend/onyx/indexing/chunking/image_section_chunker.py
Normal file
35
backend/onyx/indexing/chunking/image_section_chunker.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.indexing.chunking.section_chunker import AccumulatorState
|
||||
from onyx.indexing.chunking.section_chunker import ChunkPayload
|
||||
from onyx.indexing.chunking.section_chunker import SectionChunker
|
||||
from onyx.indexing.chunking.section_chunker import SectionChunkerOutput
|
||||
from onyx.utils.text_processing import clean_text
|
||||
|
||||
|
||||
class ImageChunker(SectionChunker):
|
||||
def chunk_section(
|
||||
self,
|
||||
section: Section,
|
||||
accumulator: AccumulatorState,
|
||||
content_token_limit: int, # noqa: ARG002
|
||||
) -> SectionChunkerOutput:
|
||||
assert section.image_file_id is not None
|
||||
|
||||
section_text = clean_text(str(section.text or ""))
|
||||
section_link = section.link or ""
|
||||
|
||||
# Flush any partially built text chunks
|
||||
payloads = accumulator.flush_to_list()
|
||||
payloads.append(
|
||||
ChunkPayload(
|
||||
text=section_text,
|
||||
links={0: section_link} if section_link else {},
|
||||
image_file_id=section.image_file_id,
|
||||
is_continuation=False,
|
||||
)
|
||||
)
|
||||
|
||||
return SectionChunkerOutput(
|
||||
payloads=payloads,
|
||||
accumulator=AccumulatorState(),
|
||||
)
|
||||
100
backend/onyx/indexing/chunking/section_chunker.py
Normal file
100
backend/onyx/indexing/chunking/section_chunker.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
|
||||
from chonkie import SentenceChunker
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
|
||||
|
||||
def extract_blurb(text: str, blurb_splitter: SentenceChunker) -> str:
|
||||
texts = cast(list[str], blurb_splitter.chunk(text))
|
||||
if not texts:
|
||||
return ""
|
||||
return texts[0]
|
||||
|
||||
|
||||
def get_mini_chunk_texts(
|
||||
chunk_text: str,
|
||||
mini_chunk_splitter: SentenceChunker | None,
|
||||
) -> list[str] | None:
|
||||
if mini_chunk_splitter and chunk_text.strip():
|
||||
return list(cast(Sequence[str], mini_chunk_splitter.chunk(chunk_text)))
|
||||
return None
|
||||
|
||||
|
||||
class ChunkPayload(BaseModel):
|
||||
"""Section-local chunk content without document-scoped fields.
|
||||
|
||||
The orchestrator upgrades these to DocAwareChunks via
|
||||
`to_doc_aware_chunk` after assigning chunk_ids and attaching
|
||||
title/metadata.
|
||||
"""
|
||||
|
||||
text: str
|
||||
links: dict[int, str]
|
||||
is_continuation: bool = False
|
||||
image_file_id: str | None = None
|
||||
|
||||
def to_doc_aware_chunk(
|
||||
self,
|
||||
document: IndexingDocument,
|
||||
chunk_id: int,
|
||||
blurb_splitter: SentenceChunker,
|
||||
title_prefix: str = "",
|
||||
metadata_suffix_semantic: str = "",
|
||||
metadata_suffix_keyword: str = "",
|
||||
mini_chunk_splitter: SentenceChunker | None = None,
|
||||
) -> DocAwareChunk:
|
||||
return DocAwareChunk(
|
||||
source_document=document,
|
||||
chunk_id=chunk_id,
|
||||
blurb=extract_blurb(self.text, blurb_splitter),
|
||||
content=self.text,
|
||||
source_links=self.links or {0: ""},
|
||||
image_file_id=self.image_file_id,
|
||||
section_continuation=self.is_continuation,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
mini_chunk_texts=get_mini_chunk_texts(self.text, mini_chunk_splitter),
|
||||
large_chunk_id=None,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
)
|
||||
|
||||
|
||||
class AccumulatorState(BaseModel):
|
||||
"""Cross-section text buffer threaded through SectionChunkers."""
|
||||
|
||||
text: str = ""
|
||||
link_offsets: dict[int, str] = Field(default_factory=dict)
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return not self.text.strip()
|
||||
|
||||
def flush_to_list(self) -> list[ChunkPayload]:
|
||||
if self.is_empty():
|
||||
return []
|
||||
return [ChunkPayload(text=self.text, links=self.link_offsets)]
|
||||
|
||||
|
||||
class SectionChunkerOutput(BaseModel):
|
||||
payloads: list[ChunkPayload]
|
||||
accumulator: AccumulatorState
|
||||
|
||||
|
||||
class SectionChunker(ABC):
|
||||
@abstractmethod
|
||||
def chunk_section(
|
||||
self,
|
||||
section: Section,
|
||||
accumulator: AccumulatorState,
|
||||
content_token_limit: int,
|
||||
) -> SectionChunkerOutput: ...
|
||||
272
backend/onyx/indexing/chunking/tabular_section_chunker.py
Normal file
272
backend/onyx/indexing/chunking/tabular_section_chunker.py
Normal file
@@ -0,0 +1,272 @@
|
||||
import csv
|
||||
import io
|
||||
from collections.abc import Iterable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.indexing.chunking.section_chunker import AccumulatorState
|
||||
from onyx.indexing.chunking.section_chunker import ChunkPayload
|
||||
from onyx.indexing.chunking.section_chunker import SectionChunker
|
||||
from onyx.indexing.chunking.section_chunker import SectionChunkerOutput
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.natural_language_processing.utils import count_tokens
|
||||
from onyx.natural_language_processing.utils import split_text_by_tokens
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
COLUMNS_MARKER = "Columns:"
|
||||
FIELD_VALUE_SEPARATOR = ", "
|
||||
ROW_JOIN = "\n"
|
||||
NEWLINE_TOKENS = 1
|
||||
|
||||
|
||||
class _ParsedRow(BaseModel):
|
||||
header: list[str]
|
||||
row: list[str]
|
||||
|
||||
|
||||
class _TokenizedText(BaseModel):
|
||||
text: str
|
||||
token_count: int
|
||||
|
||||
|
||||
def format_row(header: list[str], row: list[str]) -> str:
|
||||
"""
|
||||
A header-row combination is formatted like this:
|
||||
field1=value1, field2=value2, field3=value3
|
||||
"""
|
||||
pairs = _row_to_pairs(header, row)
|
||||
formatted = FIELD_VALUE_SEPARATOR.join(f"{h}={v}" for h, v in pairs)
|
||||
return formatted
|
||||
|
||||
|
||||
def format_columns_header(headers: list[str]) -> str:
|
||||
"""
|
||||
Format the column header line. Underscored headers get a
|
||||
space-substituted friendly alias in parens.
|
||||
Example:
|
||||
headers = ["id", "MTTR_hours"]
|
||||
=> "Columns: id, MTTR_hours (MTTR hours)"
|
||||
"""
|
||||
parts: list[str] = []
|
||||
for header in headers:
|
||||
friendly = header
|
||||
if "_" in header:
|
||||
friendly = f'{header} ({header.replace("_", " ")})'
|
||||
parts.append(friendly)
|
||||
return f"{COLUMNS_MARKER} " + FIELD_VALUE_SEPARATOR.join(parts)
|
||||
|
||||
|
||||
def parse_section(section: Section) -> list[_ParsedRow]:
|
||||
"""Parse CSV into headers + rows. First non-empty row is the header;
|
||||
blank rows are skipped."""
|
||||
section_text = section.text or ""
|
||||
if not section_text.strip():
|
||||
return []
|
||||
|
||||
reader = csv.reader(io.StringIO(section_text))
|
||||
non_empty_rows = [row for row in reader if any(cell.strip() for cell in row)]
|
||||
|
||||
if not non_empty_rows:
|
||||
return []
|
||||
|
||||
header, *data_rows = non_empty_rows
|
||||
return [_ParsedRow(header=header, row=row) for row in data_rows]
|
||||
|
||||
|
||||
def _row_to_pairs(headers: list[str], row: list[str]) -> list[tuple[str, str]]:
|
||||
return [(h, v) for h, v in zip(headers, row) if v.strip()]
|
||||
|
||||
|
||||
def pack_chunk(chunk: str, new_row: str) -> str:
|
||||
return chunk + "\n" + new_row
|
||||
|
||||
|
||||
def _split_row_by_pairs(
|
||||
pairs: list[tuple[str, str]],
|
||||
tokenizer: BaseTokenizer,
|
||||
max_tokens: int,
|
||||
) -> list[_TokenizedText]:
|
||||
"""Greedily pack pairs into max-sized pieces. Any single pair that
|
||||
itself exceeds ``max_tokens`` is token-split at id boundaries.
|
||||
No headers."""
|
||||
separator_tokens = count_tokens(FIELD_VALUE_SEPARATOR, tokenizer)
|
||||
pieces: list[_TokenizedText] = []
|
||||
current_parts: list[str] = []
|
||||
current_tokens = 0
|
||||
|
||||
for pair in pairs:
|
||||
pair_str = f"{pair[0]}={pair[1]}"
|
||||
pair_tokens = count_tokens(pair_str, tokenizer)
|
||||
increment = pair_tokens if not current_parts else separator_tokens + pair_tokens
|
||||
|
||||
if current_tokens + increment <= max_tokens:
|
||||
current_parts.append(pair_str)
|
||||
current_tokens += increment
|
||||
continue
|
||||
|
||||
if current_parts:
|
||||
pieces.append(
|
||||
_TokenizedText(
|
||||
text=FIELD_VALUE_SEPARATOR.join(current_parts),
|
||||
token_count=current_tokens,
|
||||
)
|
||||
)
|
||||
current_parts = []
|
||||
current_tokens = 0
|
||||
|
||||
if pair_tokens > max_tokens:
|
||||
for split_text in split_text_by_tokens(pair_str, tokenizer, max_tokens):
|
||||
pieces.append(
|
||||
_TokenizedText(
|
||||
text=split_text,
|
||||
token_count=count_tokens(split_text, tokenizer),
|
||||
)
|
||||
)
|
||||
else:
|
||||
current_parts = [pair_str]
|
||||
current_tokens = pair_tokens
|
||||
|
||||
if current_parts:
|
||||
pieces.append(
|
||||
_TokenizedText(
|
||||
text=FIELD_VALUE_SEPARATOR.join(current_parts),
|
||||
token_count=current_tokens,
|
||||
)
|
||||
)
|
||||
return pieces
|
||||
|
||||
|
||||
def _build_chunk_from_scratch(
|
||||
pairs: list[tuple[str, str]],
|
||||
formatted_row: str,
|
||||
row_tokens: int,
|
||||
column_header: str,
|
||||
column_header_tokens: int,
|
||||
sheet_header: str,
|
||||
sheet_header_tokens: int,
|
||||
tokenizer: BaseTokenizer,
|
||||
max_tokens: int,
|
||||
) -> list[_TokenizedText]:
|
||||
# 1. Row alone is too large — split by pairs, no headers.
|
||||
if row_tokens > max_tokens:
|
||||
return _split_row_by_pairs(pairs, tokenizer, max_tokens)
|
||||
|
||||
chunk = formatted_row
|
||||
chunk_tokens = row_tokens
|
||||
|
||||
# 2. Attempt to add column header
|
||||
candidate_tokens = column_header_tokens + NEWLINE_TOKENS + chunk_tokens
|
||||
if candidate_tokens <= max_tokens:
|
||||
chunk = column_header + ROW_JOIN + chunk
|
||||
chunk_tokens = candidate_tokens
|
||||
|
||||
# 3. Attempt to add sheet header
|
||||
if sheet_header:
|
||||
candidate_tokens = sheet_header_tokens + NEWLINE_TOKENS + chunk_tokens
|
||||
if candidate_tokens <= max_tokens:
|
||||
chunk = sheet_header + ROW_JOIN + chunk
|
||||
chunk_tokens = candidate_tokens
|
||||
|
||||
return [_TokenizedText(text=chunk, token_count=chunk_tokens)]
|
||||
|
||||
|
||||
def parse_to_chunks(
|
||||
rows: Iterable[_ParsedRow],
|
||||
sheet_header: str,
|
||||
tokenizer: BaseTokenizer,
|
||||
max_tokens: int,
|
||||
) -> list[str]:
|
||||
rows_list = list(rows)
|
||||
if not rows_list:
|
||||
return []
|
||||
|
||||
column_header = format_columns_header(rows_list[0].header)
|
||||
column_header_tokens = count_tokens(column_header, tokenizer)
|
||||
sheet_header_tokens = count_tokens(sheet_header, tokenizer) if sheet_header else 0
|
||||
|
||||
chunks: list[str] = []
|
||||
current_chunk = ""
|
||||
current_chunk_tokens = 0
|
||||
|
||||
for row in rows_list:
|
||||
pairs: list[tuple[str, str]] = _row_to_pairs(row.header, row.row)
|
||||
formatted = format_row(row.header, row.row)
|
||||
row_tokens = count_tokens(formatted, tokenizer)
|
||||
|
||||
if current_chunk:
|
||||
# Attempt to pack it in (additive approximation)
|
||||
if current_chunk_tokens + NEWLINE_TOKENS + row_tokens <= max_tokens:
|
||||
current_chunk = pack_chunk(current_chunk, formatted)
|
||||
current_chunk_tokens += NEWLINE_TOKENS + row_tokens
|
||||
continue
|
||||
# Doesn't fit — flush and start new
|
||||
chunks.append(current_chunk)
|
||||
current_chunk = ""
|
||||
current_chunk_tokens = 0
|
||||
|
||||
# Build chunk from scratch
|
||||
for piece in _build_chunk_from_scratch(
|
||||
pairs=pairs,
|
||||
formatted_row=formatted,
|
||||
row_tokens=row_tokens,
|
||||
column_header=column_header,
|
||||
column_header_tokens=column_header_tokens,
|
||||
sheet_header=sheet_header,
|
||||
sheet_header_tokens=sheet_header_tokens,
|
||||
tokenizer=tokenizer,
|
||||
max_tokens=max_tokens,
|
||||
):
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk)
|
||||
current_chunk = piece.text
|
||||
current_chunk_tokens = piece.token_count
|
||||
|
||||
# Flush remaining
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk)
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
class TabularChunker(SectionChunker):
|
||||
def __init__(self, tokenizer: BaseTokenizer) -> None:
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def chunk_section(
|
||||
self,
|
||||
section: Section,
|
||||
accumulator: AccumulatorState,
|
||||
content_token_limit: int,
|
||||
) -> SectionChunkerOutput:
|
||||
payloads = accumulator.flush_to_list()
|
||||
|
||||
parsed_rows = parse_section(section)
|
||||
if not parsed_rows:
|
||||
logger.warning(
|
||||
f"TabularChunker: skipping unparseable section (link={section.link})"
|
||||
)
|
||||
return SectionChunkerOutput(
|
||||
payloads=payloads, accumulator=AccumulatorState()
|
||||
)
|
||||
|
||||
sheet_header = section.link or ""
|
||||
chunk_texts = parse_to_chunks(
|
||||
rows=parsed_rows,
|
||||
sheet_header=sheet_header,
|
||||
tokenizer=self.tokenizer,
|
||||
max_tokens=content_token_limit,
|
||||
)
|
||||
|
||||
for i, text in enumerate(chunk_texts):
|
||||
payloads.append(
|
||||
ChunkPayload(
|
||||
text=text,
|
||||
links={0: section.link or ""},
|
||||
is_continuation=(i > 0),
|
||||
)
|
||||
)
|
||||
return SectionChunkerOutput(payloads=payloads, accumulator=AccumulatorState())
|
||||
117
backend/onyx/indexing/chunking/text_section_chunker.py
Normal file
117
backend/onyx/indexing/chunking/text_section_chunker.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from typing import cast
|
||||
|
||||
from chonkie import SentenceChunker
|
||||
|
||||
from onyx.configs.constants import SECTION_SEPARATOR
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.indexing.chunking.section_chunker import AccumulatorState
|
||||
from onyx.indexing.chunking.section_chunker import ChunkPayload
|
||||
from onyx.indexing.chunking.section_chunker import SectionChunker
|
||||
from onyx.indexing.chunking.section_chunker import SectionChunkerOutput
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.natural_language_processing.utils import count_tokens
|
||||
from onyx.natural_language_processing.utils import split_text_by_tokens
|
||||
from onyx.utils.text_processing import clean_text
|
||||
from onyx.utils.text_processing import shared_precompare_cleanup
|
||||
from shared_configs.configs import STRICT_CHUNK_TOKEN_LIMIT
|
||||
|
||||
|
||||
class TextChunker(SectionChunker):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: BaseTokenizer,
|
||||
chunk_splitter: SentenceChunker,
|
||||
) -> None:
|
||||
self.tokenizer = tokenizer
|
||||
self.chunk_splitter = chunk_splitter
|
||||
|
||||
self.section_separator_token_count = count_tokens(
|
||||
SECTION_SEPARATOR,
|
||||
self.tokenizer,
|
||||
)
|
||||
|
||||
def chunk_section(
|
||||
self,
|
||||
section: Section,
|
||||
accumulator: AccumulatorState,
|
||||
content_token_limit: int,
|
||||
) -> SectionChunkerOutput:
|
||||
section_text = clean_text(str(section.text or ""))
|
||||
section_link = section.link or ""
|
||||
section_token_count = len(self.tokenizer.encode(section_text))
|
||||
|
||||
# Oversized — flush buffer and split the section
|
||||
if section_token_count > content_token_limit:
|
||||
return self._handle_oversized_section(
|
||||
section_text=section_text,
|
||||
section_link=section_link,
|
||||
accumulator=accumulator,
|
||||
content_token_limit=content_token_limit,
|
||||
)
|
||||
|
||||
current_token_count = count_tokens(accumulator.text, self.tokenizer)
|
||||
next_section_tokens = self.section_separator_token_count + section_token_count
|
||||
|
||||
# Fits — extend the accumulator
|
||||
if next_section_tokens + current_token_count <= content_token_limit:
|
||||
offset = len(shared_precompare_cleanup(accumulator.text))
|
||||
new_text = accumulator.text
|
||||
if new_text:
|
||||
new_text += SECTION_SEPARATOR
|
||||
new_text += section_text
|
||||
return SectionChunkerOutput(
|
||||
payloads=[],
|
||||
accumulator=AccumulatorState(
|
||||
text=new_text,
|
||||
link_offsets={**accumulator.link_offsets, offset: section_link},
|
||||
),
|
||||
)
|
||||
|
||||
# Doesn't fit — flush buffer and restart with this section
|
||||
return SectionChunkerOutput(
|
||||
payloads=accumulator.flush_to_list(),
|
||||
accumulator=AccumulatorState(
|
||||
text=section_text,
|
||||
link_offsets={0: section_link},
|
||||
),
|
||||
)
|
||||
|
||||
def _handle_oversized_section(
|
||||
self,
|
||||
section_text: str,
|
||||
section_link: str,
|
||||
accumulator: AccumulatorState,
|
||||
content_token_limit: int,
|
||||
) -> SectionChunkerOutput:
|
||||
payloads = accumulator.flush_to_list()
|
||||
|
||||
split_texts = cast(list[str], self.chunk_splitter.chunk(section_text))
|
||||
for i, split_text in enumerate(split_texts):
|
||||
if (
|
||||
STRICT_CHUNK_TOKEN_LIMIT
|
||||
and count_tokens(split_text, self.tokenizer) > content_token_limit
|
||||
):
|
||||
smaller_chunks = split_text_by_tokens(
|
||||
split_text, self.tokenizer, content_token_limit
|
||||
)
|
||||
for j, small_chunk in enumerate(smaller_chunks):
|
||||
payloads.append(
|
||||
ChunkPayload(
|
||||
text=small_chunk,
|
||||
links={0: section_link},
|
||||
is_continuation=(j != 0),
|
||||
)
|
||||
)
|
||||
else:
|
||||
payloads.append(
|
||||
ChunkPayload(
|
||||
text=split_text,
|
||||
links={0: section_link},
|
||||
is_continuation=(i != 0),
|
||||
)
|
||||
)
|
||||
|
||||
return SectionChunkerOutput(
|
||||
payloads=payloads,
|
||||
accumulator=AccumulatorState(),
|
||||
)
|
||||
@@ -3,6 +3,8 @@ from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
|
||||
import sentry_sdk
|
||||
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorStopSignal
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
@@ -291,6 +293,13 @@ def embed_chunks_with_failure_handling(
|
||||
)
|
||||
embedded_chunks.extend(doc_embedded_chunks)
|
||||
except Exception as e:
|
||||
with sentry_sdk.new_scope() as scope:
|
||||
scope.set_tag("stage", "embedding")
|
||||
scope.set_tag("doc_id", doc_id)
|
||||
if tenant_id:
|
||||
scope.set_tag("tenant_id", tenant_id)
|
||||
scope.fingerprint = ["embedding-failure", type(e).__name__]
|
||||
sentry_sdk.capture_exception(e)
|
||||
logger.exception(f"Failed to embed chunks for document '{doc_id}'")
|
||||
failures.append(
|
||||
ConnectorFailure(
|
||||
|
||||
@@ -5,6 +5,7 @@ from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import Protocol
|
||||
|
||||
import sentry_sdk
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -332,6 +333,13 @@ def index_doc_batch_with_handler(
|
||||
except Exception as e:
|
||||
# don't log the batch directly, it's too much text
|
||||
document_ids = [doc.id for doc in document_batch]
|
||||
with sentry_sdk.new_scope() as scope:
|
||||
scope.set_tag("stage", "indexing_pipeline")
|
||||
scope.set_tag("tenant_id", tenant_id)
|
||||
scope.set_tag("batch_size", str(len(document_batch)))
|
||||
scope.set_extra("document_ids", document_ids)
|
||||
scope.fingerprint = ["indexing-pipeline-failure", type(e).__name__]
|
||||
sentry_sdk.capture_exception(e)
|
||||
logger.exception(f"Failed to index document batch: {document_ids}")
|
||||
|
||||
index_pipeline_result = IndexingPipelineResult(
|
||||
@@ -542,6 +550,7 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
|
||||
**document.model_dump(),
|
||||
processed_sections=[
|
||||
Section(
|
||||
type=section.type,
|
||||
text=section.text if isinstance(section, TextSection) else "",
|
||||
link=section.link,
|
||||
image_file_id=(
|
||||
@@ -566,6 +575,7 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
|
||||
if isinstance(section, ImageSection):
|
||||
# Default section with image path preserved - ensure text is always a string
|
||||
processed_section = Section(
|
||||
type=section.type,
|
||||
link=section.link,
|
||||
image_file_id=section.image_file_id,
|
||||
text="", # Initialize with empty string
|
||||
@@ -609,6 +619,7 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
|
||||
# For TextSection, create a base Section with text and link
|
||||
elif isinstance(section, TextSection):
|
||||
processed_section = Section(
|
||||
type=section.type,
|
||||
text=section.text or "", # Ensure text is always a string, not None
|
||||
link=section.link,
|
||||
image_file_id=None,
|
||||
|
||||
@@ -6,6 +6,7 @@ from itertools import chain
|
||||
from itertools import groupby
|
||||
|
||||
import httpx
|
||||
import sentry_sdk
|
||||
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
@@ -88,6 +89,12 @@ def write_chunks_to_vector_db_with_backoff(
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
with sentry_sdk.new_scope() as scope:
|
||||
scope.set_tag("stage", "vector_db_write")
|
||||
scope.set_tag("doc_id", doc_id)
|
||||
scope.set_tag("tenant_id", index_batch_params.tenant_id)
|
||||
scope.fingerprint = ["vector-db-write-failure", type(e).__name__]
|
||||
sentry_sdk.capture_exception(e)
|
||||
logger.exception(
|
||||
f"Failed to write document chunks for '{doc_id}' to vector db"
|
||||
)
|
||||
|
||||
@@ -66,7 +66,7 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
|
||||
LlmProviderNames.LM_STUDIO: "LM Studio",
|
||||
LlmProviderNames.LITELLM_PROXY: "LiteLLM Proxy",
|
||||
LlmProviderNames.BIFROST: "Bifrost",
|
||||
LlmProviderNames.OPENAI_COMPATIBLE: "OpenAI Compatible",
|
||||
LlmProviderNames.OPENAI_COMPATIBLE: "OpenAI-Compatible",
|
||||
"groq": "Groq",
|
||||
"anyscale": "Anyscale",
|
||||
"deepseek": "DeepSeek",
|
||||
|
||||
@@ -290,7 +290,11 @@ def litellm_exception_to_error_msg(
|
||||
error_code = "BUDGET_EXCEEDED"
|
||||
is_retryable = False
|
||||
elif isinstance(core_exception, Timeout):
|
||||
error_msg = "Request timed out: The operation took too long to complete. Please try again."
|
||||
error_msg = (
|
||||
"The LLM took too long to respond. "
|
||||
"If you're running a local model, try increasing the "
|
||||
"LLM_SOCKET_READ_TIMEOUT environment variable (current default: 120 seconds)."
|
||||
)
|
||||
error_code = "CONNECTION_ERROR"
|
||||
is_retryable = True
|
||||
elif isinstance(core_exception, APIError):
|
||||
|
||||
@@ -338,7 +338,7 @@ def get_provider_display_name(provider_name: str) -> str:
|
||||
VERTEXAI_PROVIDER_NAME: "Google Vertex AI",
|
||||
OPENROUTER_PROVIDER_NAME: "OpenRouter",
|
||||
LITELLM_PROXY_PROVIDER_NAME: "LiteLLM Proxy",
|
||||
OPENAI_COMPATIBLE_PROVIDER_NAME: "OpenAI Compatible",
|
||||
OPENAI_COMPATIBLE_PROVIDER_NAME: "OpenAI-Compatible",
|
||||
}
|
||||
|
||||
if provider_name in _ONYX_PROVIDER_DISPLAY_NAMES:
|
||||
|
||||
@@ -434,11 +434,14 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
lifespan=lifespan_override or lifespan,
|
||||
)
|
||||
if SENTRY_DSN:
|
||||
from onyx.configs.sentry import _add_instance_tags
|
||||
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
integrations=[StarletteIntegration(), FastApiIntegration()],
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
before_send=_add_instance_tags,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
|
||||
@@ -201,6 +201,33 @@ def count_tokens(
|
||||
return total
|
||||
|
||||
|
||||
def split_text_by_tokens(
|
||||
text: str,
|
||||
tokenizer: BaseTokenizer,
|
||||
max_tokens: int,
|
||||
) -> list[str]:
|
||||
"""Split ``text`` into pieces of ≤ ``max_tokens`` tokens each, via
|
||||
encode/decode at token-id boundaries.
|
||||
|
||||
Note: the returned pieces are not strictly guaranteed to re-tokenize to
|
||||
≤ max_tokens. BPE merges at window boundaries may drift by a few tokens,
|
||||
and cuts landing mid-multi-byte-UTF-8-character produce replacement
|
||||
characters on decode. Good enough for "best-effort" splitting of
|
||||
oversized content, not for hard limit enforcement.
|
||||
"""
|
||||
if not text:
|
||||
return []
|
||||
|
||||
token_ids: list[int] = []
|
||||
for start in range(0, len(text), _ENCODE_CHUNK_SIZE):
|
||||
token_ids.extend(tokenizer.encode(text[start : start + _ENCODE_CHUNK_SIZE]))
|
||||
|
||||
return [
|
||||
tokenizer.decode(token_ids[start : start + max_tokens])
|
||||
for start in range(0, len(token_ids), max_tokens)
|
||||
]
|
||||
|
||||
|
||||
def tokenizer_trim_content(
|
||||
content: str, desired_length: int, tokenizer: BaseTokenizer
|
||||
) -> str:
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
"date-fns": "^4.1.0",
|
||||
"embla-carousel-react": "^8.6.0",
|
||||
"lucide-react": "^0.562.0",
|
||||
"next": "16.1.7",
|
||||
"next": "16.2.3",
|
||||
"next-themes": "^0.4.6",
|
||||
"radix-ui": "^1.4.3",
|
||||
"react": "19.2.3",
|
||||
@@ -961,9 +961,9 @@
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@hono/node-server": {
|
||||
"version": "1.19.10",
|
||||
"resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.19.10.tgz",
|
||||
"integrity": "sha512-hZ7nOssGqRgyV3FVVQdfi+U4q02uB23bpnYpdvNXkYTRRyWx84b7yf1ans+dnJ/7h41sGL3CeQTfO+ZGxuO+Iw==",
|
||||
"version": "1.19.13",
|
||||
"resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.19.13.tgz",
|
||||
"integrity": "sha512-TsQLe4i2gvoTtrHje625ngThGBySOgSK3Xo2XRYOdqGN1teR8+I7vchQC46uLJi8OF62YTYA3AhSpumtkhsaKQ==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=18.14.1"
|
||||
@@ -1711,9 +1711,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/env": {
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/env/-/env-16.1.7.tgz",
|
||||
"integrity": "sha512-rJJbIdJB/RQr2F1nylZr/PJzamvNNhfr3brdKP6s/GW850jbtR70QlSfFselvIBbcPUOlQwBakexjFzqLzF6pg==",
|
||||
"version": "16.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/env/-/env-16.2.3.tgz",
|
||||
"integrity": "sha512-ZWXyj4uNu4GCWQw9cjRxWlbD+33mcDszIo9iQxFnBX3Wmgq9ulaSJcl6VhuWx5pCWqqD+9W6Wfz7N0lM5lYPMA==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@next/eslint-plugin-next": {
|
||||
@@ -1727,9 +1727,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-darwin-arm64": {
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-16.1.7.tgz",
|
||||
"integrity": "sha512-b2wWIE8sABdyafc4IM8r5Y/dS6kD80JRtOGrUiKTsACFQfWWgUQ2NwoUX1yjFMXVsAwcQeNpnucF2ZrujsBBPg==",
|
||||
"version": "16.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-16.2.3.tgz",
|
||||
"integrity": "sha512-u37KDKTKQ+OQLvY+z7SNXixwo4Q2/IAJFDzU1fYe66IbCE51aDSAzkNDkWmLN0yjTUh4BKBd+hb69jYn6qqqSg==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -1743,9 +1743,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-darwin-x64": {
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-16.1.7.tgz",
|
||||
"integrity": "sha512-zcnVaaZulS1WL0Ss38R5Q6D2gz7MtBu8GZLPfK+73D/hp4GFMrC2sudLky1QibfV7h6RJBJs/gOFvYP0X7UVlQ==",
|
||||
"version": "16.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-16.2.3.tgz",
|
||||
"integrity": "sha512-gHjL/qy6Q6CG3176FWbAKyKh9IfntKZTB3RY/YOJdDFpHGsUDXVH38U4mMNpHVGXmeYW4wj22dMp1lTfmu/bTQ==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -1759,9 +1759,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-arm64-gnu": {
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-16.1.7.tgz",
|
||||
"integrity": "sha512-2ant89Lux/Q3VyC8vNVg7uBaFVP9SwoK2jJOOR0L8TQnX8CAYnh4uctAScy2Hwj2dgjVHqHLORQZJ2wH6VxhSQ==",
|
||||
"version": "16.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-16.2.3.tgz",
|
||||
"integrity": "sha512-U6vtblPtU/P14Y/b/n9ZY0GOxbbIhTFuaFR7F4/uMBidCi2nSdaOFhA0Go81L61Zd6527+yvuX44T4ksnf8T+Q==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -1775,9 +1775,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-arm64-musl": {
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-16.1.7.tgz",
|
||||
"integrity": "sha512-uufcze7LYv0FQg9GnNeZ3/whYfo+1Q3HnQpm16o6Uyi0OVzLlk2ZWoY7j07KADZFY8qwDbsmFnMQP3p3+Ftprw==",
|
||||
"version": "16.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-16.2.3.tgz",
|
||||
"integrity": "sha512-/YV0LgjHUmfhQpn9bVoGc4x4nan64pkhWR5wyEV8yCOfwwrH630KpvRg86olQHTwHIn1z59uh6JwKvHq1h4QEw==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -1791,9 +1791,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-x64-gnu": {
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-16.1.7.tgz",
|
||||
"integrity": "sha512-KWVf2gxYvHtvuT+c4MBOGxuse5TD7DsMFYSxVxRBnOzok/xryNeQSjXgxSv9QpIVlaGzEn/pIuI6Koosx8CGWA==",
|
||||
"version": "16.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-16.2.3.tgz",
|
||||
"integrity": "sha512-/HiWEcp+WMZ7VajuiMEFGZ6cg0+aYZPqCJD3YJEfpVWQsKYSjXQG06vJP6F1rdA03COD9Fef4aODs3YxKx+RDQ==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -1807,9 +1807,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-x64-musl": {
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-16.1.7.tgz",
|
||||
"integrity": "sha512-HguhaGwsGr1YAGs68uRKc4aGWxLET+NevJskOcCAwXbwj0fYX0RgZW2gsOCzr9S11CSQPIkxmoSbuVaBp4Z3dA==",
|
||||
"version": "16.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-16.2.3.tgz",
|
||||
"integrity": "sha512-Kt44hGJfZSefebhk/7nIdivoDr3Ugp5+oNz9VvF3GUtfxutucUIHfIO0ZYO8QlOPDQloUVQn4NVC/9JvHRk9hw==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -1823,9 +1823,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-win32-arm64-msvc": {
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-16.1.7.tgz",
|
||||
"integrity": "sha512-S0n3KrDJokKTeFyM/vGGGR8+pCmXYrjNTk2ZozOL1C/JFdfUIL9O1ATaJOl5r2POe56iRChbsszrjMAdWSv7kQ==",
|
||||
"version": "16.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-16.2.3.tgz",
|
||||
"integrity": "sha512-O2NZ9ie3Tq6xj5Z5CSwBT3+aWAMW2PIZ4egUi9MaWLkwaehgtB7YZjPm+UpcNpKOme0IQuqDcor7BsW6QBiQBw==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -1839,9 +1839,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-win32-x64-msvc": {
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-16.1.7.tgz",
|
||||
"integrity": "sha512-mwgtg8CNZGYm06LeEd+bNnOUfwOyNem/rOiP14Lsz+AnUY92Zq/LXwtebtUiaeVkhbroRCQ0c8GlR4UT1U+0yg==",
|
||||
"version": "16.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-16.2.3.tgz",
|
||||
"integrity": "sha512-Ibm29/GgB/ab5n7XKqlStkm54qqZE8v2FnijUPBgrd67FWrac45o/RsNlaOWjme/B5UqeWt/8KM4aWBwA1D2Kw==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -7427,9 +7427,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/hono": {
|
||||
"version": "4.12.7",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.7.tgz",
|
||||
"integrity": "sha512-jq9l1DM0zVIvsm3lv9Nw9nlJnMNPOcAtsbsgiUhWcFzPE99Gvo6yRTlszSLLYacMeQ6quHD6hMfId8crVHvexw==",
|
||||
"version": "4.12.12",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.12.tgz",
|
||||
"integrity": "sha512-p1JfQMKaceuCbpJKAPKVqyqviZdS0eUxH9v82oWo1kb9xjQ5wA6iP3FNVAPDFlz5/p7d45lO+BpSk1tuSZMF4Q==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=16.9.0"
|
||||
@@ -8637,9 +8637,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/lodash": {
|
||||
"version": "4.17.23",
|
||||
"resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.23.tgz",
|
||||
"integrity": "sha512-LgVTMpQtIopCi79SJeDiP0TfWi5CNEc/L/aRdTh3yIvmZXTnheWpKjSZhnvMl8iXbC1tFg9gdHHDMLoV7CnG+w==",
|
||||
"version": "4.18.1",
|
||||
"resolved": "https://registry.npmjs.org/lodash/-/lodash-4.18.1.tgz",
|
||||
"integrity": "sha512-dMInicTPVE8d1e5otfwmmjlxkZoUpiVLwyeTdUsi/Caj/gfzzblBcCE5sRHV/AsjuCmxWrte2TNGSYuCeCq+0Q==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/lodash.merge": {
|
||||
@@ -8978,12 +8978,12 @@
|
||||
}
|
||||
},
|
||||
"node_modules/next": {
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/next/-/next-16.1.7.tgz",
|
||||
"integrity": "sha512-WM0L7WrSvKwoLegLYr6V+mz+RIofqQgVAfHhMp9a88ms0cFX8iX9ew+snpWlSBwpkURJOUdvCEt3uLl3NNzvWg==",
|
||||
"version": "16.2.3",
|
||||
"resolved": "https://registry.npmjs.org/next/-/next-16.2.3.tgz",
|
||||
"integrity": "sha512-9V3zV4oZFza3PVev5/poB9g0dEafVcgNyQ8eTRop8GvxZjV2G15FC5ARuG1eFD42QgeYkzJBJzHghNP8Ad9xtA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@next/env": "16.1.7",
|
||||
"@next/env": "16.2.3",
|
||||
"@swc/helpers": "0.5.15",
|
||||
"baseline-browser-mapping": "^2.9.19",
|
||||
"caniuse-lite": "^1.0.30001579",
|
||||
@@ -8997,15 +8997,15 @@
|
||||
"node": ">=20.9.0"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@next/swc-darwin-arm64": "16.1.7",
|
||||
"@next/swc-darwin-x64": "16.1.7",
|
||||
"@next/swc-linux-arm64-gnu": "16.1.7",
|
||||
"@next/swc-linux-arm64-musl": "16.1.7",
|
||||
"@next/swc-linux-x64-gnu": "16.1.7",
|
||||
"@next/swc-linux-x64-musl": "16.1.7",
|
||||
"@next/swc-win32-arm64-msvc": "16.1.7",
|
||||
"@next/swc-win32-x64-msvc": "16.1.7",
|
||||
"sharp": "^0.34.4"
|
||||
"@next/swc-darwin-arm64": "16.2.3",
|
||||
"@next/swc-darwin-x64": "16.2.3",
|
||||
"@next/swc-linux-arm64-gnu": "16.2.3",
|
||||
"@next/swc-linux-arm64-musl": "16.2.3",
|
||||
"@next/swc-linux-x64-gnu": "16.2.3",
|
||||
"@next/swc-linux-x64-musl": "16.2.3",
|
||||
"@next/swc-win32-arm64-msvc": "16.2.3",
|
||||
"@next/swc-win32-x64-msvc": "16.2.3",
|
||||
"sharp": "^0.34.5"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@opentelemetry/api": "^1.1.0",
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
"date-fns": "^4.1.0",
|
||||
"embla-carousel-react": "^8.6.0",
|
||||
"lucide-react": "^0.562.0",
|
||||
"next": "16.1.7",
|
||||
"next": "16.2.3",
|
||||
"next-themes": "^0.4.6",
|
||||
"radix-ui": "^1.4.3",
|
||||
"react": "19.2.3",
|
||||
|
||||
@@ -618,6 +618,7 @@ done
|
||||
"app.kubernetes.io/managed-by": "onyx",
|
||||
"onyx.app/sandbox-id": sandbox_id,
|
||||
"onyx.app/tenant-id": tenant_id,
|
||||
"admission.datadoghq.com/enabled": "false",
|
||||
},
|
||||
),
|
||||
spec=pod_spec,
|
||||
|
||||
@@ -63,6 +63,7 @@ class DocumentSetCreationRequest(BaseModel):
|
||||
|
||||
class DocumentSetUpdateRequest(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str
|
||||
cc_pair_ids: list[int]
|
||||
is_public: bool
|
||||
|
||||
@@ -96,6 +96,32 @@ def _truncate_description(description: str | None, max_length: int = 500) -> str
|
||||
return description[: max_length - 3] + "..."
|
||||
|
||||
|
||||
# TODO: Replace mask-comparison approach with an explicit Unset sentinel from the
|
||||
# frontend indicating whether each credential field was actually modified. The current
|
||||
# approach is brittle (e.g. short credentials produce a fixed-length mask that could
|
||||
# collide) and mutates request values, which is surprising. The frontend should signal
|
||||
# "unchanged" vs "new value" directly rather than relying on masked-string equality.
|
||||
def _restore_masked_oauth_credentials(
|
||||
request_client_id: str | None,
|
||||
request_client_secret: str | None,
|
||||
existing_client: OAuthClientInformationFull,
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""If the frontend sent back masked credentials, restore the real stored values."""
|
||||
if (
|
||||
request_client_id
|
||||
and existing_client.client_id
|
||||
and request_client_id == mask_string(existing_client.client_id)
|
||||
):
|
||||
request_client_id = existing_client.client_id
|
||||
if (
|
||||
request_client_secret
|
||||
and existing_client.client_secret
|
||||
and request_client_secret == mask_string(existing_client.client_secret)
|
||||
):
|
||||
request_client_secret = existing_client.client_secret
|
||||
return request_client_id, request_client_secret
|
||||
|
||||
|
||||
router = APIRouter(prefix="/mcp")
|
||||
admin_router = APIRouter(prefix="/admin/mcp")
|
||||
STATE_TTL_SECONDS = 60 * 5 # 5 minutes
|
||||
@@ -392,6 +418,26 @@ async def _connect_oauth(
|
||||
detail=f"Server was configured with authentication type {auth_type_str}",
|
||||
)
|
||||
|
||||
# If the frontend sent back masked credentials (unchanged by the user),
|
||||
# restore the real stored values so we don't overwrite them with masks.
|
||||
if mcp_server.admin_connection_config:
|
||||
existing_data = extract_connection_data(
|
||||
mcp_server.admin_connection_config, apply_mask=False
|
||||
)
|
||||
existing_client_raw = existing_data.get(MCPOAuthKeys.CLIENT_INFO.value)
|
||||
if existing_client_raw:
|
||||
existing_client = OAuthClientInformationFull.model_validate(
|
||||
existing_client_raw
|
||||
)
|
||||
(
|
||||
request.oauth_client_id,
|
||||
request.oauth_client_secret,
|
||||
) = _restore_masked_oauth_credentials(
|
||||
request.oauth_client_id,
|
||||
request.oauth_client_secret,
|
||||
existing_client,
|
||||
)
|
||||
|
||||
# Create admin config with client info if provided
|
||||
config_data = MCPConnectionData(headers={})
|
||||
if request.oauth_client_id and request.oauth_client_secret:
|
||||
@@ -1356,6 +1402,19 @@ def _upsert_mcp_server(
|
||||
if client_info_raw:
|
||||
client_info = OAuthClientInformationFull.model_validate(client_info_raw)
|
||||
|
||||
# If the frontend sent back masked credentials (unchanged by the user),
|
||||
# restore the real stored values so the comparison below sees no change
|
||||
# and the credentials aren't overwritten with masked strings.
|
||||
if client_info and request.auth_type == MCPAuthenticationType.OAUTH:
|
||||
(
|
||||
request.oauth_client_id,
|
||||
request.oauth_client_secret,
|
||||
) = _restore_masked_oauth_credentials(
|
||||
request.oauth_client_id,
|
||||
request.oauth_client_secret,
|
||||
client_info,
|
||||
)
|
||||
|
||||
changing_connection_config = (
|
||||
not mcp_server.admin_connection_config
|
||||
or (
|
||||
|
||||
@@ -11,6 +11,9 @@ from onyx.db.notification import dismiss_notification
|
||||
from onyx.db.notification import get_notification_by_id
|
||||
from onyx.db.notification import get_notifications
|
||||
from onyx.server.features.build.utils import ensure_build_mode_intro_notification
|
||||
from onyx.server.features.notifications.utils import (
|
||||
ensure_permissions_migration_notification,
|
||||
)
|
||||
from onyx.server.features.release_notes.utils import (
|
||||
ensure_release_notes_fresh_and_notify,
|
||||
)
|
||||
@@ -49,6 +52,13 @@ def get_notifications_api(
|
||||
except Exception:
|
||||
logger.exception("Failed to check for release notes in notifications endpoint")
|
||||
|
||||
try:
|
||||
ensure_permissions_migration_notification(user, db_session)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to create permissions_migration_v1 announcement in notifications endpoint"
|
||||
)
|
||||
|
||||
notifications = [
|
||||
NotificationModel.from_model(notif)
|
||||
for notif in get_notifications(user, db_session, include_dismissed=True)
|
||||
|
||||
21
backend/onyx/server/features/notifications/utils.py
Normal file
21
backend/onyx/server/features/notifications/utils.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.db.models import User
|
||||
from onyx.db.notification import create_notification
|
||||
|
||||
|
||||
def ensure_permissions_migration_notification(user: User, db_session: Session) -> None:
|
||||
# Feature id "permissions_migration_v1" must not change after shipping —
|
||||
# it is the dedup key on (user_id, notif_type, additional_data).
|
||||
create_notification(
|
||||
user_id=user.id,
|
||||
notif_type=NotificationType.FEATURE_ANNOUNCEMENT,
|
||||
db_session=db_session,
|
||||
title="Permissions are changing in Onyx",
|
||||
description="Roles are moving to group-based permissions. Click for details.",
|
||||
additional_data={
|
||||
"feature": "permissions_migration_v1",
|
||||
"link": "https://docs.onyx.app/admins/permissions/whats_changing",
|
||||
},
|
||||
)
|
||||
@@ -185,6 +185,10 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
for doc_set in persona.document_sets:
|
||||
for cc_pair in doc_set.connector_credential_pairs:
|
||||
sources.add(cc_pair.connector.source)
|
||||
for fed_ds in doc_set.federated_connectors:
|
||||
non_fed = fed_ds.federated_connector.source.to_non_federated_source()
|
||||
if non_fed is not None:
|
||||
sources.add(non_fed)
|
||||
|
||||
# Sources from hierarchy nodes
|
||||
for node in persona.hierarchy_nodes:
|
||||
@@ -195,6 +199,9 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
if doc.parent_hierarchy_node:
|
||||
sources.add(doc.parent_hierarchy_node.source)
|
||||
|
||||
if persona.user_files:
|
||||
sources.add(DocumentSource.USER_FILE)
|
||||
|
||||
return MinimalPersonaSnapshot(
|
||||
# Core fields actually used by ChatPage
|
||||
id=persona.id,
|
||||
|
||||
@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.background.indexing.models import IndexAttemptErrorPydantic
|
||||
from onyx.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import KV_GEN_AI_KEY_CHECK_TIME
|
||||
@@ -28,6 +29,7 @@ from onyx.db.feedback import fetch_docs_ranked_by_boost_for_user
|
||||
from onyx.db.feedback import update_document_boost_for_user
|
||||
from onyx.db.feedback import update_document_hidden_for_user
|
||||
from onyx.db.index_attempt import cancel_indexing_attempts_for_ccpair
|
||||
from onyx.db.index_attempt import get_index_attempt_errors_across_connectors
|
||||
from onyx.db.models import User
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
@@ -35,6 +37,7 @@ from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.llm.utils import test_llm
|
||||
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from onyx.server.documents.models import PaginatedReturn
|
||||
from onyx.server.manage.models import BoostDoc
|
||||
from onyx.server.manage.models import BoostUpdateRequest
|
||||
from onyx.server.manage.models import HiddenUpdateRequest
|
||||
@@ -206,3 +209,40 @@ def create_deletion_attempt_for_connector_id(
|
||||
file_store = get_default_file_store()
|
||||
for file_id in connector.connector_specific_config.get("file_locations", []):
|
||||
file_store.delete_file(file_id)
|
||||
|
||||
|
||||
@router.get("/admin/indexing/failed-documents")
|
||||
def get_failed_documents(
|
||||
cc_pair_id: int | None = None,
|
||||
error_type: str | None = None,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
include_resolved: bool = False,
|
||||
page_num: int = 0,
|
||||
page_size: int = 25,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> PaginatedReturn[IndexAttemptErrorPydantic]:
|
||||
"""Get indexing errors across all connectors with optional filters.
|
||||
|
||||
Provides a cross-connector view of document indexing failures.
|
||||
Defaults to last 30 days if no start_time is provided to avoid
|
||||
unbounded count queries.
|
||||
"""
|
||||
if start_time is None:
|
||||
start_time = datetime.now(tz=timezone.utc) - timedelta(days=30)
|
||||
|
||||
errors, total = get_index_attempt_errors_across_connectors(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
error_type=error_type,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
unresolved_only=not include_resolved,
|
||||
page=page_num,
|
||||
page_size=page_size,
|
||||
)
|
||||
return PaginatedReturn(
|
||||
items=[IndexAttemptErrorPydantic.from_model(e) for e in errors],
|
||||
total_items=total,
|
||||
)
|
||||
|
||||
@@ -111,6 +111,43 @@ def _mask_string(value: str) -> str:
|
||||
return value[:4] + "****" + value[-4:]
|
||||
|
||||
|
||||
def _resolve_api_key(
|
||||
api_key: str | None,
|
||||
provider_name: str | None,
|
||||
api_base: str | None,
|
||||
db_session: Session,
|
||||
) -> str | None:
|
||||
"""Return the real API key for model-fetch endpoints.
|
||||
|
||||
When editing an existing provider the form value is masked (e.g.
|
||||
``sk-a****b1c2``). If *provider_name* is supplied we can look up
|
||||
the unmasked key from the database so the external request succeeds.
|
||||
|
||||
The stored key is only returned when the request's *api_base*
|
||||
matches the value stored in the database.
|
||||
"""
|
||||
if not provider_name:
|
||||
return api_key
|
||||
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
if existing_provider and existing_provider.api_key:
|
||||
# Normalise both URLs before comparing so trailing-slash
|
||||
# differences don't cause a false mismatch.
|
||||
stored_base = (existing_provider.api_base or "").strip().rstrip("/")
|
||||
request_base = (api_base or "").strip().rstrip("/")
|
||||
if stored_base != request_base:
|
||||
return api_key
|
||||
|
||||
stored_key = existing_provider.api_key.get_value(apply_mask=False)
|
||||
# Only resolve when the incoming value is the masked form of the
|
||||
# stored key — i.e. the user hasn't typed a new key.
|
||||
if api_key and api_key == _mask_string(stored_key):
|
||||
return stored_key
|
||||
return api_key
|
||||
|
||||
|
||||
def _sync_fetched_models(
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
@@ -1174,16 +1211,17 @@ def get_ollama_available_models(
|
||||
return sorted_results
|
||||
|
||||
|
||||
def _get_openrouter_models_response(api_base: str, api_key: str) -> dict:
|
||||
def _get_openrouter_models_response(api_base: str, api_key: str | None) -> dict:
|
||||
"""Perform GET to OpenRouter /models and return parsed JSON."""
|
||||
cleaned_api_base = api_base.strip().rstrip("/")
|
||||
url = f"{cleaned_api_base}/models"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
headers: dict[str, str] = {
|
||||
# Optional headers recommended by OpenRouter for attribution
|
||||
"HTTP-Referer": "https://onyx.app",
|
||||
"X-Title": "Onyx",
|
||||
}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
try:
|
||||
response = httpx.get(url, headers=headers, timeout=10.0)
|
||||
response.raise_for_status()
|
||||
@@ -1206,8 +1244,12 @@ def get_openrouter_available_models(
|
||||
Parses id, name (display), context_length, and architecture.input_modalities.
|
||||
"""
|
||||
|
||||
api_key = _resolve_api_key(
|
||||
request.api_key, request.provider_name, request.api_base, db_session
|
||||
)
|
||||
|
||||
response_json = _get_openrouter_models_response(
|
||||
api_base=request.api_base, api_key=request.api_key
|
||||
api_base=request.api_base, api_key=api_key
|
||||
)
|
||||
|
||||
data = response_json.get("data", [])
|
||||
@@ -1300,13 +1342,18 @@ def get_lm_studio_available_models(
|
||||
|
||||
# If provider_name is given and the api_key hasn't been changed by the user,
|
||||
# fall back to the stored API key from the database (the form value is masked).
|
||||
# Only do so when the api_base matches what is stored.
|
||||
api_key = request.api_key
|
||||
if request.provider_name and not request.api_key_changed:
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=request.provider_name, db_session=db_session
|
||||
)
|
||||
if existing_provider and existing_provider.custom_config:
|
||||
api_key = existing_provider.custom_config.get(LM_STUDIO_API_KEY_CONFIG_KEY)
|
||||
stored_base = (existing_provider.api_base or "").strip().rstrip("/")
|
||||
if stored_base == cleaned_api_base:
|
||||
api_key = existing_provider.custom_config.get(
|
||||
LM_STUDIO_API_KEY_CONFIG_KEY
|
||||
)
|
||||
|
||||
url = f"{cleaned_api_base}/api/v1/models"
|
||||
headers: dict[str, str] = {}
|
||||
@@ -1390,8 +1437,12 @@ def get_litellm_available_models(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LitellmFinalModelResponse]:
|
||||
"""Fetch available models from Litellm proxy /v1/models endpoint."""
|
||||
api_key = _resolve_api_key(
|
||||
request.api_key, request.provider_name, request.api_base, db_session
|
||||
)
|
||||
|
||||
response_json = _get_litellm_models_response(
|
||||
api_key=request.api_key, api_base=request.api_base
|
||||
api_key=api_key, api_base=request.api_base
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
@@ -1448,7 +1499,7 @@ def get_litellm_available_models(
|
||||
return sorted_results
|
||||
|
||||
|
||||
def _get_litellm_models_response(api_key: str, api_base: str) -> dict:
|
||||
def _get_litellm_models_response(api_key: str | None, api_base: str) -> dict:
|
||||
"""Perform GET to Litellm proxy /api/v1/models and return parsed JSON."""
|
||||
cleaned_api_base = api_base.strip().rstrip("/")
|
||||
url = f"{cleaned_api_base}/v1/models"
|
||||
@@ -1523,8 +1574,12 @@ def get_bifrost_available_models(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[BifrostFinalModelResponse]:
|
||||
"""Fetch available models from Bifrost gateway /v1/models endpoint."""
|
||||
api_key = _resolve_api_key(
|
||||
request.api_key, request.provider_name, request.api_base, db_session
|
||||
)
|
||||
|
||||
response_json = _get_bifrost_models_response(
|
||||
api_base=request.api_base, api_key=request.api_key
|
||||
api_base=request.api_base, api_key=api_key
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
@@ -1613,8 +1668,12 @@ def get_openai_compatible_server_available_models(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[OpenAICompatibleFinalModelResponse]:
|
||||
"""Fetch available models from a generic OpenAI-compatible /v1/models endpoint."""
|
||||
api_key = _resolve_api_key(
|
||||
request.api_key, request.provider_name, request.api_base, db_session
|
||||
)
|
||||
|
||||
response_json = _get_openai_compatible_server_response(
|
||||
api_base=request.api_base, api_key=request.api_key
|
||||
api_base=request.api_base, api_key=api_key
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
@@ -1674,7 +1733,7 @@ def get_openai_compatible_server_available_models(
|
||||
)
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="OpenAI Compatible",
|
||||
source_label="OpenAI-Compatible",
|
||||
)
|
||||
|
||||
return sorted_results
|
||||
@@ -1693,6 +1752,6 @@ def _get_openai_compatible_server_response(
|
||||
|
||||
return _get_openai_compatible_models_response(
|
||||
url=url,
|
||||
source_name="OpenAI Compatible",
|
||||
source_name="OpenAI-Compatible",
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
@@ -183,6 +183,9 @@ def generate_ollama_display_name(model_name: str) -> str:
|
||||
"qwen2.5:7b" → "Qwen 2.5 7B"
|
||||
"mistral:latest" → "Mistral"
|
||||
"deepseek-r1:14b" → "DeepSeek R1 14B"
|
||||
"gemma4:e4b" → "Gemma 4 E4B"
|
||||
"deepseek-v3.1:671b-cloud" → "DeepSeek V3.1 671B Cloud"
|
||||
"qwen3-vl:235b-instruct-cloud" → "Qwen 3-vl 235B Instruct Cloud"
|
||||
"""
|
||||
# Split into base name and tag
|
||||
if ":" in model_name:
|
||||
@@ -209,13 +212,24 @@ def generate_ollama_display_name(model_name: str) -> str:
|
||||
# Default: Title case with dashes converted to spaces
|
||||
display_name = base.replace("-", " ").title()
|
||||
|
||||
# Process tag to extract size info (skip "latest")
|
||||
# Process tag (skip "latest")
|
||||
if tag and tag.lower() != "latest":
|
||||
# Extract size like "7b", "70b", "14b"
|
||||
size_match = re.match(r"^(\d+(?:\.\d+)?[bBmM])", tag)
|
||||
# Check for size prefix like "7b", "70b", optionally followed by modifiers
|
||||
size_match = re.match(r"^(\d+(?:\.\d+)?[bBmM])(-.+)?$", tag)
|
||||
if size_match:
|
||||
size = size_match.group(1).upper()
|
||||
display_name = f"{display_name} {size}"
|
||||
remainder = size_match.group(2)
|
||||
if remainder:
|
||||
# Format modifiers like "-cloud", "-instruct-cloud"
|
||||
modifiers = " ".join(
|
||||
p.title() for p in remainder.strip("-").split("-") if p
|
||||
)
|
||||
display_name = f"{display_name} {size} {modifiers}"
|
||||
else:
|
||||
display_name = f"{display_name} {size}"
|
||||
else:
|
||||
# Non-size tags like "e4b", "q4_0", "fp16", "cloud"
|
||||
display_name = f"{display_name} {tag.upper()}"
|
||||
|
||||
return display_name
|
||||
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import json
|
||||
import secrets
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import Query
|
||||
from fastapi import UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.permissions import require_permission
|
||||
@@ -113,28 +114,47 @@ async def transcribe_audio(
|
||||
) from exc
|
||||
|
||||
|
||||
def _extract_provider_error(exc: Exception) -> str:
|
||||
"""Extract a human-readable message from a provider exception.
|
||||
|
||||
Provider errors often embed JSON from upstream APIs (e.g. ElevenLabs).
|
||||
This tries to parse a readable ``message`` field out of common JSON
|
||||
error shapes; falls back to ``str(exc)`` if nothing better is found.
|
||||
"""
|
||||
raw = str(exc)
|
||||
try:
|
||||
# Many providers embed JSON after a prefix like "ElevenLabs TTS failed: {...}"
|
||||
json_start = raw.find("{")
|
||||
if json_start == -1:
|
||||
return raw
|
||||
parsed = json.loads(raw[json_start:])
|
||||
# Shape: {"detail": {"message": "..."}} (ElevenLabs)
|
||||
detail = parsed.get("detail", parsed)
|
||||
if isinstance(detail, dict):
|
||||
return detail.get("message") or detail.get("error") or raw
|
||||
if isinstance(detail, str):
|
||||
return detail
|
||||
except (json.JSONDecodeError, AttributeError, TypeError):
|
||||
pass
|
||||
return raw
|
||||
|
||||
|
||||
class SynthesizeRequest(BaseModel):
|
||||
text: str = Field(..., min_length=1)
|
||||
voice: str | None = None
|
||||
speed: float | None = Field(default=None, ge=0.5, le=2.0)
|
||||
|
||||
|
||||
@router.post("/synthesize")
|
||||
async def synthesize_speech(
|
||||
text: str | None = Query(
|
||||
default=None, description="Text to synthesize", max_length=4096
|
||||
),
|
||||
voice: str | None = Query(default=None, description="Voice ID to use"),
|
||||
speed: float | None = Query(
|
||||
default=None, description="Playback speed (0.5-2.0)", ge=0.5, le=2.0
|
||||
),
|
||||
body: SynthesizeRequest,
|
||||
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
) -> StreamingResponse:
|
||||
"""
|
||||
Synthesize text to speech using the default TTS provider.
|
||||
|
||||
Accepts parameters via query string for streaming compatibility.
|
||||
"""
|
||||
logger.info(
|
||||
f"TTS request: text length={len(text) if text else 0}, voice={voice}, speed={speed}"
|
||||
)
|
||||
|
||||
if not text:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Text is required")
|
||||
"""Synthesize text to speech using the default TTS provider."""
|
||||
text = body.text
|
||||
voice = body.voice
|
||||
speed = body.speed
|
||||
logger.info(f"TTS request: text length={len(text)}, voice={voice}, speed={speed}")
|
||||
|
||||
# Use short-lived session to fetch provider config, then release connection
|
||||
# before starting the long-running streaming response
|
||||
@@ -177,31 +197,36 @@ async def synthesize_speech(
|
||||
logger.error(f"Failed to get voice provider: {exc}")
|
||||
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, str(exc)) from exc
|
||||
|
||||
# Session is now closed - streaming response won't hold DB connection
|
||||
# Pull the first chunk before returning the StreamingResponse. If the
|
||||
# provider rejects the request (e.g. text too long), the error surfaces
|
||||
# as a proper HTTP error instead of a broken audio stream.
|
||||
stream_iter = provider.synthesize_stream(
|
||||
text=text, voice=final_voice, speed=final_speed
|
||||
)
|
||||
try:
|
||||
first_chunk = await stream_iter.__anext__()
|
||||
except StopAsyncIteration:
|
||||
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, "TTS provider returned no audio")
|
||||
except Exception as exc:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY, _extract_provider_error(exc)
|
||||
) from exc
|
||||
|
||||
async def audio_stream() -> AsyncIterator[bytes]:
|
||||
try:
|
||||
chunk_count = 0
|
||||
async for chunk in provider.synthesize_stream(
|
||||
text=text, voice=final_voice, speed=final_speed
|
||||
):
|
||||
chunk_count += 1
|
||||
yield chunk
|
||||
logger.info(f"TTS streaming complete: {chunk_count} chunks sent")
|
||||
except NotImplementedError as exc:
|
||||
logger.error(f"TTS not implemented: {exc}")
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Synthesis failed: {exc}")
|
||||
raise
|
||||
yield first_chunk
|
||||
chunk_count = 1
|
||||
async for chunk in stream_iter:
|
||||
chunk_count += 1
|
||||
yield chunk
|
||||
logger.info(f"TTS streaming complete: {chunk_count} chunks sent")
|
||||
|
||||
return StreamingResponse(
|
||||
audio_stream(),
|
||||
media_type="audio/mpeg",
|
||||
headers={
|
||||
"Content-Disposition": "inline; filename=speech.mp3",
|
||||
# Allow streaming by not setting content-length
|
||||
"Cache-Control": "no-cache",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Generic Celery task lifecycle Prometheus metrics.
|
||||
|
||||
Provides signal handlers that track task started/completed/failed counts,
|
||||
active task gauge, task duration histograms, and retry/reject/revoke counts.
|
||||
active task gauge, task duration histograms, queue wait time histograms,
|
||||
and retry/reject/revoke counts.
|
||||
These fire for ALL tasks on the worker — no per-connector enrichment
|
||||
(see indexing_task_metrics.py for that).
|
||||
|
||||
@@ -71,6 +72,32 @@ TASK_REJECTED = Counter(
|
||||
["task_name"],
|
||||
)
|
||||
|
||||
TASK_QUEUE_WAIT = Histogram(
|
||||
"onyx_celery_task_queue_wait_seconds",
|
||||
"Time a Celery task spent waiting in the queue before execution started",
|
||||
["task_name", "queue"],
|
||||
buckets=[
|
||||
0.1,
|
||||
0.5,
|
||||
1,
|
||||
5,
|
||||
30,
|
||||
60,
|
||||
300,
|
||||
600,
|
||||
1800,
|
||||
3600,
|
||||
7200,
|
||||
14400,
|
||||
28800,
|
||||
43200,
|
||||
86400,
|
||||
172800,
|
||||
432000,
|
||||
864000,
|
||||
],
|
||||
)
|
||||
|
||||
# task_id → (monotonic start time, metric labels)
|
||||
_task_start_times: dict[str, tuple[float, dict[str, str]]] = {}
|
||||
|
||||
@@ -133,6 +160,13 @@ def on_celery_task_prerun(
|
||||
with _task_start_times_lock:
|
||||
_evict_stale_start_times()
|
||||
_task_start_times[task_id] = (time.monotonic(), labels)
|
||||
|
||||
headers = getattr(task.request, "headers", None) or {}
|
||||
enqueued_at = headers.get("enqueued_at")
|
||||
if isinstance(enqueued_at, (int, float)):
|
||||
TASK_QUEUE_WAIT.labels(**labels).observe(
|
||||
max(0.0, time.time() - enqueued_at)
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to record celery task prerun metrics", exc_info=True)
|
||||
|
||||
|
||||
110
backend/onyx/server/metrics/connector_health_metrics.py
Normal file
110
backend/onyx/server/metrics/connector_health_metrics.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""Prometheus metrics for connector health and index attempts.
|
||||
|
||||
Emitted by docfetching and docprocessing workers when connector or
|
||||
index attempt state changes. All functions silently catch exceptions
|
||||
to avoid disrupting the caller's business logic.
|
||||
|
||||
Gauge metrics (error state, last success timestamp) are per-process.
|
||||
With multiple worker pods, use max() aggregation in PromQL to get the
|
||||
correct value across instances, e.g.:
|
||||
max by (cc_pair_id) (onyx_connector_in_error_state)
|
||||
"""
|
||||
|
||||
from prometheus_client import Counter
|
||||
from prometheus_client import Gauge
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# --- Index attempt lifecycle ---
|
||||
|
||||
INDEX_ATTEMPT_STATUS = Counter(
|
||||
"onyx_index_attempt_transitions_total",
|
||||
"Index attempt status transitions",
|
||||
["tenant_id", "source", "cc_pair_id", "status"],
|
||||
)
|
||||
|
||||
# --- Connector health ---
|
||||
|
||||
CONNECTOR_IN_ERROR_STATE = Gauge(
|
||||
"onyx_connector_in_error_state",
|
||||
"Whether the connector is in a repeated error state (1=yes, 0=no)",
|
||||
["tenant_id", "source", "cc_pair_id"],
|
||||
)
|
||||
|
||||
CONNECTOR_LAST_SUCCESS_TIMESTAMP = Gauge(
|
||||
"onyx_connector_last_success_timestamp_seconds",
|
||||
"Unix timestamp of last successful indexing for this connector",
|
||||
["tenant_id", "source", "cc_pair_id"],
|
||||
)
|
||||
|
||||
CONNECTOR_DOCS_INDEXED = Counter(
|
||||
"onyx_connector_docs_indexed_total",
|
||||
"Total documents indexed per connector (monotonic)",
|
||||
["tenant_id", "source", "cc_pair_id"],
|
||||
)
|
||||
|
||||
CONNECTOR_INDEXING_ERRORS = Counter(
|
||||
"onyx_connector_indexing_errors_total",
|
||||
"Total failed index attempts per connector (monotonic)",
|
||||
["tenant_id", "source", "cc_pair_id"],
|
||||
)
|
||||
|
||||
|
||||
def on_index_attempt_status_change(
|
||||
tenant_id: str,
|
||||
source: str,
|
||||
cc_pair_id: int,
|
||||
status: str,
|
||||
) -> None:
|
||||
"""Called on any index attempt status transition."""
|
||||
try:
|
||||
labels = {
|
||||
"tenant_id": tenant_id,
|
||||
"source": source,
|
||||
"cc_pair_id": str(cc_pair_id),
|
||||
}
|
||||
INDEX_ATTEMPT_STATUS.labels(**labels, status=status).inc()
|
||||
if status == "failed":
|
||||
CONNECTOR_INDEXING_ERRORS.labels(**labels).inc()
|
||||
except Exception:
|
||||
logger.debug("Failed to record index attempt status metric", exc_info=True)
|
||||
|
||||
|
||||
def on_connector_error_state_change(
|
||||
tenant_id: str,
|
||||
source: str,
|
||||
cc_pair_id: int,
|
||||
in_error: bool,
|
||||
) -> None:
|
||||
"""Called when a connector's in_repeated_error_state changes."""
|
||||
try:
|
||||
CONNECTOR_IN_ERROR_STATE.labels(
|
||||
tenant_id=tenant_id,
|
||||
source=source,
|
||||
cc_pair_id=str(cc_pair_id),
|
||||
).set(1.0 if in_error else 0.0)
|
||||
except Exception:
|
||||
logger.debug("Failed to record connector error state metric", exc_info=True)
|
||||
|
||||
|
||||
def on_connector_indexing_success(
|
||||
tenant_id: str,
|
||||
source: str,
|
||||
cc_pair_id: int,
|
||||
docs_indexed: int,
|
||||
success_timestamp: float,
|
||||
) -> None:
|
||||
"""Called when an indexing run completes successfully."""
|
||||
try:
|
||||
labels = {
|
||||
"tenant_id": tenant_id,
|
||||
"source": source,
|
||||
"cc_pair_id": str(cc_pair_id),
|
||||
}
|
||||
CONNECTOR_LAST_SUCCESS_TIMESTAMP.labels(**labels).set(success_timestamp)
|
||||
if docs_indexed > 0:
|
||||
CONNECTOR_DOCS_INDEXED.labels(**labels).inc(docs_indexed)
|
||||
except Exception:
|
||||
logger.debug("Failed to record connector success metric", exc_info=True)
|
||||
104
backend/onyx/server/metrics/deletion_metrics.py
Normal file
104
backend/onyx/server/metrics/deletion_metrics.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Connector-deletion-specific Prometheus metrics.
|
||||
|
||||
Tracks the deletion lifecycle:
|
||||
1. Deletions started (taskset generated)
|
||||
2. Deletions completed (success or failure)
|
||||
3. Taskset duration (from taskset generation to completion or failure).
|
||||
Note: this measures the most recent taskset execution, NOT wall-clock
|
||||
time since the user triggered the deletion. When deletion is blocked by
|
||||
indexing/pruning/permissions, the fence is cleared and a fresh taskset
|
||||
is generated on each retry, resetting this timer.
|
||||
4. Deletion blocked by dependencies (indexing, pruning, permissions, etc.)
|
||||
5. Fence resets (stuck deletion recovery)
|
||||
|
||||
All metrics are labeled by tenant_id. cc_pair_id is intentionally excluded
|
||||
to avoid unbounded cardinality.
|
||||
|
||||
Usage:
|
||||
from onyx.server.metrics.deletion_metrics import (
|
||||
inc_deletion_started,
|
||||
inc_deletion_completed,
|
||||
observe_deletion_taskset_duration,
|
||||
inc_deletion_blocked,
|
||||
inc_deletion_fence_reset,
|
||||
)
|
||||
"""
|
||||
|
||||
from prometheus_client import Counter
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
DELETION_STARTED = Counter(
|
||||
"onyx_deletion_started_total",
|
||||
"Connector deletions initiated (taskset generated)",
|
||||
["tenant_id"],
|
||||
)
|
||||
|
||||
DELETION_COMPLETED = Counter(
|
||||
"onyx_deletion_completed_total",
|
||||
"Connector deletions completed",
|
||||
["tenant_id", "outcome"],
|
||||
)
|
||||
|
||||
DELETION_TASKSET_DURATION = Histogram(
|
||||
"onyx_deletion_taskset_duration_seconds",
|
||||
"Duration of a connector deletion taskset, from taskset generation "
|
||||
"to completion or failure. Does not include time spent blocked on "
|
||||
"indexing/pruning/permissions before the taskset was generated.",
|
||||
["tenant_id", "outcome"],
|
||||
buckets=[10, 30, 60, 120, 300, 600, 1800, 3600, 7200, 21600],
|
||||
)
|
||||
|
||||
DELETION_BLOCKED = Counter(
|
||||
"onyx_deletion_blocked_total",
|
||||
"Times deletion was blocked by a dependency",
|
||||
["tenant_id", "blocker"],
|
||||
)
|
||||
|
||||
DELETION_FENCE_RESET = Counter(
|
||||
"onyx_deletion_fence_reset_total",
|
||||
"Deletion fences reset due to missing celery tasks",
|
||||
["tenant_id"],
|
||||
)
|
||||
|
||||
|
||||
def inc_deletion_started(tenant_id: str) -> None:
|
||||
try:
|
||||
DELETION_STARTED.labels(tenant_id=tenant_id).inc()
|
||||
except Exception:
|
||||
logger.debug("Failed to record deletion started", exc_info=True)
|
||||
|
||||
|
||||
def inc_deletion_completed(tenant_id: str, outcome: str) -> None:
|
||||
try:
|
||||
DELETION_COMPLETED.labels(tenant_id=tenant_id, outcome=outcome).inc()
|
||||
except Exception:
|
||||
logger.debug("Failed to record deletion completed", exc_info=True)
|
||||
|
||||
|
||||
def observe_deletion_taskset_duration(
|
||||
tenant_id: str, outcome: str, duration_seconds: float
|
||||
) -> None:
|
||||
try:
|
||||
DELETION_TASKSET_DURATION.labels(tenant_id=tenant_id, outcome=outcome).observe(
|
||||
duration_seconds
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to record deletion taskset duration", exc_info=True)
|
||||
|
||||
|
||||
def inc_deletion_blocked(tenant_id: str, blocker: str) -> None:
|
||||
try:
|
||||
DELETION_BLOCKED.labels(tenant_id=tenant_id, blocker=blocker).inc()
|
||||
except Exception:
|
||||
logger.debug("Failed to record deletion blocked", exc_info=True)
|
||||
|
||||
|
||||
def inc_deletion_fence_reset(tenant_id: str) -> None:
|
||||
try:
|
||||
DELETION_FENCE_RESET.labels(tenant_id=tenant_id).inc()
|
||||
except Exception:
|
||||
logger.debug("Failed to record deletion fence reset", exc_info=True)
|
||||
@@ -1,25 +1,30 @@
|
||||
"""Prometheus collectors for Celery queue depths and indexing pipeline state.
|
||||
"""Prometheus collectors for Celery queue depths and infrastructure health.
|
||||
|
||||
These collectors query Redis and Postgres at scrape time (the Collector pattern),
|
||||
These collectors query Redis at scrape time (the Collector pattern),
|
||||
so metrics are always fresh when Prometheus scrapes /metrics. They run inside the
|
||||
monitoring celery worker which already has Redis and DB access.
|
||||
monitoring celery worker which already has Redis access.
|
||||
|
||||
To avoid hammering Redis/Postgres on every 15s scrape, results are cached with
|
||||
To avoid hammering Redis on every 15s scrape, results are cached with
|
||||
a configurable TTL (default 30s). This means metrics may be up to TTL seconds
|
||||
stale, which is fine for monitoring dashboards.
|
||||
|
||||
Note: connector health and index attempt metrics are push-based (emitted by
|
||||
workers at state-change time) and live in connector_health_metrics.py.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import concurrent.futures
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from prometheus_client.core import GaugeMetricFamily
|
||||
from prometheus_client.registry import Collector
|
||||
from redis import Redis
|
||||
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
@@ -31,6 +36,11 @@ logger = setup_logger()
|
||||
# the previous result without re-querying Redis/Postgres.
|
||||
_DEFAULT_CACHE_TTL = 30.0
|
||||
|
||||
# Maximum time (seconds) a single _collect_fresh() call may take before
|
||||
# the collector gives up and returns stale/empty results. Prevents the
|
||||
# /metrics endpoint from hanging indefinitely when a DB or Redis query stalls.
|
||||
_DEFAULT_COLLECT_TIMEOUT = 120.0
|
||||
|
||||
_QUEUE_LABEL_MAP: dict[str, str] = {
|
||||
OnyxCeleryQueues.PRIMARY: "primary",
|
||||
OnyxCeleryQueues.DOCPROCESSING: "docprocessing",
|
||||
@@ -62,18 +72,32 @@ _UNACKED_QUEUES: list[str] = [
|
||||
|
||||
|
||||
class _CachedCollector(Collector):
|
||||
"""Base collector with TTL-based caching.
|
||||
"""Base collector with TTL-based caching and timeout protection.
|
||||
|
||||
Subclasses implement ``_collect_fresh()`` to query the actual data source.
|
||||
The base ``collect()`` returns cached results if the TTL hasn't expired,
|
||||
avoiding repeated queries when Prometheus scrapes frequently.
|
||||
|
||||
A per-collection timeout prevents a slow DB or Redis query from blocking
|
||||
the /metrics endpoint indefinitely. If _collect_fresh() exceeds the
|
||||
timeout, stale cached results are returned instead.
|
||||
"""
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
cache_ttl: float = _DEFAULT_CACHE_TTL,
|
||||
collect_timeout: float = _DEFAULT_COLLECT_TIMEOUT,
|
||||
) -> None:
|
||||
self._cache_ttl = cache_ttl
|
||||
self._collect_timeout = collect_timeout
|
||||
self._cached_result: list[GaugeMetricFamily] | None = None
|
||||
self._last_collect_time: float = 0.0
|
||||
self._lock = threading.Lock()
|
||||
self._executor = concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=1,
|
||||
thread_name_prefix=type(self).__name__,
|
||||
)
|
||||
self._inflight: concurrent.futures.Future | None = None
|
||||
|
||||
def collect(self) -> list[GaugeMetricFamily]:
|
||||
with self._lock:
|
||||
@@ -84,12 +108,28 @@ class _CachedCollector(Collector):
|
||||
):
|
||||
return self._cached_result
|
||||
|
||||
# If a previous _collect_fresh() is still running, wait on it
|
||||
# rather than queuing another. This prevents unbounded task
|
||||
# accumulation in the executor during extended DB outages.
|
||||
if self._inflight is not None and not self._inflight.done():
|
||||
future = self._inflight
|
||||
else:
|
||||
future = self._executor.submit(self._collect_fresh)
|
||||
self._inflight = future
|
||||
|
||||
try:
|
||||
result = self._collect_fresh()
|
||||
result = future.result(timeout=self._collect_timeout)
|
||||
self._inflight = None
|
||||
self._cached_result = result
|
||||
self._last_collect_time = now
|
||||
return result
|
||||
except concurrent.futures.TimeoutError:
|
||||
logger.warning(
|
||||
f"{type(self).__name__}._collect_fresh() timed out after {self._collect_timeout}s, returning stale cache"
|
||||
)
|
||||
return self._cached_result if self._cached_result is not None else []
|
||||
except Exception:
|
||||
self._inflight = None
|
||||
logger.exception(f"Error in {type(self).__name__}.collect()")
|
||||
# Return stale cache on error rather than nothing — avoids
|
||||
# metrics disappearing during transient failures.
|
||||
@@ -117,8 +157,6 @@ class QueueDepthCollector(_CachedCollector):
|
||||
if self._celery_app is None:
|
||||
return []
|
||||
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
|
||||
redis_client = celery_get_broker_client(self._celery_app)
|
||||
|
||||
depth = GaugeMetricFamily(
|
||||
@@ -194,208 +232,6 @@ class QueueDepthCollector(_CachedCollector):
|
||||
return None
|
||||
|
||||
|
||||
class IndexAttemptCollector(_CachedCollector):
|
||||
"""Queries Postgres for index attempt state on each scrape."""
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._configured: bool = False
|
||||
self._terminal_statuses: list = []
|
||||
|
||||
def configure(self) -> None:
|
||||
"""Call once DB engine is initialized."""
|
||||
from onyx.db.enums import IndexingStatus
|
||||
|
||||
self._terminal_statuses = [s for s in IndexingStatus if s.is_terminal()]
|
||||
self._configured = True
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if not self._configured:
|
||||
return []
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids
|
||||
from onyx.db.index_attempt import get_active_index_attempts_for_metrics
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
attempts_gauge = GaugeMetricFamily(
|
||||
"onyx_index_attempts_active",
|
||||
"Number of non-terminal index attempts",
|
||||
labels=[
|
||||
"status",
|
||||
"source",
|
||||
"tenant_id",
|
||||
"connector_name",
|
||||
"cc_pair_id",
|
||||
],
|
||||
)
|
||||
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
for tid in tenant_ids:
|
||||
# Defensive guard — get_all_tenant_ids() should never yield None,
|
||||
# but we guard here for API stability in case the contract changes.
|
||||
if tid is None:
|
||||
continue
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tid)
|
||||
try:
|
||||
with get_session_with_current_tenant() as session:
|
||||
rows = get_active_index_attempts_for_metrics(session)
|
||||
|
||||
for status, source, cc_id, cc_name, count in rows:
|
||||
name_val = cc_name or f"cc_pair_{cc_id}"
|
||||
attempts_gauge.add_metric(
|
||||
[
|
||||
status.value,
|
||||
source.value,
|
||||
tid,
|
||||
name_val,
|
||||
str(cc_id),
|
||||
],
|
||||
count,
|
||||
)
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
return [attempts_gauge]
|
||||
|
||||
|
||||
class ConnectorHealthCollector(_CachedCollector):
|
||||
"""Queries Postgres for connector health state on each scrape."""
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._configured: bool = False
|
||||
|
||||
def configure(self) -> None:
|
||||
"""Call once DB engine is initialized."""
|
||||
self._configured = True
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if not self._configured:
|
||||
return []
|
||||
|
||||
from onyx.db.connector_credential_pair import (
|
||||
get_connector_health_for_metrics,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids
|
||||
from onyx.db.index_attempt import get_docs_indexed_by_cc_pair
|
||||
from onyx.db.index_attempt import get_failed_attempt_counts_by_cc_pair
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
staleness_gauge = GaugeMetricFamily(
|
||||
"onyx_connector_last_success_age_seconds",
|
||||
"Seconds since last successful index for this connector",
|
||||
labels=["tenant_id", "source", "cc_pair_id", "connector_name"],
|
||||
)
|
||||
error_state_gauge = GaugeMetricFamily(
|
||||
"onyx_connector_in_error_state",
|
||||
"Whether the connector is in a repeated error state (1=yes, 0=no)",
|
||||
labels=["tenant_id", "source", "cc_pair_id", "connector_name"],
|
||||
)
|
||||
by_status_gauge = GaugeMetricFamily(
|
||||
"onyx_connectors_by_status",
|
||||
"Number of connectors grouped by status",
|
||||
labels=["tenant_id", "status"],
|
||||
)
|
||||
error_total_gauge = GaugeMetricFamily(
|
||||
"onyx_connectors_in_error_total",
|
||||
"Total number of connectors in repeated error state",
|
||||
labels=["tenant_id"],
|
||||
)
|
||||
per_connector_labels = [
|
||||
"tenant_id",
|
||||
"source",
|
||||
"cc_pair_id",
|
||||
"connector_name",
|
||||
]
|
||||
docs_success_gauge = GaugeMetricFamily(
|
||||
"onyx_connector_docs_indexed",
|
||||
"Total new documents indexed (90-day rolling sum) per connector",
|
||||
labels=per_connector_labels,
|
||||
)
|
||||
docs_error_gauge = GaugeMetricFamily(
|
||||
"onyx_connector_error_count",
|
||||
"Total number of failed index attempts per connector",
|
||||
labels=per_connector_labels,
|
||||
)
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
for tid in tenant_ids:
|
||||
# Defensive guard — get_all_tenant_ids() should never yield None,
|
||||
# but we guard here for API stability in case the contract changes.
|
||||
if tid is None:
|
||||
continue
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tid)
|
||||
try:
|
||||
with get_session_with_current_tenant() as session:
|
||||
pairs = get_connector_health_for_metrics(session)
|
||||
error_counts_by_cc = get_failed_attempt_counts_by_cc_pair(session)
|
||||
docs_by_cc = get_docs_indexed_by_cc_pair(session)
|
||||
|
||||
status_counts: dict[str, int] = {}
|
||||
error_count = 0
|
||||
|
||||
for (
|
||||
cc_id,
|
||||
status,
|
||||
in_error,
|
||||
last_success,
|
||||
cc_name,
|
||||
source,
|
||||
) in pairs:
|
||||
cc_id_str = str(cc_id)
|
||||
source_val = source.value
|
||||
name_val = cc_name or f"cc_pair_{cc_id}"
|
||||
label_vals = [tid, source_val, cc_id_str, name_val]
|
||||
|
||||
if last_success is not None:
|
||||
# Both `now` and `last_success` are timezone-aware
|
||||
# (the DB column uses DateTime(timezone=True)),
|
||||
# so subtraction is safe.
|
||||
age = (now - last_success).total_seconds()
|
||||
staleness_gauge.add_metric(label_vals, age)
|
||||
|
||||
error_state_gauge.add_metric(
|
||||
label_vals,
|
||||
1.0 if in_error else 0.0,
|
||||
)
|
||||
if in_error:
|
||||
error_count += 1
|
||||
|
||||
docs_success_gauge.add_metric(
|
||||
label_vals,
|
||||
docs_by_cc.get(cc_id, 0),
|
||||
)
|
||||
|
||||
docs_error_gauge.add_metric(
|
||||
label_vals,
|
||||
error_counts_by_cc.get(cc_id, 0),
|
||||
)
|
||||
|
||||
status_val = status.value
|
||||
status_counts[status_val] = status_counts.get(status_val, 0) + 1
|
||||
|
||||
for status_val, count in status_counts.items():
|
||||
by_status_gauge.add_metric([tid, status_val], count)
|
||||
|
||||
error_total_gauge.add_metric([tid], error_count)
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
return [
|
||||
staleness_gauge,
|
||||
error_state_gauge,
|
||||
by_status_gauge,
|
||||
error_total_gauge,
|
||||
docs_success_gauge,
|
||||
docs_error_gauge,
|
||||
]
|
||||
|
||||
|
||||
class RedisHealthCollector(_CachedCollector):
|
||||
"""Collects Redis server health metrics (memory, clients, etc.)."""
|
||||
|
||||
@@ -411,8 +247,6 @@ class RedisHealthCollector(_CachedCollector):
|
||||
if self._celery_app is None:
|
||||
return []
|
||||
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
|
||||
redis_client = celery_get_broker_client(self._celery_app)
|
||||
|
||||
memory_used = GaugeMetricFamily(
|
||||
@@ -495,7 +329,9 @@ class WorkerHeartbeatMonitor:
|
||||
},
|
||||
)
|
||||
recv.capture(
|
||||
limit=None, timeout=self._HEARTBEAT_TIMEOUT_SECONDS, wakeup=True
|
||||
limit=None,
|
||||
timeout=self._HEARTBEAT_TIMEOUT_SECONDS,
|
||||
wakeup=True,
|
||||
)
|
||||
except Exception:
|
||||
if self._running:
|
||||
|
||||
@@ -6,8 +6,6 @@ Called once by the monitoring celery worker after Redis and DB are ready.
|
||||
from celery import Celery
|
||||
from prometheus_client.registry import REGISTRY
|
||||
|
||||
from onyx.server.metrics.indexing_pipeline import ConnectorHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
|
||||
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import RedisHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import WorkerHealthCollector
|
||||
@@ -21,8 +19,6 @@ logger = setup_logger()
|
||||
# module level ensures they survive the lifetime of the worker process and are
|
||||
# only registered with the Prometheus registry once.
|
||||
_queue_collector = QueueDepthCollector()
|
||||
_attempt_collector = IndexAttemptCollector()
|
||||
_connector_collector = ConnectorHealthCollector()
|
||||
_redis_health_collector = RedisHealthCollector()
|
||||
_worker_health_collector = WorkerHealthCollector()
|
||||
_heartbeat_monitor: WorkerHeartbeatMonitor | None = None
|
||||
@@ -34,6 +30,9 @@ def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
|
||||
Args:
|
||||
celery_app: The Celery application instance. Used to obtain a
|
||||
broker Redis client on each scrape for queue depth metrics.
|
||||
|
||||
Note: connector health and index attempt metrics are push-based
|
||||
(see connector_health_metrics.py) and do not use collectors.
|
||||
"""
|
||||
_queue_collector.set_celery_app(celery_app)
|
||||
_redis_health_collector.set_celery_app(celery_app)
|
||||
@@ -47,13 +46,8 @@ def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
|
||||
_heartbeat_monitor.start()
|
||||
_worker_health_collector.set_monitor(_heartbeat_monitor)
|
||||
|
||||
_attempt_collector.configure()
|
||||
_connector_collector.configure()
|
||||
|
||||
for collector in (
|
||||
_queue_collector,
|
||||
_attempt_collector,
|
||||
_connector_collector,
|
||||
_redis_health_collector,
|
||||
_worker_health_collector,
|
||||
):
|
||||
|
||||
@@ -27,6 +27,7 @@ _DEFAULT_PORTS: dict[str, int] = {
|
||||
"docfetching": 9092,
|
||||
"docprocessing": 9093,
|
||||
"heavy": 9094,
|
||||
"light": 9095,
|
||||
}
|
||||
|
||||
_server_started = False
|
||||
|
||||
@@ -28,14 +28,14 @@ PRUNING_ENUMERATION_DURATION = Histogram(
|
||||
"onyx_pruning_enumeration_duration_seconds",
|
||||
"Duration of document ID enumeration from the source connector during pruning",
|
||||
["connector_type"],
|
||||
buckets=[1, 5, 15, 30, 60, 120, 300, 600, 1800, 3600],
|
||||
buckets=[5, 60, 600, 1800, 3600, 10800, 21600],
|
||||
)
|
||||
|
||||
PRUNING_DIFF_DURATION = Histogram(
|
||||
"onyx_pruning_diff_duration_seconds",
|
||||
"Duration of diff computation and subtask dispatch during pruning",
|
||||
["connector_type"],
|
||||
buckets=[1, 5, 15, 30, 60, 120, 300, 600, 1800, 3600],
|
||||
buckets=[0.1, 0.25, 0.5, 1, 2, 5, 15, 30, 60],
|
||||
)
|
||||
|
||||
PRUNING_RATE_LIMIT_ERRORS = Counter(
|
||||
|
||||
@@ -65,7 +65,8 @@ class Settings(BaseModel):
|
||||
anonymous_user_enabled: bool | None = None
|
||||
invite_only_enabled: bool = False
|
||||
deep_research_enabled: bool | None = None
|
||||
search_ui_enabled: bool | None = None
|
||||
multi_model_chat_enabled: bool | None = True
|
||||
search_ui_enabled: bool | None = True
|
||||
|
||||
# Whether EE features are unlocked for use.
|
||||
# Depends on license status: True when the user has a valid license
|
||||
@@ -89,7 +90,8 @@ class Settings(BaseModel):
|
||||
default=DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB, ge=0
|
||||
)
|
||||
file_token_count_threshold_k: int | None = Field(
|
||||
default=None, ge=0 # thousands of tokens; None = context-aware default
|
||||
default=None,
|
||||
ge=0, # thousands of tokens; None = context-aware default
|
||||
)
|
||||
|
||||
# Connector settings
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user