mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-16 15:06:45 +00:00
Compare commits
1 Commits
jamison/ti
...
jamison/rm
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cd02b99d5e |
@@ -1,63 +0,0 @@
|
||||
FROM ubuntu:26.04@sha256:cc925e589b7543b910fea57a240468940003fbfc0515245a495dd0ad8fe7cef1
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
curl \
|
||||
default-jre \
|
||||
fd-find \
|
||||
fzf \
|
||||
git \
|
||||
jq \
|
||||
less \
|
||||
make \
|
||||
neovim \
|
||||
openssh-client \
|
||||
python3-venv \
|
||||
ripgrep \
|
||||
sudo \
|
||||
ca-certificates \
|
||||
iptables \
|
||||
ipset \
|
||||
iproute2 \
|
||||
dnsutils \
|
||||
unzip \
|
||||
wget \
|
||||
zsh \
|
||||
&& curl -fsSL https://deb.nodesource.com/setup_20.x | bash - \
|
||||
&& apt-get install -y nodejs \
|
||||
&& install -m 0755 -d /etc/apt/keyrings \
|
||||
&& curl -fsSL https://cli.github.com/packages/githubcli-archive-keyring.gpg -o /etc/apt/keyrings/githubcli-archive-keyring.gpg \
|
||||
&& chmod go+r /etc/apt/keyrings/githubcli-archive-keyring.gpg \
|
||||
&& echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" > /etc/apt/sources.list.d/github-cli.list \
|
||||
&& apt-get update \
|
||||
&& apt-get install -y --no-install-recommends gh \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# fd-find installs as fdfind on Debian/Ubuntu — symlink to fd
|
||||
RUN ln -sf "$(which fdfind)" /usr/local/bin/fd
|
||||
|
||||
# Install uv (Python package manager)
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /usr/local/bin/
|
||||
|
||||
# Create non-root dev user with passwordless sudo
|
||||
RUN useradd -m -s /bin/zsh dev && \
|
||||
echo "dev ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/dev && \
|
||||
chmod 0440 /etc/sudoers.d/dev
|
||||
|
||||
ENV DEVCONTAINER=true
|
||||
|
||||
RUN mkdir -p /workspace && \
|
||||
chown -R dev:dev /workspace
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
# Install Claude Code
|
||||
ARG CLAUDE_CODE_VERSION=latest
|
||||
RUN npm install -g @anthropic-ai/claude-code@${CLAUDE_CODE_VERSION}
|
||||
|
||||
# Configure zsh — source the repo-local zshrc so shell customization
|
||||
# doesn't require an image rebuild.
|
||||
RUN chsh -s /bin/zsh root && \
|
||||
for rc in /root/.zshrc /home/dev/.zshrc; do \
|
||||
echo '[ -f /workspace/.devcontainer/zshrc ] && . /workspace/.devcontainer/zshrc' >> "$rc"; \
|
||||
done && \
|
||||
chown dev:dev /home/dev/.zshrc
|
||||
@@ -1,86 +0,0 @@
|
||||
# Onyx Dev Container
|
||||
|
||||
A containerized development environment for working on Onyx.
|
||||
|
||||
## What's included
|
||||
|
||||
- Ubuntu 26.04 base image
|
||||
- Node.js 20, uv, Claude Code
|
||||
- GitHub CLI (`gh`)
|
||||
- Neovim, ripgrep, fd, fzf, jq, make, wget, unzip
|
||||
- Zsh as default shell (sources host `~/.zshrc` if available)
|
||||
- Python venv auto-activation
|
||||
- Network firewall (default-deny, whitelists npm, GitHub, Anthropic APIs, Sentry, and VS Code update servers)
|
||||
|
||||
## Usage
|
||||
|
||||
### CLI (`ods dev`)
|
||||
|
||||
The [`ods` devtools CLI](../tools/ods/README.md) provides workspace-aware wrappers
|
||||
for all devcontainer operations (also available as `ods dc`):
|
||||
|
||||
```bash
|
||||
# Start the container
|
||||
ods dev up
|
||||
|
||||
# Open a shell
|
||||
ods dev into
|
||||
|
||||
# Run a command
|
||||
ods dev exec npm test
|
||||
|
||||
# Stop the container
|
||||
ods dev stop
|
||||
```
|
||||
|
||||
## Restarting the container
|
||||
|
||||
```bash
|
||||
# Restart the container
|
||||
ods dev restart
|
||||
|
||||
# Pull the latest published image and recreate
|
||||
ods dev rebuild
|
||||
```
|
||||
|
||||
## Image
|
||||
|
||||
The devcontainer uses a prebuilt image published to `onyxdotapp/onyx-devcontainer`.
|
||||
The tag is pinned in `devcontainer.json` — no local build is required.
|
||||
|
||||
To build the image locally (e.g. while iterating on the Dockerfile):
|
||||
|
||||
```bash
|
||||
docker buildx bake devcontainer
|
||||
```
|
||||
|
||||
The `devcontainer` target is defined in `docker-bake.hcl` at the repo root.
|
||||
|
||||
## User & permissions
|
||||
|
||||
The container runs as the `dev` user by default (`remoteUser` in devcontainer.json).
|
||||
An init script (`init-dev-user.sh`) runs at container start to ensure the active
|
||||
user has read/write access to the bind-mounted workspace:
|
||||
|
||||
- **Standard Docker** — `dev`'s UID/GID is remapped to match the workspace owner,
|
||||
so file permissions work seamlessly.
|
||||
- **Rootless Docker** — The workspace appears as root-owned (UID 0) inside the
|
||||
container due to user-namespace mapping. `ods dev up` auto-detects rootless Docker
|
||||
and sets `DEVCONTAINER_REMOTE_USER=root` so the container runs as root — which
|
||||
maps back to your host user via the user namespace. New files are owned by your
|
||||
host UID and no ACL workarounds are needed.
|
||||
|
||||
To override the auto-detection, set `DEVCONTAINER_REMOTE_USER` before running
|
||||
`ods dev up`.
|
||||
|
||||
## Firewall
|
||||
|
||||
The container starts with a default-deny firewall (`init-firewall.sh`) that only allows outbound traffic to:
|
||||
|
||||
- npm registry
|
||||
- GitHub
|
||||
- Anthropic API
|
||||
- Sentry
|
||||
- VS Code update servers
|
||||
|
||||
This requires the `NET_ADMIN` and `NET_RAW` capabilities, which are added via `runArgs` in `devcontainer.json`.
|
||||
@@ -1,26 +0,0 @@
|
||||
{
|
||||
"name": "Onyx Dev Sandbox",
|
||||
"image": "onyxdotapp/onyx-devcontainer@sha256:0f02d9299928849c7b15f3b348dcfdcdcb64411ff7a4580cbc026a6ee7aa1554",
|
||||
"runArgs": ["--cap-add=NET_ADMIN", "--cap-add=NET_RAW", "--network=onyx_default"],
|
||||
"mounts": [
|
||||
"source=${localEnv:HOME}/.claude,target=/home/dev/.claude,type=bind",
|
||||
"source=${localEnv:HOME}/.claude.json,target=/home/dev/.claude.json,type=bind",
|
||||
"source=${localEnv:HOME}/.zshrc,target=/home/dev/.zshrc.host,type=bind,readonly",
|
||||
"source=${localEnv:HOME}/.gitconfig,target=/home/dev/.gitconfig,type=bind,readonly",
|
||||
"source=${localEnv:HOME}/.config/nvim,target=/home/dev/.config/nvim,type=bind,readonly",
|
||||
"source=onyx-devcontainer-cache,target=/home/dev/.cache,type=volume",
|
||||
"source=onyx-devcontainer-local,target=/home/dev/.local,type=volume"
|
||||
],
|
||||
"containerEnv": {
|
||||
"SSH_AUTH_SOCK": "/tmp/ssh-agent.sock",
|
||||
"POSTGRES_HOST": "relational_db",
|
||||
"REDIS_HOST": "cache"
|
||||
},
|
||||
"remoteUser": "${localEnv:DEVCONTAINER_REMOTE_USER:dev}",
|
||||
"updateRemoteUserUID": false,
|
||||
"initializeCommand": "docker network create onyx_default 2>/dev/null || true",
|
||||
"workspaceMount": "source=${localWorkspaceFolder},target=/workspace,type=bind,consistency=delegated",
|
||||
"workspaceFolder": "/workspace",
|
||||
"postStartCommand": "sudo bash /workspace/.devcontainer/init-dev-user.sh && sudo bash /workspace/.devcontainer/init-firewall.sh",
|
||||
"waitFor": "postStartCommand"
|
||||
}
|
||||
@@ -1,107 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
# Remap the dev user's UID/GID to match the workspace owner so that
|
||||
# bind-mounted files are accessible without running as root.
|
||||
#
|
||||
# Standard Docker: Workspace is owned by the host user's UID (e.g. 1000).
|
||||
# We remap dev to that UID -- fast and seamless.
|
||||
#
|
||||
# Rootless Docker: Workspace appears as root-owned (UID 0) inside the
|
||||
# container due to user-namespace mapping. Requires
|
||||
# DEVCONTAINER_REMOTE_USER=root (set automatically by
|
||||
# ods dev up). Container root IS the host user, so
|
||||
# bind-mounts and named volumes are symlinked into /root.
|
||||
|
||||
WORKSPACE=/workspace
|
||||
TARGET_USER=dev
|
||||
REMOTE_USER="${SUDO_USER:-$TARGET_USER}"
|
||||
|
||||
WS_UID=$(stat -c '%u' "$WORKSPACE")
|
||||
WS_GID=$(stat -c '%g' "$WORKSPACE")
|
||||
DEV_UID=$(id -u "$TARGET_USER")
|
||||
DEV_GID=$(id -g "$TARGET_USER")
|
||||
|
||||
# devcontainer.json bind-mounts and named volumes target /home/dev regardless
|
||||
# of remoteUser. When running as root ($HOME=/root), Phase 1 bridges the gap
|
||||
# with symlinks from ACTIVE_HOME → MOUNT_HOME.
|
||||
MOUNT_HOME=/home/"$TARGET_USER"
|
||||
|
||||
if [ "$REMOTE_USER" = "root" ]; then
|
||||
ACTIVE_HOME="/root"
|
||||
else
|
||||
ACTIVE_HOME="$MOUNT_HOME"
|
||||
fi
|
||||
|
||||
# ── Phase 1: home directory setup ───────────────────────────────────
|
||||
|
||||
# ~/.local and ~/.cache are named Docker volumes mounted under MOUNT_HOME.
|
||||
mkdir -p "$MOUNT_HOME"/.local/state "$MOUNT_HOME"/.local/share
|
||||
|
||||
# When running as root, symlink bind-mounts and named volumes into /root
|
||||
# so that $HOME-relative tools (Claude Code, git, etc.) find them.
|
||||
if [ "$ACTIVE_HOME" != "$MOUNT_HOME" ]; then
|
||||
for item in .claude .cache .local; do
|
||||
[ -d "$MOUNT_HOME/$item" ] || continue
|
||||
if [ -e "$ACTIVE_HOME/$item" ] && [ ! -L "$ACTIVE_HOME/$item" ]; then
|
||||
echo "warning: replacing $ACTIVE_HOME/$item with symlink to $MOUNT_HOME/$item" >&2
|
||||
rm -rf "$ACTIVE_HOME/$item"
|
||||
fi
|
||||
ln -sfn "$MOUNT_HOME/$item" "$ACTIVE_HOME/$item"
|
||||
done
|
||||
# Symlink files (not directories).
|
||||
for file in .claude.json .gitconfig .zshrc.host; do
|
||||
[ -f "$MOUNT_HOME/$file" ] && ln -sf "$MOUNT_HOME/$file" "$ACTIVE_HOME/$file"
|
||||
done
|
||||
|
||||
# Nested mount: .config/nvim
|
||||
if [ -d "$MOUNT_HOME/.config/nvim" ]; then
|
||||
mkdir -p "$ACTIVE_HOME/.config"
|
||||
if [ -e "$ACTIVE_HOME/.config/nvim" ] && [ ! -L "$ACTIVE_HOME/.config/nvim" ]; then
|
||||
echo "warning: replacing $ACTIVE_HOME/.config/nvim with symlink" >&2
|
||||
rm -rf "$ACTIVE_HOME/.config/nvim"
|
||||
fi
|
||||
ln -sfn "$MOUNT_HOME/.config/nvim" "$ACTIVE_HOME/.config/nvim"
|
||||
fi
|
||||
fi
|
||||
|
||||
# ── Phase 2: workspace access ───────────────────────────────────────
|
||||
|
||||
# Root always has workspace access; Phase 1 handled home setup.
|
||||
if [ "$REMOTE_USER" = "root" ]; then
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Already matching -- nothing to do.
|
||||
if [ "$WS_UID" = "$DEV_UID" ] && [ "$WS_GID" = "$DEV_GID" ]; then
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [ "$WS_UID" != "0" ]; then
|
||||
# ── Standard Docker ──────────────────────────────────────────────
|
||||
# Workspace is owned by a non-root UID (the host user).
|
||||
# Remap dev's UID/GID to match.
|
||||
if [ "$DEV_GID" != "$WS_GID" ]; then
|
||||
if ! groupmod -g "$WS_GID" "$TARGET_USER" 2>&1; then
|
||||
echo "warning: failed to remap $TARGET_USER GID to $WS_GID" >&2
|
||||
fi
|
||||
fi
|
||||
if [ "$DEV_UID" != "$WS_UID" ]; then
|
||||
if ! usermod -u "$WS_UID" -g "$WS_GID" "$TARGET_USER" 2>&1; then
|
||||
echo "warning: failed to remap $TARGET_USER UID to $WS_UID" >&2
|
||||
fi
|
||||
fi
|
||||
if ! chown -R "$TARGET_USER":"$TARGET_USER" "$MOUNT_HOME" 2>&1; then
|
||||
echo "warning: failed to chown $MOUNT_HOME" >&2
|
||||
fi
|
||||
else
|
||||
# ── Rootless Docker ──────────────────────────────────────────────
|
||||
# Workspace is root-owned (UID 0) due to user-namespace mapping.
|
||||
# The supported path is remoteUser=root (set DEVCONTAINER_REMOTE_USER=root),
|
||||
# which is handled above. If we reach here, the user is running as dev
|
||||
# under rootless Docker without the override.
|
||||
echo "error: rootless Docker detected but remoteUser is not root." >&2
|
||||
echo " Set DEVCONTAINER_REMOTE_USER=root before starting the container," >&2
|
||||
echo " or use 'ods dev up' which sets it automatically." >&2
|
||||
exit 1
|
||||
fi
|
||||
@@ -1,104 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
echo "Setting up firewall..."
|
||||
|
||||
# Only flush the filter table. The nat and mangle tables are managed by Docker
|
||||
# (DNS DNAT to 127.0.0.11, container networking, etc.) and must not be touched —
|
||||
# flushing them breaks Docker's embedded DNS resolver.
|
||||
iptables -F
|
||||
iptables -X
|
||||
|
||||
# Create ipset for allowed destinations
|
||||
ipset create allowed-domains hash:net || true
|
||||
ipset flush allowed-domains
|
||||
|
||||
# Fetch GitHub IP ranges (IPv4 only -- ipset hash:net and iptables are IPv4)
|
||||
GITHUB_IPS=$(curl -s https://api.github.com/meta | jq -r '.api[]' 2>/dev/null | grep -v ':' || echo "")
|
||||
for ip in $GITHUB_IPS; do
|
||||
if ! ipset add allowed-domains "$ip" -exist 2>&1; then
|
||||
echo "warning: failed to add GitHub IP $ip to allowlist" >&2
|
||||
fi
|
||||
done
|
||||
|
||||
# Resolve allowed domains
|
||||
ALLOWED_DOMAINS=(
|
||||
"github.com"
|
||||
"registry.npmjs.org"
|
||||
"api.anthropic.com"
|
||||
"api-staging.anthropic.com"
|
||||
"files.anthropic.com"
|
||||
"sentry.io"
|
||||
"update.code.visualstudio.com"
|
||||
"pypi.org"
|
||||
"files.pythonhosted.org"
|
||||
"go.dev"
|
||||
"storage.googleapis.com"
|
||||
"static.rust-lang.org"
|
||||
)
|
||||
|
||||
for domain in "${ALLOWED_DOMAINS[@]}"; do
|
||||
IPS=$(getent ahosts "$domain" 2>/dev/null | awk '{print $1}' | grep -v ':' | sort -u || echo "")
|
||||
for ip in $IPS; do
|
||||
if ! ipset add allowed-domains "$ip/32" -exist 2>&1; then
|
||||
echo "warning: failed to add $domain ($ip) to allowlist" >&2
|
||||
fi
|
||||
done
|
||||
done
|
||||
|
||||
# Allow traffic to the Docker gateway so the container can reach host services
|
||||
# (e.g. the Onyx stack at localhost:3000, localhost:8080, etc.)
|
||||
DOCKER_GATEWAY=$(ip -4 route show default | awk '{print $3}')
|
||||
if [ -n "$DOCKER_GATEWAY" ]; then
|
||||
if ! ipset add allowed-domains "$DOCKER_GATEWAY/32" -exist 2>&1; then
|
||||
echo "warning: failed to add Docker gateway $DOCKER_GATEWAY to allowlist" >&2
|
||||
fi
|
||||
fi
|
||||
|
||||
# Allow traffic to all attached Docker network subnets so the container can
|
||||
# reach sibling services (e.g. relational_db, cache) on shared compose networks.
|
||||
for subnet in $(ip -4 -o addr show scope global | awk '{print $4}'); do
|
||||
if ! ipset add allowed-domains "$subnet" -exist 2>&1; then
|
||||
echo "warning: failed to add Docker subnet $subnet to allowlist" >&2
|
||||
fi
|
||||
done
|
||||
|
||||
# Set default policies to DROP
|
||||
iptables -P FORWARD DROP
|
||||
iptables -P INPUT DROP
|
||||
iptables -P OUTPUT DROP
|
||||
|
||||
# Allow established connections
|
||||
iptables -A INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT
|
||||
iptables -A OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT
|
||||
|
||||
# Allow loopback
|
||||
iptables -A INPUT -i lo -j ACCEPT
|
||||
iptables -A OUTPUT -o lo -j ACCEPT
|
||||
|
||||
# Allow DNS
|
||||
iptables -A OUTPUT -p udp --dport 53 -j ACCEPT
|
||||
iptables -A OUTPUT -p tcp --dport 53 -j ACCEPT
|
||||
|
||||
# Allow outbound to allowed destinations
|
||||
iptables -A OUTPUT -m set --match-set allowed-domains dst -j ACCEPT
|
||||
|
||||
# Reject unauthorized outbound
|
||||
iptables -A OUTPUT -j REJECT --reject-with icmp-host-unreachable
|
||||
|
||||
# Validate firewall configuration
|
||||
echo "Validating firewall configuration..."
|
||||
|
||||
BLOCKED_SITES=("example.com" "google.com" "facebook.com")
|
||||
for site in "${BLOCKED_SITES[@]}"; do
|
||||
if timeout 2 ping -c 1 "$site" &>/dev/null; then
|
||||
echo "Warning: $site is still reachable"
|
||||
fi
|
||||
done
|
||||
|
||||
if ! timeout 5 curl -s https://api.github.com/meta > /dev/null; then
|
||||
echo "Warning: GitHub API is not accessible"
|
||||
fi
|
||||
|
||||
echo "Firewall setup complete"
|
||||
@@ -1,10 +0,0 @@
|
||||
# Devcontainer zshrc — sourced automatically for both root and dev users.
|
||||
# Edit this file to customize the shell without rebuilding the image.
|
||||
|
||||
# Auto-activate Python venv
|
||||
if [ -f /workspace/.venv/bin/activate ]; then
|
||||
. /workspace/.venv/bin/activate
|
||||
fi
|
||||
|
||||
# Source host zshrc if bind-mounted
|
||||
[ -f ~/.zshrc.host ] && . ~/.zshrc.host
|
||||
10
.github/workflows/deployment.yml
vendored
10
.github/workflows/deployment.yml
vendored
@@ -13,7 +13,7 @@ permissions:
|
||||
id-token: write # zizmor: ignore[excessive-permissions]
|
||||
|
||||
env:
|
||||
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') || github.ref_name == 'edge' }}
|
||||
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
|
||||
|
||||
jobs:
|
||||
# Determine which components to build based on the tag
|
||||
@@ -44,7 +44,7 @@ jobs:
|
||||
fetch-tags: true
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
version: "0.9.9"
|
||||
enable-cache: false
|
||||
@@ -156,7 +156,7 @@ jobs:
|
||||
check-version-tag:
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 10
|
||||
if: ${{ !startsWith(github.ref_name, 'nightly-latest') && github.ref_name != 'edge' && github.event_name != 'workflow_dispatch' }}
|
||||
if: ${{ !startsWith(github.ref_name, 'nightly-latest') && github.event_name != 'workflow_dispatch' }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
@@ -165,7 +165,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
version: "0.9.9"
|
||||
# NOTE: This isn't caching much and zizmor suggests this could be poisoned, so disable.
|
||||
@@ -228,7 +228,7 @@ jobs:
|
||||
|
||||
- name: Create GitHub Release
|
||||
id: create-release
|
||||
uses: softprops/action-gh-release@153bb8e04406b158c6c84fc1615b65b24149a1fe # ratchet:softprops/action-gh-release@v2
|
||||
uses: softprops/action-gh-release@da05d552573ad5aba039eaac05058a918a7bf631 # ratchet:softprops/action-gh-release@v2
|
||||
with:
|
||||
tag_name: ${{ steps.release-tag.outputs.tag }}
|
||||
name: ${{ steps.release-tag.outputs.tag }}
|
||||
|
||||
2
.github/workflows/helm-chart-releases.yml
vendored
2
.github/workflows/helm-chart-releases.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Helm CLI
|
||||
uses: azure/setup-helm@dda3372f752e03dde6b3237bc9431cdc2f7a02a2 # ratchet:azure/setup-helm@v5.0.0
|
||||
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4
|
||||
with:
|
||||
version: v3.12.1
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # ratchet:actions/stale@v10
|
||||
- uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # ratchet:actions/stale@v10
|
||||
with:
|
||||
stale-issue-message: 'This issue is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
|
||||
stale-pr-message: 'This PR is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
|
||||
|
||||
@@ -114,7 +114,7 @@ jobs:
|
||||
ref: main
|
||||
|
||||
- name: Install the latest version of uv
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
2
.github/workflows/pr-helm-chart-testing.yml
vendored
2
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -36,7 +36,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@dda3372f752e03dde6b3237bc9431cdc2f7a02a2 # ratchet:azure/setup-helm@v5.0.0
|
||||
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4.3.1
|
||||
with:
|
||||
version: v3.19.0
|
||||
|
||||
|
||||
4
.github/workflows/pr-playwright-tests.yml
vendored
4
.github/workflows/pr-playwright-tests.yml
vendored
@@ -471,7 +471,7 @@ jobs:
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: always()
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
@@ -710,7 +710,7 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Download visual diff summaries
|
||||
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3
|
||||
with:
|
||||
pattern: screenshot-diff-summary-*
|
||||
path: summaries/
|
||||
|
||||
2
.github/workflows/pr-quality-checks.yml
vendored
2
.github/workflows/pr-quality-checks.yml
vendored
@@ -38,7 +38,7 @@ jobs:
|
||||
- name: Install node dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
- uses: j178/prek-action@cbc2f23eb5539cf20d82d1aabd0d0ecbcc56f4e3
|
||||
- uses: j178/prek-action@0bb87d7f00b0c99306c8bcb8b8beba1eb581c037 # ratchet:j178/prek-action@v1
|
||||
with:
|
||||
prek-version: '0.3.4'
|
||||
extra-args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || github.event_name == 'merge_group' && format('--from-ref {0} --to-ref {1}', github.event.merge_group.base_sha, github.event.merge_group.head_sha) || github.ref_name == 'main' && '--all-files' || '' }}
|
||||
|
||||
2
.github/workflows/release-cli.yml
vendored
2
.github/workflows/release-cli.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
- uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
2
.github/workflows/release-devtools.yml
vendored
2
.github/workflows/release-devtools.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
- uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
2
.github/workflows/zizmor.yml
vendored
2
.github/workflows/zizmor.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install the latest version of uv
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -59,6 +59,3 @@ node_modules
|
||||
|
||||
# plans
|
||||
plans/
|
||||
|
||||
# Added context for LLMs
|
||||
onyx-llm-context/
|
||||
|
||||
@@ -1,57 +1,64 @@
|
||||
{
|
||||
"labels": [],
|
||||
"comment": "",
|
||||
"fixWithAI": true,
|
||||
"hideFooter": false,
|
||||
"strictness": 3,
|
||||
"statusCheck": true,
|
||||
"commentTypes": ["logic", "syntax", "style"],
|
||||
"instructions": "",
|
||||
"disabledLabels": [],
|
||||
"excludeAuthors": ["dependabot[bot]", "renovate[bot]"],
|
||||
"ignoreKeywords": "",
|
||||
"ignorePatterns": "",
|
||||
"includeAuthors": [],
|
||||
"summarySection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"excludeBranches": [],
|
||||
"fileChangeLimit": 300,
|
||||
"includeBranches": [],
|
||||
"includeKeywords": "",
|
||||
"triggerOnUpdates": false,
|
||||
"updateExistingSummaryComment": true,
|
||||
"updateSummaryOnly": false,
|
||||
"issuesTableSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"statusCommentsEnabled": true,
|
||||
"confidenceScoreSection": {
|
||||
"included": true,
|
||||
"collapsible": false
|
||||
},
|
||||
"sequenceDiagramSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"shouldUpdateDescription": false,
|
||||
"rules": [
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx's Next.js app, the `app/ee/admin/` directory is a filesystem convention for Enterprise Edition route overrides — it does NOT add an `/ee/` prefix to the URL. Both `app/admin/groups/page.tsx` and `app/ee/admin/groups/page.tsx` serve the same URL `/admin/groups`. Hardcoded `/admin/...` paths in router.push() calls are correct and do NOT break EE deployments. Do not flag hardcoded admin paths as bugs."
|
||||
"labels": [],
|
||||
"comment": "",
|
||||
"fixWithAI": true,
|
||||
"hideFooter": false,
|
||||
"strictness": 3,
|
||||
"statusCheck": true,
|
||||
"commentTypes": [
|
||||
"logic",
|
||||
"syntax",
|
||||
"style"
|
||||
],
|
||||
"instructions": "",
|
||||
"disabledLabels": [],
|
||||
"excludeAuthors": [
|
||||
"dependabot[bot]",
|
||||
"renovate[bot]"
|
||||
],
|
||||
"ignoreKeywords": "",
|
||||
"ignorePatterns": "",
|
||||
"includeAuthors": [],
|
||||
"summarySection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx, each API key creates a unique user row in the database with a unique `user_id` (UUID). There is a 1:1 mapping between API keys and their backing user records. Multiple API keys do NOT share the same `user_id`. Do not flag potential duplicate row IDs when using `user_id` from API key descriptors."
|
||||
"excludeBranches": [],
|
||||
"fileChangeLimit": 300,
|
||||
"includeBranches": [],
|
||||
"includeKeywords": "",
|
||||
"triggerOnUpdates": true,
|
||||
"updateExistingSummaryComment": true,
|
||||
"updateSummaryOnly": false,
|
||||
"issuesTableSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
{
|
||||
"scope": ["backend/**/*.py"],
|
||||
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
|
||||
}
|
||||
]
|
||||
"statusCommentsEnabled": true,
|
||||
"confidenceScoreSection": {
|
||||
"included": true,
|
||||
"collapsible": false
|
||||
},
|
||||
"sequenceDiagramSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"shouldUpdateDescription": false,
|
||||
"rules": [
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx's Next.js app, the `app/ee/admin/` directory is a filesystem convention for Enterprise Edition route overrides — it does NOT add an `/ee/` prefix to the URL. Both `app/admin/groups/page.tsx` and `app/ee/admin/groups/page.tsx` serve the same URL `/admin/groups`. Hardcoded `/admin/...` paths in router.push() calls are correct and do NOT break EE deployments. Do not flag hardcoded admin paths as bugs."
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx, each API key creates a unique user row in the database with a unique `user_id` (UUID). There is a 1:1 mapping between API keys and their backing user records. Multiple API keys do NOT share the same `user_id`. Do not flag potential duplicate row IDs when using `user_id` from API key descriptors."
|
||||
},
|
||||
{
|
||||
"scope": ["backend/**/*.py"],
|
||||
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ repos:
|
||||
rev: d30b4298e4fb63ce8609e29acdbcf4c9018a483c
|
||||
hooks:
|
||||
- id: uv-sync
|
||||
args: ["--locked", "--all-extras"]
|
||||
- id: uv-lock
|
||||
- id: uv-export
|
||||
name: uv-export default.txt
|
||||
@@ -17,7 +18,7 @@ repos:
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--group",
|
||||
"--extra",
|
||||
"backend",
|
||||
"-o",
|
||||
"backend/requirements/default.txt",
|
||||
@@ -30,7 +31,7 @@ repos:
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--group",
|
||||
"--extra",
|
||||
"dev",
|
||||
"-o",
|
||||
"backend/requirements/dev.txt",
|
||||
@@ -43,7 +44,7 @@ repos:
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--group",
|
||||
"--extra",
|
||||
"ee",
|
||||
"-o",
|
||||
"backend/requirements/ee.txt",
|
||||
@@ -56,7 +57,7 @@ repos:
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--group",
|
||||
"--extra",
|
||||
"model_server",
|
||||
"-o",
|
||||
"backend/requirements/model_server.txt",
|
||||
|
||||
15
.vscode/launch.json
vendored
15
.vscode/launch.json
vendored
@@ -475,18 +475,6 @@
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Start Monitoring Stack (Prometheus + Grafana)",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "docker",
|
||||
"runtimeArgs": ["compose", "up", "-d"],
|
||||
"cwd": "${workspaceFolder}/profiling",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Clear and Restart External Volumes and Containers",
|
||||
"type": "node",
|
||||
@@ -543,7 +531,8 @@
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "uv",
|
||||
"runtimeArgs": [
|
||||
"sync"
|
||||
"sync",
|
||||
"--all-extras"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
|
||||
@@ -49,12 +49,12 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
|
||||
|
||||
4. **Light Worker** (`light`)
|
||||
- Handles lightweight, fast operations
|
||||
- Tasks: vespa metadata sync, connector deletion, doc permissions upsert, checkpoint cleanup, index attempt cleanup
|
||||
- Tasks: vespa operations, document permissions sync, external group sync
|
||||
- Higher concurrency for quick tasks
|
||||
|
||||
5. **Heavy Worker** (`heavy`)
|
||||
- Handles resource-intensive operations
|
||||
- Tasks: connector pruning, document permissions sync, external group sync, CSV generation
|
||||
- Primary task: document pruning operations
|
||||
- Runs with 4 threads concurrency
|
||||
|
||||
6. **KG Processing Worker** (`kg_processing`)
|
||||
|
||||
@@ -117,7 +117,7 @@ If using PowerShell, the command slightly differs:
|
||||
Install the required Python dependencies:
|
||||
|
||||
```bash
|
||||
uv sync
|
||||
uv sync --all-extras
|
||||
```
|
||||
|
||||
Install Playwright for Python (headless browser required by the Web Connector):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM python:3.11-slim-bookworm@sha256:9c6f90801e6b68e772b7c0ca74260cbf7af9f320acec894e26fccdaccfbe3b47
|
||||
FROM python:3.11.7-slim-bookworm
|
||||
|
||||
LABEL com.danswer.maintainer="founders@onyx.app"
|
||||
LABEL com.danswer.description="This image is the web/frontend container of Onyx which \
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Base stage with dependencies
|
||||
FROM python:3.11-slim-bookworm@sha256:9c6f90801e6b68e772b7c0ca74260cbf7af9f320acec894e26fccdaccfbe3b47 AS base
|
||||
FROM python:3.11.7-slim-bookworm AS base
|
||||
|
||||
ENV DANSWER_RUNNING_IN_DOCKER="true" \
|
||||
HF_HOME=/app/.cache/huggingface
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
from onyx.db.engine.iam_auth import get_iam_auth_token
|
||||
from onyx.configs.app_configs import USE_IAM_AUTH
|
||||
from onyx.configs.app_configs import POSTGRES_HOST
|
||||
@@ -19,6 +19,7 @@ from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.sql.schema import SchemaItem
|
||||
from onyx.configs.constants import SSL_CERT_FILE
|
||||
from shared_configs.configs import (
|
||||
MULTI_TENANT,
|
||||
@@ -44,6 +45,8 @@ if config.config_file_name is not None and config.attributes.get(
|
||||
|
||||
target_metadata = [Base.metadata, ResultModelBase.metadata]
|
||||
|
||||
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ssl_context: ssl.SSLContext | None = None
|
||||
@@ -53,6 +56,25 @@ if USE_IAM_AUTH:
|
||||
ssl_context = ssl.create_default_context(cafile=SSL_CERT_FILE)
|
||||
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem, # noqa: ARG001
|
||||
name: str | None,
|
||||
type_: Literal[
|
||||
"schema",
|
||||
"table",
|
||||
"column",
|
||||
"index",
|
||||
"unique_constraint",
|
||||
"foreign_key_constraint",
|
||||
],
|
||||
reflected: bool, # noqa: ARG001
|
||||
compare_to: SchemaItem | None, # noqa: ARG001
|
||||
) -> bool:
|
||||
if type_ == "table" and name in EXCLUDE_TABLES:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def filter_tenants_by_range(
|
||||
tenant_ids: list[str], start_range: int | None = None, end_range: int | None = None
|
||||
) -> list[str]:
|
||||
@@ -208,7 +230,8 @@ def do_run_migrations(
|
||||
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
include_object=include_object,
|
||||
version_table_schema=schema_name,
|
||||
include_schemas=True,
|
||||
compare_type=True,
|
||||
@@ -380,8 +403,9 @@ def run_migrations_offline() -> None:
|
||||
logger.info(f"Migrating schema: {schema}")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
literal_binds=True,
|
||||
include_object=include_object,
|
||||
version_table_schema=schema,
|
||||
include_schemas=True,
|
||||
script_location=config.get_main_option("script_location"),
|
||||
@@ -421,8 +445,9 @@ def run_migrations_offline() -> None:
|
||||
logger.info(f"Migrating schema: {schema}")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
literal_binds=True,
|
||||
include_object=include_object,
|
||||
version_table_schema=schema,
|
||||
include_schemas=True,
|
||||
script_location=config.get_main_option("script_location"),
|
||||
@@ -464,7 +489,8 @@ def run_migrations_online() -> None:
|
||||
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
include_object=include_object,
|
||||
version_table_schema=schema_name,
|
||||
include_schemas=True,
|
||||
compare_type=True,
|
||||
|
||||
@@ -25,7 +25,7 @@ def upgrade() -> None:
|
||||
|
||||
# Use batch mode to modify the enum type
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.alter_column(
|
||||
batch_op.alter_column( # type: ignore[attr-defined]
|
||||
"role",
|
||||
type_=sa.Enum(
|
||||
"BASIC",
|
||||
@@ -71,7 +71,7 @@ def downgrade() -> None:
|
||||
op.drop_column("user__user_group", "is_curator")
|
||||
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.alter_column(
|
||||
batch_op.alter_column( # type: ignore[attr-defined]
|
||||
"role",
|
||||
type_=sa.Enum(
|
||||
"BASIC", "ADMIN", name="userrole", native_enum=False, length=20
|
||||
|
||||
@@ -63,7 +63,7 @@ def upgrade() -> None:
|
||||
"time_created",
|
||||
existing_type=postgresql.TIMESTAMP(timezone=True),
|
||||
nullable=False,
|
||||
existing_server_default=sa.text("now()"),
|
||||
existing_server_default=sa.text("now()"), # type: ignore
|
||||
)
|
||||
op.alter_column(
|
||||
"index_attempt",
|
||||
@@ -85,7 +85,7 @@ def downgrade() -> None:
|
||||
"time_created",
|
||||
existing_type=postgresql.TIMESTAMP(timezone=True),
|
||||
nullable=True,
|
||||
existing_server_default=sa.text("now()"),
|
||||
existing_server_default=sa.text("now()"), # type: ignore
|
||||
)
|
||||
op.drop_index(op.f("ix_accesstoken_created_at"), table_name="accesstoken")
|
||||
op.drop_table("accesstoken")
|
||||
|
||||
@@ -19,7 +19,7 @@ depends_on: None = None
|
||||
|
||||
def upgrade() -> None:
|
||||
sequence = Sequence("connector_credential_pair_id_seq")
|
||||
op.execute(CreateSequence(sequence))
|
||||
op.execute(CreateSequence(sequence)) # type: ignore
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
"""add_error_tracking_fields_to_index_attempt_errors
|
||||
|
||||
Revision ID: d129f37b3d87
|
||||
Revises: 503883791c39
|
||||
Create Date: 2026-04-06 19:11:18.261800
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d129f37b3d87"
|
||||
down_revision = "503883791c39"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"index_attempt_errors",
|
||||
sa.Column("error_type", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("index_attempt_errors", "error_type")
|
||||
@@ -1,9 +1,11 @@
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.schema import SchemaItem
|
||||
|
||||
from alembic import context
|
||||
from onyx.db.engine.sql_engine import build_connection_string
|
||||
@@ -33,6 +35,27 @@ target_metadata = [PublicBase.metadata]
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem, # noqa: ARG001
|
||||
name: str | None,
|
||||
type_: Literal[
|
||||
"schema",
|
||||
"table",
|
||||
"column",
|
||||
"index",
|
||||
"unique_constraint",
|
||||
"foreign_key_constraint",
|
||||
],
|
||||
reflected: bool, # noqa: ARG001
|
||||
compare_to: SchemaItem | None, # noqa: ARG001
|
||||
) -> bool:
|
||||
if type_ == "table" and name in EXCLUDE_TABLES:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
@@ -49,7 +72,7 @@ def run_migrations_offline() -> None:
|
||||
url = build_connection_string()
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
@@ -61,7 +84,8 @@ def run_migrations_offline() -> None:
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
target_metadata=target_metadata, # type: ignore[arg-type]
|
||||
include_object=include_object,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
|
||||
@@ -10,10 +10,9 @@ from fastapi import status
|
||||
from ee.onyx.configs.app_configs import SUPER_CLOUD_API_KEY
|
||||
from ee.onyx.configs.app_configs import SUPER_USERS
|
||||
from ee.onyx.server.seeding import get_seed_config
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -40,7 +39,7 @@ def get_default_admin_user_emails_() -> list[str]:
|
||||
|
||||
async def current_cloud_superuser(
|
||||
request: Request,
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
user: User = Depends(current_admin_user),
|
||||
) -> User:
|
||||
api_key = request.headers.get("Authorization", "").replace("Bearer ", "")
|
||||
if api_key != SUPER_CLOUD_API_KEY:
|
||||
|
||||
@@ -5,7 +5,6 @@ from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from ee.onyx.server.tenants.product_gating import get_gated_tenants
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
@@ -31,7 +30,6 @@ def cloud_beat_task_generator(
|
||||
queue: str = OnyxCeleryTask.DEFAULT,
|
||||
priority: int = OnyxCeleryPriority.MEDIUM,
|
||||
expires: int = BEAT_EXPIRES_DEFAULT,
|
||||
skip_gated: bool = True,
|
||||
) -> bool | None:
|
||||
"""a lightweight task used to kick off individual beat tasks per tenant."""
|
||||
time_start = time.monotonic()
|
||||
@@ -50,22 +48,20 @@ def cloud_beat_task_generator(
|
||||
last_lock_time = time.monotonic()
|
||||
tenant_ids: list[str] = []
|
||||
num_processed_tenants = 0
|
||||
num_skipped_gated = 0
|
||||
|
||||
try:
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
# Per-task control over whether gated tenants are included. Most periodic tasks
|
||||
# do no useful work on gated tenants and just waste DB connections fanning out
|
||||
# to ~10k+ inactive tenants. A small number of cleanup tasks (connector deletion,
|
||||
# checkpoint/index attempt cleanup) need to run on gated tenants and pass
|
||||
# `skip_gated=False` from the beat schedule.
|
||||
gated_tenants: set[str] = get_gated_tenants() if skip_gated else set()
|
||||
# NOTE: for now, we are running tasks for gated tenants, since we want to allow
|
||||
# connector deletion to run successfully. The new plan is to continously prune
|
||||
# the gated tenants set, so we won't have a build up of old, unused gated tenants.
|
||||
# Keeping this around in case we want to revert to the previous behavior.
|
||||
# gated_tenants = get_gated_tenants()
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
if tenant_id in gated_tenants:
|
||||
num_skipped_gated += 1
|
||||
continue
|
||||
# Same comment here as the above NOTE
|
||||
# if tenant_id in gated_tenants:
|
||||
# continue
|
||||
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
|
||||
@@ -108,7 +104,6 @@ def cloud_beat_task_generator(
|
||||
f"cloud_beat_task_generator finished: "
|
||||
f"task={task_name} "
|
||||
f"num_processed_tenants={num_processed_tenants} "
|
||||
f"num_skipped_gated={num_skipped_gated} "
|
||||
f"num_tenants={len(tenant_ids)} "
|
||||
f"elapsed={time_elapsed:.2f}"
|
||||
)
|
||||
|
||||
@@ -13,7 +13,6 @@ from ee.onyx.server.license.models import LicenseSource
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import License
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -108,13 +107,12 @@ def get_used_seats(tenant_id: str | None = None) -> int:
|
||||
Get current seat usage directly from database.
|
||||
|
||||
For multi-tenant: counts users in UserTenantMapping for this tenant.
|
||||
For self-hosted: counts all active users.
|
||||
For self-hosted: counts all active users (excludes EXT_PERM_USER role
|
||||
and the anonymous system user).
|
||||
|
||||
Only human accounts count toward seat limits.
|
||||
SERVICE_ACCOUNT (API key dummy users), EXT_PERM_USER, and the
|
||||
anonymous system user are excluded. BOT (Slack users) ARE counted
|
||||
because they represent real humans and get upgraded to STANDARD
|
||||
when they log in via web.
|
||||
TODO: Exclude API key dummy users from seat counting. API keys create
|
||||
users with emails like `__DANSWER_API_KEY_*` that should not count toward
|
||||
seat limits. See: https://linear.app/onyx-app/issue/ENG-3518
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_count
|
||||
@@ -131,7 +129,6 @@ def get_used_seats(tenant_id: str | None = None) -> int:
|
||||
User.is_active == True, # type: ignore # noqa: E712
|
||||
User.role != UserRole.EXT_PERM_USER,
|
||||
User.email != ANONYMOUS_USER_EMAIL, # type: ignore
|
||||
User.account_type != AccountType.SERVICE_ACCOUNT,
|
||||
)
|
||||
)
|
||||
return result.scalar() or 0
|
||||
|
||||
@@ -39,7 +39,6 @@ from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserGroup__ConnectorCredentialPair
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.db.permissions import recompute_permissions_for_group__no_commit
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import fetch_user_by_id
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -953,46 +952,3 @@ def delete_user_group_cc_pair_relationship__no_commit(
|
||||
UserGroup__ConnectorCredentialPair.cc_pair_id == cc_pair_id,
|
||||
)
|
||||
db_session.execute(delete_stmt)
|
||||
|
||||
|
||||
def set_group_permission__no_commit(
|
||||
group_id: int,
|
||||
permission: Permission,
|
||||
enabled: bool,
|
||||
granted_by: UUID,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Grant or revoke a single permission for a group using soft-delete.
|
||||
|
||||
Does NOT commit — caller must commit the session.
|
||||
"""
|
||||
existing = db_session.execute(
|
||||
select(PermissionGrant)
|
||||
.where(
|
||||
PermissionGrant.group_id == group_id,
|
||||
PermissionGrant.permission == permission,
|
||||
)
|
||||
.with_for_update()
|
||||
).scalar_one_or_none()
|
||||
|
||||
if enabled:
|
||||
if existing is not None:
|
||||
if existing.is_deleted:
|
||||
existing.is_deleted = False
|
||||
existing.granted_by = granted_by
|
||||
existing.granted_at = func.now()
|
||||
else:
|
||||
db_session.add(
|
||||
PermissionGrant(
|
||||
group_id=group_id,
|
||||
permission=permission,
|
||||
grant_source=GrantSource.USER,
|
||||
granted_by=granted_by,
|
||||
)
|
||||
)
|
||||
else:
|
||||
if existing is not None and not existing.is_deleted:
|
||||
existing.is_deleted = True
|
||||
|
||||
db_session.flush()
|
||||
recompute_permissions_for_group__no_commit(group_id, db_session)
|
||||
|
||||
@@ -155,7 +155,7 @@ def get_application() -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, license_router)
|
||||
|
||||
# Unified billing API - always registered in EE.
|
||||
# Each endpoint is protected by admin permission checks.
|
||||
# Each endpoint is protected by the `current_admin_user` dependency (admin auth).
|
||||
include_router_with_global_prefix_prepended(application, billing_router)
|
||||
|
||||
if MULTI_TENANT:
|
||||
|
||||
@@ -17,10 +17,10 @@ from ee.onyx.db.analytics import fetch_persona_message_analytics
|
||||
from ee.onyx.db.analytics import fetch_persona_unique_users
|
||||
from ee.onyx.db.analytics import fetch_query_analytics
|
||||
from ee.onyx.db.analytics import user_can_view_assistant_stats
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
|
||||
router = APIRouter(prefix="/analytics", tags=PUBLIC_API_TAGS)
|
||||
@@ -40,7 +40,7 @@ class QueryAnalyticsResponse(BaseModel):
|
||||
def get_query_analytics(
|
||||
start: datetime.datetime | None = None,
|
||||
end: datetime.datetime | None = None,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[QueryAnalyticsResponse]:
|
||||
daily_query_usage_info = fetch_query_analytics(
|
||||
@@ -71,7 +71,7 @@ class UserAnalyticsResponse(BaseModel):
|
||||
def get_user_analytics(
|
||||
start: datetime.datetime | None = None,
|
||||
end: datetime.datetime | None = None,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[UserAnalyticsResponse]:
|
||||
daily_query_usage_info_per_user = fetch_per_user_query_analytics(
|
||||
@@ -105,7 +105,7 @@ class OnyxbotAnalyticsResponse(BaseModel):
|
||||
def get_onyxbot_analytics(
|
||||
start: datetime.datetime | None = None,
|
||||
end: datetime.datetime | None = None,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[OnyxbotAnalyticsResponse]:
|
||||
daily_onyxbot_info = fetch_onyxbot_analytics(
|
||||
@@ -141,7 +141,7 @@ def get_persona_messages(
|
||||
persona_id: int,
|
||||
start: datetime.datetime | None = None,
|
||||
end: datetime.datetime | None = None,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[PersonaMessageAnalyticsResponse]:
|
||||
"""Fetch daily message counts for a single persona within the given time range."""
|
||||
@@ -179,7 +179,7 @@ def get_persona_unique_users(
|
||||
persona_id: int,
|
||||
start: datetime.datetime,
|
||||
end: datetime.datetime,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[PersonaUniqueUsersResponse]:
|
||||
"""Get unique users per day for a single persona."""
|
||||
@@ -218,7 +218,7 @@ def get_assistant_stats(
|
||||
assistant_id: int,
|
||||
start: datetime.datetime | None = None,
|
||||
end: datetime.datetime | None = None,
|
||||
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> AssistantStatsResponse:
|
||||
"""
|
||||
|
||||
@@ -29,6 +29,7 @@ from fastapi import Depends
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.auth.users import current_admin_user
|
||||
from ee.onyx.db.license import get_license
|
||||
from ee.onyx.db.license import get_used_seats
|
||||
from ee.onyx.server.billing.models import BillingInformationResponse
|
||||
@@ -50,13 +51,11 @@ from ee.onyx.server.billing.service import (
|
||||
get_billing_information as get_billing_service,
|
||||
)
|
||||
from ee.onyx.server.billing.service import update_seat_count as update_seat_service
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.redis.redis_pool import get_shared_redis_client
|
||||
@@ -148,7 +147,7 @@ def _get_tenant_id() -> str | None:
|
||||
@router.post("/create-checkout-session")
|
||||
async def create_checkout_session(
|
||||
request: CreateCheckoutSessionRequest | None = None,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> CreateCheckoutSessionResponse:
|
||||
"""Create a Stripe checkout session for new subscription or renewal.
|
||||
@@ -192,7 +191,7 @@ async def create_checkout_session(
|
||||
@router.post("/create-customer-portal-session")
|
||||
async def create_customer_portal_session(
|
||||
request: CreateCustomerPortalSessionRequest | None = None,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> CreateCustomerPortalSessionResponse:
|
||||
"""Create a Stripe customer portal session for managing subscription.
|
||||
@@ -217,7 +216,7 @@ async def create_customer_portal_session(
|
||||
|
||||
@router.get("/billing-information")
|
||||
async def get_billing_information(
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> BillingInformationResponse | SubscriptionStatusResponse:
|
||||
"""Get billing information for the current subscription.
|
||||
@@ -259,7 +258,7 @@ async def get_billing_information(
|
||||
@router.post("/seats/update")
|
||||
async def update_seats(
|
||||
request: SeatUpdateRequest,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SeatUpdateResponse:
|
||||
"""Update the seat count for the current subscription.
|
||||
@@ -365,7 +364,7 @@ class ResetConnectionResponse(BaseModel):
|
||||
|
||||
@router.post("/reset-connection")
|
||||
async def reset_stripe_connection(
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> ResetConnectionResponse:
|
||||
"""Reset the Stripe connection circuit breaker.
|
||||
|
||||
|
||||
@@ -27,12 +27,11 @@ from ee.onyx.server.scim.auth import generate_scim_token
|
||||
from ee.onyx.server.scim.models import ScimTokenCreate
|
||||
from ee.onyx.server.scim.models import ScimTokenCreatedResponse
|
||||
from ee.onyx.server.scim.models import ScimTokenResponse
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_user_with_expired_token
|
||||
from onyx.auth.users import get_user_manager
|
||||
from onyx.auth.users import UserManager
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
@@ -121,8 +120,7 @@ async def refresh_access_token(
|
||||
|
||||
@admin_router.put("")
|
||||
def admin_ee_put_settings(
|
||||
settings: EnterpriseSettings,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
settings: EnterpriseSettings, _: User = Depends(current_admin_user)
|
||||
) -> None:
|
||||
store_settings(settings)
|
||||
|
||||
@@ -141,7 +139,7 @@ def ee_fetch_settings() -> EnterpriseSettings:
|
||||
def put_logo(
|
||||
file: UploadFile,
|
||||
is_logotype: bool = False,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> None:
|
||||
upload_logo(file=file, is_logotype=is_logotype)
|
||||
|
||||
@@ -198,8 +196,7 @@ def fetch_logo(
|
||||
|
||||
@admin_router.put("/custom-analytics-script")
|
||||
def upload_custom_analytics_script(
|
||||
script_upload: AnalyticsScriptUpload,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
script_upload: AnalyticsScriptUpload, _: User = Depends(current_admin_user)
|
||||
) -> None:
|
||||
try:
|
||||
store_analytics_script(script_upload)
|
||||
@@ -223,7 +220,7 @@ def _get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
|
||||
|
||||
@admin_router.get("/scim/token")
|
||||
def get_active_scim_token(
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
dal: ScimDAL = Depends(_get_scim_dal),
|
||||
) -> ScimTokenResponse:
|
||||
"""Return the currently active SCIM token's metadata, or 404 if none."""
|
||||
@@ -253,7 +250,7 @@ def get_active_scim_token(
|
||||
@admin_router.post("/scim/token", status_code=201)
|
||||
def create_scim_token(
|
||||
body: ScimTokenCreate,
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
user: User = Depends(current_admin_user),
|
||||
dal: ScimDAL = Depends(_get_scim_dal),
|
||||
) -> ScimTokenCreatedResponse:
|
||||
"""Create a new SCIM bearer token.
|
||||
|
||||
@@ -4,13 +4,12 @@ from fastapi import Depends
|
||||
from fastapi import Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.db.constants import UNSET
|
||||
from onyx.db.constants import UnsetType
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.hook import create_hook__no_commit
|
||||
from onyx.db.hook import delete_hook__no_commit
|
||||
from onyx.db.hook import get_hook_by_id
|
||||
@@ -179,7 +178,7 @@ router = APIRouter(prefix="/admin/hooks")
|
||||
|
||||
@router.get("/specs")
|
||||
def get_hook_point_specs(
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
) -> list[HookPointMetaResponse]:
|
||||
return [
|
||||
@@ -200,7 +199,7 @@ def get_hook_point_specs(
|
||||
|
||||
@router.get("")
|
||||
def list_hooks(
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[HookResponse]:
|
||||
@@ -211,7 +210,7 @@ def list_hooks(
|
||||
@router.post("")
|
||||
def create_hook(
|
||||
req: HookCreateRequest,
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
user: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
@@ -247,7 +246,7 @@ def create_hook(
|
||||
@router.get("/{hook_id}")
|
||||
def get_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
@@ -259,7 +258,7 @@ def get_hook(
|
||||
def update_hook(
|
||||
hook_id: int,
|
||||
req: HookUpdateRequest,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
@@ -329,7 +328,7 @@ def update_hook(
|
||||
@router.delete("/{hook_id}")
|
||||
def delete_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
@@ -340,7 +339,7 @@ def delete_hook(
|
||||
@router.post("/{hook_id}/activate")
|
||||
def activate_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
@@ -382,7 +381,7 @@ def activate_hook(
|
||||
@router.post("/{hook_id}/validate")
|
||||
def validate_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookValidateResponse:
|
||||
@@ -410,7 +409,7 @@ def validate_hook(
|
||||
@router.post("/{hook_id}/deactivate")
|
||||
def deactivate_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
@@ -433,7 +432,7 @@ def deactivate_hook(
|
||||
def list_hook_execution_logs(
|
||||
hook_id: int,
|
||||
limit: int = Query(default=10, ge=1, le=100),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[HookExecutionRecord]:
|
||||
|
||||
@@ -17,6 +17,7 @@ from fastapi import File
|
||||
from fastapi import UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.auth.users import current_admin_user
|
||||
from ee.onyx.configs.app_configs import CLOUD_DATA_PLANE_URL
|
||||
from ee.onyx.db.license import delete_license as db_delete_license
|
||||
from ee.onyx.db.license import get_license
|
||||
@@ -31,10 +32,8 @@ from ee.onyx.server.license.models import LicenseStatusResponse
|
||||
from ee.onyx.server.license.models import LicenseUploadResponse
|
||||
from ee.onyx.server.license.models import SeatUsageResponse
|
||||
from ee.onyx.utils.license import verify_license_signature
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import User
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -61,7 +60,7 @@ def _strip_pem_delimiters(content: str) -> str:
|
||||
|
||||
@router.get("")
|
||||
async def get_license_status(
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LicenseStatusResponse:
|
||||
"""Get current license status and seat usage."""
|
||||
@@ -85,7 +84,7 @@ async def get_license_status(
|
||||
|
||||
@router.get("/seats")
|
||||
async def get_seat_usage(
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SeatUsageResponse:
|
||||
"""Get detailed seat usage information."""
|
||||
@@ -108,7 +107,7 @@ async def get_seat_usage(
|
||||
@router.post("/claim")
|
||||
async def claim_license(
|
||||
session_id: str | None = None,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LicenseResponse:
|
||||
"""
|
||||
@@ -216,7 +215,7 @@ async def claim_license(
|
||||
@router.post("/upload")
|
||||
async def upload_license(
|
||||
license_file: UploadFile = File(...),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LicenseUploadResponse:
|
||||
"""
|
||||
@@ -264,7 +263,7 @@ async def upload_license(
|
||||
|
||||
@router.post("/refresh")
|
||||
async def refresh_license_cache_endpoint(
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LicenseStatusResponse:
|
||||
"""
|
||||
@@ -293,7 +292,7 @@ async def refresh_license_cache_endpoint(
|
||||
|
||||
@router.delete("")
|
||||
async def delete_license(
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""
|
||||
|
||||
@@ -12,9 +12,8 @@ from ee.onyx.db.standard_answer import insert_standard_answer_category
|
||||
from ee.onyx.db.standard_answer import remove_standard_answer
|
||||
from ee.onyx.db.standard_answer import update_standard_answer
|
||||
from ee.onyx.db.standard_answer import update_standard_answer_category
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.server.manage.models import StandardAnswer
|
||||
from onyx.server.manage.models import StandardAnswerCategory
|
||||
@@ -28,7 +27,7 @@ router = APIRouter(prefix="/manage")
|
||||
def create_standard_answer(
|
||||
standard_answer_creation_request: StandardAnswerCreationRequest,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> StandardAnswer:
|
||||
standard_answer_model = insert_standard_answer(
|
||||
keyword=standard_answer_creation_request.keyword,
|
||||
@@ -44,7 +43,7 @@ def create_standard_answer(
|
||||
@router.get("/admin/standard-answer")
|
||||
def list_standard_answers(
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> list[StandardAnswer]:
|
||||
standard_answer_models = fetch_standard_answers(db_session=db_session)
|
||||
return [
|
||||
@@ -58,7 +57,7 @@ def patch_standard_answer(
|
||||
standard_answer_id: int,
|
||||
standard_answer_creation_request: StandardAnswerCreationRequest,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> StandardAnswer:
|
||||
existing_standard_answer = fetch_standard_answer(
|
||||
standard_answer_id=standard_answer_id,
|
||||
@@ -84,7 +83,7 @@ def patch_standard_answer(
|
||||
def delete_standard_answer(
|
||||
standard_answer_id: int,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> None:
|
||||
return remove_standard_answer(
|
||||
standard_answer_id=standard_answer_id,
|
||||
@@ -96,7 +95,7 @@ def delete_standard_answer(
|
||||
def create_standard_answer_category(
|
||||
standard_answer_category_creation_request: StandardAnswerCategoryCreationRequest,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> StandardAnswerCategory:
|
||||
standard_answer_category_model = insert_standard_answer_category(
|
||||
category_name=standard_answer_category_creation_request.name,
|
||||
@@ -108,7 +107,7 @@ def create_standard_answer_category(
|
||||
@router.get("/admin/standard-answer/category")
|
||||
def list_standard_answer_categories(
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> list[StandardAnswerCategory]:
|
||||
standard_answer_category_models = fetch_standard_answer_categories(
|
||||
db_session=db_session
|
||||
@@ -124,7 +123,7 @@ def patch_standard_answer_category(
|
||||
standard_answer_category_id: int,
|
||||
standard_answer_category_creation_request: StandardAnswerCategoryCreationRequest,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> StandardAnswerCategory:
|
||||
existing_standard_answer_category = fetch_standard_answer_category(
|
||||
standard_answer_category_id=standard_answer_category_id,
|
||||
|
||||
@@ -9,10 +9,9 @@ from ee.onyx.server.oauth.api_router import router
|
||||
from ee.onyx.server.oauth.confluence_cloud import ConfluenceCloudOAuth
|
||||
from ee.onyx.server.oauth.google_drive import GoogleDriveOAuth
|
||||
from ee.onyx.server.oauth.slack import SlackOAuth
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -25,7 +24,7 @@ logger = setup_logger()
|
||||
def prepare_authorization_request(
|
||||
connector: DocumentSource,
|
||||
redirect_on_success: str | None,
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
user: User = Depends(current_admin_user),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
"""Used by the frontend to generate the url for the user's browser during auth request.
|
||||
|
||||
@@ -15,7 +15,7 @@ from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.oauth.api_router import router
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
|
||||
from onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
|
||||
@@ -26,7 +26,6 @@ from onyx.db.credentials import create_credential
|
||||
from onyx.db.credentials import fetch_credential_by_id_for_user
|
||||
from onyx.db.credentials import update_credential_json
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
@@ -147,7 +146,7 @@ class ConfluenceCloudOAuth:
|
||||
def confluence_oauth_callback(
|
||||
code: str,
|
||||
state: str,
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
@@ -259,7 +258,7 @@ def confluence_oauth_callback(
|
||||
@router.get("/connector/confluence/accessible-resources")
|
||||
def confluence_oauth_accessible_resources(
|
||||
credential_id: int,
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id), # noqa: ARG001
|
||||
) -> JSONResponse:
|
||||
@@ -326,7 +325,7 @@ def confluence_oauth_finalize(
|
||||
cloud_id: str,
|
||||
cloud_name: str,
|
||||
cloud_url: str,
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id), # noqa: ARG001
|
||||
) -> JSONResponse:
|
||||
|
||||
@@ -12,7 +12,7 @@ from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.oauth.api_router import router
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
from onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
@@ -34,7 +34,6 @@ from onyx.connectors.google_utils.shared_constants import (
|
||||
)
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
@@ -115,7 +114,7 @@ class GoogleDriveOAuth:
|
||||
def handle_google_drive_oauth_callback(
|
||||
code: str,
|
||||
state: str,
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
|
||||
@@ -10,7 +10,7 @@ from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.oauth.api_router import router
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
|
||||
from onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
|
||||
@@ -18,7 +18,6 @@ from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
@@ -99,7 +98,7 @@ class SlackOAuth:
|
||||
def handle_slack_oauth_callback(
|
||||
code: str,
|
||||
state: str,
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
|
||||
@@ -8,9 +8,8 @@ from ee.onyx.onyxbot.slack.handlers.handle_standard_answers import (
|
||||
)
|
||||
from ee.onyx.server.query_and_chat.models import StandardAnswerRequest
|
||||
from ee.onyx.server.query_and_chat.models import StandardAnswerResponse
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -23,7 +22,7 @@ basic_router = APIRouter(prefix="/query")
|
||||
def get_standard_answer(
|
||||
request: StandardAnswerRequest,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
_: User = Depends(current_user),
|
||||
) -> StandardAnswerResponse:
|
||||
try:
|
||||
standard_answers = oneoff_standard_answers(
|
||||
|
||||
@@ -19,11 +19,10 @@ from ee.onyx.server.query_and_chat.models import SearchHistoryResponse
|
||||
from ee.onyx.server.query_and_chat.models import SearchQueryResponse
|
||||
from ee.onyx.server.query_and_chat.models import SendSearchQueryRequest
|
||||
from ee.onyx.server.query_and_chat.streaming_models import SearchErrorPacket
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.app_configs import ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.server.usage_limits import check_llm_cost_limit_for_provider
|
||||
@@ -40,7 +39,7 @@ router = APIRouter(prefix="/search")
|
||||
@router.post("/search-flow-classification")
|
||||
def search_flow_classification(
|
||||
request: SearchFlowClassificationRequest,
|
||||
_: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
_: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SearchFlowClassificationResponse:
|
||||
query = request.user_query
|
||||
@@ -80,7 +79,7 @@ def search_flow_classification(
|
||||
)
|
||||
def handle_send_search_message(
|
||||
request: SendSearchQueryRequest,
|
||||
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StreamingResponse | SearchFullResponse:
|
||||
"""
|
||||
@@ -130,7 +129,7 @@ def handle_send_search_message(
|
||||
def get_search_history(
|
||||
limit: int = 100,
|
||||
filter_days: int | None = None,
|
||||
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SearchHistoryResponse:
|
||||
"""
|
||||
|
||||
@@ -20,7 +20,7 @@ from ee.onyx.server.query_history.models import ChatSessionMinimal
|
||||
from ee.onyx.server.query_history.models import ChatSessionSnapshot
|
||||
from ee.onyx.server.query_history.models import MessageSnapshot
|
||||
from ee.onyx.server.query_history.models import QueryHistoryExport
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import get_display_email
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.background.task_utils import construct_query_history_report_name
|
||||
@@ -39,7 +39,6 @@ from onyx.configs.constants import SessionType
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.enums import TaskStatus
|
||||
from onyx.db.file_record import get_query_history_export_files
|
||||
from onyx.db.models import ChatSession
|
||||
@@ -154,7 +153,7 @@ def snapshot_from_chat_session(
|
||||
@router.get("/admin/chat-sessions")
|
||||
def admin_get_chat_sessions(
|
||||
user_id: UUID,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSessionsResponse:
|
||||
# we specifically don't allow this endpoint if "anonymized" since
|
||||
@@ -197,7 +196,7 @@ def get_chat_session_history(
|
||||
feedback_type: QAFeedbackType | None = None,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> PaginatedReturn[ChatSessionMinimal]:
|
||||
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
|
||||
@@ -235,7 +234,7 @@ def get_chat_session_history(
|
||||
@router.get("/admin/chat-session-history/{chat_session_id}")
|
||||
def get_chat_session_admin(
|
||||
chat_session_id: UUID,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSessionSnapshot:
|
||||
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
|
||||
@@ -270,7 +269,7 @@ def get_chat_session_admin(
|
||||
|
||||
@router.get("/admin/query-history/list")
|
||||
def list_all_query_history_exports(
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[QueryHistoryExport]:
|
||||
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
|
||||
@@ -298,7 +297,7 @@ def list_all_query_history_exports(
|
||||
|
||||
@router.post("/admin/query-history/start-export", tags=PUBLIC_API_TAGS)
|
||||
def start_query_history_export(
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
@@ -345,7 +344,7 @@ def start_query_history_export(
|
||||
@router.get("/admin/query-history/export-status", tags=PUBLIC_API_TAGS)
|
||||
def get_query_history_export_status(
|
||||
request_id: str,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict[str, str]:
|
||||
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
|
||||
@@ -379,7 +378,7 @@ def get_query_history_export_status(
|
||||
@router.get("/admin/query-history/download", tags=PUBLIC_API_TAGS)
|
||||
def download_query_history_csv(
|
||||
request_id: str,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StreamingResponse:
|
||||
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
|
||||
|
||||
@@ -12,11 +12,10 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.db.usage_export import get_all_usage_reports
|
||||
from ee.onyx.db.usage_export import get_usage_report_data
|
||||
from ee.onyx.db.usage_export import UsageReportMetadata
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.file_store.constants import STANDARD_CHUNK_SIZE
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
@@ -32,7 +31,7 @@ class GenerateUsageReportParams(BaseModel):
|
||||
@router.post("/admin/usage-report", status_code=204)
|
||||
def generate_report(
|
||||
params: GenerateUsageReportParams,
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
user: User = Depends(current_admin_user),
|
||||
) -> None:
|
||||
# Validate period parameters
|
||||
if params.period_from and params.period_to:
|
||||
@@ -59,7 +58,7 @@ def generate_report(
|
||||
@router.get("/admin/usage-report/{report_name}")
|
||||
def read_usage_report(
|
||||
report_name: str,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session), # noqa: ARG001
|
||||
) -> Response:
|
||||
try:
|
||||
@@ -83,7 +82,7 @@ def read_usage_report(
|
||||
|
||||
@router.get("/admin/usage-report")
|
||||
def fetch_usage_reports(
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[UsageReportMetadata]:
|
||||
try:
|
||||
|
||||
@@ -11,8 +11,6 @@ require a valid SCIM bearer token.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import struct
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
@@ -24,7 +22,6 @@ from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -68,25 +65,12 @@ from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Group names reserved for system default groups (seeded by migration).
|
||||
_RESERVED_GROUP_NAMES = frozenset({"Admin", "Basic"})
|
||||
|
||||
# Namespace prefix for the seat-allocation advisory lock. Hashed together
|
||||
# with the tenant ID so the lock is scoped per-tenant (unrelated tenants
|
||||
# never block each other) and cannot collide with unrelated advisory locks.
|
||||
_SEAT_LOCK_NAMESPACE = "onyx_scim_seat_lock"
|
||||
|
||||
|
||||
def _seat_lock_id_for_tenant(tenant_id: str) -> int:
|
||||
"""Derive a stable 64-bit signed int lock id for this tenant's seat lock."""
|
||||
digest = hashlib.sha256(f"{_SEAT_LOCK_NAMESPACE}:{tenant_id}".encode()).digest()
|
||||
# pg_advisory_xact_lock takes a signed 8-byte int; unpack as such.
|
||||
return struct.unpack("q", digest[:8])[0]
|
||||
|
||||
|
||||
class ScimJSONResponse(JSONResponse):
|
||||
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
|
||||
@@ -225,37 +209,12 @@ def _apply_exclusions(
|
||||
|
||||
|
||||
def _check_seat_availability(dal: ScimDAL) -> str | None:
|
||||
"""Return an error message if seat limit is reached, else None.
|
||||
|
||||
Acquires a transaction-scoped advisory lock so that concurrent
|
||||
SCIM requests are serialized. IdPs like Okta send provisioning
|
||||
requests in parallel batches — without serialization the check is
|
||||
vulnerable to a TOCTOU race where N concurrent requests each see
|
||||
"seats available", all insert, and the tenant ends up over its
|
||||
seat limit.
|
||||
|
||||
The lock is held until the caller's next COMMIT or ROLLBACK, which
|
||||
means the seat count cannot change between the check here and the
|
||||
subsequent INSERT/UPDATE. Each call site in this module follows
|
||||
the pattern: _check_seat_availability → write → dal.commit()
|
||||
(which releases the lock for the next waiting request).
|
||||
"""
|
||||
"""Return an error message if seat limit is reached, else None."""
|
||||
check_fn = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license", "check_seat_availability", None
|
||||
)
|
||||
if check_fn is None:
|
||||
return None
|
||||
|
||||
# Transaction-scoped advisory lock — released on dal.commit() / dal.rollback().
|
||||
# The lock id is derived from the tenant so unrelated tenants never block
|
||||
# each other, and from a namespace string so it cannot collide with
|
||||
# unrelated advisory locks elsewhere in the codebase.
|
||||
lock_id = _seat_lock_id_for_tenant(get_current_tenant_id())
|
||||
dal.session.execute(
|
||||
text("SELECT pg_advisory_xact_lock(:lock_id)"),
|
||||
{"lock_id": lock_id},
|
||||
)
|
||||
|
||||
result = check_fn(dal.session, seats_needed=1)
|
||||
if not result.available:
|
||||
return result.error_message or "Seat limit reached"
|
||||
|
||||
@@ -12,13 +12,12 @@ from ee.onyx.server.tenants.anonymous_user_path import (
|
||||
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
|
||||
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
|
||||
from ee.onyx.server.tenants.models import AnonymousUserPath
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import anonymous_user_enabled
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.constants import ANONYMOUS_USER_COOKIE_NAME
|
||||
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
|
||||
from onyx.db.engine.sql_engine import get_session_with_shared_schema
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
@@ -29,7 +28,7 @@ router = APIRouter(prefix="/tenants")
|
||||
|
||||
@router.get("/anonymous-user-path")
|
||||
async def get_anonymous_user_path_api(
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> AnonymousUserPath:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
@@ -45,7 +44,7 @@ async def get_anonymous_user_path_api(
|
||||
@router.post("/anonymous-user-path")
|
||||
async def set_anonymous_user_path_api(
|
||||
anonymous_user_path: str,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
|
||||
@@ -22,6 +22,7 @@ import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
|
||||
from ee.onyx.auth.users import current_admin_user
|
||||
from ee.onyx.server.tenants.access import control_plane_dep
|
||||
from ee.onyx.server.tenants.billing import fetch_billing_information
|
||||
from ee.onyx.server.tenants.billing import fetch_customer_portal_session
|
||||
@@ -37,12 +38,10 @@ from ee.onyx.server.tenants.models import SubscriptionSessionResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
|
||||
from ee.onyx.server.tenants.product_gating import overwrite_full_gated_set
|
||||
from ee.onyx.server.tenants.product_gating import store_product_gating
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -100,7 +99,7 @@ def gate_product_full_sync(
|
||||
|
||||
@router.get("/billing-information")
|
||||
async def billing_information(
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> BillingInformation | SubscriptionStatusResponse:
|
||||
logger.info("Fetching billing information")
|
||||
tenant_id = get_current_tenant_id()
|
||||
@@ -109,7 +108,7 @@ async def billing_information(
|
||||
|
||||
@router.post("/create-customer-portal-session")
|
||||
async def create_customer_portal_session(
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> dict:
|
||||
"""Create a Stripe customer portal session via the control plane."""
|
||||
tenant_id = get_current_tenant_id()
|
||||
@@ -131,7 +130,7 @@ async def create_customer_portal_session(
|
||||
@router.post("/create-checkout-session")
|
||||
async def create_checkout_session(
|
||||
request: CreateCheckoutSessionRequest | None = None,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> dict:
|
||||
"""Create a Stripe checkout session via the control plane."""
|
||||
tenant_id = get_current_tenant_id()
|
||||
@@ -154,7 +153,7 @@ async def create_checkout_session(
|
||||
@router.post("/create-subscription-session")
|
||||
async def create_subscription_session(
|
||||
request: CreateSubscriptionSessionRequest | None = None,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> SubscriptionSessionResponse:
|
||||
try:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
|
||||
@@ -6,11 +6,10 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
|
||||
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
|
||||
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.users import delete_user_from_db
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.server.manage.models import UserByEmail
|
||||
@@ -25,9 +24,7 @@ router = APIRouter(prefix="/tenants")
|
||||
@router.post("/leave-team")
|
||||
async def leave_organization(
|
||||
user_email: UserByEmail,
|
||||
current_user: User = Depends(
|
||||
require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)
|
||||
),
|
||||
current_user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
@@ -3,9 +3,8 @@ from fastapi import Depends
|
||||
|
||||
from ee.onyx.server.tenants.models import TenantByDomainResponse
|
||||
from ee.onyx.server.tenants.provisioning import get_tenant_by_domain_from_control_plane
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
@@ -27,7 +26,7 @@ FORBIDDEN_COMMON_EMAIL_SUBSTRINGS = [
|
||||
|
||||
@router.get("/existing-team-by-domain")
|
||||
def get_existing_tenant_by_domain(
|
||||
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
user: User = Depends(current_user),
|
||||
) -> TenantByDomainResponse | None:
|
||||
domain = user.email.split("@")[1]
|
||||
if any(substring in domain for substring in FORBIDDEN_COMMON_EMAIL_SUBSTRINGS):
|
||||
|
||||
@@ -10,9 +10,9 @@ from ee.onyx.server.tenants.user_mapping import approve_user_invite
|
||||
from ee.onyx.server.tenants.user_mapping import deny_user_invite
|
||||
from ee.onyx.server.tenants.user_mapping import invite_self_to_tenant
|
||||
from onyx.auth.invited_users import get_pending_users
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
@@ -24,7 +24,7 @@ router = APIRouter(prefix="/tenants")
|
||||
@router.post("/users/invite/request")
|
||||
async def request_invite(
|
||||
invite_request: RequestInviteRequest,
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
user: User = Depends(current_admin_user),
|
||||
) -> None:
|
||||
try:
|
||||
invite_self_to_tenant(user.email, invite_request.tenant_id)
|
||||
@@ -37,7 +37,7 @@ async def request_invite(
|
||||
|
||||
@router.get("/users/pending")
|
||||
def list_pending_users(
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> list[PendingUserSnapshot]:
|
||||
pending_emails = get_pending_users()
|
||||
return [PendingUserSnapshot(email=email) for email in pending_emails]
|
||||
@@ -46,7 +46,7 @@ def list_pending_users(
|
||||
@router.post("/users/invite/approve")
|
||||
async def approve_user(
|
||||
approve_user_request: ApproveUserRequest,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
approve_user_invite(approve_user_request.email, tenant_id)
|
||||
@@ -55,7 +55,7 @@ async def approve_user(
|
||||
@router.post("/users/invite/accept")
|
||||
async def accept_invite(
|
||||
invite_request: RequestInviteRequest,
|
||||
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
user: User = Depends(current_user),
|
||||
) -> None:
|
||||
"""
|
||||
Accept an invitation to join a tenant.
|
||||
@@ -70,7 +70,7 @@ async def accept_invite(
|
||||
@router.post("/users/invite/deny")
|
||||
async def deny_invite(
|
||||
invite_request: RequestInviteRequest,
|
||||
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
user: User = Depends(current_user),
|
||||
) -> None:
|
||||
"""
|
||||
Deny an invitation to join a tenant.
|
||||
|
||||
@@ -7,11 +7,10 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.db.token_limit import fetch_all_user_group_token_rate_limits_by_group
|
||||
from ee.onyx.db.token_limit import fetch_user_group_token_rate_limits_for_user
|
||||
from ee.onyx.db.token_limit import insert_user_group_token_rate_limit
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.db.token_limit import fetch_all_user_token_rate_limits
|
||||
from onyx.db.token_limit import insert_user_token_rate_limit
|
||||
@@ -29,7 +28,7 @@ Group Token Limit Settings
|
||||
|
||||
@router.get("/user-groups")
|
||||
def get_all_group_token_limit_settings(
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict[str, list[TokenRateLimitDisplay]]:
|
||||
user_groups_to_token_rate_limits = fetch_all_user_group_token_rate_limits_by_group(
|
||||
@@ -65,7 +64,7 @@ def get_group_token_limit_settings(
|
||||
def create_group_token_limit_settings(
|
||||
group_id: int,
|
||||
token_limit_settings: TokenRateLimitArgs,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> TokenRateLimitDisplay:
|
||||
rate_limit_display = TokenRateLimitDisplay.from_db(
|
||||
@@ -87,7 +86,7 @@ User Token Limit Settings
|
||||
|
||||
@router.get("/users")
|
||||
def get_user_token_limit_settings(
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[TokenRateLimitDisplay]:
|
||||
return [
|
||||
@@ -99,7 +98,7 @@ def get_user_token_limit_settings(
|
||||
@router.post("/users")
|
||||
def create_user_token_limit_settings(
|
||||
token_limit_settings: TokenRateLimitArgs,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> TokenRateLimitDisplay:
|
||||
rate_limit_display = TokenRateLimitDisplay.from_db(
|
||||
|
||||
@@ -13,26 +13,22 @@ from ee.onyx.db.user_group import fetch_user_groups_for_user
|
||||
from ee.onyx.db.user_group import insert_user_group
|
||||
from ee.onyx.db.user_group import prepare_user_group_for_deletion
|
||||
from ee.onyx.db.user_group import rename_user_group
|
||||
from ee.onyx.db.user_group import set_group_permission__no_commit
|
||||
from ee.onyx.db.user_group import update_user_curator_relationship
|
||||
from ee.onyx.db.user_group import update_user_group
|
||||
from ee.onyx.server.user_group.models import AddUsersToUserGroupRequest
|
||||
from ee.onyx.server.user_group.models import MinimalUserGroupSnapshot
|
||||
from ee.onyx.server.user_group.models import SetCuratorRequest
|
||||
from ee.onyx.server.user_group.models import SetPermissionRequest
|
||||
from ee.onyx.server.user_group.models import SetPermissionResponse
|
||||
from ee.onyx.server.user_group.models import UpdateGroupAgentsRequest
|
||||
from ee.onyx.server.user_group.models import UserGroup
|
||||
from ee.onyx.server.user_group.models import UserGroupCreate
|
||||
from ee.onyx.server.user_group.models import UserGroupRename
|
||||
from ee.onyx.server.user_group.models import UserGroupUpdate
|
||||
from onyx.auth.permissions import NON_TOGGLEABLE_PERMISSIONS
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
@@ -72,7 +68,7 @@ def list_user_groups(
|
||||
@router.get("/user-groups/minimal")
|
||||
def list_minimal_user_groups(
|
||||
include_default: bool = False,
|
||||
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[MinimalUserGroupSnapshot]:
|
||||
if user.role == UserRole.ADMIN:
|
||||
@@ -95,50 +91,23 @@ def list_minimal_user_groups(
|
||||
@router.get("/admin/user-group/{user_group_id}/permissions")
|
||||
def get_user_group_permissions(
|
||||
user_group_id: int,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[Permission]:
|
||||
) -> list[str]:
|
||||
group = fetch_user_group(db_session, user_group_id)
|
||||
if group is None:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "User group not found")
|
||||
return [
|
||||
grant.permission for grant in group.permission_grants if not grant.is_deleted
|
||||
grant.permission.value
|
||||
for grant in group.permission_grants
|
||||
if not grant.is_deleted
|
||||
]
|
||||
|
||||
|
||||
@router.put("/admin/user-group/{user_group_id}/permissions")
|
||||
def set_user_group_permission(
|
||||
user_group_id: int,
|
||||
request: SetPermissionRequest,
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SetPermissionResponse:
|
||||
group = fetch_user_group(db_session, user_group_id)
|
||||
if group is None:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "User group not found")
|
||||
|
||||
if request.permission in NON_TOGGLEABLE_PERMISSIONS:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
f"Permission '{request.permission}' cannot be toggled via this endpoint",
|
||||
)
|
||||
|
||||
set_group_permission__no_commit(
|
||||
group_id=user_group_id,
|
||||
permission=request.permission,
|
||||
enabled=request.enabled,
|
||||
granted_by=user.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
return SetPermissionResponse(permission=request.permission, enabled=request.enabled)
|
||||
|
||||
|
||||
@router.post("/admin/user-group")
|
||||
def create_user_group(
|
||||
user_group: UserGroupCreate,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserGroup:
|
||||
try:
|
||||
@@ -155,7 +124,7 @@ def create_user_group(
|
||||
@router.patch("/admin/user-group/rename")
|
||||
def rename_user_group_endpoint(
|
||||
rename_request: UserGroupRename,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserGroup:
|
||||
group = fetch_user_group(db_session, rename_request.id)
|
||||
@@ -243,7 +212,7 @@ def set_user_curator(
|
||||
@router.delete("/admin/user-group/{user_group_id}")
|
||||
def delete_user_group(
|
||||
user_group_id: int,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
group = fetch_user_group(db_session, user_group_id)
|
||||
@@ -264,7 +233,7 @@ def delete_user_group(
|
||||
def update_group_agents(
|
||||
user_group_id: int,
|
||||
request: UpdateGroupAgentsRequest,
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
for agent_id in request.added_agent_ids:
|
||||
|
||||
@@ -2,7 +2,6 @@ from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.auth.permissions import Permission
|
||||
from onyx.db.models import UserGroup as UserGroupModel
|
||||
from onyx.server.documents.models import ConnectorCredentialPairDescriptor
|
||||
from onyx.server.documents.models import ConnectorSnapshot
|
||||
@@ -122,13 +121,3 @@ class SetCuratorRequest(BaseModel):
|
||||
class UpdateGroupAgentsRequest(BaseModel):
|
||||
added_agent_ids: list[int]
|
||||
removed_agent_ids: list[int]
|
||||
|
||||
|
||||
class SetPermissionRequest(BaseModel):
|
||||
permission: Permission
|
||||
enabled: bool
|
||||
|
||||
|
||||
class SetPermissionResponse(BaseModel):
|
||||
permission: Permission
|
||||
enabled: bool
|
||||
|
||||
@@ -96,14 +96,11 @@ def get_model_app() -> FastAPI:
|
||||
title="Onyx Model Server", version=__version__, lifespan=lifespan
|
||||
)
|
||||
if SENTRY_DSN:
|
||||
from onyx.configs.sentry import _add_instance_tags
|
||||
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
integrations=[StarletteIntegration(), FastApiIntegration()],
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
before_send=_add_instance_tags,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
|
||||
@@ -47,20 +47,6 @@ IMPLIED_PERMISSIONS: dict[str, set[str]] = {
|
||||
},
|
||||
}
|
||||
|
||||
# Permissions that cannot be toggled via the group-permission API.
|
||||
# BASIC_ACCESS is always granted, FULL_ADMIN_PANEL_ACCESS is too broad,
|
||||
# and READ_* permissions are implied (never stored directly).
|
||||
NON_TOGGLEABLE_PERMISSIONS: frozenset[Permission] = frozenset(
|
||||
{
|
||||
Permission.BASIC_ACCESS,
|
||||
Permission.FULL_ADMIN_PANEL_ACCESS,
|
||||
Permission.READ_CONNECTORS,
|
||||
Permission.READ_DOCUMENT_SETS,
|
||||
Permission.READ_AGENTS,
|
||||
Permission.READ_USERS,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def resolve_effective_permissions(granted: set[str]) -> set[str]:
|
||||
"""Expand granted permissions with their implied permissions.
|
||||
@@ -121,5 +107,4 @@ def require_permission(
|
||||
|
||||
return user
|
||||
|
||||
dependency._is_require_permission = True # type: ignore[attr-defined] # sentinel for auth_check detection
|
||||
return dependency
|
||||
|
||||
@@ -127,7 +127,6 @@ from onyx.db.models import User
|
||||
from onyx.db.pat import fetch_user_for_pat
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.db.users import is_limited_user
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import log_onyx_error
|
||||
from onyx.error_handling.exceptions import onyx_error_to_json_response
|
||||
@@ -1682,9 +1681,9 @@ async def current_user(
|
||||
) -> User:
|
||||
user = await double_check_user(user)
|
||||
|
||||
if is_limited_user(user):
|
||||
if user.role == UserRole.LIMITED:
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User has limited permissions.",
|
||||
detail="Access denied. User role is LIMITED. BASIC or higher permissions are required.",
|
||||
)
|
||||
return user
|
||||
|
||||
@@ -1701,6 +1700,15 @@ async def current_curator_or_admin_user(
|
||||
return user
|
||||
|
||||
|
||||
async def current_admin_user(user: User = Depends(current_user)) -> User:
|
||||
if user.role != UserRole.ADMIN:
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User must be an admin to perform this action.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def _get_user_from_token_data(token_data: dict) -> User | None:
|
||||
"""Shared logic: token data dict → User object.
|
||||
|
||||
@@ -1809,11 +1817,11 @@ async def current_user_from_websocket(
|
||||
# Apply same checks as HTTP auth (verification, OIDC expiry, role)
|
||||
user = await double_check_user(user)
|
||||
|
||||
# Block limited users (same as current_user)
|
||||
if is_limited_user(user):
|
||||
logger.warning(f"WS auth: user {user.email} is limited")
|
||||
# Block LIMITED users (same as current_user)
|
||||
if user.role == UserRole.LIMITED:
|
||||
logger.warning(f"WS auth: user {user.email} has LIMITED role")
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User has limited permissions.",
|
||||
detail="Access denied. User role is LIMITED. BASIC or higher permissions are required.",
|
||||
)
|
||||
|
||||
logger.debug(f"WS auth: authenticated {user.email}")
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# Overview of Onyx Background Jobs
|
||||
|
||||
The background jobs take care of:
|
||||
|
||||
1. Pulling/Indexing documents (from connectors)
|
||||
2. Updating document metadata (from connectors)
|
||||
3. Cleaning up checkpoints and logic around indexing work (indexing indexing checkpoints and index attempt metadata)
|
||||
@@ -10,41 +9,37 @@ The background jobs take care of:
|
||||
|
||||
## Worker → Queue Mapping
|
||||
|
||||
| Worker | File | Queues |
|
||||
| ------------------------- | ------------------------------ | -------------------------------------------------------------------------------------------------------------------- |
|
||||
| Primary | `apps/primary.py` | `celery` |
|
||||
| Light | `apps/light.py` | `vespa_metadata_sync`, `connector_deletion`, `doc_permissions_upsert`, `checkpoint_cleanup`, `index_attempt_cleanup` |
|
||||
| Heavy | `apps/heavy.py` | `connector_pruning`, `connector_doc_permissions_sync`, `connector_external_group_sync`, `csv_generation`, `sandbox` |
|
||||
| Docprocessing | `apps/docprocessing.py` | `docprocessing` |
|
||||
| Docfetching | `apps/docfetching.py` | `connector_doc_fetching` |
|
||||
| User File Processing | `apps/user_file_processing.py` | `user_file_processing`, `user_file_project_sync`, `user_file_delete` |
|
||||
| Monitoring | `apps/monitoring.py` | `monitoring` |
|
||||
| Background (consolidated) | `apps/background.py` | All queues above except `celery` |
|
||||
| Worker | File | Queues |
|
||||
|--------|------|--------|
|
||||
| Primary | `apps/primary.py` | `celery` |
|
||||
| Light | `apps/light.py` | `vespa_metadata_sync`, `connector_deletion`, `doc_permissions_upsert`, `checkpoint_cleanup`, `index_attempt_cleanup` |
|
||||
| Heavy | `apps/heavy.py` | `connector_pruning`, `connector_doc_permissions_sync`, `connector_external_group_sync`, `csv_generation`, `sandbox` |
|
||||
| Docprocessing | `apps/docprocessing.py` | `docprocessing` |
|
||||
| Docfetching | `apps/docfetching.py` | `connector_doc_fetching` |
|
||||
| User File Processing | `apps/user_file_processing.py` | `user_file_processing`, `user_file_project_sync`, `user_file_delete` |
|
||||
| Monitoring | `apps/monitoring.py` | `monitoring` |
|
||||
| Background (consolidated) | `apps/background.py` | All queues above except `celery` |
|
||||
|
||||
## Non-Worker Apps
|
||||
|
||||
| App | File | Purpose |
|
||||
| ---------- | ----------- | ----------------------------------------------------------------------------------------------------- |
|
||||
| **Beat** | `beat.py` | Celery beat scheduler with `DynamicTenantScheduler` that generates per-tenant periodic task schedules |
|
||||
| **Client** | `client.py` | Minimal app for task submission from non-worker processes (e.g., API server) |
|
||||
| App | File | Purpose |
|
||||
|-----|------|---------|
|
||||
| **Beat** | `beat.py` | Celery beat scheduler with `DynamicTenantScheduler` that generates per-tenant periodic task schedules |
|
||||
| **Client** | `client.py` | Minimal app for task submission from non-worker processes (e.g., API server) |
|
||||
|
||||
### Shared Module
|
||||
|
||||
`app_base.py` provides:
|
||||
|
||||
- `TenantAwareTask` - Base task class that sets tenant context
|
||||
- Signal handlers for logging, cleanup, and lifecycle events
|
||||
- Readiness probes and health checks
|
||||
|
||||
|
||||
## Worker Details
|
||||
|
||||
### Primary (Coordinator and task dispatcher)
|
||||
|
||||
It is the single worker which handles tasks from the default celery queue. It is a singleton worker ensured by the `PRIMARY_WORKER` Redis lock
|
||||
which it touches every `CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8` seconds (using Celery Bootsteps)
|
||||
|
||||
On startup:
|
||||
|
||||
- waits for redis, postgres, document index to all be healthy
|
||||
- acquires the singleton lock
|
||||
- cleans all the redis states associated with background jobs
|
||||
@@ -52,34 +47,34 @@ On startup:
|
||||
|
||||
Then it cycles through its tasks as scheduled by Celery Beat:
|
||||
|
||||
| Task | Frequency | Description |
|
||||
| --------------------------------- | --------- | ------------------------------------------------------------------------------------------ |
|
||||
| `check_for_indexing` | 15s | Scans for connectors needing indexing → dispatches to `DOCFETCHING` queue |
|
||||
| `check_for_vespa_sync_task` | 20s | Finds stale documents/document sets → dispatches sync tasks to `VESPA_METADATA_SYNC` queue |
|
||||
| `check_for_pruning` | 20s | Finds connectors due for pruning → dispatches to `CONNECTOR_PRUNING` queue |
|
||||
| `check_for_connector_deletion` | 20s | Processes deletion requests → dispatches to `CONNECTOR_DELETION` queue |
|
||||
| `check_for_user_file_processing` | 20s | Checks for user uploads → dispatches to `USER_FILE_PROCESSING` queue |
|
||||
| `check_for_checkpoint_cleanup` | 1h | Cleans up old indexing checkpoints |
|
||||
| `check_for_index_attempt_cleanup` | 30m | Cleans up old index attempts |
|
||||
| `celery_beat_heartbeat` | 1m | Heartbeat for Beat watchdog |
|
||||
| Task | Frequency | Description |
|
||||
|------|-----------|-------------|
|
||||
| `check_for_indexing` | 15s | Scans for connectors needing indexing → dispatches to `DOCFETCHING` queue |
|
||||
| `check_for_vespa_sync_task` | 20s | Finds stale documents/document sets → dispatches sync tasks to `VESPA_METADATA_SYNC` queue |
|
||||
| `check_for_pruning` | 20s | Finds connectors due for pruning → dispatches to `CONNECTOR_PRUNING` queue |
|
||||
| `check_for_connector_deletion` | 20s | Processes deletion requests → dispatches to `CONNECTOR_DELETION` queue |
|
||||
| `check_for_user_file_processing` | 20s | Checks for user uploads → dispatches to `USER_FILE_PROCESSING` queue |
|
||||
| `check_for_checkpoint_cleanup` | 1h | Cleans up old indexing checkpoints |
|
||||
| `check_for_index_attempt_cleanup` | 30m | Cleans up old index attempts |
|
||||
| `kombu_message_cleanup_task` | periodic | Cleans orphaned Kombu messages from DB (Kombu being the messaging framework used by Celery) |
|
||||
| `celery_beat_heartbeat` | 1m | Heartbeat for Beat watchdog |
|
||||
|
||||
Watchdog is a separate Python process managed by supervisord which runs alongside celery workers. It checks the ONYX_CELERY_BEAT_HEARTBEAT_KEY in
|
||||
Redis to ensure Celery Beat is not dead. Beat schedules the celery_beat_heartbeat for Primary to touch the key and share that it's still alive.
|
||||
See supervisord.conf for watchdog config.
|
||||
|
||||
### Light
|
||||
|
||||
### Light
|
||||
Fast and short living tasks that are not resource intensive. High concurrency:
|
||||
Can have 24 concurrent workers, each with a prefetch of 8 for a total of 192 tasks in flight at once.
|
||||
|
||||
Tasks it handles:
|
||||
|
||||
- Syncs access/permissions, document sets, boosts, hidden state
|
||||
- Deletes documents that are marked for deletion in Postgres
|
||||
- Cleanup of checkpoints and index attempts
|
||||
|
||||
### Heavy
|
||||
|
||||
### Heavy
|
||||
Long running, resource intensive tasks, handles pruning and sandbox operations. Low concurrency - max concurrency of 4 with 1 prefetch.
|
||||
|
||||
Does not interact with the Document Index, it handles the syncs with external systems. Large volume API calls to handle pruning and fetching permissions, etc.
|
||||
@@ -88,24 +83,16 @@ Generates CSV exports which may take a long time with significant data in Postgr
|
||||
|
||||
Sandbox (new feature) for running Next.js, Python virtual env, OpenCode AI Agent, and access to knowledge files
|
||||
|
||||
|
||||
### Docprocessing, Docfetching, User File Processing
|
||||
|
||||
Docprocessing and Docfetching are for indexing documents:
|
||||
|
||||
- Docfetching runs connectors to pull documents from external APIs (Google Drive, Confluence, etc.), stores batches to file storage, and dispatches docprocessing tasks
|
||||
- Docprocessing retrieves batches, runs the indexing pipeline (chunking, embedding), and indexes into the Document Index
|
||||
- User Files come from uploads directly via the input bar
|
||||
- Docprocessing retrieves batches, runs the indexing pipeline (chunking, embedding), and indexes into the Document Index
|
||||
User Files come from uploads directly via the input bar
|
||||
|
||||
|
||||
### Monitoring
|
||||
|
||||
Observability and metrics collections:
|
||||
|
||||
- Queue lengths, connector success/failure, connector latencies
|
||||
- Queue lengths, connector success/failure, lconnector latencies
|
||||
- Memory of supervisor managed processes (workers, beat, slack)
|
||||
- Cloud and multitenant specific monitorings
|
||||
|
||||
## Prometheus Metrics
|
||||
|
||||
Workers can expose Prometheus metrics via a standalone HTTP server. Currently docfetching and docprocessing have push-based task lifecycle metrics; the monitoring worker runs pull-based collectors for queue depth and connector health.
|
||||
|
||||
For the full metric reference, integration guide, and PromQL examples, see [`docs/METRICS.md`](../../../docs/METRICS.md#celery-worker-metrics).
|
||||
|
||||
@@ -10,7 +10,6 @@ from celery import bootsteps # type: ignore
|
||||
from celery import Task
|
||||
from celery.app import trace
|
||||
from celery.exceptions import WorkerShutdown
|
||||
from celery.signals import before_task_publish
|
||||
from celery.signals import task_postrun
|
||||
from celery.signals import task_prerun
|
||||
from celery.states import READY_STATES
|
||||
@@ -63,14 +62,11 @@ logger = setup_logger()
|
||||
task_logger = get_task_logger(__name__)
|
||||
|
||||
if SENTRY_DSN:
|
||||
from onyx.configs.sentry import _add_instance_tags
|
||||
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
integrations=[CeleryIntegration()],
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
before_send=_add_instance_tags,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
@@ -98,17 +94,6 @@ class TenantAwareTask(Task):
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(None)
|
||||
|
||||
|
||||
@before_task_publish.connect
|
||||
def on_before_task_publish(
|
||||
headers: dict[str, Any] | None = None,
|
||||
**kwargs: Any, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Stamp the current wall-clock time into the task message headers so that
|
||||
workers can compute queue wait time (time between publish and execution)."""
|
||||
if headers is not None:
|
||||
headers["enqueued_at"] = time.time()
|
||||
|
||||
|
||||
@task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None, # noqa: ARG001
|
||||
|
||||
@@ -13,12 +13,6 @@ from celery.signals import worker_shutdown
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_retry
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
@@ -40,7 +34,6 @@ def on_task_prerun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
on_celery_task_prerun(task_id, task)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
@@ -55,31 +48,6 @@ def on_task_postrun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
on_celery_task_postrun(task_id, task, state)
|
||||
|
||||
|
||||
@signals.task_retry.connect
|
||||
def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
task_id = getattr(getattr(sender, "request", None), "id", None)
|
||||
on_celery_task_retry(task_id, sender)
|
||||
|
||||
|
||||
@signals.task_revoked.connect
|
||||
def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None:
|
||||
task_name = getattr(sender, "name", None) or str(sender)
|
||||
on_celery_task_revoked(kwargs.get("task_id"), task_name)
|
||||
|
||||
|
||||
@signals.task_rejected.connect
|
||||
def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
message = kwargs.get("message")
|
||||
task_name: str | None = None
|
||||
if message is not None:
|
||||
headers = getattr(message, "headers", None) or {}
|
||||
task_name = headers.get("task")
|
||||
if task_name is None:
|
||||
task_name = "unknown"
|
||||
on_celery_task_rejected(None, task_name)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
@@ -108,7 +76,6 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
start_metrics_server("heavy")
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -16,12 +16,6 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_retry
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
@@ -42,7 +36,6 @@ def on_task_prerun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
on_celery_task_prerun(task_id, task)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
@@ -57,31 +50,6 @@ def on_task_postrun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
on_celery_task_postrun(task_id, task, state)
|
||||
|
||||
|
||||
@signals.task_retry.connect
|
||||
def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
task_id = getattr(getattr(sender, "request", None), "id", None)
|
||||
on_celery_task_retry(task_id, sender)
|
||||
|
||||
|
||||
@signals.task_revoked.connect
|
||||
def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None:
|
||||
task_name = getattr(sender, "name", None) or str(sender)
|
||||
on_celery_task_revoked(kwargs.get("task_id"), task_name)
|
||||
|
||||
|
||||
@signals.task_rejected.connect
|
||||
def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
message = kwargs.get("message")
|
||||
task_name: str | None = None
|
||||
if message is not None:
|
||||
headers = getattr(message, "headers", None) or {}
|
||||
task_name = headers.get("task")
|
||||
if task_name is None:
|
||||
task_name = "unknown"
|
||||
on_celery_task_rejected(None, task_name)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
@@ -122,7 +90,6 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
start_metrics_server("light")
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -317,6 +317,7 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
"onyx.background.celery.tasks.evals",
|
||||
"onyx.background.celery.tasks.hierarchyfetching",
|
||||
"onyx.background.celery.tasks.periodic",
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Sequence
|
||||
@@ -31,8 +30,6 @@ from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.server.metrics.pruning_metrics import inc_pruning_rate_limit_error
|
||||
from onyx.server.metrics.pruning_metrics import observe_pruning_enumeration_duration
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -133,7 +130,6 @@ def _extract_from_batch(
|
||||
def extract_ids_from_runnable_connector(
|
||||
runnable_connector: BaseConnector,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
connector_type: str = "unknown",
|
||||
) -> SlimConnectorExtractionResult:
|
||||
"""
|
||||
Extract document IDs and hierarchy nodes from a runnable connector.
|
||||
@@ -183,38 +179,21 @@ def extract_ids_from_runnable_connector(
|
||||
)
|
||||
|
||||
# process raw batches to extract both IDs and hierarchy nodes
|
||||
enumeration_start = time.monotonic()
|
||||
try:
|
||||
for doc_list in raw_batch_generator:
|
||||
if callback and callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"extract_ids_from_runnable_connector: Stop signal detected"
|
||||
)
|
||||
for doc_list in raw_batch_generator:
|
||||
if callback and callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"extract_ids_from_runnable_connector: Stop signal detected"
|
||||
)
|
||||
|
||||
batch_result = _extract_from_batch(doc_list)
|
||||
batch_ids = batch_result.raw_id_to_parent
|
||||
batch_nodes = batch_result.hierarchy_nodes
|
||||
doc_batch_processing_func(batch_ids)
|
||||
all_raw_id_to_parent.update(batch_ids)
|
||||
all_hierarchy_nodes.extend(batch_nodes)
|
||||
batch_result = _extract_from_batch(doc_list)
|
||||
batch_ids = batch_result.raw_id_to_parent
|
||||
batch_nodes = batch_result.hierarchy_nodes
|
||||
doc_batch_processing_func(batch_ids)
|
||||
all_raw_id_to_parent.update(batch_ids)
|
||||
all_hierarchy_nodes.extend(batch_nodes)
|
||||
|
||||
if callback:
|
||||
callback.progress("extract_ids_from_runnable_connector", len(batch_ids))
|
||||
except Exception as e:
|
||||
# Best-effort rate limit detection via string matching.
|
||||
# Connectors surface rate limits inconsistently — some raise HTTP 429,
|
||||
# some use SDK-specific exceptions (e.g. google.api_core.exceptions.ResourceExhausted)
|
||||
# that may or may not include "rate limit" or "429" in the message.
|
||||
# TODO(Bo): replace with a standard ConnectorRateLimitError exception that all
|
||||
# connectors raise when rate limited, making this check precise.
|
||||
error_str = str(e)
|
||||
if "rate limit" in error_str.lower() or "429" in error_str:
|
||||
inc_pruning_rate_limit_error(connector_type)
|
||||
raise
|
||||
finally:
|
||||
observe_pruning_enumeration_duration(
|
||||
time.monotonic() - enumeration_start, connector_type
|
||||
)
|
||||
if callback:
|
||||
callback.progress("extract_ids_from_runnable_connector", len(batch_ids))
|
||||
|
||||
return SlimConnectorExtractionResult(
|
||||
raw_id_to_parent=all_raw_id_to_parent,
|
||||
|
||||
@@ -75,8 +75,6 @@ beat_task_templates: list[dict] = [
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
# Run on gated tenants too — they may still have stale checkpoints to clean.
|
||||
"skip_gated": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -86,8 +84,6 @@ beat_task_templates: list[dict] = [
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
# Run on gated tenants too — they may still have stale index attempts.
|
||||
"skip_gated": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -97,8 +93,6 @@ beat_task_templates: list[dict] = [
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
# Gated tenants may still have connectors awaiting deletion.
|
||||
"skip_gated": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -142,14 +136,7 @@ beat_task_templates: list[dict] = [
|
||||
{
|
||||
"name": "cleanup-idle-sandboxes",
|
||||
"task": OnyxCeleryTask.CLEANUP_IDLE_SANDBOXES,
|
||||
# SANDBOX_IDLE_TIMEOUT_SECONDS defaults to 1 hour, so there is no
|
||||
# functional reason to scan more often than every ~15 minutes. In the
|
||||
# cloud this is multiplied by CLOUD_BEAT_MULTIPLIER_DEFAULT (=8) so
|
||||
# the effective cadence becomes ~2 hours, which still meets the
|
||||
# idle-detection SLA. The previous 1-minute base schedule produced
|
||||
# an 8-minute per-tenant fan-out and was the dominant source of
|
||||
# background DB load on the cloud cluster.
|
||||
"schedule": timedelta(minutes=15),
|
||||
"schedule": timedelta(minutes=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
@@ -279,7 +266,7 @@ def make_cloud_generator_task(task: dict[str, Any]) -> dict[str, Any]:
|
||||
cloud_task["kwargs"] = {}
|
||||
cloud_task["kwargs"]["task_name"] = task["task"]
|
||||
|
||||
optional_fields = ["queue", "priority", "expires", "skip_gated"]
|
||||
optional_fields = ["queue", "priority", "expires"]
|
||||
for field in optional_fields:
|
||||
if field in task["options"]:
|
||||
cloud_task["kwargs"][field] = task["options"][field]
|
||||
@@ -372,13 +359,7 @@ if not MULTI_TENANT:
|
||||
]
|
||||
)
|
||||
|
||||
# `skip_gated` is a cloud-only hint consumed by `cloud_beat_task_generator`. Strip
|
||||
# it before extending the self-hosted schedule so it doesn't leak into apply_async
|
||||
# as an unrecognised option on every fired task message.
|
||||
for _template in beat_task_templates:
|
||||
_self_hosted_template = copy.deepcopy(_template)
|
||||
_self_hosted_template["options"].pop("skip_gated", None)
|
||||
tasks_to_schedule.append(_self_hosted_template)
|
||||
tasks_to_schedule.extend(beat_task_templates)
|
||||
|
||||
|
||||
def generate_cloud_tasks(
|
||||
|
||||
@@ -59,11 +59,6 @@ from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDeletePayload
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_blocked
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_completed
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_fence_reset
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_started
|
||||
from onyx.server.metrics.deletion_metrics import observe_deletion_taskset_duration
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
@@ -107,7 +102,7 @@ def revoke_tasks_blocking_deletion(
|
||||
f"Revoked permissions sync task {permissions_sync_payload.celery_task_id}."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Exception while revoking permissions sync task")
|
||||
task_logger.exception("Exception while revoking pruning task")
|
||||
|
||||
try:
|
||||
prune_payload = redis_connector.prune.payload
|
||||
@@ -115,7 +110,7 @@ def revoke_tasks_blocking_deletion(
|
||||
app.control.revoke(prune_payload.celery_task_id)
|
||||
task_logger.info(f"Revoked pruning task {prune_payload.celery_task_id}.")
|
||||
except Exception:
|
||||
task_logger.exception("Exception while revoking pruning task")
|
||||
task_logger.exception("Exception while revoking permissions sync task")
|
||||
|
||||
try:
|
||||
external_group_sync_payload = redis_connector.external_group_sync.payload
|
||||
@@ -305,7 +300,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
recent_index_attempts
|
||||
and recent_index_attempts[0].status == IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
inc_deletion_blocked(tenant_id, "indexing")
|
||||
raise TaskDependencyError(
|
||||
"Connector deletion - Delayed (indexing in progress): "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
@@ -313,13 +307,11 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
)
|
||||
|
||||
if redis_connector.prune.fenced:
|
||||
inc_deletion_blocked(tenant_id, "pruning")
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (pruning in progress): cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
if redis_connector.permissions.fenced:
|
||||
inc_deletion_blocked(tenant_id, "permissions")
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (permissions in progress): cc_pair={cc_pair_id}"
|
||||
)
|
||||
@@ -367,7 +359,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
# set this only after all tasks have been added
|
||||
fence_payload.num_tasks = tasks_generated
|
||||
redis_connector.delete.set_fence(fence_payload)
|
||||
inc_deletion_started(tenant_id)
|
||||
|
||||
return tasks_generated
|
||||
|
||||
@@ -517,11 +508,7 @@ def monitor_connector_deletion_taskset(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id_to_delete,
|
||||
)
|
||||
if not connector:
|
||||
task_logger.info(
|
||||
"Connector deletion - Connector already deleted, skipping connector cleanup"
|
||||
)
|
||||
elif not len(connector.credentials):
|
||||
if not connector or not len(connector.credentials):
|
||||
task_logger.info(
|
||||
"Connector deletion - Found no credentials left for connector, deleting connector"
|
||||
)
|
||||
@@ -536,12 +523,6 @@ def monitor_connector_deletion_taskset(
|
||||
num_docs_synced=fence_data.num_tasks,
|
||||
)
|
||||
|
||||
duration = (
|
||||
datetime.now(timezone.utc) - fence_data.submitted
|
||||
).total_seconds()
|
||||
observe_deletion_taskset_duration(tenant_id, "success", duration)
|
||||
inc_deletion_completed(tenant_id, "success")
|
||||
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
stack_trace = traceback.format_exc()
|
||||
@@ -560,11 +541,6 @@ def monitor_connector_deletion_taskset(
|
||||
f"Connector deletion exceptioned: "
|
||||
f"cc_pair={cc_pair_id} connector={connector_id_to_delete} credential={credential_id_to_delete}"
|
||||
)
|
||||
duration = (
|
||||
datetime.now(timezone.utc) - fence_data.submitted
|
||||
).total_seconds()
|
||||
observe_deletion_taskset_duration(tenant_id, "failure", duration)
|
||||
inc_deletion_completed(tenant_id, "failure")
|
||||
raise e
|
||||
|
||||
task_logger.info(
|
||||
@@ -741,6 +717,5 @@ def validate_connector_deletion_fence(
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
|
||||
inc_deletion_fence_reset(tenant_id)
|
||||
redis_connector.delete.reset()
|
||||
return
|
||||
|
||||
@@ -135,13 +135,10 @@ def _docfetching_task(
|
||||
# Since connector_indexing_proxy_task spawns a new process using this function as
|
||||
# the entrypoint, we init Sentry here.
|
||||
if SENTRY_DSN:
|
||||
from onyx.configs.sentry import _add_instance_tags
|
||||
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
before_send=_add_instance_tags,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
|
||||
@@ -3,7 +3,6 @@ import os
|
||||
import time
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
@@ -51,7 +50,6 @@ from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -87,8 +85,6 @@ from onyx.db.indexing_coordination import INDEXING_PROGRESS_TIMEOUT_HOURS
|
||||
from onyx.db.indexing_coordination import IndexingCoordination
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.db.notification import get_notifications
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.search_settings import get_secondary_search_settings
|
||||
from onyx.db.swap_index import check_and_perform_index_swap
|
||||
@@ -109,9 +105,6 @@ from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.redis.redis_utils import is_fence
|
||||
from onyx.server.metrics.connector_health_metrics import on_connector_error_state_change
|
||||
from onyx.server.metrics.connector_health_metrics import on_connector_indexing_success
|
||||
from onyx.server.metrics.connector_health_metrics import on_index_attempt_status_change
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.middleware import make_randomized_onyx_request_id
|
||||
@@ -407,6 +400,7 @@ def check_indexing_completion(
|
||||
tenant_id: str,
|
||||
task: Task,
|
||||
) -> None:
|
||||
|
||||
logger.info(
|
||||
f"Checking for indexing completion: attempt={index_attempt_id} tenant={tenant_id}"
|
||||
)
|
||||
@@ -527,23 +521,13 @@ def check_indexing_completion(
|
||||
|
||||
# Update CC pair status if successful
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session,
|
||||
attempt.connector_credential_pair_id,
|
||||
eager_load_connector=True,
|
||||
db_session, attempt.connector_credential_pair_id
|
||||
)
|
||||
if cc_pair is None:
|
||||
raise RuntimeError(
|
||||
f"CC pair {attempt.connector_credential_pair_id} not found in database"
|
||||
)
|
||||
|
||||
source = cc_pair.connector.source.value
|
||||
on_index_attempt_status_change(
|
||||
tenant_id=tenant_id,
|
||||
source=source,
|
||||
cc_pair_id=cc_pair.id,
|
||||
status=attempt.status.value,
|
||||
)
|
||||
|
||||
if attempt.status.is_successful():
|
||||
# NOTE: we define the last successful index time as the time the last successful
|
||||
# attempt finished. This is distinct from the poll_range_end of the last successful
|
||||
@@ -564,39 +548,10 @@ def check_indexing_completion(
|
||||
event=MilestoneRecordType.CONNECTOR_SUCCEEDED,
|
||||
)
|
||||
|
||||
on_connector_indexing_success(
|
||||
tenant_id=tenant_id,
|
||||
source=source,
|
||||
cc_pair_id=cc_pair.id,
|
||||
docs_indexed=attempt.new_docs_indexed or 0,
|
||||
success_timestamp=attempt.time_updated.timestamp(),
|
||||
)
|
||||
|
||||
# Clear repeated error state on success
|
||||
if cc_pair.in_repeated_error_state:
|
||||
cc_pair.in_repeated_error_state = False
|
||||
|
||||
# Delete any existing error notification for this CC pair so a
|
||||
# fresh one is created if the connector fails again later.
|
||||
for notif in get_notifications(
|
||||
user=None,
|
||||
db_session=db_session,
|
||||
notif_type=NotificationType.CONNECTOR_REPEATED_ERRORS,
|
||||
include_dismissed=True,
|
||||
):
|
||||
if (
|
||||
notif.additional_data
|
||||
and notif.additional_data.get("cc_pair_id") == cc_pair.id
|
||||
):
|
||||
db_session.delete(notif)
|
||||
|
||||
db_session.commit()
|
||||
on_connector_error_state_change(
|
||||
tenant_id=tenant_id,
|
||||
source=source,
|
||||
cc_pair_id=cc_pair.id,
|
||||
in_error=False,
|
||||
)
|
||||
|
||||
if attempt.status == IndexingStatus.SUCCESS:
|
||||
logger.info(
|
||||
@@ -653,27 +608,6 @@ def active_indexing_attempt(
|
||||
return bool(active_indexing_attempt)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _KickoffResult:
|
||||
"""Tracks diagnostic counts from a _kickoff_indexing_tasks run."""
|
||||
|
||||
created: int = 0
|
||||
skipped_active: int = 0
|
||||
skipped_not_found: int = 0
|
||||
skipped_not_indexable: int = 0
|
||||
failed_to_create: int = 0
|
||||
|
||||
@property
|
||||
def evaluated(self) -> int:
|
||||
return (
|
||||
self.created
|
||||
+ self.skipped_active
|
||||
+ self.skipped_not_found
|
||||
+ self.skipped_not_indexable
|
||||
+ self.failed_to_create
|
||||
)
|
||||
|
||||
|
||||
def _kickoff_indexing_tasks(
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
@@ -683,12 +617,12 @@ def _kickoff_indexing_tasks(
|
||||
redis_client: Redis,
|
||||
lock_beat: RedisLock,
|
||||
tenant_id: str,
|
||||
) -> _KickoffResult:
|
||||
) -> int:
|
||||
"""Kick off indexing tasks for the given cc_pair_ids and search_settings.
|
||||
|
||||
Returns a _KickoffResult with diagnostic counts.
|
||||
Returns the number of tasks successfully created.
|
||||
"""
|
||||
result = _KickoffResult()
|
||||
tasks_created = 0
|
||||
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
lock_beat.reacquire()
|
||||
@@ -699,7 +633,6 @@ def _kickoff_indexing_tasks(
|
||||
search_settings_id=search_settings.id,
|
||||
db_session=db_session,
|
||||
):
|
||||
result.skipped_active += 1
|
||||
continue
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
@@ -710,7 +643,6 @@ def _kickoff_indexing_tasks(
|
||||
task_logger.warning(
|
||||
f"_kickoff_indexing_tasks - CC pair not found: cc_pair={cc_pair_id}"
|
||||
)
|
||||
result.skipped_not_found += 1
|
||||
continue
|
||||
|
||||
# Heavyweight check after fetching cc pair
|
||||
@@ -725,7 +657,6 @@ def _kickoff_indexing_tasks(
|
||||
f"search_settings={search_settings.id}, "
|
||||
f"secondary_index_building={secondary_index_building}"
|
||||
)
|
||||
result.skipped_not_indexable += 1
|
||||
continue
|
||||
|
||||
task_logger.debug(
|
||||
@@ -765,14 +696,13 @@ def _kickoff_indexing_tasks(
|
||||
task_logger.info(
|
||||
f"Connector indexing queued: index_attempt={attempt_id} cc_pair={cc_pair.id} search_settings={search_settings.id}"
|
||||
)
|
||||
result.created += 1
|
||||
tasks_created += 1
|
||||
else:
|
||||
task_logger.error(
|
||||
f"Failed to create indexing task: cc_pair={cc_pair.id} search_settings={search_settings.id}"
|
||||
)
|
||||
result.failed_to_create += 1
|
||||
|
||||
return result
|
||||
return tasks_created
|
||||
|
||||
|
||||
@shared_task(
|
||||
@@ -798,8 +728,6 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
task_logger.warning("check_for_indexing - Starting")
|
||||
|
||||
tasks_created = 0
|
||||
primary_result = _KickoffResult()
|
||||
secondary_result: _KickoffResult | None = None
|
||||
locked = False
|
||||
redis_client = get_redis_client()
|
||||
redis_client_replica = get_redis_replica_client()
|
||||
@@ -920,39 +848,6 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
cc_pair_id=cc_pair_id,
|
||||
in_repeated_error_state=True,
|
||||
)
|
||||
on_connector_error_state_change(
|
||||
tenant_id=tenant_id,
|
||||
source=cc_pair.connector.source.value,
|
||||
cc_pair_id=cc_pair_id,
|
||||
in_error=True,
|
||||
)
|
||||
|
||||
connector_name = (
|
||||
cc_pair.name
|
||||
or cc_pair.connector.name
|
||||
or f"CC pair {cc_pair.id}"
|
||||
)
|
||||
source = cc_pair.connector.source.value
|
||||
connector_url = f"/admin/connector/{cc_pair.id}"
|
||||
create_notification(
|
||||
user_id=None,
|
||||
notif_type=NotificationType.CONNECTOR_REPEATED_ERRORS,
|
||||
db_session=db_session,
|
||||
title=f"Connector '{connector_name}' has entered repeated error state",
|
||||
description=(
|
||||
f"The {source} connector has failed repeatedly and "
|
||||
f"has been flagged. View indexing history in the "
|
||||
f"Advanced section: {connector_url}"
|
||||
),
|
||||
additional_data={"cc_pair_id": cc_pair.id},
|
||||
)
|
||||
|
||||
task_logger.error(
|
||||
f"Connector entered repeated error state: "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"connector={cc_pair.connector.name} "
|
||||
f"source={source}"
|
||||
)
|
||||
# When entering repeated error state, also pause the connector
|
||||
# to prevent continued indexing retry attempts burning through embedding credits.
|
||||
# NOTE: only for Cloud, since most self-hosted users use self-hosted embedding
|
||||
@@ -968,7 +863,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
# Heavy check, should_index(), is called in _kickoff_indexing_tasks
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Primary first
|
||||
primary_result = _kickoff_indexing_tasks(
|
||||
tasks_created += _kickoff_indexing_tasks(
|
||||
celery_app=self.app,
|
||||
db_session=db_session,
|
||||
search_settings=current_search_settings,
|
||||
@@ -978,7 +873,6 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
lock_beat=lock_beat,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
tasks_created += primary_result.created
|
||||
|
||||
# Secondary indexing (only if secondary search settings exist and switchover_type is not INSTANT)
|
||||
if (
|
||||
@@ -986,7 +880,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
and secondary_search_settings.switchover_type != SwitchoverType.INSTANT
|
||||
and secondary_cc_pair_ids
|
||||
):
|
||||
secondary_result = _kickoff_indexing_tasks(
|
||||
tasks_created += _kickoff_indexing_tasks(
|
||||
celery_app=self.app,
|
||||
db_session=db_session,
|
||||
search_settings=secondary_search_settings,
|
||||
@@ -996,7 +890,6 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
lock_beat=lock_beat,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
tasks_created += secondary_result.created
|
||||
elif (
|
||||
secondary_search_settings
|
||||
and secondary_search_settings.switchover_type == SwitchoverType.INSTANT
|
||||
@@ -1109,26 +1002,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
redis_lock_dump(lock_beat, redis_client)
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.info(
|
||||
f"check_for_indexing finished: "
|
||||
f"elapsed={time_elapsed:.2f}s "
|
||||
f"primary=[evaluated={primary_result.evaluated} "
|
||||
f"created={primary_result.created} "
|
||||
f"skipped_active={primary_result.skipped_active} "
|
||||
f"skipped_not_found={primary_result.skipped_not_found} "
|
||||
f"skipped_not_indexable={primary_result.skipped_not_indexable} "
|
||||
f"failed={primary_result.failed_to_create}]"
|
||||
+ (
|
||||
f" secondary=[evaluated={secondary_result.evaluated} "
|
||||
f"created={secondary_result.created} "
|
||||
f"skipped_active={secondary_result.skipped_active} "
|
||||
f"skipped_not_found={secondary_result.skipped_not_found} "
|
||||
f"skipped_not_indexable={secondary_result.skipped_not_indexable} "
|
||||
f"failed={secondary_result.failed_to_create}]"
|
||||
if secondary_result
|
||||
else ""
|
||||
)
|
||||
)
|
||||
task_logger.info(f"check_for_indexing finished: elapsed={time_elapsed:.2f}")
|
||||
return tasks_created
|
||||
|
||||
|
||||
|
||||
@@ -172,10 +172,6 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
|
||||
task_logger.debug(
|
||||
"Verified tenant info, migration record, and search settings."
|
||||
)
|
||||
|
||||
# 2.e. Build sanitized to original doc ID mapping to check for
|
||||
# conflicts in the event we sanitize a doc ID to an
|
||||
# already-existing doc ID.
|
||||
@@ -329,7 +325,6 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
task_logger.debug("Released the OpenSearch migration lock.")
|
||||
else:
|
||||
task_logger.warning(
|
||||
"The OpenSearch migration lock was not owned on completion of the migration task."
|
||||
|
||||
138
backend/onyx/background/celery/tasks/periodic/tasks.py
Normal file
138
backend/onyx/background/celery/tasks/periodic/tasks.py
Normal file
@@ -0,0 +1,138 @@
|
||||
#####
|
||||
# Periodic Tasks
|
||||
#####
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from celery import shared_task
|
||||
from celery.contrib.abortable import AbortableTask # type: ignore
|
||||
from celery.exceptions import TaskRevokedError
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import PostgresAdvisoryLocks
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
base=AbortableTask,
|
||||
)
|
||||
def kombu_message_cleanup_task(self: Any, tenant_id: str) -> int: # noqa: ARG001
|
||||
"""Runs periodically to clean up the kombu_message table"""
|
||||
|
||||
# we will select messages older than this amount to clean up
|
||||
KOMBU_MESSAGE_CLEANUP_AGE = 7 # days
|
||||
KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT = 1000
|
||||
|
||||
ctx = {}
|
||||
ctx["last_processed_id"] = 0
|
||||
ctx["deleted"] = 0
|
||||
ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
|
||||
ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Exit the task if we can't take the advisory lock
|
||||
result = db_session.execute(
|
||||
text("SELECT pg_try_advisory_lock(:id)"),
|
||||
{"id": PostgresAdvisoryLocks.KOMBU_MESSAGE_CLEANUP_LOCK_ID.value},
|
||||
).scalar()
|
||||
if not result:
|
||||
return 0
|
||||
|
||||
while True:
|
||||
if self.is_aborted():
|
||||
raise TaskRevokedError("kombu_message_cleanup_task was aborted.")
|
||||
|
||||
b = kombu_message_cleanup_task_helper(ctx, db_session)
|
||||
if not b:
|
||||
break
|
||||
|
||||
db_session.commit()
|
||||
|
||||
if ctx["deleted"] > 0:
|
||||
task_logger.info(
|
||||
f"Deleted {ctx['deleted']} orphaned messages from kombu_message."
|
||||
)
|
||||
|
||||
return ctx["deleted"]
|
||||
|
||||
|
||||
def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool:
|
||||
"""
|
||||
Helper function to clean up old messages from the `kombu_message` table that are no longer relevant.
|
||||
|
||||
This function retrieves messages from the `kombu_message` table that are no longer visible and
|
||||
older than a specified interval. It checks if the corresponding task_id exists in the
|
||||
`celery_taskmeta` table. If the task_id does not exist, the message is deleted.
|
||||
|
||||
Args:
|
||||
ctx (dict): A context dictionary containing configuration parameters such as:
|
||||
- 'cleanup_age' (int): The age in days after which messages are considered old.
|
||||
- 'page_limit' (int): The maximum number of messages to process in one batch.
|
||||
- 'last_processed_id' (int): The ID of the last processed message to handle pagination.
|
||||
- 'deleted' (int): A counter to track the number of deleted messages.
|
||||
db_session (Session): The SQLAlchemy database session for executing queries.
|
||||
|
||||
Returns:
|
||||
bool: Returns True if there are more rows to process, False if not.
|
||||
"""
|
||||
|
||||
inspector = inspect(db_session.bind)
|
||||
if not inspector:
|
||||
return False
|
||||
|
||||
# With the move to redis as celery's broker and backend, kombu tables may not even exist.
|
||||
# We can fail silently.
|
||||
if not inspector.has_table("kombu_message"):
|
||||
return False
|
||||
|
||||
query = text(
|
||||
"""
|
||||
SELECT id, timestamp, payload
|
||||
FROM kombu_message WHERE visible = 'false'
|
||||
AND timestamp < CURRENT_TIMESTAMP - INTERVAL :interval_days
|
||||
AND id > :last_processed_id
|
||||
ORDER BY id
|
||||
LIMIT :page_limit
|
||||
"""
|
||||
)
|
||||
kombu_messages = db_session.execute(
|
||||
query,
|
||||
{
|
||||
"interval_days": f"{ctx['cleanup_age']} days",
|
||||
"page_limit": ctx["page_limit"],
|
||||
"last_processed_id": ctx["last_processed_id"],
|
||||
},
|
||||
).fetchall()
|
||||
|
||||
if len(kombu_messages) == 0:
|
||||
return False
|
||||
|
||||
for msg in kombu_messages:
|
||||
payload = json.loads(msg[2])
|
||||
task_id = payload["headers"]["id"]
|
||||
|
||||
# Check if task_id exists in celery_taskmeta
|
||||
task_exists = db_session.execute(
|
||||
text("SELECT 1 FROM celery_taskmeta WHERE task_id = :task_id"),
|
||||
{"task_id": task_id},
|
||||
).fetchone()
|
||||
|
||||
# If task_id does not exist, delete the message
|
||||
if not task_exists:
|
||||
result = db_session.execute(
|
||||
text("DELETE FROM kombu_message WHERE id = :message_id"),
|
||||
{"message_id": msg[0]},
|
||||
)
|
||||
if result.rowcount > 0: # type: ignore
|
||||
ctx["deleted"] += 1
|
||||
|
||||
ctx["last_processed_id"] = msg[0]
|
||||
|
||||
return True
|
||||
@@ -38,7 +38,6 @@ from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.connector import mark_ccpair_as_pruned
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair
|
||||
@@ -73,7 +72,6 @@ from onyx.redis.redis_hierarchy import get_source_node_id_from_cache
|
||||
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.server.metrics.pruning_metrics import observe_pruning_diff_duration
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.server.utils import make_short_id
|
||||
from onyx.utils.logger import format_error_for_logging
|
||||
@@ -219,7 +217,7 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
|
||||
try:
|
||||
# the entire task needs to run frequently in order to finalize pruning
|
||||
|
||||
# but pruning only kicks off once per min
|
||||
# but pruning only kicks off once per hour
|
||||
if not r.exists(OnyxRedisSignals.BLOCK_PRUNING):
|
||||
task_logger.info("Checking for pruning due")
|
||||
|
||||
@@ -526,14 +524,6 @@ def connector_pruning_generator_task(
|
||||
return None
|
||||
|
||||
try:
|
||||
# Session 1: pre-enumeration — load cc_pair and instantiate the connector.
|
||||
# The session is closed before enumeration so the DB connection is not held
|
||||
# open during the 10–30+ minute connector crawl.
|
||||
connector_source: DocumentSource | None = None
|
||||
connector_type: str = ""
|
||||
is_connector_public: bool = False
|
||||
runnable_connector: BaseConnector | None = None
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
cc_pair = get_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
@@ -559,51 +549,48 @@ def connector_pruning_generator_task(
|
||||
)
|
||||
redis_connector.prune.set_fence(new_payload)
|
||||
|
||||
connector_source = cc_pair.connector.source
|
||||
connector_type = connector_source.value
|
||||
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
|
||||
|
||||
task_logger.info(
|
||||
f"Pruning generator running connector: cc_pair={cc_pair_id} connector_source={connector_source}"
|
||||
f"Pruning generator running connector: cc_pair={cc_pair_id} connector_source={cc_pair.connector.source}"
|
||||
)
|
||||
|
||||
runnable_connector = instantiate_connector(
|
||||
db_session,
|
||||
connector_source,
|
||||
cc_pair.connector.source,
|
||||
InputType.SLIM_RETRIEVAL,
|
||||
cc_pair.connector.connector_specific_config,
|
||||
cc_pair.credential,
|
||||
)
|
||||
# Session 1 closed here — connection released before enumeration.
|
||||
|
||||
callback = PruneCallback(
|
||||
0,
|
||||
redis_connector,
|
||||
lock,
|
||||
r,
|
||||
timeout_seconds=JOB_TIMEOUT,
|
||||
)
|
||||
callback = PruneCallback(
|
||||
0,
|
||||
redis_connector,
|
||||
lock,
|
||||
r,
|
||||
timeout_seconds=JOB_TIMEOUT,
|
||||
)
|
||||
|
||||
# Extract docs and hierarchy nodes from the source (no DB session held).
|
||||
extraction_result = extract_ids_from_runnable_connector(
|
||||
runnable_connector, callback, connector_type=connector_type
|
||||
)
|
||||
all_connector_doc_ids = extraction_result.raw_id_to_parent
|
||||
# Extract docs and hierarchy nodes from the source
|
||||
extraction_result = extract_ids_from_runnable_connector(
|
||||
runnable_connector, callback
|
||||
)
|
||||
all_connector_doc_ids = extraction_result.raw_id_to_parent
|
||||
|
||||
# Session 2: post-enumeration — hierarchy upserts, diff computation, task dispatch.
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
source = connector_source
|
||||
# Process hierarchy nodes (same as docfetching):
|
||||
# upsert to Postgres and cache in Redis
|
||||
source = cc_pair.connector.source
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
ensure_source_node_exists(redis_client, db_session, source)
|
||||
|
||||
upserted_nodes: list[DBHierarchyNode] = []
|
||||
if extraction_result.hierarchy_nodes:
|
||||
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
|
||||
|
||||
upserted_nodes = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=extraction_result.hierarchy_nodes,
|
||||
source=source,
|
||||
commit=False,
|
||||
commit=True,
|
||||
is_connector_public=is_connector_public,
|
||||
)
|
||||
|
||||
@@ -612,13 +599,9 @@ def connector_pruning_generator_task(
|
||||
hierarchy_node_ids=[n.id for n in upserted_nodes],
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
commit=False,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
# Single commit so the FK reference in the join table can never
|
||||
# outrun the parent hierarchy_node insert.
|
||||
db_session.commit()
|
||||
|
||||
cache_entries = [
|
||||
HierarchyNodeCacheEntry.from_db_model(node)
|
||||
for node in upserted_nodes
|
||||
@@ -653,46 +636,40 @@ def connector_pruning_generator_task(
|
||||
commit=True,
|
||||
)
|
||||
|
||||
diff_start = time.monotonic()
|
||||
try:
|
||||
# a list of docs in our local index
|
||||
all_indexed_document_ids = {
|
||||
doc.id
|
||||
for doc in get_documents_for_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
}
|
||||
# a list of docs in our local index
|
||||
all_indexed_document_ids = {
|
||||
doc.id
|
||||
for doc in get_documents_for_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
}
|
||||
|
||||
# generate list of docs to remove (no longer in the source)
|
||||
doc_ids_to_remove = list(
|
||||
all_indexed_document_ids - all_connector_doc_ids.keys()
|
||||
)
|
||||
# generate list of docs to remove (no longer in the source)
|
||||
doc_ids_to_remove = list(
|
||||
all_indexed_document_ids - all_connector_doc_ids.keys()
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
"Pruning set collected: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"connector_source={connector_source} "
|
||||
f"docs_to_remove={len(doc_ids_to_remove)}"
|
||||
)
|
||||
task_logger.info(
|
||||
"Pruning set collected: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"connector_source={cc_pair.connector.source} "
|
||||
f"docs_to_remove={len(doc_ids_to_remove)}"
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.prune.generate_tasks starting. cc_pair={cc_pair_id}"
|
||||
)
|
||||
tasks_generated = redis_connector.prune.generate_tasks(
|
||||
set(doc_ids_to_remove), self.app, db_session, None
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
task_logger.info(
|
||||
f"RedisConnector.prune.generate_tasks starting. cc_pair={cc_pair_id}"
|
||||
)
|
||||
tasks_generated = redis_connector.prune.generate_tasks(
|
||||
set(doc_ids_to_remove), self.app, db_session, None
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.prune.generate_tasks finished. cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
finally:
|
||||
observe_pruning_diff_duration(
|
||||
time.monotonic() - diff_start, connector_type
|
||||
)
|
||||
task_logger.info(
|
||||
f"RedisConnector.prune.generate_tasks finished. cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
redis_connector.prune.generator_complete = tasks_generated
|
||||
|
||||
|
||||
@@ -23,8 +23,6 @@ class IndexAttemptErrorPydantic(BaseModel):
|
||||
|
||||
index_attempt_id: int
|
||||
|
||||
error_type: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, model: IndexAttemptError) -> "IndexAttemptErrorPydantic":
|
||||
return cls(
|
||||
@@ -39,5 +37,4 @@ class IndexAttemptErrorPydantic(BaseModel):
|
||||
is_resolved=model.is_resolved,
|
||||
time_created=model.time_created,
|
||||
index_attempt_id=model.index_attempt_id,
|
||||
error_type=model.error_type,
|
||||
)
|
||||
|
||||
@@ -5,7 +5,6 @@ from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
import sentry_sdk
|
||||
from celery import Celery
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -69,7 +68,6 @@ from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.features.build.indexing.persistent_document_writer import (
|
||||
get_persistent_document_writer,
|
||||
)
|
||||
from onyx.server.metrics.connector_health_metrics import on_index_attempt_status_change
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.middleware import make_randomized_onyx_request_id
|
||||
from onyx.utils.postgres_sanitization import sanitize_document_for_postgres
|
||||
@@ -269,13 +267,6 @@ def run_docfetching_entrypoint(
|
||||
)
|
||||
credential_id = attempt.connector_credential_pair.credential_id
|
||||
|
||||
on_index_attempt_status_change(
|
||||
tenant_id=tenant_id,
|
||||
source=attempt.connector_credential_pair.connector.source.value,
|
||||
cc_pair_id=connector_credential_pair_id,
|
||||
status="in_progress",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Docfetching starting{tenant_str}: "
|
||||
f"connector='{connector_name}' "
|
||||
@@ -565,27 +556,6 @@ def connector_document_extraction(
|
||||
|
||||
# save record of any failures at the connector level
|
||||
if failure is not None:
|
||||
if failure.exception is not None:
|
||||
with sentry_sdk.new_scope() as scope:
|
||||
scope.set_tag("stage", "connector_fetch")
|
||||
scope.set_tag("connector_source", db_connector.source.value)
|
||||
scope.set_tag("cc_pair_id", str(cc_pair_id))
|
||||
scope.set_tag("index_attempt_id", str(index_attempt_id))
|
||||
scope.set_tag("tenant_id", tenant_id)
|
||||
if failure.failed_document:
|
||||
scope.set_tag(
|
||||
"doc_id", failure.failed_document.document_id
|
||||
)
|
||||
if failure.failed_entity:
|
||||
scope.set_tag(
|
||||
"entity_id", failure.failed_entity.entity_id
|
||||
)
|
||||
scope.fingerprint = [
|
||||
"connector-fetch-failure",
|
||||
db_connector.source.value,
|
||||
type(failure.exception).__name__,
|
||||
]
|
||||
sentry_sdk.capture_exception(failure.exception)
|
||||
total_failures += 1
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
create_index_attempt_error(
|
||||
|
||||
@@ -364,7 +364,7 @@ def _get_or_extract_plaintext(
|
||||
plaintext_io = file_store.read_file(plaintext_key, mode="b")
|
||||
return plaintext_io.read().decode("utf-8")
|
||||
except Exception:
|
||||
logger.info(f"Cache miss for file with id={file_id}")
|
||||
logger.exception(f"Error when reading file, id={file_id}")
|
||||
|
||||
# Cache miss — extract and store.
|
||||
content_text = extract_fn()
|
||||
|
||||
@@ -4,6 +4,8 @@ from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_utils import create_tool_call_failure_messages
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
@@ -633,6 +635,7 @@ def run_llm_loop(
|
||||
user_memory_context: UserMemoryContext | None,
|
||||
llm: LLM,
|
||||
token_counter: Callable[[str], int],
|
||||
db_session: Session,
|
||||
forced_tool_id: int | None = None,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
chat_session_id: str | None = None,
|
||||
@@ -1017,16 +1020,20 @@ def run_llm_loop(
|
||||
persisted_memory_id: int | None = None
|
||||
if user_memory_context and user_memory_context.user_id:
|
||||
if tool_response.rich_response.index_to_replace is not None:
|
||||
persisted_memory_id = update_memory_at_index(
|
||||
memory = update_memory_at_index(
|
||||
user_id=user_memory_context.user_id,
|
||||
index=tool_response.rich_response.index_to_replace,
|
||||
new_text=tool_response.rich_response.memory_text,
|
||||
db_session=db_session,
|
||||
)
|
||||
persisted_memory_id = memory.id if memory else None
|
||||
else:
|
||||
persisted_memory_id = add_memory(
|
||||
memory = add_memory(
|
||||
user_id=user_memory_context.user_id,
|
||||
memory_text=tool_response.rich_response.memory_text,
|
||||
db_session=db_session,
|
||||
)
|
||||
persisted_memory_id = memory.id
|
||||
operation: Literal["add", "update"] = (
|
||||
"update"
|
||||
if tool_response.rich_response.index_to_replace is not None
|
||||
|
||||
@@ -67,6 +67,7 @@ from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.chat import reserve_message_id
|
||||
from onyx.db.chat import reserve_multi_model_message_ids
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.memory import get_memories
|
||||
from onyx.db.models import ChatMessage
|
||||
@@ -93,7 +94,6 @@ from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.llm.factory import get_llm_token_counter
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.multi_llm import LLMTimeoutError
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.request_context import reset_llm_mock_response
|
||||
from onyx.llm.request_context import set_llm_mock_response
|
||||
@@ -996,7 +996,6 @@ def _run_models(
|
||||
|
||||
def _run_model(model_idx: int) -> None:
|
||||
"""Run one LLM loop inside a worker thread, writing packets to ``merged_queue``."""
|
||||
|
||||
model_emitter = Emitter(
|
||||
model_idx=model_idx,
|
||||
merged_queue=merged_queue,
|
||||
@@ -1006,86 +1005,93 @@ def _run_models(
|
||||
model_llm = setup.llms[model_idx]
|
||||
|
||||
try:
|
||||
# Each function opens short-lived DB sessions on demand.
|
||||
# Do NOT pass a long-lived session here — it would hold a
|
||||
# connection for the entire LLM loop (minutes), and cloud
|
||||
# infrastructure may drop idle connections.
|
||||
thread_tool_dict = construct_tools(
|
||||
persona=setup.persona,
|
||||
emitter=model_emitter,
|
||||
user=user,
|
||||
llm=model_llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=setup.new_msg_req.internal_search_filters,
|
||||
project_id_filter=setup.search_params.project_id_filter,
|
||||
persona_id_filter=setup.search_params.persona_id_filter,
|
||||
bypass_acl=setup.bypass_acl,
|
||||
slack_context=setup.slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
setup.persona, setup.new_msg_req.internal_search_filters
|
||||
),
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=setup.chat_session.id,
|
||||
message_id=setup.user_message.id,
|
||||
additional_headers=setup.custom_tool_additional_headers,
|
||||
mcp_headers=setup.mcp_headers,
|
||||
),
|
||||
file_reader_tool_config=FileReaderToolConfig(
|
||||
user_file_ids=setup.available_files.user_file_ids,
|
||||
chat_file_ids=setup.available_files.chat_file_ids,
|
||||
),
|
||||
allowed_tool_ids=setup.new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=setup.search_params.search_usage,
|
||||
)
|
||||
model_tools = [
|
||||
tool for tool_list in thread_tool_dict.values() for tool in tool_list
|
||||
]
|
||||
|
||||
if setup.forced_tool_id and setup.forced_tool_id not in {
|
||||
tool.id for tool in model_tools
|
||||
}:
|
||||
raise ValueError(
|
||||
f"Forced tool {setup.forced_tool_id} not found in tools"
|
||||
)
|
||||
|
||||
# Per-thread copy: run_llm_loop mutates simple_chat_history in-place.
|
||||
if n_models == 1 and setup.new_msg_req.deep_research:
|
||||
if setup.chat_session.project_id:
|
||||
raise RuntimeError("Deep research is not supported for projects")
|
||||
run_deep_research_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
skip_clarification=setup.skip_clarification,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
)
|
||||
else:
|
||||
run_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
context_files=setup.extracted_context_files,
|
||||
# Each worker opens its own session — SQLAlchemy sessions are not thread-safe.
|
||||
# Do NOT write to the outer db_session (or any shared DB state) from here;
|
||||
# all DB writes in this thread must go through thread_db_session.
|
||||
with get_session_with_current_tenant() as thread_db_session:
|
||||
thread_tool_dict = construct_tools(
|
||||
persona=setup.persona,
|
||||
user_memory_context=setup.user_memory_context,
|
||||
db_session=thread_db_session,
|
||||
emitter=model_emitter,
|
||||
user=user,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
forced_tool_id=setup.forced_tool_id,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
chat_files=setup.chat_files_for_tools,
|
||||
include_citations=setup.new_msg_req.include_citations,
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=setup.new_msg_req.internal_search_filters,
|
||||
project_id_filter=setup.search_params.project_id_filter,
|
||||
persona_id_filter=setup.search_params.persona_id_filter,
|
||||
bypass_acl=setup.bypass_acl,
|
||||
slack_context=setup.slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
setup.persona, setup.new_msg_req.internal_search_filters
|
||||
),
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=setup.chat_session.id,
|
||||
message_id=setup.user_message.id,
|
||||
additional_headers=setup.custom_tool_additional_headers,
|
||||
mcp_headers=setup.mcp_headers,
|
||||
),
|
||||
file_reader_tool_config=FileReaderToolConfig(
|
||||
user_file_ids=setup.available_files.user_file_ids,
|
||||
chat_file_ids=setup.available_files.chat_file_ids,
|
||||
),
|
||||
allowed_tool_ids=setup.new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=setup.search_params.search_usage,
|
||||
)
|
||||
model_tools = [
|
||||
tool
|
||||
for tool_list in thread_tool_dict.values()
|
||||
for tool in tool_list
|
||||
]
|
||||
|
||||
if setup.forced_tool_id and setup.forced_tool_id not in {
|
||||
tool.id for tool in model_tools
|
||||
}:
|
||||
raise ValueError(
|
||||
f"Forced tool {setup.forced_tool_id} not found in tools"
|
||||
)
|
||||
|
||||
# Per-thread copy: run_llm_loop mutates simple_chat_history in-place.
|
||||
if n_models == 1 and setup.new_msg_req.deep_research:
|
||||
if setup.chat_session.project_id:
|
||||
raise RuntimeError(
|
||||
"Deep research is not supported for projects"
|
||||
)
|
||||
run_deep_research_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
db_session=thread_db_session,
|
||||
skip_clarification=setup.skip_clarification,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
)
|
||||
else:
|
||||
run_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
context_files=setup.extracted_context_files,
|
||||
persona=setup.persona,
|
||||
user_memory_context=setup.user_memory_context,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
db_session=thread_db_session,
|
||||
forced_tool_id=setup.forced_tool_id,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
chat_files=setup.chat_files_for_tools,
|
||||
include_citations=setup.new_msg_req.include_citations,
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
)
|
||||
|
||||
model_succeeded[model_idx] = True
|
||||
|
||||
@@ -1096,33 +1102,33 @@ def _run_models(
|
||||
finally:
|
||||
merged_queue.put((model_idx, _MODEL_DONE))
|
||||
|
||||
def _save_errored_message(model_idx: int, context: str) -> None:
|
||||
"""Save an error message to a reserved ChatMessage that failed during execution."""
|
||||
def _delete_orphaned_message(model_idx: int, context: str) -> None:
|
||||
"""Delete a reserved ChatMessage that was never populated due to a model error."""
|
||||
try:
|
||||
msg = db_session.get(ChatMessage, setup.reserved_messages[model_idx].id)
|
||||
if msg is not None:
|
||||
error_text = f"Error from {setup.model_display_names[model_idx]}: model encountered an error during generation."
|
||||
msg.message = error_text
|
||||
msg.error = error_text
|
||||
orphaned = db_session.get(
|
||||
ChatMessage, setup.reserved_messages[model_idx].id
|
||||
)
|
||||
if orphaned is not None:
|
||||
db_session.delete(orphaned)
|
||||
db_session.commit()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"%s error save failed for model %d (%s)",
|
||||
"%s orphan cleanup failed for model %d (%s)",
|
||||
context,
|
||||
model_idx,
|
||||
setup.model_display_names[model_idx],
|
||||
)
|
||||
|
||||
# Each worker thread needs its own Context copy — a single Context object
|
||||
# cannot be entered concurrently by multiple threads (RuntimeError).
|
||||
# Copy contextvars before submitting futures — ThreadPoolExecutor does NOT
|
||||
# auto-propagate contextvars in Python 3.11; threads would inherit a blank context.
|
||||
worker_context = contextvars.copy_context()
|
||||
executor = ThreadPoolExecutor(
|
||||
max_workers=n_models, thread_name_prefix="multi-model"
|
||||
)
|
||||
completion_persisted: bool = False
|
||||
try:
|
||||
for i in range(n_models):
|
||||
ctx = contextvars.copy_context()
|
||||
executor.submit(ctx.run, _run_model, i)
|
||||
executor.submit(worker_context.run, _run_model, i)
|
||||
|
||||
# ── Main thread: merge and yield packets ────────────────────────────
|
||||
models_remaining = n_models
|
||||
@@ -1139,7 +1145,7 @@ def _run_models(
|
||||
# save "stopped by user" for a model that actually threw an exception.
|
||||
for i in range(n_models):
|
||||
if model_errored[i]:
|
||||
_save_errored_message(i, "stop-button")
|
||||
_delete_orphaned_message(i, "stop-button")
|
||||
continue
|
||||
try:
|
||||
succeeded = model_succeeded[i]
|
||||
@@ -1167,32 +1173,6 @@ def _run_models(
|
||||
else:
|
||||
if item is _MODEL_DONE:
|
||||
models_remaining -= 1
|
||||
elif isinstance(item, LLMTimeoutError):
|
||||
model_llm = setup.llms[model_idx]
|
||||
error_msg = (
|
||||
"The LLM took too long to respond. "
|
||||
"If you're running a local model, try increasing the "
|
||||
"LLM_SOCKET_READ_TIMEOUT environment variable "
|
||||
"(current default: 120 seconds)."
|
||||
)
|
||||
stack_trace = "".join(
|
||||
traceback.format_exception(type(item), item, item.__traceback__)
|
||||
)
|
||||
if model_llm.config.api_key and len(model_llm.config.api_key) > 2:
|
||||
stack_trace = stack_trace.replace(
|
||||
model_llm.config.api_key, "[REDACTED_API_KEY]"
|
||||
)
|
||||
yield StreamingError(
|
||||
error=error_msg,
|
||||
stack_trace=stack_trace,
|
||||
error_code="CONNECTION_ERROR",
|
||||
is_retryable=True,
|
||||
details={
|
||||
"model": model_llm.config.model_name,
|
||||
"provider": model_llm.config.model_provider,
|
||||
"model_index": model_idx,
|
||||
},
|
||||
)
|
||||
elif isinstance(item, Exception):
|
||||
# Yield a tagged error for this model but keep the other models running.
|
||||
# Do NOT decrement models_remaining — _run_model's finally always posts
|
||||
@@ -1231,7 +1211,7 @@ def _run_models(
|
||||
for i in range(n_models):
|
||||
if not model_succeeded[i]:
|
||||
# Model errored — delete its orphaned reserved message.
|
||||
_save_errored_message(i, "normal")
|
||||
_delete_orphaned_message(i, "normal")
|
||||
continue
|
||||
try:
|
||||
llm_loop_completion_handle(
|
||||
@@ -1284,7 +1264,7 @@ def _run_models(
|
||||
setup.model_display_names[i],
|
||||
)
|
||||
elif model_errored[i]:
|
||||
_save_errored_message(i, "disconnect")
|
||||
_delete_orphaned_message(i, "disconnect")
|
||||
# 4. Drain buffered packets from memory — no consumer is running.
|
||||
while not merged_queue.empty():
|
||||
try:
|
||||
|
||||
@@ -379,14 +379,6 @@ POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "127.0.0.1"
|
||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
|
||||
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
|
||||
AWS_REGION_NAME = os.environ.get("AWS_REGION_NAME") or "us-east-2"
|
||||
# Comma-separated replica / multi-host list. If unset, defaults to POSTGRES_HOST
|
||||
# only.
|
||||
_POSTGRES_HOSTS_STR = os.environ.get("POSTGRES_HOSTS", "").strip()
|
||||
POSTGRES_HOSTS: list[str] = (
|
||||
[h.strip() for h in _POSTGRES_HOSTS_STR.split(",") if h.strip()]
|
||||
if _POSTGRES_HOSTS_STR
|
||||
else [POSTGRES_HOST]
|
||||
)
|
||||
|
||||
POSTGRES_API_SERVER_POOL_SIZE = int(
|
||||
os.environ.get("POSTGRES_API_SERVER_POOL_SIZE") or 40
|
||||
|
||||
@@ -12,11 +12,6 @@ SLACK_USER_TOKEN_PREFIX = "xoxp-"
|
||||
SLACK_BOT_TOKEN_PREFIX = "xoxb-"
|
||||
ONYX_EMAILABLE_LOGO_MAX_DIM = 512
|
||||
|
||||
# The mask_string() function in encryption.py uses "•" (U+2022 BULLET) to mask secrets.
|
||||
MASK_CREDENTIAL_CHAR = "\u2022"
|
||||
# Pattern produced by mask_string for strings >= 14 chars: "abcd...wxyz" (exactly 11 chars)
|
||||
MASK_CREDENTIAL_LONG_RE = re.compile(r"^.{4}\.{3}.{4}$")
|
||||
|
||||
SOURCE_TYPE = "source_type"
|
||||
# stored in the `metadata` of a chunk. Used to signify that this chunk should
|
||||
# not be used for QA. For example, Google Drive file types which can't be parsed
|
||||
@@ -283,7 +278,6 @@ class NotificationType(str, Enum):
|
||||
RELEASE_NOTES = "release_notes"
|
||||
ASSISTANT_FILES_READY = "assistant_files_ready"
|
||||
FEATURE_ANNOUNCEMENT = "feature_announcement"
|
||||
CONNECTOR_REPEATED_ERRORS = "connector_repeated_errors"
|
||||
|
||||
|
||||
class BlobType(str, Enum):
|
||||
@@ -397,6 +391,10 @@ class MilestoneRecordType(str, Enum):
|
||||
REQUESTED_CONNECTOR = "requested_connector"
|
||||
|
||||
|
||||
class PostgresAdvisoryLocks(Enum):
|
||||
KOMBU_MESSAGE_CLEANUP_LOCK_ID = auto()
|
||||
|
||||
|
||||
class OnyxCeleryQueues:
|
||||
# "celery" is the default queue defined by celery and also the queue
|
||||
# we are running in the primary worker to run system tasks
|
||||
@@ -579,6 +577,7 @@ class OnyxCeleryTask:
|
||||
MONITOR_PROCESS_MEMORY = "monitor_process_memory"
|
||||
CELERY_BEAT_HEARTBEAT = "celery_beat_heartbeat"
|
||||
|
||||
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
|
||||
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
|
||||
"connector_permission_sync_generator_task"
|
||||
)
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from sentry_sdk.types import Event
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_instance_id_resolved = False
|
||||
|
||||
|
||||
def _add_instance_tags(
|
||||
event: Event,
|
||||
hint: dict[str, Any], # noqa: ARG001
|
||||
) -> Event | None:
|
||||
"""Sentry before_send hook that lazily attaches instance identification tags.
|
||||
|
||||
On the first event, resolves the instance UUID from the KV store (requires DB)
|
||||
and sets it as a global Sentry tag. Subsequent events pick it up automatically.
|
||||
"""
|
||||
global _instance_id_resolved
|
||||
|
||||
if _instance_id_resolved:
|
||||
return event
|
||||
|
||||
try:
|
||||
import sentry_sdk
|
||||
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
if MULTI_TENANT:
|
||||
instance_id = "multi-tenant-cloud"
|
||||
else:
|
||||
from onyx.utils.telemetry import get_or_generate_uuid
|
||||
|
||||
instance_id = get_or_generate_uuid()
|
||||
|
||||
sentry_sdk.set_tag("instance_id", instance_id)
|
||||
|
||||
# Also set on this event since set_tag won't retroactively apply
|
||||
event.setdefault("tags", {})["instance_id"] = instance_id
|
||||
|
||||
# Only mark resolved after success — if DB wasn't ready, retry next event
|
||||
_instance_id_resolved = True
|
||||
except Exception:
|
||||
logger.debug("Failed to resolve instance_id for Sentry tagging")
|
||||
|
||||
return event
|
||||
@@ -27,19 +27,16 @@ _STATUS_TO_ERROR_CODE: dict[int, OnyxErrorCode] = {
|
||||
401: OnyxErrorCode.CREDENTIAL_EXPIRED,
|
||||
403: OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
404: OnyxErrorCode.BAD_GATEWAY,
|
||||
429: OnyxErrorCode.RATE_LIMITED,
|
||||
}
|
||||
|
||||
|
||||
def _error_code_for_status(status_code: int) -> OnyxErrorCode:
|
||||
"""Map an HTTP status code to the appropriate OnyxErrorCode.
|
||||
|
||||
Expects a >= 400 status code. Known codes (401, 403, 404) are
|
||||
Expects a >= 400 status code. Known codes (401, 403, 404, 429) are
|
||||
mapped to specific error codes; all other codes (unrecognised 4xx
|
||||
and 5xx) map to BAD_GATEWAY as unexpected upstream errors.
|
||||
|
||||
Note: 429 is intentionally omitted — the rl_requests wrapper
|
||||
handles rate limits transparently at the HTTP layer, so 429
|
||||
responses never reach this function.
|
||||
"""
|
||||
if status_code in _STATUS_TO_ERROR_CODE:
|
||||
return _STATUS_TO_ERROR_CODE[status_code]
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
from typing import NoReturn
|
||||
from typing import TypeAlias
|
||||
|
||||
from pydantic import BaseModel
|
||||
from retry import retry
|
||||
@@ -24,11 +25,8 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import EntityFailure
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
@@ -49,6 +47,10 @@ def _handle_canvas_api_error(e: OnyxError) -> NoReturn:
|
||||
raise InsufficientPermissionsError(
|
||||
"Canvas API token does not have sufficient permissions (HTTP 403)."
|
||||
)
|
||||
elif e.status_code == 429:
|
||||
raise ConnectorValidationError(
|
||||
"Canvas rate-limit exceeded (HTTP 429). Please try again later."
|
||||
)
|
||||
elif e.status_code >= 500:
|
||||
raise UnexpectedValidationError(
|
||||
f"Unexpected Canvas HTTP error (status={e.status_code}): {e}"
|
||||
@@ -59,60 +61,6 @@ def _handle_canvas_api_error(e: OnyxError) -> NoReturn:
|
||||
)
|
||||
|
||||
|
||||
class CanvasStage(StrEnum):
|
||||
PAGES = "pages"
|
||||
ASSIGNMENTS = "assignments"
|
||||
ANNOUNCEMENTS = "announcements"
|
||||
|
||||
|
||||
_STAGE_CONFIG: dict[CanvasStage, dict[str, Any]] = {
|
||||
CanvasStage.PAGES: {
|
||||
"endpoint": "courses/{course_id}/pages",
|
||||
"params": {
|
||||
"per_page": "100",
|
||||
"include[]": "body",
|
||||
"published": "true",
|
||||
"sort": "updated_at",
|
||||
"order": "desc",
|
||||
},
|
||||
},
|
||||
CanvasStage.ASSIGNMENTS: {
|
||||
"endpoint": "courses/{course_id}/assignments",
|
||||
"params": {"per_page": "100", "published": "true"},
|
||||
},
|
||||
CanvasStage.ANNOUNCEMENTS: {
|
||||
"endpoint": "announcements",
|
||||
"params": {
|
||||
"per_page": "100",
|
||||
"context_codes[]": "course_{course_id}",
|
||||
"active_only": "true",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _parse_canvas_dt(timestamp_str: str) -> datetime:
|
||||
"""Parse a Canvas ISO-8601 timestamp (e.g. '2025-06-15T12:00:00Z')
|
||||
into a timezone-aware UTC datetime.
|
||||
|
||||
Canvas returns timestamps with a trailing 'Z' instead of '+00:00',
|
||||
so we normalise before parsing.
|
||||
"""
|
||||
return datetime.fromisoformat(timestamp_str.replace("Z", "+00:00")).astimezone(
|
||||
timezone.utc
|
||||
)
|
||||
|
||||
|
||||
def _unix_to_canvas_time(epoch: float) -> str:
|
||||
"""Convert a Unix timestamp to Canvas ISO-8601 format (e.g. '2025-06-15T12:00:00Z')."""
|
||||
return datetime.fromtimestamp(epoch, tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
|
||||
|
||||
def _in_time_window(timestamp_str: str, start: float, end: float) -> bool:
|
||||
"""Check whether a Canvas ISO-8601 timestamp falls within (start, end]."""
|
||||
return start < _parse_canvas_dt(timestamp_str).timestamp() <= end
|
||||
|
||||
|
||||
class CanvasCourse(BaseModel):
|
||||
id: int
|
||||
name: str | None = None
|
||||
@@ -197,6 +145,9 @@ class CanvasAnnouncement(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
CanvasStage: TypeAlias = Literal["pages", "assignments", "announcements"]
|
||||
|
||||
|
||||
class CanvasConnectorCheckpoint(ConnectorCheckpoint):
|
||||
"""Checkpoint state for resumable Canvas indexing.
|
||||
|
||||
@@ -214,30 +165,15 @@ class CanvasConnectorCheckpoint(ConnectorCheckpoint):
|
||||
|
||||
course_ids: list[int] = []
|
||||
current_course_index: int = 0
|
||||
stage: CanvasStage = CanvasStage.PAGES
|
||||
stage: CanvasStage = "pages"
|
||||
next_url: str | None = None
|
||||
|
||||
def advance_course(self) -> None:
|
||||
"""Move to the next course and reset within-course state."""
|
||||
self.current_course_index += 1
|
||||
self.stage = CanvasStage.PAGES
|
||||
self.stage = "pages"
|
||||
self.next_url = None
|
||||
|
||||
def advance_stage(self) -> None:
|
||||
"""Advance past the current stage.
|
||||
|
||||
Moves to the next stage within the same course, or to the next
|
||||
course if the current stage is the last one. Resets next_url so
|
||||
the next call starts fresh on the new stage.
|
||||
"""
|
||||
self.next_url = None
|
||||
stages: list[CanvasStage] = list(CanvasStage)
|
||||
next_idx = stages.index(self.stage) + 1
|
||||
if next_idx < len(stages):
|
||||
self.stage = stages[next_idx]
|
||||
else:
|
||||
self.advance_course()
|
||||
|
||||
|
||||
class CanvasConnector(
|
||||
CheckpointedConnectorWithPermSync[CanvasConnectorCheckpoint],
|
||||
@@ -359,7 +295,13 @@ class CanvasConnector(
|
||||
if body_text:
|
||||
text_parts.append(body_text)
|
||||
|
||||
doc_updated_at = _parse_canvas_dt(page.updated_at) if page.updated_at else None
|
||||
doc_updated_at = (
|
||||
datetime.fromisoformat(page.updated_at.replace("Z", "+00:00")).astimezone(
|
||||
timezone.utc
|
||||
)
|
||||
if page.updated_at
|
||||
else None
|
||||
)
|
||||
|
||||
document = self._build_document(
|
||||
doc_id=f"canvas-page-{page.course_id}-{page.page_id}",
|
||||
@@ -383,11 +325,17 @@ class CanvasConnector(
|
||||
if desc_text:
|
||||
text_parts.append(desc_text)
|
||||
if assignment.due_at:
|
||||
due_dt = _parse_canvas_dt(assignment.due_at)
|
||||
due_dt = datetime.fromisoformat(
|
||||
assignment.due_at.replace("Z", "+00:00")
|
||||
).astimezone(timezone.utc)
|
||||
text_parts.append(f"Due: {due_dt.strftime('%B %d, %Y %H:%M UTC')}")
|
||||
|
||||
doc_updated_at = (
|
||||
_parse_canvas_dt(assignment.updated_at) if assignment.updated_at else None
|
||||
datetime.fromisoformat(
|
||||
assignment.updated_at.replace("Z", "+00:00")
|
||||
).astimezone(timezone.utc)
|
||||
if assignment.updated_at
|
||||
else None
|
||||
)
|
||||
|
||||
document = self._build_document(
|
||||
@@ -413,7 +361,11 @@ class CanvasConnector(
|
||||
text_parts.append(msg_text)
|
||||
|
||||
doc_updated_at = (
|
||||
_parse_canvas_dt(announcement.posted_at) if announcement.posted_at else None
|
||||
datetime.fromisoformat(
|
||||
announcement.posted_at.replace("Z", "+00:00")
|
||||
).astimezone(timezone.utc)
|
||||
if announcement.posted_at
|
||||
else None
|
||||
)
|
||||
|
||||
document = self._build_document(
|
||||
@@ -448,314 +400,6 @@ class CanvasConnector(
|
||||
self._canvas_client = client
|
||||
return None
|
||||
|
||||
def _fetch_stage_page(
|
||||
self,
|
||||
next_url: str | None,
|
||||
endpoint: str,
|
||||
params: dict[str, Any],
|
||||
) -> tuple[list[Any], str | None]:
|
||||
"""Fetch one page of API results for the current stage.
|
||||
|
||||
Returns (items, next_url). All error handling is done by the
|
||||
caller (_load_from_checkpoint).
|
||||
"""
|
||||
if next_url:
|
||||
# Resuming mid-pagination: the next_url from Canvas's
|
||||
# Link header already contains endpoint + query params.
|
||||
response, result_next_url = self.canvas_client.get(full_url=next_url)
|
||||
else:
|
||||
# First request for this stage: build from endpoint + params.
|
||||
response, result_next_url = self.canvas_client.get(
|
||||
endpoint=endpoint, params=params
|
||||
)
|
||||
return response or [], result_next_url
|
||||
|
||||
def _process_items(
|
||||
self,
|
||||
response: list[Any],
|
||||
stage: CanvasStage,
|
||||
course_id: int,
|
||||
start: float,
|
||||
end: float,
|
||||
include_permissions: bool,
|
||||
) -> tuple[list[Document | ConnectorFailure], bool]:
|
||||
"""Process a page of API results into documents.
|
||||
|
||||
Returns (docs, early_exit). early_exit is True when pages
|
||||
(sorted desc by updated_at) hit an item older than start,
|
||||
signaling that pagination should stop.
|
||||
"""
|
||||
results: list[Document | ConnectorFailure] = []
|
||||
early_exit = False
|
||||
|
||||
for item in response:
|
||||
try:
|
||||
if stage == CanvasStage.PAGES:
|
||||
page = CanvasPage.from_api(item, course_id=course_id)
|
||||
if not page.updated_at:
|
||||
continue
|
||||
# Pages are sorted by updated_at desc — once we see
|
||||
# an item at or before `start`, all remaining items
|
||||
# on this and subsequent pages are older too.
|
||||
if not _in_time_window(page.updated_at, start, end):
|
||||
if _parse_canvas_dt(page.updated_at).timestamp() <= start:
|
||||
early_exit = True
|
||||
break
|
||||
# ts > end: page is newer than our window, skip it
|
||||
continue
|
||||
doc = self._convert_page_to_document(page)
|
||||
results.append(
|
||||
self._maybe_attach_permissions(
|
||||
doc, course_id, include_permissions
|
||||
)
|
||||
)
|
||||
|
||||
elif stage == CanvasStage.ASSIGNMENTS:
|
||||
assignment = CanvasAssignment.from_api(item, course_id=course_id)
|
||||
if not assignment.updated_at or not _in_time_window(
|
||||
assignment.updated_at, start, end
|
||||
):
|
||||
continue
|
||||
doc = self._convert_assignment_to_document(assignment)
|
||||
results.append(
|
||||
self._maybe_attach_permissions(
|
||||
doc, course_id, include_permissions
|
||||
)
|
||||
)
|
||||
|
||||
elif stage == CanvasStage.ANNOUNCEMENTS:
|
||||
announcement = CanvasAnnouncement.from_api(
|
||||
item, course_id=course_id
|
||||
)
|
||||
if not announcement.posted_at:
|
||||
logger.debug(
|
||||
f"Skipping announcement {announcement.id} in "
|
||||
f"course {course_id}: no posted_at"
|
||||
)
|
||||
continue
|
||||
if not _in_time_window(announcement.posted_at, start, end):
|
||||
continue
|
||||
doc = self._convert_announcement_to_document(announcement)
|
||||
results.append(
|
||||
self._maybe_attach_permissions(
|
||||
doc, course_id, include_permissions
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
item_id = item.get("id") or item.get("page_id", "unknown")
|
||||
if stage == CanvasStage.PAGES:
|
||||
doc_link = (
|
||||
f"{self.canvas_base_url}/courses/{course_id}"
|
||||
f"/pages/{item.get('url', '')}"
|
||||
)
|
||||
else:
|
||||
doc_link = item.get("html_url", "")
|
||||
results.append(
|
||||
ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=f"canvas-{stage.removesuffix('s')}-{course_id}-{item_id}",
|
||||
document_link=doc_link,
|
||||
),
|
||||
failure_message=f"Failed to process {stage.removesuffix('s')}: {e}",
|
||||
exception=e,
|
||||
)
|
||||
)
|
||||
|
||||
return results, early_exit
|
||||
|
||||
def _maybe_attach_permissions(
|
||||
self,
|
||||
document: Document,
|
||||
course_id: int,
|
||||
include_permissions: bool,
|
||||
) -> Document:
|
||||
if include_permissions:
|
||||
document.external_access = self._get_course_permissions(course_id)
|
||||
return document
|
||||
|
||||
def _load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: CanvasConnectorCheckpoint,
|
||||
include_permissions: bool = False,
|
||||
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
|
||||
"""Shared implementation for load_from_checkpoint and load_from_checkpoint_with_perm_sync."""
|
||||
new_checkpoint = checkpoint.model_copy(deep=True)
|
||||
|
||||
# First call: materialize the list of course IDs.
|
||||
# On failure, let the exception propagate so the framework fails the
|
||||
# attempt cleanly. Swallowing errors here would leave the checkpoint
|
||||
# state unchanged and cause an infinite retry loop.
|
||||
if not new_checkpoint.course_ids:
|
||||
try:
|
||||
courses = self._list_courses()
|
||||
except OnyxError as e:
|
||||
if e.status_code in (401, 403):
|
||||
_handle_canvas_api_error(e) # NoReturn — always raises
|
||||
raise
|
||||
new_checkpoint.course_ids = [c.id for c in courses]
|
||||
logger.info(f"Found {len(courses)} Canvas courses to process")
|
||||
new_checkpoint.has_more = len(new_checkpoint.course_ids) > 0
|
||||
return new_checkpoint
|
||||
|
||||
# All courses done.
|
||||
if new_checkpoint.current_course_index >= len(new_checkpoint.course_ids):
|
||||
new_checkpoint.has_more = False
|
||||
return new_checkpoint
|
||||
|
||||
course_id = new_checkpoint.course_ids[new_checkpoint.current_course_index]
|
||||
try:
|
||||
stage = CanvasStage(new_checkpoint.stage)
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"Invalid checkpoint stage: {new_checkpoint.stage!r}. "
|
||||
f"Valid stages: {[s.value for s in CanvasStage]}"
|
||||
) from e
|
||||
|
||||
# Build endpoint + params from the static template.
|
||||
config = _STAGE_CONFIG[stage]
|
||||
endpoint = config["endpoint"].format(course_id=course_id)
|
||||
params = {k: v.format(course_id=course_id) for k, v in config["params"].items()}
|
||||
# Only the announcements API supports server-side date filtering
|
||||
# (start_date/end_date). Pages support server-side sorting
|
||||
# (sort=updated_at desc) enabling early exit, but not date
|
||||
# filtering. Assignments support neither. Both are filtered
|
||||
# client-side via _in_time_window after fetching.
|
||||
if stage == CanvasStage.ANNOUNCEMENTS:
|
||||
params["start_date"] = _unix_to_canvas_time(start)
|
||||
params["end_date"] = _unix_to_canvas_time(end)
|
||||
|
||||
try:
|
||||
response, result_next_url = self._fetch_stage_page(
|
||||
next_url=new_checkpoint.next_url,
|
||||
endpoint=endpoint,
|
||||
params=params,
|
||||
)
|
||||
except OnyxError as oe:
|
||||
# Security errors from _parse_next_link (host/scheme
|
||||
# mismatch on pagination URLs) have no status code override
|
||||
# and must not be silenced.
|
||||
is_api_error = oe._status_code_override is not None
|
||||
if not is_api_error:
|
||||
raise
|
||||
if oe.status_code in (401, 403):
|
||||
_handle_canvas_api_error(oe) # NoReturn — always raises
|
||||
|
||||
# 404 means the course itself is gone or inaccessible. The
|
||||
# other stages on this course will hit the same 404, so skip
|
||||
# the whole course rather than burning API calls on each stage.
|
||||
if oe.status_code == 404:
|
||||
logger.warning(
|
||||
f"Canvas course {course_id} not found while fetching "
|
||||
f"{stage} (HTTP 404). Skipping course."
|
||||
)
|
||||
yield ConnectorFailure(
|
||||
failed_entity=EntityFailure(
|
||||
entity_id=f"canvas-course-{course_id}",
|
||||
),
|
||||
failure_message=(f"Canvas course {course_id} not found: {oe}"),
|
||||
exception=oe,
|
||||
)
|
||||
new_checkpoint.advance_course()
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to fetch {stage} for course {course_id}: {oe}. "
|
||||
f"Skipping remainder of this stage."
|
||||
)
|
||||
yield ConnectorFailure(
|
||||
failed_entity=EntityFailure(
|
||||
entity_id=f"canvas-{stage}-{course_id}",
|
||||
),
|
||||
failure_message=(
|
||||
f"Failed to fetch {stage} for course {course_id}: {oe}"
|
||||
),
|
||||
exception=oe,
|
||||
)
|
||||
new_checkpoint.advance_stage()
|
||||
new_checkpoint.has_more = new_checkpoint.current_course_index < len(
|
||||
new_checkpoint.course_ids
|
||||
)
|
||||
return new_checkpoint
|
||||
except Exception as e:
|
||||
# Unknown error — skip the stage and try to continue.
|
||||
logger.warning(
|
||||
f"Failed to fetch {stage} for course {course_id}: {e}. "
|
||||
f"Skipping remainder of this stage."
|
||||
)
|
||||
yield ConnectorFailure(
|
||||
failed_entity=EntityFailure(
|
||||
entity_id=f"canvas-{stage}-{course_id}",
|
||||
),
|
||||
failure_message=(
|
||||
f"Failed to fetch {stage} for course {course_id}: {e}"
|
||||
),
|
||||
exception=e,
|
||||
)
|
||||
new_checkpoint.advance_stage()
|
||||
new_checkpoint.has_more = new_checkpoint.current_course_index < len(
|
||||
new_checkpoint.course_ids
|
||||
)
|
||||
return new_checkpoint
|
||||
|
||||
# Process fetched items
|
||||
results, early_exit = self._process_items(
|
||||
response, stage, course_id, start, end, include_permissions
|
||||
)
|
||||
for result in results:
|
||||
yield result
|
||||
|
||||
# If we hit an item older than our window (pages sorted desc),
|
||||
# skip remaining pagination and advance to the next stage.
|
||||
if early_exit:
|
||||
result_next_url = None
|
||||
|
||||
# If there are more pages, save the cursor and return
|
||||
if result_next_url:
|
||||
new_checkpoint.next_url = result_next_url
|
||||
else:
|
||||
# Stage complete — advance to next stage (or next course if last).
|
||||
new_checkpoint.advance_stage()
|
||||
|
||||
new_checkpoint.has_more = new_checkpoint.current_course_index < len(
|
||||
new_checkpoint.course_ids
|
||||
)
|
||||
return new_checkpoint
|
||||
|
||||
@override
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: CanvasConnectorCheckpoint,
|
||||
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
|
||||
return self._load_from_checkpoint(
|
||||
start, end, checkpoint, include_permissions=False
|
||||
)
|
||||
|
||||
@override
|
||||
def load_from_checkpoint_with_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: CanvasConnectorCheckpoint,
|
||||
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
|
||||
"""Load documents from checkpoint with permission information included."""
|
||||
return self._load_from_checkpoint(
|
||||
start, end, checkpoint, include_permissions=True
|
||||
)
|
||||
|
||||
@override
|
||||
def build_dummy_checkpoint(self) -> CanvasConnectorCheckpoint:
|
||||
return CanvasConnectorCheckpoint(has_more=True)
|
||||
|
||||
@override
|
||||
def validate_checkpoint_json(
|
||||
self, checkpoint_json: str
|
||||
) -> CanvasConnectorCheckpoint:
|
||||
return CanvasConnectorCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
@override
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate Canvas connector settings by testing API access."""
|
||||
@@ -771,6 +415,38 @@ class CanvasConnector(
|
||||
f"Unexpected error during Canvas settings validation: {exc}"
|
||||
)
|
||||
|
||||
@override
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: CanvasConnectorCheckpoint,
|
||||
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
|
||||
# TODO(benwu408): implemented in PR3 (checkpoint)
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def load_from_checkpoint_with_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: CanvasConnectorCheckpoint,
|
||||
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
|
||||
# TODO(benwu408): implemented in PR3 (checkpoint)
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def build_dummy_checkpoint(self) -> CanvasConnectorCheckpoint:
|
||||
# TODO(benwu408): implemented in PR3 (checkpoint)
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def validate_checkpoint_json(
|
||||
self, checkpoint_json: str
|
||||
) -> CanvasConnectorCheckpoint:
|
||||
# TODO(benwu408): implemented in PR3 (checkpoint)
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
|
||||
@@ -171,10 +171,7 @@ class ClickupConnector(LoadConnector, PollConnector):
|
||||
document.metadata[extra_field] = task[extra_field]
|
||||
|
||||
if self.retrieve_task_comments:
|
||||
document.sections = [
|
||||
*document.sections,
|
||||
*self._get_task_comments(task["id"]),
|
||||
]
|
||||
document.sections.extend(self._get_task_comments(task["id"]))
|
||||
|
||||
doc_batch.append(document)
|
||||
|
||||
|
||||
@@ -61,9 +61,6 @@ _USER_NOT_FOUND = "Unknown Confluence User"
|
||||
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
|
||||
_USER_EMAIL_CACHE: dict[str, str | None] = {}
|
||||
_DEFAULT_PAGINATION_LIMIT = 1000
|
||||
_MINIMUM_PAGINATION_LIMIT = 5
|
||||
|
||||
_SERVER_ERROR_CODES = {500, 502, 503, 504}
|
||||
|
||||
_CONFLUENCE_SPACES_API_V1 = "rest/api/space"
|
||||
_CONFLUENCE_SPACES_API_V2 = "wiki/api/v2/spaces"
|
||||
@@ -572,8 +569,7 @@ class OnyxConfluence:
|
||||
if not limit:
|
||||
limit = _DEFAULT_PAGINATION_LIMIT
|
||||
|
||||
current_limit = limit
|
||||
url_suffix = update_param_in_path(url_suffix, "limit", str(current_limit))
|
||||
url_suffix = update_param_in_path(url_suffix, "limit", str(limit))
|
||||
|
||||
while url_suffix:
|
||||
logger.debug(f"Making confluence call to {url_suffix}")
|
||||
@@ -613,61 +609,40 @@ class OnyxConfluence:
|
||||
)
|
||||
continue
|
||||
|
||||
if raw_response.status_code in _SERVER_ERROR_CODES:
|
||||
# Try reducing the page size -- Confluence often times out
|
||||
# on large result sets (especially Cloud 504s).
|
||||
if current_limit > _MINIMUM_PAGINATION_LIMIT:
|
||||
old_limit = current_limit
|
||||
current_limit = max(
|
||||
current_limit // 2, _MINIMUM_PAGINATION_LIMIT
|
||||
)
|
||||
logger.warning(
|
||||
f"Confluence returned {raw_response.status_code}. "
|
||||
f"Reducing limit from {old_limit} to {current_limit} "
|
||||
f"and retrying."
|
||||
)
|
||||
url_suffix = update_param_in_path(
|
||||
url_suffix, "limit", str(current_limit)
|
||||
)
|
||||
continue
|
||||
# If we fail due to a 500, try one by one.
|
||||
# NOTE: this iterative approach only works for server, since cloud uses cursor-based
|
||||
# pagination
|
||||
if raw_response.status_code == 500 and not self._is_cloud:
|
||||
initial_start = get_start_param_from_url(url_suffix)
|
||||
if initial_start is None:
|
||||
# can't handle this if we don't have offset-based pagination
|
||||
raise
|
||||
|
||||
# Limit reduction exhausted -- for Server, fall back to
|
||||
# one-by-one offset pagination as a last resort.
|
||||
if not self._is_cloud:
|
||||
initial_start = get_start_param_from_url(url_suffix)
|
||||
# this will just yield the successful items from the batch
|
||||
new_url_suffix = (
|
||||
yield from self._try_one_by_one_for_paginated_url(
|
||||
url_suffix,
|
||||
initial_start=initial_start,
|
||||
limit=current_limit,
|
||||
)
|
||||
)
|
||||
# this means we ran into an empty page
|
||||
if new_url_suffix is None:
|
||||
if next_page_callback:
|
||||
next_page_callback("")
|
||||
break
|
||||
# this will just yield the successful items from the batch
|
||||
new_url_suffix = yield from self._try_one_by_one_for_paginated_url(
|
||||
url_suffix,
|
||||
initial_start=initial_start,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
url_suffix = new_url_suffix
|
||||
continue
|
||||
# this means we ran into an empty page
|
||||
if new_url_suffix is None:
|
||||
if next_page_callback:
|
||||
next_page_callback("")
|
||||
break
|
||||
|
||||
url_suffix = new_url_suffix
|
||||
continue
|
||||
|
||||
else:
|
||||
logger.exception(
|
||||
f"Error in confluence call to {url_suffix} "
|
||||
f"after reducing limit to {current_limit}.\n"
|
||||
f"Raw Response Text: {raw_response.text}\n"
|
||||
f"Error: {e}\n"
|
||||
f"Error in confluence call to {url_suffix} \n"
|
||||
f"Raw Response Text: {raw_response.text} \n"
|
||||
f"Full Response: {raw_response.__dict__} \n"
|
||||
f"Error: {e} \n"
|
||||
)
|
||||
raise
|
||||
|
||||
logger.exception(
|
||||
f"Error in confluence call to {url_suffix} \n"
|
||||
f"Raw Response Text: {raw_response.text} \n"
|
||||
f"Full Response: {raw_response.__dict__} \n"
|
||||
f"Error: {e} \n"
|
||||
)
|
||||
raise
|
||||
|
||||
try:
|
||||
next_response = raw_response.json()
|
||||
except Exception as e:
|
||||
@@ -705,10 +680,6 @@ class OnyxConfluence:
|
||||
old_url_suffix = url_suffix
|
||||
updated_start = get_start_param_from_url(old_url_suffix)
|
||||
url_suffix = cast(str, next_response.get("_links", {}).get("next", ""))
|
||||
if url_suffix and current_limit != limit:
|
||||
url_suffix = update_param_in_path(
|
||||
url_suffix, "limit", str(current_limit)
|
||||
)
|
||||
for i, result in enumerate(results):
|
||||
updated_start += 1
|
||||
if url_suffix and next_page_callback and i == len(results) - 1:
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
import csv
|
||||
import io
|
||||
from typing import IO
|
||||
|
||||
from onyx.connectors.models import TabularSection
|
||||
from onyx.file_processing.extract_file_text import file_io_to_text
|
||||
from onyx.file_processing.extract_file_text import xlsx_sheet_extraction
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def is_tabular_file(file_name: str) -> bool:
|
||||
lowered = file_name.lower()
|
||||
return any(lowered.endswith(ext) for ext in OnyxFileExtensions.TABULAR_EXTENSIONS)
|
||||
|
||||
|
||||
def _tsv_to_csv(tsv_text: str) -> str:
|
||||
"""Re-serialize tab-separated text as CSV so downstream parsers that
|
||||
assume the default Excel dialect read the columns correctly."""
|
||||
out = io.StringIO()
|
||||
csv.writer(out, lineterminator="\n").writerows(
|
||||
csv.reader(io.StringIO(tsv_text), dialect="excel-tab")
|
||||
)
|
||||
return out.getvalue().rstrip("\n")
|
||||
|
||||
|
||||
def tabular_file_to_sections(
|
||||
file: IO[bytes],
|
||||
file_name: str,
|
||||
link: str = "",
|
||||
) -> list[TabularSection]:
|
||||
"""Convert a tabular file into one or more TabularSections.
|
||||
|
||||
- .xlsx → one TabularSection per non-empty sheet.
|
||||
- .csv / .tsv → a single TabularSection containing the full decoded
|
||||
file.
|
||||
|
||||
Returns an empty list when the file yields no extractable content.
|
||||
"""
|
||||
lowered = file_name.lower()
|
||||
|
||||
if lowered.endswith(".xlsx"):
|
||||
return [
|
||||
TabularSection(link=f"{file_name} :: {sheet_title}", text=csv_text)
|
||||
for csv_text, sheet_title in xlsx_sheet_extraction(
|
||||
file, file_name=file_name
|
||||
)
|
||||
]
|
||||
|
||||
if not lowered.endswith((".csv", ".tsv")):
|
||||
raise ValueError(f"{file_name!r} is not a tabular file")
|
||||
|
||||
try:
|
||||
text = file_io_to_text(file).strip()
|
||||
except Exception:
|
||||
logger.exception(f"Failure decoding {file_name}")
|
||||
raise
|
||||
|
||||
if not text:
|
||||
return []
|
||||
if lowered.endswith(".tsv"):
|
||||
text = _tsv_to_csv(text)
|
||||
return [TabularSection(link=link or file_name, text=text)]
|
||||
@@ -42,9 +42,6 @@ from onyx.connectors.google_drive.file_retrieval import (
|
||||
get_all_files_in_my_drive_and_shared,
|
||||
)
|
||||
from onyx.connectors.google_drive.file_retrieval import get_external_access_for_folder
|
||||
from onyx.connectors.google_drive.file_retrieval import (
|
||||
get_files_by_web_view_links_batch,
|
||||
)
|
||||
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
|
||||
from onyx.connectors.google_drive.file_retrieval import get_folder_metadata
|
||||
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
|
||||
@@ -73,14 +70,11 @@ from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import NormalizationResult
|
||||
from onyx.connectors.interfaces import Resolver
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import EntityFailure
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
@@ -208,10 +202,7 @@ class DriveIdStatus(Enum):
|
||||
|
||||
|
||||
class GoogleDriveConnector(
|
||||
SlimConnector,
|
||||
SlimConnectorWithPermSync,
|
||||
CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint],
|
||||
Resolver,
|
||||
SlimConnectorWithPermSync, CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint]
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1674,89 +1665,12 @@ class GoogleDriveConnector(
|
||||
start, end, checkpoint, include_permissions=True
|
||||
)
|
||||
|
||||
@override
|
||||
def resolve_errors(
|
||||
self,
|
||||
errors: list[ConnectorFailure],
|
||||
include_permissions: bool = False,
|
||||
) -> Generator[Document | ConnectorFailure | HierarchyNode, None, None]:
|
||||
if self._creds is None or self._primary_admin_email is None:
|
||||
raise RuntimeError(
|
||||
"Credentials missing, should not call this method before calling load_credentials"
|
||||
)
|
||||
|
||||
logger.info(f"Resolving {len(errors)} errors")
|
||||
doc_ids = [
|
||||
failure.failed_document.document_id
|
||||
for failure in errors
|
||||
if failure.failed_document
|
||||
]
|
||||
service = get_drive_service(self.creds, self.primary_admin_email)
|
||||
field_type = (
|
||||
DriveFileFieldType.WITH_PERMISSIONS
|
||||
if include_permissions or self.exclude_domain_link_only
|
||||
else DriveFileFieldType.STANDARD
|
||||
)
|
||||
batch_result = get_files_by_web_view_links_batch(service, doc_ids, field_type)
|
||||
|
||||
for doc_id, error in batch_result.errors.items():
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=doc_id,
|
||||
document_link=doc_id,
|
||||
),
|
||||
failure_message=f"Failed to retrieve file during error resolution: {error}",
|
||||
exception=error,
|
||||
)
|
||||
|
||||
permission_sync_context = (
|
||||
PermissionSyncContext(
|
||||
primary_admin_email=self.primary_admin_email,
|
||||
google_domain=self.google_domain,
|
||||
)
|
||||
if include_permissions
|
||||
else None
|
||||
)
|
||||
|
||||
retrieved_files = [
|
||||
RetrievedDriveFile(
|
||||
drive_file=file,
|
||||
user_email=self.primary_admin_email,
|
||||
completion_stage=DriveRetrievalStage.DONE,
|
||||
)
|
||||
for file in batch_result.files.values()
|
||||
]
|
||||
|
||||
yield from self._get_new_ancestors_for_files(
|
||||
files=retrieved_files,
|
||||
seen_hierarchy_node_raw_ids=ThreadSafeSet(),
|
||||
fully_walked_hierarchy_node_raw_ids=ThreadSafeSet(),
|
||||
permission_sync_context=permission_sync_context,
|
||||
add_prefix=True,
|
||||
)
|
||||
|
||||
func_with_args = [
|
||||
(
|
||||
self._convert_retrieved_file_to_document,
|
||||
(rf, permission_sync_context),
|
||||
)
|
||||
for rf in retrieved_files
|
||||
]
|
||||
results = cast(
|
||||
list[Document | ConnectorFailure | None],
|
||||
run_functions_tuples_in_parallel(func_with_args, max_workers=8),
|
||||
)
|
||||
for result in results:
|
||||
if result is not None:
|
||||
yield result
|
||||
|
||||
def _extract_slim_docs_from_google_drive(
|
||||
self,
|
||||
checkpoint: GoogleDriveCheckpoint,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
include_permissions: bool = True,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
files_batch: list[RetrievedDriveFile] = []
|
||||
slim_batch: list[SlimDocument | HierarchyNode] = []
|
||||
@@ -1766,13 +1680,9 @@ class GoogleDriveConnector(
|
||||
nonlocal files_batch, slim_batch
|
||||
|
||||
# Get new ancestor hierarchy nodes first
|
||||
permission_sync_context = (
|
||||
PermissionSyncContext(
|
||||
primary_admin_email=self.primary_admin_email,
|
||||
google_domain=self.google_domain,
|
||||
)
|
||||
if include_permissions
|
||||
else None
|
||||
permission_sync_context = PermissionSyncContext(
|
||||
primary_admin_email=self.primary_admin_email,
|
||||
google_domain=self.google_domain,
|
||||
)
|
||||
new_ancestors = self._get_new_ancestors_for_files(
|
||||
files=files_batch,
|
||||
@@ -1786,7 +1696,10 @@ class GoogleDriveConnector(
|
||||
if doc := build_slim_document(
|
||||
self.creds,
|
||||
file.drive_file,
|
||||
permission_sync_context,
|
||||
PermissionSyncContext(
|
||||
primary_admin_email=self.primary_admin_email,
|
||||
google_domain=self.google_domain,
|
||||
),
|
||||
retriever_email=file.user_email,
|
||||
):
|
||||
slim_batch.append(doc)
|
||||
@@ -1826,12 +1739,11 @@ class GoogleDriveConnector(
|
||||
if files_batch:
|
||||
yield _yield_slim_batch()
|
||||
|
||||
def _retrieve_all_slim_docs_impl(
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
include_permissions: bool = True,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
try:
|
||||
checkpoint = self.build_dummy_checkpoint()
|
||||
@@ -1841,34 +1753,13 @@ class GoogleDriveConnector(
|
||||
start=start,
|
||||
end=end,
|
||||
callback=callback,
|
||||
include_permissions=include_permissions,
|
||||
)
|
||||
logger.info("Drive slim doc retrieval complete")
|
||||
logger.info("Drive perm sync: Slim doc retrieval complete")
|
||||
|
||||
except Exception as e:
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
|
||||
raise
|
||||
|
||||
@override
|
||||
def retrieve_all_slim_docs(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
return self._retrieve_all_slim_docs_impl(
|
||||
start=start, end=end, callback=callback, include_permissions=False
|
||||
)
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
return self._retrieve_all_slim_docs_impl(
|
||||
start=start, end=end, callback=callback, include_permissions=True
|
||||
)
|
||||
raise e
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
if self._creds is None:
|
||||
|
||||
@@ -9,7 +9,6 @@ from urllib.parse import urlparse
|
||||
|
||||
from googleapiclient.discovery import Resource # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
from googleapiclient.http import BatchHttpRequest # type: ignore
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
|
||||
@@ -61,8 +60,6 @@ SLIM_FILE_FIELDS = (
|
||||
)
|
||||
FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)"
|
||||
|
||||
MAX_BATCH_SIZE = 100
|
||||
|
||||
HIERARCHY_FIELDS = "id, name, parents, webViewLink, mimeType, driveId"
|
||||
|
||||
HIERARCHY_FIELDS_WITH_PERMISSIONS = (
|
||||
@@ -219,7 +216,7 @@ def get_external_access_for_folder(
|
||||
|
||||
|
||||
def _get_fields_for_file_type(field_type: DriveFileFieldType) -> str:
|
||||
"""Get the appropriate fields string for files().list() based on the field type enum."""
|
||||
"""Get the appropriate fields string based on the field type enum"""
|
||||
if field_type == DriveFileFieldType.SLIM:
|
||||
return SLIM_FILE_FIELDS
|
||||
elif field_type == DriveFileFieldType.WITH_PERMISSIONS:
|
||||
@@ -228,25 +225,6 @@ def _get_fields_for_file_type(field_type: DriveFileFieldType) -> str:
|
||||
return FILE_FIELDS
|
||||
|
||||
|
||||
def _extract_single_file_fields(list_fields: str) -> str:
|
||||
"""Convert a files().list() fields string to one suitable for files().get().
|
||||
|
||||
List fields look like "nextPageToken, files(field1, field2, ...)"
|
||||
Single-file fields should be just "field1, field2, ..."
|
||||
"""
|
||||
start = list_fields.find("files(")
|
||||
if start == -1:
|
||||
return list_fields
|
||||
inner_start = start + len("files(")
|
||||
inner_end = list_fields.rfind(")")
|
||||
return list_fields[inner_start:inner_end]
|
||||
|
||||
|
||||
def _get_single_file_fields(field_type: DriveFileFieldType) -> str:
|
||||
"""Get the appropriate fields string for files().get() based on the field type enum."""
|
||||
return _extract_single_file_fields(_get_fields_for_file_type(field_type))
|
||||
|
||||
|
||||
def _get_files_in_parent(
|
||||
service: Resource,
|
||||
parent_id: str,
|
||||
@@ -558,74 +536,3 @@ def get_file_by_web_view_link(
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
|
||||
class BatchRetrievalResult:
|
||||
"""Result of a batch file retrieval, separating successes from errors."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.files: dict[str, GoogleDriveFileType] = {}
|
||||
self.errors: dict[str, Exception] = {}
|
||||
|
||||
|
||||
def get_files_by_web_view_links_batch(
|
||||
service: GoogleDriveService,
|
||||
web_view_links: list[str],
|
||||
field_type: DriveFileFieldType,
|
||||
) -> BatchRetrievalResult:
|
||||
"""Retrieve multiple Google Drive files by webViewLink using the batch API.
|
||||
|
||||
Returns a BatchRetrievalResult containing successful file retrievals
|
||||
and errors for any files that could not be fetched.
|
||||
Automatically splits into chunks of MAX_BATCH_SIZE.
|
||||
"""
|
||||
fields = _get_single_file_fields(field_type)
|
||||
if len(web_view_links) <= MAX_BATCH_SIZE:
|
||||
return _get_files_by_web_view_links_batch(service, web_view_links, fields)
|
||||
|
||||
combined = BatchRetrievalResult()
|
||||
for i in range(0, len(web_view_links), MAX_BATCH_SIZE):
|
||||
chunk = web_view_links[i : i + MAX_BATCH_SIZE]
|
||||
chunk_result = _get_files_by_web_view_links_batch(service, chunk, fields)
|
||||
combined.files.update(chunk_result.files)
|
||||
combined.errors.update(chunk_result.errors)
|
||||
return combined
|
||||
|
||||
|
||||
def _get_files_by_web_view_links_batch(
|
||||
service: GoogleDriveService,
|
||||
web_view_links: list[str],
|
||||
fields: str,
|
||||
) -> BatchRetrievalResult:
|
||||
"""Single-batch implementation."""
|
||||
|
||||
result = BatchRetrievalResult()
|
||||
|
||||
def callback(
|
||||
request_id: str,
|
||||
response: GoogleDriveFileType,
|
||||
exception: Exception | None,
|
||||
) -> None:
|
||||
if exception:
|
||||
logger.warning(f"Error retrieving file {request_id}: {exception}")
|
||||
result.errors[request_id] = exception
|
||||
else:
|
||||
result.files[request_id] = response
|
||||
|
||||
batch = cast(BatchHttpRequest, service.new_batch_http_request(callback=callback))
|
||||
|
||||
for web_view_link in web_view_links:
|
||||
try:
|
||||
file_id = _extract_file_id_from_web_view_link(web_view_link)
|
||||
request = service.files().get(
|
||||
fileId=file_id,
|
||||
supportsAllDrives=True,
|
||||
fields=fields,
|
||||
)
|
||||
batch.add(request, request_id=web_view_link)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to extract file ID from {web_view_link}: {e}")
|
||||
result.errors[web_view_link] = e
|
||||
|
||||
batch.execute()
|
||||
return result
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import ParseResult
|
||||
@@ -54,21 +53,6 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _load_google_json(raw: object) -> dict[str, Any]:
|
||||
"""Accept both the current (dict) and legacy (JSON string) KV payload shapes.
|
||||
|
||||
Payloads written before the fix for serializing Google credentials into
|
||||
``EncryptedJson`` columns are stored as JSON strings; new writes store dicts.
|
||||
Once every install has re-uploaded their Google credentials the legacy
|
||||
``str`` branch can be removed.
|
||||
"""
|
||||
if isinstance(raw, dict):
|
||||
return raw
|
||||
if isinstance(raw, str):
|
||||
return json.loads(raw)
|
||||
raise ValueError(f"Unexpected Google credential payload type: {type(raw)!r}")
|
||||
|
||||
|
||||
def _build_frontend_google_drive_redirect(source: DocumentSource) -> str:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback"
|
||||
@@ -178,13 +162,12 @@ def build_service_account_creds(
|
||||
|
||||
def get_auth_url(credential_id: int, source: DocumentSource) -> str:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
credential_json = _load_google_json(
|
||||
get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY)
|
||||
)
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
elif source == DocumentSource.GMAIL:
|
||||
credential_json = _load_google_json(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
credential_json = json.loads(creds_str)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
credential_json,
|
||||
scopes=GOOGLE_SCOPES[source],
|
||||
@@ -205,12 +188,12 @@ def get_auth_url(credential_id: int, source: DocumentSource) -> str:
|
||||
|
||||
def get_google_app_cred(source: DocumentSource) -> GoogleAppCredentials:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
creds = _load_google_json(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
elif source == DocumentSource.GMAIL:
|
||||
creds = _load_google_json(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
return GoogleAppCredentials(**creds)
|
||||
return GoogleAppCredentials(**json.loads(creds_str))
|
||||
|
||||
|
||||
def upsert_google_app_cred(
|
||||
@@ -218,14 +201,10 @@ def upsert_google_app_cred(
|
||||
) -> None:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
get_kv_store().store(
|
||||
KV_GOOGLE_DRIVE_CRED_KEY,
|
||||
app_credentials.model_dump(mode="json"),
|
||||
encrypt=True,
|
||||
KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True
|
||||
)
|
||||
elif source == DocumentSource.GMAIL:
|
||||
get_kv_store().store(
|
||||
KV_GMAIL_CRED_KEY, app_credentials.model_dump(mode="json"), encrypt=True
|
||||
)
|
||||
get_kv_store().store(KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True)
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
|
||||
@@ -241,14 +220,12 @@ def delete_google_app_cred(source: DocumentSource) -> None:
|
||||
|
||||
def get_service_account_key(source: DocumentSource) -> GoogleServiceAccountKey:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
creds = _load_google_json(
|
||||
get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
|
||||
)
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY))
|
||||
elif source == DocumentSource.GMAIL:
|
||||
creds = _load_google_json(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
return GoogleServiceAccountKey(**creds)
|
||||
return GoogleServiceAccountKey(**json.loads(creds_str))
|
||||
|
||||
|
||||
def upsert_service_account_key(
|
||||
@@ -257,14 +234,12 @@ def upsert_service_account_key(
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
get_kv_store().store(
|
||||
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY,
|
||||
service_account_key.model_dump(mode="json"),
|
||||
service_account_key.json(),
|
||||
encrypt=True,
|
||||
)
|
||||
elif source == DocumentSource.GMAIL:
|
||||
get_kv_store().store(
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY,
|
||||
service_account_key.model_dump(mode="json"),
|
||||
encrypt=True,
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
|
||||
@@ -123,9 +123,6 @@ class SlimConnector(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
def retrieve_all_slim_docs(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -301,22 +298,6 @@ class CheckpointedConnectorWithPermSync(CheckpointedConnector[CT]):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Resolver(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
def resolve_errors(
|
||||
self,
|
||||
errors: list[ConnectorFailure],
|
||||
include_permissions: bool = False,
|
||||
) -> Generator[Document | ConnectorFailure | HierarchyNode, None, None]:
|
||||
"""Attempts to yield back ALL the documents described by the errors, no checkpointing.
|
||||
|
||||
Caller's responsibility is to delete the old ConnectorFailures and replace with the new ones.
|
||||
If include_permissions is True, the documents will have permissions synced.
|
||||
May also yield HierarchyNode objects for ancestor folders of resolved documents.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class HierarchyConnector(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
def load_hierarchy(
|
||||
|
||||
@@ -60,10 +60,8 @@ logger = setup_logger()
|
||||
|
||||
ONE_HOUR = 3600
|
||||
|
||||
_MAX_RESULTS_FETCH_IDS = 5000
|
||||
_MAX_RESULTS_FETCH_IDS = 5000 # 5000
|
||||
_JIRA_FULL_PAGE_SIZE = 50
|
||||
# https://developer.atlassian.com/cloud/jira/platform/rest/v3/api-group-issues/
|
||||
_JIRA_BULK_FETCH_LIMIT = 100
|
||||
|
||||
# Constants for Jira field names
|
||||
_FIELD_REPORTER = "reporter"
|
||||
@@ -257,13 +255,15 @@ def _bulk_fetch_request(
|
||||
return resp.json()["issues"]
|
||||
|
||||
|
||||
def _bulk_fetch_batch(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch a single batch (must be <= _JIRA_BULK_FETCH_LIMIT).
|
||||
On JSONDecodeError, recursively bisects until it succeeds or reaches size 1."""
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO(evan): move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
|
||||
try:
|
||||
return _bulk_fetch_request(jira_client, issue_ids, fields)
|
||||
raw_issues = _bulk_fetch_request(jira_client, issue_ids, fields)
|
||||
except requests.exceptions.JSONDecodeError:
|
||||
if len(issue_ids) <= 1:
|
||||
logger.exception(
|
||||
@@ -277,25 +277,12 @@ def _bulk_fetch_batch(
|
||||
f"Jira bulk-fetch JSON decode failed for batch of {len(issue_ids)} issues. "
|
||||
f"Splitting into sub-batches of {mid} and {len(issue_ids) - mid}."
|
||||
)
|
||||
left = _bulk_fetch_batch(jira_client, issue_ids[:mid], fields)
|
||||
right = _bulk_fetch_batch(jira_client, issue_ids[mid:], fields)
|
||||
left = bulk_fetch_issues(jira_client, issue_ids[:mid], fields)
|
||||
right = bulk_fetch_issues(jira_client, issue_ids[mid:], fields)
|
||||
return left + right
|
||||
|
||||
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO(evan): move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
|
||||
raw_issues: list[dict[str, Any]] = []
|
||||
for batch in chunked(issue_ids, _JIRA_BULK_FETCH_LIMIT):
|
||||
try:
|
||||
raw_issues.extend(_bulk_fetch_batch(jira_client, list(batch), fields))
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching issues: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching issues: {e}")
|
||||
raise
|
||||
|
||||
return [
|
||||
Issue(jira_client._options, jira_client._session, raw=issue)
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
import sys
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
@@ -35,18 +33,9 @@ class ConnectorMissingCredentialError(PermissionError):
|
||||
)
|
||||
|
||||
|
||||
class SectionType(str, Enum):
|
||||
"""Discriminator for Section subclasses."""
|
||||
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
TABULAR = "tabular"
|
||||
|
||||
|
||||
class Section(BaseModel):
|
||||
"""Base section class with common attributes"""
|
||||
|
||||
type: SectionType
|
||||
link: str | None = None
|
||||
text: str | None = None
|
||||
image_file_id: str | None = None
|
||||
@@ -55,7 +44,6 @@ class Section(BaseModel):
|
||||
class TextSection(Section):
|
||||
"""Section containing text content"""
|
||||
|
||||
type: Literal[SectionType.TEXT] = SectionType.TEXT
|
||||
text: str
|
||||
|
||||
def __sizeof__(self) -> int:
|
||||
@@ -65,25 +53,12 @@ class TextSection(Section):
|
||||
class ImageSection(Section):
|
||||
"""Section containing an image reference"""
|
||||
|
||||
type: Literal[SectionType.IMAGE] = SectionType.IMAGE
|
||||
image_file_id: str
|
||||
|
||||
def __sizeof__(self) -> int:
|
||||
return sys.getsizeof(self.image_file_id) + sys.getsizeof(self.link)
|
||||
|
||||
|
||||
class TabularSection(Section):
|
||||
"""Section containing tabular data (csv/tsv content, or one sheet of
|
||||
an xlsx workbook rendered as CSV)."""
|
||||
|
||||
type: Literal[SectionType.TABULAR] = SectionType.TABULAR
|
||||
text: str # CSV representation in a string
|
||||
link: str
|
||||
|
||||
def __sizeof__(self) -> int:
|
||||
return sys.getsizeof(self.text) + sys.getsizeof(self.link)
|
||||
|
||||
|
||||
class BasicExpertInfo(BaseModel):
|
||||
"""Basic Information for the owner of a document, any of the fields can be left as None
|
||||
Display fallback goes as follows:
|
||||
@@ -159,6 +134,7 @@ class BasicExpertInfo(BaseModel):
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, model_dict: dict[str, Any]) -> "BasicExpertInfo":
|
||||
|
||||
first_name = cast(str, model_dict.get("FirstName"))
|
||||
last_name = cast(str, model_dict.get("LastName"))
|
||||
email = cast(str, model_dict.get("Email"))
|
||||
@@ -185,7 +161,7 @@ class DocumentBase(BaseModel):
|
||||
"""Used for Onyx ingestion api, the ID is inferred before use if not provided"""
|
||||
|
||||
id: str | None = None
|
||||
sections: Sequence[TextSection | ImageSection | TabularSection]
|
||||
sections: list[TextSection | ImageSection]
|
||||
source: DocumentSource | None = None
|
||||
semantic_identifier: str # displayed in the UI as the main identifier for the doc
|
||||
# TODO(andrei): Ideally we could improve this to where each value is just a
|
||||
@@ -395,9 +371,12 @@ class IndexingDocument(Document):
|
||||
)
|
||||
else:
|
||||
section_len = sum(
|
||||
len(section.text) if section.text is not None else 0
|
||||
(
|
||||
len(section.text)
|
||||
if isinstance(section, TextSection) and section.text is not None
|
||||
else 0
|
||||
)
|
||||
for section in self.sections
|
||||
if isinstance(section, (TextSection, TabularSection))
|
||||
)
|
||||
|
||||
return title_len + section_len
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import TypedDict
|
||||
|
||||
@@ -7,14 +6,6 @@ from pydantic import BaseModel
|
||||
from onyx.onyxbot.slack.models import ChannelType
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DirectThreadFetch:
|
||||
"""Request to fetch a Slack thread directly by channel and timestamp."""
|
||||
|
||||
channel_id: str
|
||||
thread_ts: str
|
||||
|
||||
|
||||
class ChannelMetadata(TypedDict):
|
||||
"""Type definition for cached channel metadata."""
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ from onyx.configs.chat_configs import DOC_TIME_DECAY
|
||||
from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.context.search.federated.models import ChannelMetadata
|
||||
from onyx.context.search.federated.models import DirectThreadFetch
|
||||
from onyx.context.search.federated.models import SlackMessage
|
||||
from onyx.context.search.federated.slack_search_utils import ALL_CHANNEL_TYPES
|
||||
from onyx.context.search.federated.slack_search_utils import build_channel_query_filter
|
||||
@@ -50,6 +49,7 @@ from onyx.server.federated.models import FederatedConnectorDetail
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.timing import log_function_time
|
||||
from shared_configs.configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -58,6 +58,7 @@ HIGHLIGHT_END_CHAR = "\ue001"
|
||||
|
||||
CHANNEL_METADATA_CACHE_TTL = 60 * 60 * 24 # 24 hours
|
||||
USER_PROFILE_CACHE_TTL = 60 * 60 * 24 # 24 hours
|
||||
SLACK_THREAD_CONTEXT_WINDOW = 3 # Number of messages before matched message to include
|
||||
CHANNEL_METADATA_MAX_RETRIES = 3 # Maximum retry attempts for channel metadata fetching
|
||||
CHANNEL_METADATA_RETRY_DELAY = 1 # Initial retry delay in seconds (exponential backoff)
|
||||
|
||||
@@ -420,94 +421,6 @@ class SlackQueryResult(BaseModel):
|
||||
filtered_channels: list[str] # Channels filtered out during this query
|
||||
|
||||
|
||||
def _fetch_thread_from_url(
|
||||
thread_fetch: DirectThreadFetch,
|
||||
access_token: str,
|
||||
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
|
||||
) -> SlackQueryResult:
|
||||
"""Fetch a thread directly from a Slack URL via conversations.replies."""
|
||||
channel_id = thread_fetch.channel_id
|
||||
thread_ts = thread_fetch.thread_ts
|
||||
|
||||
slack_client = WebClient(token=access_token)
|
||||
try:
|
||||
response = slack_client.conversations_replies(
|
||||
channel=channel_id,
|
||||
ts=thread_ts,
|
||||
)
|
||||
response.validate()
|
||||
messages: list[dict[str, Any]] = response.get("messages", [])
|
||||
except SlackApiError as e:
|
||||
logger.warning(
|
||||
f"Failed to fetch thread from URL (channel={channel_id}, ts={thread_ts}): {e}"
|
||||
)
|
||||
return SlackQueryResult(messages=[], filtered_channels=[])
|
||||
|
||||
if not messages:
|
||||
logger.warning(
|
||||
f"No messages found for URL override (channel={channel_id}, ts={thread_ts})"
|
||||
)
|
||||
return SlackQueryResult(messages=[], filtered_channels=[])
|
||||
|
||||
# Build thread text from all messages
|
||||
thread_text = _build_thread_text(messages, access_token, None, slack_client)
|
||||
|
||||
# Get channel name from metadata cache or API
|
||||
channel_name = "unknown"
|
||||
if channel_metadata_dict and channel_id in channel_metadata_dict:
|
||||
channel_name = channel_metadata_dict[channel_id].get("name", "unknown")
|
||||
else:
|
||||
try:
|
||||
ch_response = slack_client.conversations_info(channel=channel_id)
|
||||
ch_response.validate()
|
||||
channel_info: dict[str, Any] = ch_response.get("channel", {})
|
||||
channel_name = channel_info.get("name", "unknown")
|
||||
except SlackApiError:
|
||||
pass
|
||||
|
||||
# Build the SlackMessage
|
||||
parent_msg = messages[0]
|
||||
message_ts = parent_msg.get("ts", thread_ts)
|
||||
username = parent_msg.get("user", "unknown_user")
|
||||
parent_text = parent_msg.get("text", "")
|
||||
snippet = (
|
||||
parent_text[:50].rstrip() + "..." if len(parent_text) > 50 else parent_text
|
||||
).replace("\n", " ")
|
||||
|
||||
doc_time = datetime.fromtimestamp(float(message_ts))
|
||||
decay_factor = DOC_TIME_DECAY
|
||||
doc_age_years = (datetime.now() - doc_time).total_seconds() / (365 * 24 * 60 * 60)
|
||||
recency_bias = max(1 / (1 + decay_factor * doc_age_years), 0.75)
|
||||
|
||||
permalink = (
|
||||
f"https://slack.com/archives/{channel_id}/p{message_ts.replace('.', '')}"
|
||||
)
|
||||
|
||||
slack_message = SlackMessage(
|
||||
document_id=f"{channel_id}_{message_ts}",
|
||||
channel_id=channel_id,
|
||||
message_id=message_ts,
|
||||
thread_id=None, # Prevent double-enrichment in thread context fetch
|
||||
link=permalink,
|
||||
metadata={
|
||||
"channel": channel_name,
|
||||
"time": doc_time.isoformat(),
|
||||
},
|
||||
timestamp=doc_time,
|
||||
recency_bias=recency_bias,
|
||||
semantic_identifier=f"{username} in #{channel_name}: {snippet}",
|
||||
text=thread_text,
|
||||
highlighted_texts=set(),
|
||||
slack_score=100000.0, # High priority — user explicitly asked for this thread
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"URL override: fetched thread from channel={channel_id}, ts={thread_ts}, {len(messages)} messages"
|
||||
)
|
||||
|
||||
return SlackQueryResult(messages=[slack_message], filtered_channels=[])
|
||||
|
||||
|
||||
def query_slack(
|
||||
query_string: str,
|
||||
access_token: str,
|
||||
@@ -519,6 +432,7 @@ def query_slack(
|
||||
available_channels: list[str] | None = None,
|
||||
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
|
||||
) -> SlackQueryResult:
|
||||
|
||||
# Check if query has channel override (user specified channels in query)
|
||||
has_channel_override = query_string.startswith("__CHANNEL_OVERRIDE__")
|
||||
|
||||
@@ -748,6 +662,7 @@ def _fetch_thread_context(
|
||||
"""
|
||||
channel_id = message.channel_id
|
||||
thread_id = message.thread_id
|
||||
message_id = message.message_id
|
||||
|
||||
# If not a thread, return original text as success
|
||||
if thread_id is None:
|
||||
@@ -780,37 +695,62 @@ def _fetch_thread_context(
|
||||
if len(messages) <= 1:
|
||||
return ThreadContextResult.success(message.text)
|
||||
|
||||
# Build thread text from thread starter + all replies
|
||||
thread_text = _build_thread_text(messages, access_token, team_id, slack_client)
|
||||
# Build thread text from thread starter + context window around matched message
|
||||
thread_text = _build_thread_text(
|
||||
messages, message_id, thread_id, access_token, team_id, slack_client
|
||||
)
|
||||
return ThreadContextResult.success(thread_text)
|
||||
|
||||
|
||||
def _build_thread_text(
|
||||
messages: list[dict[str, Any]],
|
||||
message_id: str,
|
||||
thread_id: str,
|
||||
access_token: str,
|
||||
team_id: str | None,
|
||||
slack_client: WebClient,
|
||||
) -> str:
|
||||
"""Build thread text including all replies.
|
||||
|
||||
Includes the thread parent message followed by all replies in order.
|
||||
"""
|
||||
"""Build the thread text from messages."""
|
||||
msg_text = messages[0].get("text", "")
|
||||
msg_sender = messages[0].get("user", "")
|
||||
thread_text = f"<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# All messages after index 0 are replies
|
||||
replies = messages[1:]
|
||||
if not replies:
|
||||
return thread_text
|
||||
|
||||
logger.debug(f"Thread {messages[0].get('ts')}: {len(replies)} replies included")
|
||||
thread_text += "\n\nReplies:"
|
||||
if thread_id == message_id:
|
||||
message_id_idx = 0
|
||||
else:
|
||||
message_id_idx = next(
|
||||
(i for i, msg in enumerate(messages) if msg.get("ts") == message_id), 0
|
||||
)
|
||||
if not message_id_idx:
|
||||
return thread_text
|
||||
|
||||
for msg in replies:
|
||||
start_idx = max(1, message_id_idx - SLACK_THREAD_CONTEXT_WINDOW)
|
||||
|
||||
if start_idx > 1:
|
||||
thread_text += "\n..."
|
||||
|
||||
for i in range(start_idx, message_id_idx):
|
||||
msg_text = messages[i].get("text", "")
|
||||
msg_sender = messages[i].get("user", "")
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
|
||||
msg_text = messages[message_id_idx].get("text", "")
|
||||
msg_sender = messages[message_id_idx].get("user", "")
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# Add following replies
|
||||
len_replies = 0
|
||||
for msg in messages[message_id_idx + 1 :]:
|
||||
msg_text = msg.get("text", "")
|
||||
msg_sender = msg.get("user", "")
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
reply = f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
thread_text += reply
|
||||
|
||||
len_replies += len(reply)
|
||||
if len_replies >= DOC_EMBEDDING_CONTEXT_SIZE * 4:
|
||||
thread_text += "\n..."
|
||||
break
|
||||
|
||||
# Replace user IDs with names using cached lookups
|
||||
userids: set[str] = set(re.findall(r"<@([A-Z0-9]+)>", thread_text))
|
||||
@@ -1036,16 +976,7 @@ def slack_retrieval(
|
||||
|
||||
# Query slack with entity filtering
|
||||
llm = get_default_llm()
|
||||
query_items = build_slack_queries(query, llm, entities, available_channels)
|
||||
|
||||
# Partition into direct thread fetches and search query strings
|
||||
direct_fetches: list[DirectThreadFetch] = []
|
||||
query_strings: list[str] = []
|
||||
for item in query_items:
|
||||
if isinstance(item, DirectThreadFetch):
|
||||
direct_fetches.append(item)
|
||||
else:
|
||||
query_strings.append(item)
|
||||
query_strings = build_slack_queries(query, llm, entities, available_channels)
|
||||
|
||||
# Determine filtering based on entities OR context (bot)
|
||||
include_dm = False
|
||||
@@ -1062,16 +993,8 @@ def slack_retrieval(
|
||||
f"Private channel context: will only allow messages from {allowed_private_channel} + public channels"
|
||||
)
|
||||
|
||||
# Build search tasks — direct thread fetches + keyword searches
|
||||
search_tasks: list[tuple] = [
|
||||
(
|
||||
_fetch_thread_from_url,
|
||||
(fetch, access_token, channel_metadata_dict),
|
||||
)
|
||||
for fetch in direct_fetches
|
||||
]
|
||||
|
||||
search_tasks.extend(
|
||||
# Build search tasks
|
||||
search_tasks = [
|
||||
(
|
||||
query_slack,
|
||||
(
|
||||
@@ -1087,7 +1010,7 @@ def slack_retrieval(
|
||||
),
|
||||
)
|
||||
for query_string in query_strings
|
||||
)
|
||||
]
|
||||
|
||||
# If include_dm is True AND we're not already searching all channels,
|
||||
# add additional searches without channel filters.
|
||||
|
||||
@@ -10,7 +10,6 @@ from pydantic import ValidationError
|
||||
|
||||
from onyx.configs.app_configs import MAX_SLACK_QUERY_EXPANSIONS
|
||||
from onyx.context.search.federated.models import ChannelMetadata
|
||||
from onyx.context.search.federated.models import DirectThreadFetch
|
||||
from onyx.context.search.models import ChunkIndexRequest
|
||||
from onyx.federated_connectors.slack.models import SlackEntities
|
||||
from onyx.llm.interfaces import LLM
|
||||
@@ -639,38 +638,12 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
return [query_text]
|
||||
|
||||
|
||||
SLACK_URL_PATTERN = re.compile(
|
||||
r"https?://[a-z0-9-]+\.slack\.com/archives/([A-Z0-9]+)/p(\d{16})"
|
||||
)
|
||||
|
||||
|
||||
def extract_slack_message_urls(
|
||||
query_text: str,
|
||||
) -> list[tuple[str, str]]:
|
||||
"""Extract Slack message URLs from query text.
|
||||
|
||||
Parses URLs like:
|
||||
https://onyx-company.slack.com/archives/C097NBWMY8Y/p1775491616524769
|
||||
|
||||
Returns list of (channel_id, thread_ts) tuples.
|
||||
The 16-digit timestamp is converted to Slack ts format (with dot).
|
||||
"""
|
||||
results = []
|
||||
for match in SLACK_URL_PATTERN.finditer(query_text):
|
||||
channel_id = match.group(1)
|
||||
raw_ts = match.group(2)
|
||||
# Convert p1775491616524769 -> 1775491616.524769
|
||||
thread_ts = f"{raw_ts[:10]}.{raw_ts[10:]}"
|
||||
results.append((channel_id, thread_ts))
|
||||
return results
|
||||
|
||||
|
||||
def build_slack_queries(
|
||||
query: ChunkIndexRequest,
|
||||
llm: LLM,
|
||||
entities: dict[str, Any] | None = None,
|
||||
available_channels: list[str] | None = None,
|
||||
) -> list[str | DirectThreadFetch]:
|
||||
) -> list[str]:
|
||||
"""Build Slack query strings with date filtering and query expansion."""
|
||||
default_search_days = 30
|
||||
if entities:
|
||||
@@ -695,15 +668,6 @@ def build_slack_queries(
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_back)
|
||||
time_filter = f" after:{cutoff_date.strftime('%Y-%m-%d')}"
|
||||
|
||||
# Check for Slack message URLs — if found, add direct fetch requests
|
||||
url_fetches: list[DirectThreadFetch] = []
|
||||
slack_urls = extract_slack_message_urls(query.query)
|
||||
for channel_id, thread_ts in slack_urls:
|
||||
url_fetches.append(
|
||||
DirectThreadFetch(channel_id=channel_id, thread_ts=thread_ts)
|
||||
)
|
||||
logger.info(f"Detected Slack URL: channel={channel_id}, ts={thread_ts}")
|
||||
|
||||
# ALWAYS extract channel references from the query (not just for recency queries)
|
||||
channel_references = extract_channel_references_from_query(query.query)
|
||||
|
||||
@@ -720,9 +684,7 @@ def build_slack_queries(
|
||||
|
||||
# If valid channels detected, use ONLY those channels with NO keywords
|
||||
# Return query with ONLY time filter + channel filter (no keywords)
|
||||
return url_fetches + [
|
||||
build_channel_override_query(channel_references, time_filter)
|
||||
]
|
||||
return [build_channel_override_query(channel_references, time_filter)]
|
||||
except ValueError as e:
|
||||
# If validation fails, log the error and continue with normal flow
|
||||
logger.warning(f"Channel reference validation failed: {e}")
|
||||
@@ -740,8 +702,7 @@ def build_slack_queries(
|
||||
rephrased_queries = expand_query_with_llm(query.query, llm)
|
||||
|
||||
# Build final query strings with time filters
|
||||
search_queries = [
|
||||
return [
|
||||
rephrased_query.strip() + time_filter
|
||||
for rephrased_query in rephrased_queries[:MAX_SLACK_QUERY_EXPANSIONS]
|
||||
]
|
||||
return url_fetches + search_queries
|
||||
|
||||
@@ -110,8 +110,8 @@ def insert_api_key(
|
||||
|
||||
# Assign the API key virtual user to the appropriate default group
|
||||
# before commit so everything is atomic.
|
||||
# Only ADMIN and BASIC roles get default group membership.
|
||||
if api_key_args.role in (UserRole.ADMIN, UserRole.BASIC):
|
||||
# LIMITED role service accounts should have no group membership.
|
||||
if api_key_args.role != UserRole.LIMITED:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session,
|
||||
api_key_user_row,
|
||||
@@ -161,8 +161,8 @@ def update_api_key(
|
||||
)
|
||||
db_session.execute(delete_stmt)
|
||||
|
||||
# Re-assign to the correct default group (only for ADMIN/BASIC).
|
||||
if api_key_args.role in (UserRole.ADMIN, UserRole.BASIC):
|
||||
# Re-assign to the correct default group (skip for LIMITED).
|
||||
if api_key_args.role != UserRole.LIMITED:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session,
|
||||
api_key_user,
|
||||
|
||||
@@ -750,3 +750,31 @@ def resync_cc_pair(
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
# ── Metrics query helpers ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_connector_health_for_metrics(
|
||||
db_session: Session,
|
||||
) -> list: # Returns list of Row tuples
|
||||
"""Return connector health data for Prometheus metrics.
|
||||
|
||||
Each row is (cc_pair_id, status, in_repeated_error_state,
|
||||
last_successful_index_time, name, source).
|
||||
"""
|
||||
return (
|
||||
db_session.query(
|
||||
ConnectorCredentialPair.id,
|
||||
ConnectorCredentialPair.status,
|
||||
ConnectorCredentialPair.in_repeated_error_state,
|
||||
ConnectorCredentialPair.last_successful_index_time,
|
||||
ConnectorCredentialPair.name,
|
||||
Connector.source,
|
||||
)
|
||||
.join(
|
||||
Connector,
|
||||
ConnectorCredentialPair.connector_id == Connector.id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
@@ -335,7 +335,6 @@ def update_document_set(
|
||||
"Cannot update document set while it is syncing. Please wait for it to finish syncing, and then try again."
|
||||
)
|
||||
|
||||
document_set_row.name = document_set_update_request.name
|
||||
document_set_row.description = document_set_update_request.description
|
||||
if not DISABLE_VECTOR_DB:
|
||||
document_set_row.is_up_to_date = False
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user