Compare commits

7 Commits

Author SHA1 Message Date
Nik
7cbd96ffe6 fix(multi-model): replace React pointer handlers with window listeners for carousel drag
setPointerCapture + React synthetic event delegation conflict caused both
click-to-select and drag-scroll to break in selection mode. Replace
onPointerMove/onPointerUp React handlers with native window.addEventListener
closures created inside onPointerDown — window listeners fire regardless of
cursor position and don't require pointer capture.

- Removes dragStartX, baseTranslateX, dragCurrentDelta, isDraggingRef refs
- Adds dragCleanupRef for unmount-safe listener removal
- Adds e.preventDefault() in pointerdown to block text-selection drag
- Captures preferredIndex/responses/hiddenPanels at press time to avoid stale closures
2026-04-02 01:30:30 -07:00
Nik
ebe0514e21 fix(chat): defer setPointerCapture to drag threshold to restore panel clicks
setPointerCapture on pointerdown was redirecting pointerup to the
container for plain clicks, causing the browser to fire click on the
container instead of the child panel — breaking panel selection. Now
capture is set only once the 5px drag threshold is crossed in
pointermove, so clicks pass through normally.
2026-04-02 01:30:30 -07:00
Nik
c9339729c7 feat(chat): add drag-scroll to selection mode carousel
Pointer drag on the carousel container moves the track directly via DOM
transform (no re-renders during drag). On release, if the drag exceeds
80px the adjacent visible panel becomes preferred and the existing
width+position animation snaps it into center; shorter drags snap back
to the current preferred panel. Clicks after a drag are suppressed via
justDraggedRef so they don't accidentally change the selection.
2026-04-02 01:30:30 -07:00
Nik
719951cc08 docs(chat): add JSDoc to MultiModelResponseView and MultiModelPanel 2026-04-02 01:30:30 -07:00
Nik
0bcf2053ac feat(chat): multi-model response UI with carousel selection mode
Adds the frontend multi-model response UI:

- MultiModelResponseView: carousel-based layout with generation mode
  (equal panels side-by-side) and selection mode (preferred panel
  centered, non-preferred panels peeking at viewport edges with
  transform animation). Non-preferred panels are height-capped to the
  preferred panel's measured height, dimmed at 50% opacity, and receive
  a bottom fade gradient. Hidden panels constrained to 220px in both
  layouts.
- MultiModelPanel: panel header with provider icon, preferred badge,
  hide/show toggle, and AgentMessage body.
- ModelSelector: popover for selecting up to 3 models per chat.
- useMultiModelChat: hook for managing multi-model streaming state.
2026-04-02 01:30:30 -07:00
Nik
c9e1c6e742 fix(types): add type discriminant to MessageResponseIDInfo union types
Add literal type fields to MessageResponseIDInfo ("message_id_info") and
MultiModelMessageResponseIDInfo ("multi_model_message_id_info") to enable
proper TypeScript discriminated union narrowing between the two types.
2026-04-02 01:30:03 -07:00
Nik
63b84e91f1 feat(chat): add frontend types and API helpers for multi-model streaming 2026-04-02 01:30:03 -07:00
625 changed files with 10983 additions and 35382 deletions

View File

@@ -1 +0,0 @@
../../../cli/internal/embedded/SKILL.md

View File

@@ -0,0 +1,186 @@
---
name: onyx-cli
description: Query the Onyx knowledge base using the onyx-cli command. Use when the user wants to search company documents, ask questions about internal knowledge, query connected data sources, or look up information stored in Onyx.
---
# Onyx CLI — Agent Tool
Onyx is an enterprise search and Gen-AI platform that connects to company documents, apps, and people. The `onyx-cli` CLI provides non-interactive commands to query the Onyx knowledge base and list available agents.
## Prerequisites
### 1. Check if installed
```bash
which onyx-cli
```
### 2. Install (if needed)
**Primary — pip:**
```bash
pip install onyx-cli
```
**From source (Go):**
```bash
cd cli && go build -o onyx-cli . && sudo mv onyx-cli /usr/local/bin/
```
### 3. Check if configured
```bash
onyx-cli validate-config
```
This checks that the config file exists and that an API key is present, then tests the server connection via `/api/me`. It exits with code 0 on success, or non-zero with a descriptive error on failure.
If unconfigured, you have two options:
**Option A — Interactive setup (requires user input):**
```bash
onyx-cli configure
```
This prompts for the Onyx server URL and API key, tests the connection, and saves config.
**Option B — Environment variables (non-interactive, preferred for agents):**
```bash
export ONYX_SERVER_URL="https://your-onyx-server.com" # default: https://cloud.onyx.app
export ONYX_API_KEY="your-api-key"
```
Environment variables override the config file. If these are set, no config file is needed.
| Variable | Required | Description |
|----------|----------|-------------|
| `ONYX_SERVER_URL` | No | Onyx server base URL (default: `https://cloud.onyx.app`) |
| `ONYX_API_KEY` | Yes | API key for authentication |
| `ONYX_PERSONA_ID` | No | Default agent/persona ID |
If neither the config file nor environment variables are set, tell the user that `onyx-cli` needs to be configured and ask them to either:
- Run `onyx-cli configure` interactively, or
- Set `ONYX_SERVER_URL` and `ONYX_API_KEY` environment variables
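The precedence described above (environment variables win over the config file, with a built-in default server URL) can be sketched as follows. This is an illustration only: `resolve_settings` and the config-file keys `server_url` and `api_key` are hypothetical names, not the CLI's actual internals.

```python
import os
from typing import Mapping, Optional

# Default server URL as stated in the variable table above.
DEFAULT_SERVER_URL = "https://cloud.onyx.app"


def resolve_settings(
    file_config: Mapping[str, str],
    env: Optional[Mapping[str, str]] = None,
) -> dict:
    """Merge settings so that environment variables override the config file.

    The keys "server_url" and "api_key" in `file_config` are hypothetical.
    """
    env = os.environ if env is None else env
    server_url = (
        env.get("ONYX_SERVER_URL")
        or file_config.get("server_url")
        or DEFAULT_SERVER_URL
    )
    api_key = env.get("ONYX_API_KEY") or file_config.get("api_key")
    if not api_key:
        # Neither source provides a key: the caller should ask the user to configure.
        raise ValueError("onyx-cli is not configured: no API key found")
    return {"server_url": server_url, "api_key": api_key}
```

Passing `env` explicitly keeps the function easy to test; in normal use it falls back to `os.environ`.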
## Commands
### Validate configuration
```bash
onyx-cli validate-config
```
Checks that the config file exists and that an API key is present, then tests the server connection. Use this before `ask` or `agents` to confirm the CLI is properly set up.
### List available agents
```bash
onyx-cli agents
```
Prints a table of agent IDs, names, and descriptions. Use `--json` for structured output:
```bash
onyx-cli agents --json
```
Use agent IDs with `ask --agent-id` to query a specific agent.
### Basic query (plain text output)
```bash
onyx-cli ask "What is our company's PTO policy?"
```
Streams the answer as plain text to stdout. Exit code 0 on success, non-zero on error.
### JSON output (structured events)
```bash
onyx-cli ask --json "What authentication methods do we support?"
```
Outputs JSON-encoded parsed stream events (one object per line). Key event objects include message deltas, stop, errors, search-start, and citation payloads.
Each line is a JSON object with this envelope:
```json
{"type": "<event_type>", "event": { ... }}
```
| Event Type | Description |
|------------|-------------|
| `message_delta` | Content token — concatenate all `content` fields for the full answer |
| `stop` | Stream complete |
| `error` | Error with `error` message field |
| `search_tool_start` | Onyx started searching documents |
| `citation_info` | Source citation — see shape below |
`citation_info` event shape:
```json
{
"type": "citation_info",
"event": {
"citation_number": 1,
"document_id": "abc123def456",
"placement": {"turn_index": 0, "tab_index": 0, "sub_turn_index": null}
}
}
```
`placement` is metadata about where in the conversation the citation appeared and can be ignored for most use cases.
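A consumer of the `--json` stream might fold the events into a final answer along these lines. This is a minimal sketch, assuming each line follows the envelope shown above and that the `content` and `error` fields sit inside `event`; `collect_answer` is a hypothetical helper, not part of `onyx-cli`.

```python
import json
from typing import Iterable


def collect_answer(lines: Iterable[str]) -> tuple[str, list[dict]]:
    """Fold NDJSON event lines into (answer_text, citation_events)."""
    parts: list[str] = []
    citations: list[dict] = []
    for raw in lines:
        raw = raw.strip()
        if not raw:
            continue  # tolerate blank lines between events
        envelope = json.loads(raw)
        kind = envelope.get("type")
        event = envelope.get("event") or {}
        if kind == "message_delta":
            # Assumption: the token text lives at event["content"].
            parts.append(event.get("content", ""))
        elif kind == "citation_info":
            citations.append(event)
        elif kind == "error":
            raise RuntimeError(event.get("error", "unknown error"))
        elif kind == "stop":
            break
        # Other event types (e.g. search_tool_start) are ignored here.
    return "".join(parts), citations
```

The function accepts any iterable of lines, so it can be fed the `stdout` of a `subprocess.Popen(["onyx-cli", "ask", "--json", question], stdout=subprocess.PIPE, text=True)` call or a list of strings in a test.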
### Specify an agent
```bash
onyx-cli ask --agent-id 5 "Summarize our Q4 roadmap"
```
Uses a specific Onyx agent/persona instead of the default.
### All flags
| Flag | Type | Description |
|------|------|-------------|
| `--agent-id` | int | Agent ID to use (overrides default) |
| `--json` | bool | Output raw NDJSON events instead of plain text |
## Statelessness
Each `onyx-cli ask` call creates an independent chat session. There is no built-in way to chain context across multiple `ask` invocations — every call starts fresh. If you need multi-turn conversation with memory, use the interactive TUI (`onyx-cli` or `onyx-cli chat`) instead.
## When to Use
Use `onyx-cli ask` when:
- The user asks about company-specific information (policies, docs, processes)
- You need to search internal knowledge bases or connected data sources
- The user references Onyx, asks you to "search Onyx", or wants to query their documents
- You need context from company wikis, Confluence, Google Drive, Slack, or other connected sources
Do NOT use when:
- The question is about general programming knowledge (use your own knowledge)
- The user is asking about code in the current repository (use grep/read tools)
- The user hasn't mentioned Onyx and the question doesn't require internal company data
## Examples
```bash
# Simple question
onyx-cli ask "What are the steps to deploy to production?"
# Get structured output for parsing
onyx-cli ask --json "List all active API integrations"
# Use a specialized agent
onyx-cli ask --agent-id 3 "What were the action items from last week's standup?"
# Pipe the answer into another command
onyx-cli ask "What is the database schema for users?" | head -20
```

View File

@@ -1,64 +0,0 @@
FROM ubuntu:26.04@sha256:cc925e589b7543b910fea57a240468940003fbfc0515245a495dd0ad8fe7cef1
RUN apt-get update && apt-get install -y --no-install-recommends \
curl \
fd-find \
fzf \
git \
jq \
less \
make \
neovim \
openssh-client \
python3-venv \
ripgrep \
sudo \
ca-certificates \
iptables \
ipset \
iproute2 \
dnsutils \
unzip \
wget \
zsh \
&& curl -fsSL https://deb.nodesource.com/setup_20.x | bash - \
&& apt-get install -y nodejs \
&& install -m 0755 -d /etc/apt/keyrings \
&& curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc \
&& echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu $(. /etc/os-release && echo "$VERSION_CODENAME") stable" > /etc/apt/sources.list.d/docker.list \
&& curl -fsSL https://cli.github.com/packages/githubcli-archive-keyring.gpg -o /etc/apt/keyrings/githubcli-archive-keyring.gpg \
&& chmod go+r /etc/apt/keyrings/githubcli-archive-keyring.gpg \
&& echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" > /etc/apt/sources.list.d/github-cli.list \
&& apt-get update \
&& apt-get install -y --no-install-recommends docker-ce-cli docker-compose-plugin gh \
&& apt-get clean && rm -rf /var/lib/apt/lists/*
# fd-find installs as fdfind on Debian/Ubuntu — symlink to fd
RUN ln -sf "$(which fdfind)" /usr/local/bin/fd
# Install uv (Python package manager)
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /usr/local/bin/
# Create non-root dev user with passwordless sudo
RUN useradd -m -s /bin/zsh dev && \
echo "dev ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/dev && \
chmod 0440 /etc/sudoers.d/dev
ENV DEVCONTAINER=true
RUN mkdir -p /workspace && \
chown -R dev:dev /workspace
WORKDIR /workspace
# Install Claude Code
ARG CLAUDE_CODE_VERSION=latest
RUN npm install -g @anthropic-ai/claude-code@${CLAUDE_CODE_VERSION}
# Configure zsh — source the repo-local zshrc so shell customization
# doesn't require an image rebuild.
RUN chsh -s /bin/zsh root && \
for rc in /root/.zshrc /home/dev/.zshrc; do \
echo '[ -f /workspace/.devcontainer/zshrc ] && . /workspace/.devcontainer/zshrc' >> "$rc"; \
done && \
chown dev:dev /home/dev/.zshrc

View File

@@ -1,99 +0,0 @@
# Onyx Dev Container
A containerized development environment for working on Onyx.
## What's included
- Ubuntu 26.04 base image
- Node.js 20, uv, Claude Code
- Docker CLI, GitHub CLI (`gh`)
- Neovim, ripgrep, fd, fzf, jq, make, wget, unzip
- Zsh as default shell (sources host `~/.zshrc` if available)
- Python venv auto-activation
- Network firewall (default-deny, whitelists npm, GitHub, Anthropic APIs, Sentry, and VS Code update servers)
## Usage
### CLI (`ods dev`)
The [`ods` devtools CLI](../tools/ods/README.md) provides workspace-aware wrappers
for all devcontainer operations (also available as `ods dc`):
```bash
# Start the container
ods dev up
# Open a shell
ods dev into
# Run a command
ods dev exec npm test
# Stop the container
ods dev stop
```
## Restarting the container
```bash
# Restart the container
ods dev restart
# Pull the latest published image and recreate
ods dev rebuild
```
## Image
The devcontainer uses a prebuilt image published to `onyxdotapp/onyx-devcontainer`.
The tag is pinned in `devcontainer.json` — no local build is required.
To build the image locally (e.g. while iterating on the Dockerfile):
```bash
docker buildx bake devcontainer
```
The `devcontainer` target is defined in `docker-bake.hcl` at the repo root.
## User & permissions
The container runs as the `dev` user by default (`remoteUser` in devcontainer.json).
An init script (`init-dev-user.sh`) runs at container start to ensure the active
user has read/write access to the bind-mounted workspace:
- **Standard Docker** — `dev`'s UID/GID is remapped to match the workspace owner,
so file permissions work seamlessly.
- **Rootless Docker** — The workspace appears as root-owned (UID 0) inside the
container due to user-namespace mapping. `ods dev up` auto-detects rootless Docker
and sets `DEVCONTAINER_REMOTE_USER=root` so the container runs as root — which
maps back to your host user via the user namespace. New files are owned by your
host UID and no ACL workarounds are needed.
To override the auto-detection, set `DEVCONTAINER_REMOTE_USER` before running
`ods dev up`.
## Docker socket
The container mounts the host's Docker socket so you can run `docker` commands
from inside. `ods dev` auto-detects the socket path and sets `DOCKER_SOCK`:
| Environment | Socket path |
| ----------------------- | ------------------------------ |
| Linux (rootless Docker) | `$XDG_RUNTIME_DIR/docker.sock` |
| macOS (Docker Desktop) | `~/.docker/run/docker.sock` |
| Linux (standard Docker) | `/var/run/docker.sock` |
To override, set `DOCKER_SOCK` before running `ods dev up`.
## Firewall
The container starts with a default-deny firewall (`init-firewall.sh`) that only allows outbound traffic to:
- npm registry
- GitHub
- Anthropic API
- Sentry
- VS Code update servers
This requires the `NET_ADMIN` and `NET_RAW` capabilities, which are added via `runArgs` in `devcontainer.json`.

View File

@@ -1,24 +0,0 @@
{
"name": "Onyx Dev Sandbox",
"image": "onyxdotapp/onyx-devcontainer@sha256:12184169c5bcc9cca0388286d5ffe504b569bc9c37bfa631b76ee8eee2064055",
"runArgs": ["--cap-add=NET_ADMIN", "--cap-add=NET_RAW"],
"mounts": [
"source=${localEnv:DOCKER_SOCK},target=/var/run/docker.sock,type=bind",
"source=${localEnv:HOME}/.claude,target=/home/dev/.claude,type=bind",
"source=${localEnv:HOME}/.claude.json,target=/home/dev/.claude.json,type=bind",
"source=${localEnv:HOME}/.zshrc,target=/home/dev/.zshrc.host,type=bind,readonly",
"source=${localEnv:HOME}/.gitconfig,target=/home/dev/.gitconfig,type=bind,readonly",
"source=${localEnv:HOME}/.config/nvim,target=/home/dev/.config/nvim,type=bind,readonly",
"source=onyx-devcontainer-cache,target=/home/dev/.cache,type=volume",
"source=onyx-devcontainer-local,target=/home/dev/.local,type=volume"
],
"containerEnv": {
"SSH_AUTH_SOCK": "/tmp/ssh-agent.sock"
},
"remoteUser": "${localEnv:DEVCONTAINER_REMOTE_USER:dev}",
"updateRemoteUserUID": false,
"workspaceMount": "source=${localWorkspaceFolder},target=/workspace,type=bind,consistency=delegated",
"workspaceFolder": "/workspace",
"postStartCommand": "sudo bash /workspace/.devcontainer/init-dev-user.sh && sudo bash /workspace/.devcontainer/init-firewall.sh",
"waitFor": "postStartCommand"
}

View File

@@ -1,107 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail
# Remap the dev user's UID/GID to match the workspace owner so that
# bind-mounted files are accessible without running as root.
#
# Standard Docker: Workspace is owned by the host user's UID (e.g. 1000).
# We remap dev to that UID -- fast and seamless.
#
# Rootless Docker: Workspace appears as root-owned (UID 0) inside the
# container due to user-namespace mapping. Requires
# DEVCONTAINER_REMOTE_USER=root (set automatically by
# ods dev up). Container root IS the host user, so
# bind-mounts and named volumes are symlinked into /root.
WORKSPACE=/workspace
TARGET_USER=dev
REMOTE_USER="${SUDO_USER:-$TARGET_USER}"
WS_UID=$(stat -c '%u' "$WORKSPACE")
WS_GID=$(stat -c '%g' "$WORKSPACE")
DEV_UID=$(id -u "$TARGET_USER")
DEV_GID=$(id -g "$TARGET_USER")
# devcontainer.json bind-mounts and named volumes target /home/dev regardless
# of remoteUser. When running as root ($HOME=/root), Phase 1 bridges the gap
# with symlinks from ACTIVE_HOME → MOUNT_HOME.
MOUNT_HOME=/home/"$TARGET_USER"
if [ "$REMOTE_USER" = "root" ]; then
ACTIVE_HOME="/root"
else
ACTIVE_HOME="$MOUNT_HOME"
fi
# ── Phase 1: home directory setup ───────────────────────────────────
# ~/.local and ~/.cache are named Docker volumes mounted under MOUNT_HOME.
mkdir -p "$MOUNT_HOME"/.local/state "$MOUNT_HOME"/.local/share
# When running as root, symlink bind-mounts and named volumes into /root
# so that $HOME-relative tools (Claude Code, git, etc.) find them.
if [ "$ACTIVE_HOME" != "$MOUNT_HOME" ]; then
for item in .claude .cache .local; do
[ -d "$MOUNT_HOME/$item" ] || continue
if [ -e "$ACTIVE_HOME/$item" ] && [ ! -L "$ACTIVE_HOME/$item" ]; then
echo "warning: replacing $ACTIVE_HOME/$item with symlink to $MOUNT_HOME/$item" >&2
rm -rf "$ACTIVE_HOME/$item"
fi
ln -sfn "$MOUNT_HOME/$item" "$ACTIVE_HOME/$item"
done
# Symlink files (not directories).
for file in .claude.json .gitconfig .zshrc.host; do
[ -f "$MOUNT_HOME/$file" ] && ln -sf "$MOUNT_HOME/$file" "$ACTIVE_HOME/$file"
done
# Nested mount: .config/nvim
if [ -d "$MOUNT_HOME/.config/nvim" ]; then
mkdir -p "$ACTIVE_HOME/.config"
if [ -e "$ACTIVE_HOME/.config/nvim" ] && [ ! -L "$ACTIVE_HOME/.config/nvim" ]; then
echo "warning: replacing $ACTIVE_HOME/.config/nvim with symlink" >&2
rm -rf "$ACTIVE_HOME/.config/nvim"
fi
ln -sfn "$MOUNT_HOME/.config/nvim" "$ACTIVE_HOME/.config/nvim"
fi
fi
# ── Phase 2: workspace access ───────────────────────────────────────
# Root always has workspace access; Phase 1 handled home setup.
if [ "$REMOTE_USER" = "root" ]; then
exit 0
fi
# Already matching -- nothing to do.
if [ "$WS_UID" = "$DEV_UID" ] && [ "$WS_GID" = "$DEV_GID" ]; then
exit 0
fi
if [ "$WS_UID" != "0" ]; then
# ── Standard Docker ──────────────────────────────────────────────
# Workspace is owned by a non-root UID (the host user).
# Remap dev's UID/GID to match.
if [ "$DEV_GID" != "$WS_GID" ]; then
if ! groupmod -g "$WS_GID" "$TARGET_USER" 2>&1; then
echo "warning: failed to remap $TARGET_USER GID to $WS_GID" >&2
fi
fi
if [ "$DEV_UID" != "$WS_UID" ]; then
if ! usermod -u "$WS_UID" -g "$WS_GID" "$TARGET_USER" 2>&1; then
echo "warning: failed to remap $TARGET_USER UID to $WS_UID" >&2
fi
fi
if ! chown -R "$TARGET_USER":"$TARGET_USER" "$MOUNT_HOME" 2>&1; then
echo "warning: failed to chown $MOUNT_HOME" >&2
fi
else
# ── Rootless Docker ──────────────────────────────────────────────
# Workspace is root-owned (UID 0) due to user-namespace mapping.
# The supported path is remoteUser=root (set DEVCONTAINER_REMOTE_USER=root),
# which is handled above. If we reach here, the user is running as dev
# under rootless Docker without the override.
echo "error: rootless Docker detected but remoteUser is not root." >&2
echo " Set DEVCONTAINER_REMOTE_USER=root before starting the container," >&2
echo " or use 'ods dev up' which sets it automatically." >&2
exit 1
fi
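The Phase 2 branching in the script above reduces to a small decision table. The sketch below restates it in Python for readability; `plan_workspace_access` and its return labels are hypothetical and exist only to illustrate the control flow.

```python
def plan_workspace_access(
    remote_user: str,
    ws_uid: int,
    ws_gid: int,
    dev_uid: int,
    dev_gid: int,
) -> str:
    """Return the action the init script would take for workspace access.

    "noop"  -> nothing to change (root, or IDs already match)
    "remap" -> standard Docker: remap dev's UID/GID to the workspace owner
    "error" -> rootless Docker while running as dev without the root override
    """
    if remote_user == "root":
        return "noop"  # root always has workspace access
    if (ws_uid, ws_gid) == (dev_uid, dev_gid):
        return "noop"  # already matching
    if ws_uid != 0:
        return "remap"  # workspace owned by a non-root host UID
    return "error"  # workspace appears root-owned: rootless Docker
```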

View File

@@ -1,104 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail
echo "Setting up firewall..."
# Preserve docker dns resolution
DOCKER_DNS_RULES=$(iptables-save | grep -E "^-A.*-d 127.0.0.11/32" || true)
# Flush all rules
iptables -t nat -F
iptables -t nat -X
iptables -t mangle -F
iptables -t mangle -X
iptables -F
iptables -X
# Restore docker dns rules
if [ -n "$DOCKER_DNS_RULES" ]; then
echo "$DOCKER_DNS_RULES" | iptables-restore -n
fi
# Create ipset for allowed destinations
ipset create allowed-domains hash:net || true
ipset flush allowed-domains
# Fetch GitHub IP ranges (IPv4 only -- ipset hash:net and iptables are IPv4)
GITHUB_IPS=$(curl -s https://api.github.com/meta | jq -r '.api[]' 2>/dev/null | grep -v ':' || echo "")
for ip in $GITHUB_IPS; do
if ! ipset add allowed-domains "$ip" -exist 2>&1; then
echo "warning: failed to add GitHub IP $ip to allowlist" >&2
fi
done
# Resolve allowed domains
ALLOWED_DOMAINS=(
"registry.npmjs.org"
"api.anthropic.com"
"api-staging.anthropic.com"
"files.anthropic.com"
"sentry.io"
"update.code.visualstudio.com"
"pypi.org"
"files.pythonhosted.org"
"go.dev"
"storage.googleapis.com"
"static.rust-lang.org"
)
for domain in "${ALLOWED_DOMAINS[@]}"; do
IPS=$(getent ahosts "$domain" 2>/dev/null | awk '{print $1}' | grep -v ':' | sort -u || echo "")
for ip in $IPS; do
if ! ipset add allowed-domains "$ip/32" -exist 2>&1; then
echo "warning: failed to add $domain ($ip) to allowlist" >&2
fi
done
done
# Detect host network
if [[ "${DOCKER_HOST:-}" == "unix://"* ]]; then
DOCKER_GATEWAY=$(ip -4 route show | grep "^default" | awk '{print $3}')
if ! ipset add allowed-domains "$DOCKER_GATEWAY/32" -exist 2>&1; then
echo "warning: failed to add Docker gateway $DOCKER_GATEWAY to allowlist" >&2
fi
fi
# Set default policies to DROP
iptables -P FORWARD DROP
iptables -P INPUT DROP
iptables -P OUTPUT DROP
# Allow established connections
iptables -A INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT
iptables -A OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT
# Allow loopback
iptables -A INPUT -i lo -j ACCEPT
iptables -A OUTPUT -o lo -j ACCEPT
# Allow DNS
iptables -A OUTPUT -p udp --dport 53 -j ACCEPT
iptables -A OUTPUT -p tcp --dport 53 -j ACCEPT
# Allow outbound to allowed destinations
iptables -A OUTPUT -m set --match-set allowed-domains dst -j ACCEPT
# Reject unauthorized outbound
iptables -A OUTPUT -j REJECT --reject-with icmp-host-unreachable
# Validate firewall configuration
echo "Validating firewall configuration..."
BLOCKED_SITES=("example.com" "google.com" "facebook.com")
for site in "${BLOCKED_SITES[@]}"; do
if timeout 2 ping -c 1 "$site" &>/dev/null; then
echo "Warning: $site is still reachable"
fi
done
if ! timeout 5 curl -s https://api.github.com/meta > /dev/null; then
echo "Warning: GitHub API is not accessible"
fi
echo "Firewall setup complete"
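The allowlist loop above keeps only IPv4 results, deduplicates them, and adds each as a `/32` entry. A rough Python equivalent of that filtering step, using the standard `ipaddress` module, is sketched below; `ipv4_allowlist_entries` is a hypothetical helper and does not perform DNS resolution itself.

```python
import ipaddress
from typing import Iterable


def ipv4_allowlist_entries(addresses: Iterable[str]) -> list[str]:
    """Keep unique IPv4 addresses and format them as /32 allowlist entries."""
    entries: set[str] = set()
    for addr in addresses:
        try:
            ip = ipaddress.ip_address(addr)
        except ValueError:
            continue  # skip anything that is not a valid IP literal
        if ip.version == 4:
            entries.add(f"{ip}/32")
    return sorted(entries)
```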

View File

@@ -1,10 +0,0 @@
# Devcontainer zshrc — sourced automatically for both root and dev users.
# Edit this file to customize the shell without rebuilding the image.
# Auto-activate Python venv
if [ -f /workspace/.venv/bin/activate ]; then
. /workspace/.venv/bin/activate
fi
# Source host zshrc if bind-mounted
[ -f ~/.zshrc.host ] && . ~/.zshrc.host

View File

@@ -13,7 +13,7 @@ permissions:
 id-token: write # zizmor: ignore[excessive-permissions]
 env:
-EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') || github.ref_name == 'edge' }}
+EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
 jobs:
 # Determine which components to build based on the tag
@@ -156,7 +156,7 @@ jobs:
 check-version-tag:
 runs-on: ubuntu-slim
 timeout-minutes: 10
-if: ${{ !startsWith(github.ref_name, 'nightly-latest') && github.ref_name != 'edge' && github.event_name != 'workflow_dispatch' }}
+if: ${{ !startsWith(github.ref_name, 'nightly-latest') && github.event_name != 'workflow_dispatch' }}
 steps:
 - name: Checkout
 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
@@ -228,7 +228,7 @@ jobs:
 - name: Create GitHub Release
 id: create-release
-uses: softprops/action-gh-release@153bb8e04406b158c6c84fc1615b65b24149a1fe # ratchet:softprops/action-gh-release@v2
+uses: softprops/action-gh-release@da05d552573ad5aba039eaac05058a918a7bf631 # ratchet:softprops/action-gh-release@v2
 with:
 tag_name: ${{ steps.release-tag.outputs.tag }}
 name: ${{ steps.release-tag.outputs.tag }}

View File

@@ -21,7 +21,7 @@ jobs:
 persist-credentials: false
 - name: Install Helm CLI
-uses: azure/setup-helm@dda3372f752e03dde6b3237bc9431cdc2f7a02a2 # ratchet:azure/setup-helm@v5.0.0
+uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4
 with:
 version: v3.12.1

View File

@@ -13,7 +13,7 @@ jobs:
 runs-on: ubuntu-latest
 timeout-minutes: 45
 steps:
-- uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # ratchet:actions/stale@v10
+- uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # ratchet:actions/stale@v10
 with:
 stale-issue-message: 'This issue is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
 stale-pr-message: 'This PR is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'

View File

@@ -36,7 +36,7 @@ jobs:
 persist-credentials: false
 - name: Set up Helm
-uses: azure/setup-helm@dda3372f752e03dde6b3237bc9431cdc2f7a02a2 # ratchet:azure/setup-helm@v5.0.0
+uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4.3.1
 with:
 version: v3.19.0

3
.gitignore vendored
View File

@@ -59,6 +59,3 @@ node_modules
 # plans
 plans/
-# Added context for LLMs
-onyx-llm-context/

View File

@@ -9,6 +9,7 @@ repos:
 rev: d30b4298e4fb63ce8609e29acdbcf4c9018a483c
 hooks:
 - id: uv-sync
+args: ["--locked", "--all-extras"]
 - id: uv-lock
 - id: uv-export
 name: uv-export default.txt
@@ -17,7 +18,7 @@ repos:
 "--no-emit-project",
 "--no-default-groups",
 "--no-hashes",
-"--group",
+"--extra",
 "backend",
 "-o",
 "backend/requirements/default.txt",
@@ -30,7 +31,7 @@ repos:
 "--no-emit-project",
 "--no-default-groups",
 "--no-hashes",
-"--group",
+"--extra",
 "dev",
 "-o",
 "backend/requirements/dev.txt",
@@ -43,7 +44,7 @@ repos:
 "--no-emit-project",
 "--no-default-groups",
 "--no-hashes",
-"--group",
+"--extra",
 "ee",
 "-o",
 "backend/requirements/ee.txt",
@@ -56,7 +57,7 @@ repos:
 "--no-emit-project",
 "--no-default-groups",
 "--no-hashes",
-"--group",
+"--extra",
 "model_server",
 "-o",
 "backend/requirements/model_server.txt",

3
.vscode/launch.json vendored
View File

@@ -531,7 +531,8 @@
 "request": "launch",
 "runtimeExecutable": "uv",
 "runtimeArgs": [
-"sync"
+"sync",
+"--all-extras"
 ],
 "cwd": "${workspaceFolder}",
 "console": "integratedTerminal",

View File

@@ -117,7 +117,7 @@ If using PowerShell, the command slightly differs:
Install the required Python dependencies:
```bash
-uv sync
+uv sync --all-extras
```
Install Playwright for Python (headless browser required by the Web Connector):

View File

@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, Literal
 from onyx.db.engine.iam_auth import get_iam_auth_token
 from onyx.configs.app_configs import USE_IAM_AUTH
 from onyx.configs.app_configs import POSTGRES_HOST
@@ -19,6 +19,7 @@ from logging.config import fileConfig
 from alembic import context
 from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.sql.schema import SchemaItem
 from onyx.configs.constants import SSL_CERT_FILE
 from shared_configs.configs import (
 MULTI_TENANT,
@@ -44,6 +45,8 @@ if config.config_file_name is not None and config.attributes.get(
 target_metadata = [Base.metadata, ResultModelBase.metadata]
+EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
 logger = logging.getLogger(__name__)
 ssl_context: ssl.SSLContext | None = None
@@ -53,6 +56,25 @@ if USE_IAM_AUTH:
 ssl_context = ssl.create_default_context(cafile=SSL_CERT_FILE)
+def include_object(
+    object: SchemaItem,  # noqa: ARG001
+    name: str | None,
+    type_: Literal[
+        "schema",
+        "table",
+        "column",
+        "index",
+        "unique_constraint",
+        "foreign_key_constraint",
+    ],
+    reflected: bool,  # noqa: ARG001
+    compare_to: SchemaItem | None,  # noqa: ARG001
+) -> bool:
+    if type_ == "table" and name in EXCLUDE_TABLES:
+        return False
+    return True
 def filter_tenants_by_range(
 tenant_ids: list[str], start_range: int | None = None, end_range: int | None = None
 ) -> list[str]:
@@ -209,6 +231,7 @@ def do_run_migrations(
 context.configure(
 connection=connection,
 target_metadata=target_metadata, # type: ignore
+include_object=include_object,
 version_table_schema=schema_name,
 include_schemas=True,
 compare_type=True,
@@ -382,6 +405,7 @@ def run_migrations_offline() -> None:
 url=url,
 target_metadata=target_metadata, # type: ignore
 literal_binds=True,
+include_object=include_object,
 version_table_schema=schema,
 include_schemas=True,
 script_location=config.get_main_option("script_location"),
@@ -423,6 +447,7 @@ def run_migrations_offline() -> None:
 url=url,
 target_metadata=target_metadata, # type: ignore
 literal_binds=True,
+include_object=include_object,
 version_table_schema=schema,
 include_schemas=True,
 script_location=config.get_main_option("script_location"),
@@ -465,6 +490,7 @@ def run_migrations_online() -> None:
 context.configure(
 connection=connection,
 target_metadata=target_metadata, # type: ignore
+include_object=include_object,
 version_table_schema=schema_name,
 include_schemas=True,
 compare_type=True,
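The `include_object` hook in this diff is a plain predicate, so its behavior can be checked without a database or an Alembic context. The sketch below copies the filter into a standalone form with the type annotations loosened to `Any`, so that SQLAlchemy is not needed to run it; the table name `chat_message` in the usage lines is only an example.

```python
from typing import Any, Optional

# Tables that autogenerate should ignore (same set as in the diff).
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}


def include_object(
    object: Any,
    name: Optional[str],
    type_: str,
    reflected: bool,
    compare_to: Any,
) -> bool:
    """Return False for excluded tables, True for everything else."""
    if type_ == "table" and name in EXCLUDE_TABLES:
        return False
    return True


# Excluded table is skipped; other tables and non-table objects pass through.
print(include_object(None, "kombu_queue", "table", True, None))
print(include_object(None, "chat_message", "table", True, None))
print(include_object(None, "kombu_queue", "index", True, None))
```

Note that the filter only matches on `type_ == "table"`, so an index or column that happens to share an excluded name is still included.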

View File

@@ -1,108 +0,0 @@
"""backfill_account_type
Revision ID: 03d085c5c38d
Revises: 977e834c1427
Create Date: 2026-03-25 16:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "03d085c5c38d"
down_revision = "977e834c1427"
branch_labels = None
depends_on = None
_STANDARD = "STANDARD"
_BOT = "BOT"
_EXT_PERM_USER = "EXT_PERM_USER"
_SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
_ANONYMOUS = "ANONYMOUS"
# Well-known anonymous user UUID
ANONYMOUS_USER_ID = "00000000-0000-0000-0000-000000000002"
# Email pattern for API key virtual users
API_KEY_EMAIL_PATTERN = r"API\_KEY\_\_%"
# Reflect the table structure for use in DML
user_table = sa.table(
"user",
sa.column("id", sa.Uuid),
sa.column("email", sa.String),
sa.column("role", sa.String),
sa.column("account_type", sa.String),
)
def upgrade() -> None:
# ------------------------------------------------------------------
# Step 1: Backfill account_type from role.
# Order matters — most-specific matches first so the final catch-all
# only touches rows that haven't been classified yet.
# ------------------------------------------------------------------
# 1a. API key virtual users → SERVICE_ACCOUNT
op.execute(
sa.update(user_table)
.where(
user_table.c.email.ilike(API_KEY_EMAIL_PATTERN),
user_table.c.account_type.is_(None),
)
.values(account_type=_SERVICE_ACCOUNT)
)
# 1b. Anonymous user → ANONYMOUS
op.execute(
sa.update(user_table)
.where(
user_table.c.id == ANONYMOUS_USER_ID,
user_table.c.account_type.is_(None),
)
.values(account_type=_ANONYMOUS)
)
# 1c. SLACK_USER role → BOT
op.execute(
sa.update(user_table)
.where(
user_table.c.role == "SLACK_USER",
user_table.c.account_type.is_(None),
)
.values(account_type=_BOT)
)
# 1d. EXT_PERM_USER role → EXT_PERM_USER
op.execute(
sa.update(user_table)
.where(
user_table.c.role == "EXT_PERM_USER",
user_table.c.account_type.is_(None),
)
.values(account_type=_EXT_PERM_USER)
)
# 1e. Everything else → STANDARD
op.execute(
sa.update(user_table)
.where(user_table.c.account_type.is_(None))
.values(account_type=_STANDARD)
)
# ------------------------------------------------------------------
# Step 2: Set account_type to NOT NULL now that every row is filled.
# ------------------------------------------------------------------
op.alter_column(
"user",
"account_type",
nullable=False,
server_default="STANDARD",
)
def downgrade() -> None:
op.alter_column("user", "account_type", nullable=True, server_default=None)
op.execute(sa.update(user_table).values(account_type=None))
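Because every `UPDATE` above is guarded by `account_type IS NULL`, the sequence behaves like a first-match-wins classifier. The sketch below restates that ordering for a single row; `classify_account_type` is a hypothetical helper, and the prefix check approximates the `ILIKE` pattern, in which `\_` matches a literal underscore and the match is case-insensitive.

```python
ANONYMOUS_USER_ID = "00000000-0000-0000-0000-000000000002"
API_KEY_EMAIL_PREFIX = "api_key__"  # lowercase form of the ILIKE prefix


def classify_account_type(user_id: str, email: str, role: str) -> str:
    """Apply the backfill rules in the same order as the migration steps."""
    if email.lower().startswith(API_KEY_EMAIL_PREFIX):
        return "SERVICE_ACCOUNT"  # step 1a
    if user_id == ANONYMOUS_USER_ID:
        return "ANONYMOUS"  # step 1b
    if role == "SLACK_USER":
        return "BOT"  # step 1c
    if role == "EXT_PERM_USER":
        return "EXT_PERM_USER"  # step 1d
    return "STANDARD"  # step 1e: catch-all
```

An API key virtual user whose role is `EXT_PERM_USER` is classified as `SERVICE_ACCOUNT`, which is why the most specific rules run first.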

View File

@@ -1,104 +0,0 @@
"""add_effective_permissions
Adds a JSONB column `effective_permissions` to the user table to store
directly granted permissions (e.g. ["admin"] or ["basic"]). Implied
permissions are expanded at read time, not stored.
Backfill: joins user__user_group → permission_grant to collect each
user's granted permissions into a JSON array. Users without group
memberships keep the default [].
Revision ID: 503883791c39
Revises: b4b7e1028dfd
Create Date: 2026-03-30 14:49:22.261748
"""
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "503883791c39"
down_revision = "b4b7e1028dfd"
branch_labels: str | None = None
depends_on: str | Sequence[str] | None = None
user_table = sa.table(
"user",
sa.column("id", sa.Uuid),
sa.column("effective_permissions", postgresql.JSONB),
)
user_user_group = sa.table(
"user__user_group",
sa.column("user_id", sa.Uuid),
sa.column("user_group_id", sa.Integer),
)
permission_grant = sa.table(
"permission_grant",
sa.column("group_id", sa.Integer),
sa.column("permission", sa.String),
sa.column("is_deleted", sa.Boolean),
)
def upgrade() -> None:
op.add_column(
"user",
sa.Column(
"effective_permissions",
postgresql.JSONB(),
nullable=False,
server_default=sa.text("'[]'::jsonb"),
),
)
conn = op.get_bind()
# Deduplicated permissions per user
deduped = (
sa.select(
user_user_group.c.user_id,
permission_grant.c.permission,
)
.select_from(
user_user_group.join(
permission_grant,
sa.and_(
permission_grant.c.group_id == user_user_group.c.user_group_id,
permission_grant.c.is_deleted == sa.false(),
),
)
)
.distinct()
.subquery("deduped")
)
# Aggregate into JSONB array per user (order is not guaranteed;
# consumers read this as a set so ordering does not matter)
perms_per_user = (
sa.select(
deduped.c.user_id,
sa.func.jsonb_agg(
deduped.c.permission,
type_=postgresql.JSONB,
).label("perms"),
)
.group_by(deduped.c.user_id)
.subquery("sub")
)
conn.execute(
user_table.update()
.where(user_table.c.id == perms_per_user.c.user_id)
.values(effective_permissions=perms_per_user.c.perms)
)
def downgrade() -> None:
op.drop_column("user", "effective_permissions")
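The backfill query in this migration joins memberships to non-deleted grants, deduplicates the pairs, and aggregates one array per user. The same logic can be sketched in plain Python, which may help when reasoning about edge cases. The function name `collect_permissions` and the tuple shapes are hypothetical; the output is sorted only to make the example deterministic, since the migration itself does not guarantee order.

```python
from collections import defaultdict


def collect_permissions(memberships, grants):
    """Return {user_id: [permission, ...]} for users with at least one grant.

    memberships: iterable of (user_id, group_id)
    grants: iterable of (group_id, permission, is_deleted)
    """
    perms_by_group = defaultdict(set)
    for group_id, permission, is_deleted in grants:
        if not is_deleted:  # mirrors the is_deleted == false join condition
            perms_by_group[group_id].add(permission)

    perms_by_user = defaultdict(set)
    for user_id, group_id in memberships:
        # Sets deduplicate a permission reached through several groups.
        perms_by_user[user_id] |= perms_by_group.get(group_id, set())

    # Users whose groups carry no live grant are omitted, so they would keep
    # the column default of [] in the real table.
    return {u: sorted(p) for u, p in perms_by_user.items() if p}


memberships = [("u1", 1), ("u1", 2), ("u2", 2), ("u3", 3)]
grants = [
    (1, "admin", False),
    (1, "basic", False),
    (2, "basic", False),
    (3, "admin", True),  # soft-deleted grant: ignored
]
result = collect_permissions(memberships, grants)
```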


@@ -1,541 +0,0 @@
"""add proposal review tables
Revision ID: 61ea78857c97
Revises: c7bf5721733e
Create Date: 2026-04-09 10:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
import fastapi_users_db_sqlalchemy
# revision identifiers, used by Alembic.
revision = "61ea78857c97"
down_revision = "c7bf5721733e"
branch_labels: str | None = None
depends_on: str | None = None
def upgrade() -> None:
# -- proposal_review_ruleset --
op.create_table(
"proposal_review_ruleset",
sa.Column(
"id",
postgresql.UUID(as_uuid=True),
server_default=sa.text("gen_random_uuid()"),
nullable=False,
),
sa.Column("tenant_id", sa.Text(), nullable=False),
sa.Column("name", sa.Text(), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column(
"is_default",
sa.Boolean(),
server_default=sa.text("false"),
nullable=False,
),
sa.Column(
"is_active",
sa.Boolean(),
server_default=sa.text("true"),
nullable=False,
),
sa.Column(
"created_by",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(["created_by"], ["user.id"]),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_proposal_review_ruleset_tenant_id",
"proposal_review_ruleset",
["tenant_id"],
)
# -- proposal_review_rule --
op.create_table(
"proposal_review_rule",
sa.Column(
"id",
postgresql.UUID(as_uuid=True),
server_default=sa.text("gen_random_uuid()"),
nullable=False,
),
sa.Column(
"ruleset_id",
postgresql.UUID(as_uuid=True),
nullable=False,
),
sa.Column("name", sa.Text(), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("category", sa.Text(), nullable=True),
sa.Column("rule_type", sa.Text(), nullable=False),
sa.Column(
"rule_intent",
sa.Text(),
server_default=sa.text("'CHECK'"),
nullable=False,
),
sa.Column("prompt_template", sa.Text(), nullable=False),
sa.Column(
"source",
sa.Text(),
server_default=sa.text("'MANUAL'"),
nullable=False,
),
sa.Column("authority", sa.Text(), nullable=True),
sa.Column(
"is_hard_stop",
sa.Boolean(),
server_default=sa.text("false"),
nullable=False,
),
sa.Column(
"priority",
sa.Integer(),
server_default=sa.text("0"),
nullable=False,
),
sa.Column(
"is_active",
sa.Boolean(),
server_default=sa.text("true"),
nullable=False,
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["ruleset_id"],
["proposal_review_ruleset.id"],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_proposal_review_rule_ruleset_id",
"proposal_review_rule",
["ruleset_id"],
)
# -- proposal_review_proposal --
op.create_table(
"proposal_review_proposal",
sa.Column(
"id",
postgresql.UUID(as_uuid=True),
server_default=sa.text("gen_random_uuid()"),
nullable=False,
),
sa.Column("document_id", sa.Text(), nullable=False),
sa.Column("tenant_id", sa.Text(), nullable=False),
sa.Column(
"status",
sa.Text(),
server_default=sa.text("'PENDING'"),
nullable=False,
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("document_id", "tenant_id"),
)
op.create_index(
"ix_proposal_review_proposal_tenant_id",
"proposal_review_proposal",
["tenant_id"],
)
op.create_index(
"ix_proposal_review_proposal_document_id",
"proposal_review_proposal",
["document_id"],
)
op.create_index(
"ix_proposal_review_proposal_status",
"proposal_review_proposal",
["status"],
)
# -- proposal_review_run --
op.create_table(
"proposal_review_run",
sa.Column(
"id",
postgresql.UUID(as_uuid=True),
server_default=sa.text("gen_random_uuid()"),
nullable=False,
),
sa.Column(
"proposal_id",
postgresql.UUID(as_uuid=True),
nullable=False,
),
sa.Column(
"ruleset_id",
postgresql.UUID(as_uuid=True),
nullable=False,
),
sa.Column(
"triggered_by",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=False,
),
sa.Column(
"status",
sa.Text(),
server_default=sa.text("'PENDING'"),
nullable=False,
),
sa.Column("total_rules", sa.Integer(), nullable=False),
sa.Column(
"completed_rules",
sa.Integer(),
server_default=sa.text("0"),
nullable=False,
),
sa.Column("started_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["proposal_id"],
["proposal_review_proposal.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["ruleset_id"],
["proposal_review_ruleset.id"],
),
sa.ForeignKeyConstraint(["triggered_by"], ["user.id"]),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_proposal_review_run_proposal_id",
"proposal_review_run",
["proposal_id"],
)
# -- proposal_review_finding --
op.create_table(
"proposal_review_finding",
sa.Column(
"id",
postgresql.UUID(as_uuid=True),
server_default=sa.text("gen_random_uuid()"),
nullable=False,
),
sa.Column(
"proposal_id",
postgresql.UUID(as_uuid=True),
nullable=False,
),
sa.Column(
"rule_id",
postgresql.UUID(as_uuid=True),
nullable=False,
),
sa.Column(
"review_run_id",
postgresql.UUID(as_uuid=True),
nullable=False,
),
sa.Column("verdict", sa.Text(), nullable=False),
sa.Column("confidence", sa.Text(), nullable=True),
sa.Column("evidence", sa.Text(), nullable=True),
sa.Column("explanation", sa.Text(), nullable=True),
sa.Column("suggested_action", sa.Text(), nullable=True),
sa.Column("llm_model", sa.Text(), nullable=True),
sa.Column("llm_tokens_used", sa.Integer(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["proposal_id"],
["proposal_review_proposal.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["rule_id"],
["proposal_review_rule.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["review_run_id"],
["proposal_review_run.id"],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_proposal_review_finding_proposal_id",
"proposal_review_finding",
["proposal_id"],
)
op.create_index(
"ix_proposal_review_finding_review_run_id",
"proposal_review_finding",
["review_run_id"],
)
# -- proposal_review_decision (per-finding) --
op.create_table(
"proposal_review_decision",
sa.Column(
"id",
postgresql.UUID(as_uuid=True),
server_default=sa.text("gen_random_uuid()"),
nullable=False,
),
sa.Column(
"finding_id",
postgresql.UUID(as_uuid=True),
nullable=False,
),
sa.Column(
"officer_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=False,
),
sa.Column("action", sa.Text(), nullable=False),
sa.Column("notes", sa.Text(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["finding_id"],
["proposal_review_finding.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(["officer_id"], ["user.id"]),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("finding_id"),
)
# -- proposal_review_proposal_decision --
op.create_table(
"proposal_review_proposal_decision",
sa.Column(
"id",
postgresql.UUID(as_uuid=True),
server_default=sa.text("gen_random_uuid()"),
nullable=False,
),
sa.Column(
"proposal_id",
postgresql.UUID(as_uuid=True),
nullable=False,
),
sa.Column(
"officer_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=False,
),
sa.Column("decision", sa.Text(), nullable=False),
sa.Column("notes", sa.Text(), nullable=True),
sa.Column(
"jira_synced",
sa.Boolean(),
server_default=sa.text("false"),
nullable=False,
),
sa.Column("jira_synced_at", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["proposal_id"],
["proposal_review_proposal.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(["officer_id"], ["user.id"]),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_proposal_review_proposal_decision_proposal_id",
"proposal_review_proposal_decision",
["proposal_id"],
)
# -- proposal_review_document --
op.create_table(
"proposal_review_document",
sa.Column(
"id",
postgresql.UUID(as_uuid=True),
server_default=sa.text("gen_random_uuid()"),
nullable=False,
),
sa.Column(
"proposal_id",
postgresql.UUID(as_uuid=True),
nullable=False,
),
sa.Column("file_name", sa.Text(), nullable=False),
sa.Column("file_type", sa.Text(), nullable=True),
sa.Column("file_store_id", sa.Text(), nullable=True),
sa.Column("extracted_text", sa.Text(), nullable=True),
sa.Column("document_role", sa.Text(), nullable=False),
sa.Column(
"uploaded_by",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["proposal_id"],
["proposal_review_proposal.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(["uploaded_by"], ["user.id"]),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_proposal_review_document_proposal_id",
"proposal_review_document",
["proposal_id"],
)
# -- proposal_review_audit_log --
op.create_table(
"proposal_review_audit_log",
sa.Column(
"id",
postgresql.UUID(as_uuid=True),
server_default=sa.text("gen_random_uuid()"),
nullable=False,
),
sa.Column(
"proposal_id",
postgresql.UUID(as_uuid=True),
nullable=False,
),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.Column("action", sa.Text(), nullable=False),
sa.Column("details", postgresql.JSONB(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["proposal_id"],
["proposal_review_proposal.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(["user_id"], ["user.id"]),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_proposal_review_audit_log_proposal_id",
"proposal_review_audit_log",
["proposal_id"],
)
# -- proposal_review_config --
op.create_table(
"proposal_review_config",
sa.Column(
"id",
postgresql.UUID(as_uuid=True),
server_default=sa.text("gen_random_uuid()"),
nullable=False,
),
sa.Column("tenant_id", sa.Text(), nullable=False, unique=True),
sa.Column("jira_connector_id", sa.Integer(), nullable=True),
sa.Column("jira_project_key", sa.Text(), nullable=True),
sa.Column("field_mapping", postgresql.JSONB(), nullable=True),
sa.Column("jira_writeback", postgresql.JSONB(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
)
def downgrade() -> None:
op.drop_table("proposal_review_config")
op.drop_table("proposal_review_audit_log")
op.drop_table("proposal_review_document")
op.drop_table("proposal_review_proposal_decision")
op.drop_table("proposal_review_decision")
op.drop_table("proposal_review_finding")
op.drop_table("proposal_review_run")
op.drop_table("proposal_review_proposal")
op.drop_table("proposal_review_rule")
op.drop_table("proposal_review_ruleset")
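The downgrade drops the tables in the reverse of their creation order, so a table is only dropped once nothing that references it remains. A small check of that property, built from the foreign keys declared in `upgrade()` above, might look like the following. The helper `drop_order_violations` is hypothetical, and foreign keys to `user` are left out because that table is not dropped here.

```python
# table -> tables it references via foreign keys (from the upgrade() above)
REFERENCES = {
    "proposal_review_ruleset": set(),
    "proposal_review_rule": {"proposal_review_ruleset"},
    "proposal_review_proposal": set(),
    "proposal_review_run": {"proposal_review_proposal", "proposal_review_ruleset"},
    "proposal_review_finding": {
        "proposal_review_proposal",
        "proposal_review_rule",
        "proposal_review_run",
    },
    "proposal_review_decision": {"proposal_review_finding"},
    "proposal_review_proposal_decision": {"proposal_review_proposal"},
    "proposal_review_document": {"proposal_review_proposal"},
    "proposal_review_audit_log": {"proposal_review_proposal"},
    "proposal_review_config": set(),
}

DROP_ORDER = [
    "proposal_review_config",
    "proposal_review_audit_log",
    "proposal_review_document",
    "proposal_review_proposal_decision",
    "proposal_review_decision",
    "proposal_review_finding",
    "proposal_review_run",
    "proposal_review_proposal",
    "proposal_review_rule",
    "proposal_review_ruleset",
]


def drop_order_violations(drop_order, references):
    """Return (dropped_table, still_existing_referrer) pairs that would fail."""
    remaining = set(drop_order)
    violations = []
    for table in drop_order:
        remaining.discard(table)
        for other in sorted(remaining):
            if table in references.get(other, set()):
                violations.append((table, other))
    return violations


ok = drop_order_violations(DROP_ORDER, REFERENCES)
bad = drop_order_violations(list(reversed(DROP_ORDER)), REFERENCES)
```

Dropping in creation order instead (the reversed list) would fail on the first statement, because `proposal_review_rule` and `proposal_review_run` still reference `proposal_review_ruleset`.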


@@ -1,139 +0,0 @@
"""seed_default_groups
Revision ID: 977e834c1427
Revises: 8188861f4e92
Create Date: 2026-03-25 14:59:41.313091
"""
from typing import Any
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import insert as pg_insert
# revision identifiers, used by Alembic.
revision = "977e834c1427"
down_revision = "8188861f4e92"
branch_labels = None
depends_on = None
# (group_name, permission_value)
DEFAULT_GROUPS = [
("Admin", "admin"),
("Basic", "basic"),
]
CUSTOM_SUFFIX = "(Custom)"
MAX_RENAME_ATTEMPTS = 100
# Reflect table structures for use in DML
user_group_table = sa.table(
"user_group",
sa.column("id", sa.Integer),
sa.column("name", sa.String),
sa.column("is_up_to_date", sa.Boolean),
sa.column("is_up_for_deletion", sa.Boolean),
sa.column("is_default", sa.Boolean),
)
permission_grant_table = sa.table(
"permission_grant",
sa.column("group_id", sa.Integer),
sa.column("permission", sa.String),
sa.column("grant_source", sa.String),
)
user__user_group_table = sa.table(
"user__user_group",
sa.column("user_group_id", sa.Integer),
sa.column("user_id", sa.Uuid),
)
def _find_available_name(conn: sa.engine.Connection, base: str) -> str:
"""Return a name like 'Admin (Custom)' or 'Admin (Custom 2)' that is not taken."""
candidate = f"{base} {CUSTOM_SUFFIX}"
attempt = 1
while attempt <= MAX_RENAME_ATTEMPTS:
exists: Any = conn.execute(
sa.select(sa.literal(1))
.select_from(user_group_table)
.where(user_group_table.c.name == candidate)
.limit(1)
).fetchone()
if exists is None:
return candidate
attempt += 1
candidate = f"{base} (Custom {attempt})"
raise RuntimeError(
f"Could not find an available name for group '{base}' "
f"after {MAX_RENAME_ATTEMPTS} attempts"
)
def upgrade() -> None:
conn = op.get_bind()
for group_name, permission_value in DEFAULT_GROUPS:
# Step 1: Rename ALL existing groups that clash with the canonical name.
conflicting = conn.execute(
sa.select(user_group_table.c.id, user_group_table.c.name).where(
user_group_table.c.name == group_name
)
).fetchall()
for row_id, row_name in conflicting:
new_name = _find_available_name(conn, row_name)
op.execute(
sa.update(user_group_table)
.where(user_group_table.c.id == row_id)
.values(name=new_name, is_up_to_date=False)
)
# Step 2: Create a fresh default group.
result = conn.execute(
user_group_table.insert()
.values(
name=group_name,
is_up_to_date=True,
is_up_for_deletion=False,
is_default=True,
)
.returning(user_group_table.c.id)
).fetchone()
assert result is not None
group_id = result[0]
# Step 3: Upsert permission grant.
op.execute(
pg_insert(permission_grant_table)
.values(
group_id=group_id,
permission=permission_value,
grant_source="SYSTEM",
)
.on_conflict_do_nothing(index_elements=["group_id", "permission"])
)
def downgrade() -> None:
# Remove the default groups created by this migration.
# First remove user-group memberships that reference default groups
# to avoid FK violations, then delete the groups themselves.
default_group_ids = sa.select(user_group_table.c.id).where(
user_group_table.c.is_default == True # noqa: E712
)
conn = op.get_bind()
conn.execute(
sa.delete(user__user_group_table).where(
user__user_group_table.c.user_group_id.in_(default_group_ids)
)
)
conn.execute(
sa.delete(user_group_table).where(
user_group_table.c.is_default == True # noqa: E712
)
)
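The renaming helper `_find_available_name` can be exercised without a database by replacing the existence query with a set lookup. The sketch below keeps the same loop and naming scheme; the `taken` parameter is a stand-in for the `user_group` table and is an assumption of this example.

```python
CUSTOM_SUFFIX = "(Custom)"
MAX_RENAME_ATTEMPTS = 100


def find_available_name(taken: set[str], base: str) -> str:
    """Return 'Base (Custom)' or 'Base (Custom N)' that is not in `taken`."""
    candidate = f"{base} {CUSTOM_SUFFIX}"
    attempt = 1
    while attempt <= MAX_RENAME_ATTEMPTS:
        if candidate not in taken:
            return candidate
        attempt += 1
        # Numbering starts at 2 because the first candidate has no number.
        candidate = f"{base} (Custom {attempt})"
    raise RuntimeError(
        f"Could not find an available name for group '{base}' "
        f"after {MAX_RENAME_ATTEMPTS} attempts"
    )


first = find_available_name(set(), "Admin")
second = find_available_name({"Admin (Custom)"}, "Admin")
third = find_available_name({"Admin (Custom)", "Admin (Custom 2)"}, "Admin")
```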


@@ -1,84 +0,0 @@
"""grant_basic_to_existing_groups
Grants the "basic" permission to all existing groups that don't already
have it. Every group should have at least "basic" so that its members
get basic access when effective_permissions is backfilled.
Revision ID: b4b7e1028dfd
Revises: b7bcc991d722
Create Date: 2026-03-30 16:15:17.093498
"""
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "b4b7e1028dfd"
down_revision = "b7bcc991d722"
branch_labels: str | None = None
depends_on: str | Sequence[str] | None = None
user_group = sa.table(
"user_group",
sa.column("id", sa.Integer),
sa.column("is_default", sa.Boolean),
)
permission_grant = sa.table(
"permission_grant",
sa.column("group_id", sa.Integer),
sa.column("permission", sa.String),
sa.column("grant_source", sa.String),
sa.column("is_deleted", sa.Boolean),
)
def upgrade() -> None:
conn = op.get_bind()
already_has_basic = (
sa.select(sa.literal(1))
.select_from(permission_grant)
.where(
permission_grant.c.group_id == user_group.c.id,
permission_grant.c.permission == "basic",
)
.exists()
)
groups_needing_basic = sa.select(
user_group.c.id,
sa.literal("basic").label("permission"),
sa.literal("SYSTEM").label("grant_source"),
sa.literal(False).label("is_deleted"),
).where(
user_group.c.is_default == sa.false(),
~already_has_basic,
)
conn.execute(
permission_grant.insert().from_select(
["group_id", "permission", "grant_source", "is_deleted"],
groups_needing_basic,
)
)
def downgrade() -> None:
conn = op.get_bind()
non_default_group_ids = sa.select(user_group.c.id).where(
user_group.c.is_default == sa.false()
)
conn.execute(
permission_grant.delete().where(
permission_grant.c.permission == "basic",
permission_grant.c.grant_source == "SYSTEM",
permission_grant.c.group_id.in_(non_default_group_ids),
)
)
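The upgrade above is an anti-join: insert a `basic` grant for every non-default group that has no `basic` row yet. A rough equivalent in raw SQL, run against an in-memory SQLite database, is sketched here. The sample rows and the `'USER'` grant source are made up for the example. Note that the existence check, like the one in the migration, does not filter on `is_deleted`.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE user_group (id INTEGER PRIMARY KEY, is_default INTEGER NOT NULL);
    CREATE TABLE permission_grant (
        group_id INTEGER NOT NULL,
        permission TEXT NOT NULL,
        grant_source TEXT NOT NULL,
        is_deleted INTEGER NOT NULL DEFAULT 0
    );
    INSERT INTO user_group (id, is_default) VALUES (1, 1), (2, 0), (3, 0);
    INSERT INTO permission_grant (group_id, permission, grant_source) VALUES
        (1, 'admin', 'SYSTEM'),
        (2, 'basic', 'USER');
    """
)

# Insert 'basic' only for non-default groups that lack any 'basic' row.
conn.execute(
    """
    INSERT INTO permission_grant (group_id, permission, grant_source, is_deleted)
    SELECT g.id, 'basic', 'SYSTEM', 0
    FROM user_group AS g
    WHERE g.is_default = 0
      AND NOT EXISTS (
          SELECT 1 FROM permission_grant AS p
          WHERE p.group_id = g.id AND p.permission = 'basic'
      )
    """
)

basic_grants = conn.execute(
    "SELECT group_id, grant_source FROM permission_grant "
    "WHERE permission = 'basic' ORDER BY group_id"
).fetchall()
```

Group 2 keeps its existing grant, group 3 receives a `SYSTEM` grant, and the default group 1 is skipped, which is also what lets the downgrade remove only `SYSTEM`-sourced `basic` rows on non-default groups.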


@@ -1,125 +0,0 @@
"""assign_users_to_default_groups
Revision ID: b7bcc991d722
Revises: 03d085c5c38d
Create Date: 2026-03-25 16:30:39.529301
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import insert as pg_insert
# revision identifiers, used by Alembic.
revision = "b7bcc991d722"
down_revision = "03d085c5c38d"
branch_labels = None
depends_on = None
# The no-auth placeholder user must NOT be assigned to default groups.
# A database trigger (migrate_no_auth_data_to_user) will try to DELETE this
# user when the first real user registers; group membership rows would cause
# an FK violation on that DELETE.
NO_AUTH_PLACEHOLDER_USER_UUID = "00000000-0000-0000-0000-000000000001"
# Reflect table structures for use in DML
user_group_table = sa.table(
"user_group",
sa.column("id", sa.Integer),
sa.column("name", sa.String),
sa.column("is_default", sa.Boolean),
)
user_table = sa.table(
"user",
sa.column("id", sa.Uuid),
sa.column("role", sa.String),
sa.column("account_type", sa.String),
sa.column("is_active", sa.Boolean),
)
user__user_group_table = sa.table(
"user__user_group",
sa.column("user_group_id", sa.Integer),
sa.column("user_id", sa.Uuid),
)
def upgrade() -> None:
conn = op.get_bind()
# Look up default group IDs
admin_row = conn.execute(
sa.select(user_group_table.c.id).where(
user_group_table.c.name == "Admin",
user_group_table.c.is_default == True, # noqa: E712
)
).fetchone()
basic_row = conn.execute(
sa.select(user_group_table.c.id).where(
user_group_table.c.name == "Basic",
user_group_table.c.is_default == True, # noqa: E712
)
).fetchone()
if admin_row is None:
raise RuntimeError(
"Default 'Admin' group not found. "
"Ensure migration 977e834c1427 (seed_default_groups) ran successfully."
)
if basic_row is None:
raise RuntimeError(
"Default 'Basic' group not found. "
"Ensure migration 977e834c1427 (seed_default_groups) ran successfully."
)
# Users with role=admin → Admin group
# Include inactive users so reactivation doesn't require reconciliation.
# Exclude non-human account types (mirrors assign_user_to_default_groups logic).
admin_users = sa.select(
sa.literal(admin_row[0]).label("user_group_id"),
user_table.c.id.label("user_id"),
).where(
user_table.c.role == "ADMIN",
user_table.c.account_type.notin_(["BOT", "EXT_PERM_USER", "ANONYMOUS"]),
user_table.c.id != NO_AUTH_PLACEHOLDER_USER_UUID,
)
op.execute(
pg_insert(user__user_group_table)
.from_select(["user_group_id", "user_id"], admin_users)
.on_conflict_do_nothing(index_elements=["user_group_id", "user_id"])
)
# STANDARD users (non-admin) and SERVICE_ACCOUNT users (role=basic) → Basic group
# Include inactive users so reactivation doesn't require reconciliation.
basic_users = sa.select(
sa.literal(basic_row[0]).label("user_group_id"),
user_table.c.id.label("user_id"),
).where(
user_table.c.account_type.notin_(["BOT", "EXT_PERM_USER", "ANONYMOUS"]),
user_table.c.id != NO_AUTH_PLACEHOLDER_USER_UUID,
sa.or_(
sa.and_(
user_table.c.account_type == "STANDARD",
user_table.c.role != "ADMIN",
),
sa.and_(
user_table.c.account_type == "SERVICE_ACCOUNT",
user_table.c.role == "BASIC",
),
),
)
op.execute(
pg_insert(user__user_group_table)
.from_select(["user_group_id", "user_id"], basic_users)
.on_conflict_do_nothing(index_elements=["user_group_id", "user_id"])
)
def downgrade() -> None:
# Group memberships are left in place — removing them risks
# deleting memberships that existed before this migration.
pass
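The two `INSERT ... SELECT` statements in this migration encode a small decision table over `role` and `account_type`. Written as a plain predicate it is easier to test individual cases. The function name `default_groups_for` is hypothetical, and the sketch assumes `role` and `account_type` are never NULL.

```python
NO_AUTH_PLACEHOLDER_USER_UUID = "00000000-0000-0000-0000-000000000001"
NON_HUMAN_ACCOUNT_TYPES = {"BOT", "EXT_PERM_USER", "ANONYMOUS"}


def default_groups_for(user_id: str, role: str, account_type: str) -> set[str]:
    """Return the default group names the migration would assign to a user."""
    if user_id == NO_AUTH_PLACEHOLDER_USER_UUID:
        return set()
    if account_type in NON_HUMAN_ACCOUNT_TYPES:
        return set()

    groups: set[str] = set()
    if role == "ADMIN":
        groups.add("Admin")
    is_standard_non_admin = account_type == "STANDARD" and role != "ADMIN"
    is_basic_service_account = account_type == "SERVICE_ACCOUNT" and role == "BASIC"
    if is_standard_non_admin or is_basic_service_account:
        groups.add("Basic")
    return groups


cases = {
    "standard_admin": default_groups_for("u1", "ADMIN", "STANDARD"),
    "standard_curator": default_groups_for("u2", "CURATOR", "STANDARD"),
    "service_basic": default_groups_for("u3", "BASIC", "SERVICE_ACCOUNT"),
    "service_admin": default_groups_for("u4", "ADMIN", "SERVICE_ACCOUNT"),
    "slack_bot": default_groups_for("u5", "SLACK_USER", "BOT"),
    "placeholder": default_groups_for(
        NO_AUTH_PLACEHOLDER_USER_UUID, "BASIC", "STANDARD"
    ),
}
```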


@@ -1,9 +1,11 @@
import asyncio
from logging.config import fileConfig
from typing import Literal
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.schema import SchemaItem
from alembic import context
from onyx.db.engine.sql_engine import build_connection_string
@@ -33,6 +35,27 @@ target_metadata = [PublicBase.metadata]
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
def include_object(
object: SchemaItem, # noqa: ARG001
name: str | None,
type_: Literal[
"schema",
"table",
"column",
"index",
"unique_constraint",
"foreign_key_constraint",
],
reflected: bool, # noqa: ARG001
compare_to: SchemaItem | None, # noqa: ARG001
) -> bool:
if type_ == "table" and name in EXCLUDE_TABLES:
return False
return True
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
@@ -62,6 +85,7 @@ def do_run_migrations(connection: Connection) -> None:
context.configure(
connection=connection,
target_metadata=target_metadata, # type: ignore[arg-type]
include_object=include_object,
)
with context.begin_transaction():
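Alembic calls the `include_object` hook for each object it considers during autogenerate and skips any object for which the hook returns `False`. The filter in this hunk can be checked in isolation with a dependency-free copy; the untyped signature below is a simplification for the example, not the one used in the file.

```python
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}


def include_object(obj, name, type_, reflected, compare_to):
    """Skip the excluded tables; keep every other object."""
    if type_ == "table" and name in EXCLUDE_TABLES:
        return False
    return True


decisions = {
    "kombu table": include_object(None, "kombu_queue", "table", True, None),
    "app table": include_object(None, "user", "table", True, None),
    # Only type_ == "table" is filtered, so an index with a matching name
    # still passes this particular check.
    "kombu-named index": include_object(None, "kombu_queue", "index", True, None),
}
```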


@@ -10,10 +10,9 @@ from fastapi import status
from ee.onyx.configs.app_configs import SUPER_CLOUD_API_KEY
from ee.onyx.configs.app_configs import SUPER_USERS
from ee.onyx.server.seeding import get_seed_config
from onyx.auth.permissions import require_permission
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.db.enums import Permission
from onyx.db.models import User
from onyx.utils.logger import setup_logger
@@ -40,7 +39,7 @@ def get_default_admin_user_emails_() -> list[str]:
async def current_cloud_superuser(
request: Request,
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
user: User = Depends(current_admin_user),
) -> User:
api_key = request.headers.get("Authorization", "").replace("Bearer ", "")
if api_key != SUPER_CLOUD_API_KEY:


@@ -5,7 +5,6 @@ from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis.lock import Lock as RedisLock
from ee.onyx.server.tenants.product_gating import get_gated_tenants
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
@@ -31,7 +30,6 @@ def cloud_beat_task_generator(
queue: str = OnyxCeleryTask.DEFAULT,
priority: int = OnyxCeleryPriority.MEDIUM,
expires: int = BEAT_EXPIRES_DEFAULT,
skip_gated: bool = True,
) -> bool | None:
"""a lightweight task used to kick off individual beat tasks per tenant."""
time_start = time.monotonic()
@@ -50,22 +48,20 @@ def cloud_beat_task_generator(
last_lock_time = time.monotonic()
tenant_ids: list[str] = []
num_processed_tenants = 0
num_skipped_gated = 0
try:
tenant_ids = get_all_tenant_ids()
# Per-task control over whether gated tenants are included. Most periodic tasks
# do no useful work on gated tenants and just waste DB connections fanning out
# to ~10k+ inactive tenants. A small number of cleanup tasks (connector deletion,
# checkpoint/index attempt cleanup) need to run on gated tenants and pass
# `skip_gated=False` from the beat schedule.
gated_tenants: set[str] = get_gated_tenants() if skip_gated else set()
# NOTE: for now, we are running tasks for gated tenants, since we want to allow
# connector deletion to run successfully. The new plan is to continuously prune
# the gated tenants set, so we won't have a buildup of old, unused gated tenants.
# Keeping this around in case we want to revert to the previous behavior.
# gated_tenants = get_gated_tenants()
for tenant_id in tenant_ids:
if tenant_id in gated_tenants:
num_skipped_gated += 1
continue
# Same comment here as the above NOTE
# if tenant_id in gated_tenants:
# continue
current_time = time.monotonic()
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
@@ -108,7 +104,6 @@ def cloud_beat_task_generator(
f"cloud_beat_task_generator finished: "
f"task={task_name} "
f"num_processed_tenants={num_processed_tenants} "
f"num_skipped_gated={num_skipped_gated} "
f"num_tenants={len(tenant_ids)} "
f"elapsed={time_elapsed:.2f}"
)
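One side of this hunk adds a per-task `skip_gated` flag and a `num_skipped_gated` counter. The selection logic on that side can be sketched as a pure function; `select_tenants` is a hypothetical helper, and the gated set is passed in directly instead of being fetched with `get_gated_tenants()`.

```python
def select_tenants(
    tenant_ids: list[str],
    gated_tenants: set[str],
    skip_gated: bool = True,
) -> tuple[list[str], int]:
    """Return (tenants to process, number of gated tenants skipped)."""
    # With skip_gated=False the gated set is treated as empty, so cleanup
    # tasks still fan out to every tenant.
    effective_gated = gated_tenants if skip_gated else set()
    to_process = [t for t in tenant_ids if t not in effective_gated]
    return to_process, len(tenant_ids) - len(to_process)


tenants = ["tenant_a", "tenant_b", "tenant_c"]
gated = {"tenant_b"}

default_run = select_tenants(tenants, gated)
cleanup_run = select_tenants(tenants, gated, skip_gated=False)
```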


@@ -27,13 +27,13 @@ from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import TENANT_ID_PREFIX
# Maximum tenants to provision in a single task run.
# Each tenant takes ~80s (alembic migrations), so 15 tenants ≈ 20 minutes.
_MAX_TENANTS_PER_RUN = 15
# Each tenant takes ~80s (alembic migrations), so 5 tenants ≈ 7 minutes.
_MAX_TENANTS_PER_RUN = 5
# Time limits sized for worst-case: provisioning up to _MAX_TENANTS_PER_RUN new tenants
# (~90s each) plus migrating up to TARGET_AVAILABLE_TENANTS pool tenants (~90s each).
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 40 # 40 minutes
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 45 # 45 minutes
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 20 # 20 minutes
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 25 # 25 minutes
@shared_task(
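The batch-size comments in this hunk follow from simple arithmetic on the per-tenant estimate. A short worked check, using the roughly 80 seconds per tenant quoted in the comment, is shown here; the helper name `batch_minutes` is made up for the example.

```python
SECONDS_PER_TENANT = 80  # estimate quoted in the comment above


def batch_minutes(num_tenants: int, seconds_per_tenant: int = SECONDS_PER_TENANT) -> float:
    """Estimated wall-clock minutes to provision `num_tenants` sequentially."""
    return num_tenants * seconds_per_tenant / 60


fifteen = batch_minutes(15)  # 1200 s
five = batch_minutes(5)      # 400 s
```

Fifteen tenants come to 20 minutes and five tenants to a little under 7 minutes, which matches the two comment variants. The time limits also have to cover the pool-tenant migrations mentioned in the second comment, so they are set above the provisioning estimate alone.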


@@ -1,14 +1,20 @@
from datetime import datetime
from datetime import timezone
from uuid import UUID
from celery import shared_task
from celery import Task
from ee.onyx.background.celery_utils import should_perform_chat_ttl_check
from ee.onyx.background.task_name_builders import name_chat_ttl_task
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.chat import delete_chat_session
from onyx.db.chat import get_chat_sessions_older_than
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import TaskStatus
from onyx.db.tasks import mark_task_as_finished_with_id
from onyx.db.tasks import register_task
from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger
@@ -23,42 +29,59 @@ logger = setup_logger()
trail=False,
)
def perform_ttl_management_task(
self: Task, retention_limit_days: int, *, tenant_id: str # noqa: ARG001
self: Task, retention_limit_days: int, *, tenant_id: str
) -> None:
task_id = self.request.id
if not task_id:
raise RuntimeError("No task id defined for this task; cannot identify it")
start_time = datetime.now(tz=timezone.utc)
user_id: UUID | None = None
session_id: UUID | None = None
try:
with get_session_with_current_tenant() as db_session:
# we generally want to move off this, but keeping for now
register_task(
db_session=db_session,
task_name=name_chat_ttl_task(retention_limit_days, tenant_id),
task_id=task_id,
status=TaskStatus.STARTED,
start_time=start_time,
)
old_chat_sessions = get_chat_sessions_older_than(
retention_limit_days, db_session
)
for user_id, session_id in old_chat_sessions:
try:
with get_session_with_current_tenant() as db_session:
delete_chat_session(
user_id,
session_id,
db_session,
include_deleted=True,
hard_delete=True,
)
except Exception:
logger.exception(
"Failed to delete chat session "
f"user_id={user_id} session_id={session_id}, "
"continuing with remaining sessions"
# one session per delete so that we don't blow up if a deletion fails.
with get_session_with_current_tenant() as db_session:
delete_chat_session(
user_id,
session_id,
db_session,
include_deleted=True,
hard_delete=True,
)
with get_session_with_current_tenant() as db_session:
mark_task_as_finished_with_id(
db_session=db_session,
task_id=task_id,
success=True,
)
except Exception:
logger.exception(
f"delete_chat_session exceptioned. user_id={user_id} session_id={session_id}"
)
with get_session_with_current_tenant() as db_session:
mark_task_as_finished_with_id(
db_session=db_session,
task_id=task_id,
success=False,
)
raise
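The hunk above contains two error-handling shapes for the deletion loop: one wraps each deletion in its own `try`/`except` and continues with the remaining sessions, while the other lets the first failure abort the task and records it as failed. A minimal sketch of the first shape, with the database call replaced by an injected callable, is given below. `delete_sessions` and `fake_delete` are hypothetical names.

```python
from typing import Callable, Iterable


def delete_sessions(
    sessions: Iterable[tuple[str, str]],
    delete_one: Callable[[str, str], None],
) -> list[tuple[str, str]]:
    """Try to delete every session; return the (user_id, session_id) pairs that failed."""
    failed: list[tuple[str, str]] = []
    for user_id, session_id in sessions:
        try:
            delete_one(user_id, session_id)
        except Exception:
            # Record the failure and keep going, so one bad session does not
            # block cleanup of the rest.
            failed.append((user_id, session_id))
    return failed


deleted: list[str] = []


def fake_delete(user_id: str, session_id: str) -> None:
    if session_id == "s2":
        raise ValueError("simulated failure")
    deleted.append(session_id)


failed = delete_sessions([("u1", "s1"), ("u1", "s2"), ("u2", "s3")], fake_delete)
```

The trade-off is visibility: continuing past failures maximizes cleanup per run, while aborting on the first error makes the failure show up in the task status.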


@@ -13,7 +13,6 @@ from ee.onyx.server.license.models import LicenseSource
from onyx.auth.schemas import UserRole
from onyx.cache.factory import get_cache_backend
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
from onyx.db.enums import AccountType
from onyx.db.models import License
from onyx.db.models import User
from onyx.utils.logger import setup_logger
@@ -108,13 +107,12 @@ def get_used_seats(tenant_id: str | None = None) -> int:
Get current seat usage directly from database.
For multi-tenant: counts users in UserTenantMapping for this tenant.
For self-hosted: counts all active users.
For self-hosted: counts all active users (excludes EXT_PERM_USER role
and the anonymous system user).
Only human accounts count toward seat limits.
SERVICE_ACCOUNT (API key dummy users), EXT_PERM_USER, and the
anonymous system user are excluded. BOT (Slack users) ARE counted
because they represent real humans and get upgraded to STANDARD
when they log in via web.
TODO: Exclude API key dummy users from seat counting. API keys create
users with emails like `__DANSWER_API_KEY_*` that should not count toward
seat limits. See: https://linear.app/onyx-app/issue/ENG-3518
"""
if MULTI_TENANT:
from ee.onyx.server.tenants.user_mapping import get_tenant_count
@@ -131,7 +129,6 @@ def get_used_seats(tenant_id: str | None = None) -> int:
User.is_active == True, # type: ignore # noqa: E712
User.role != UserRole.EXT_PERM_USER,
User.email != ANONYMOUS_USER_EMAIL, # type: ignore
User.account_type != AccountType.SERVICE_ACCOUNT,
)
)
return result.scalar() or 0


@@ -36,16 +36,13 @@ from ee.onyx.server.scim.filtering import ScimFilter
from ee.onyx.server.scim.filtering import ScimFilterOperator
from ee.onyx.server.scim.models import ScimMappingFields
from onyx.db.dal import DAL
from onyx.db.enums import AccountType
from onyx.db.enums import GrantSource
from onyx.db.enums import Permission
from onyx.db.models import PermissionGrant
from onyx.db.models import ScimGroupMapping
from onyx.db.models import ScimToken
from onyx.db.models import ScimUserMapping
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -283,9 +280,7 @@ class ScimDAL(DAL):
query = (
select(User)
.join(ScimUserMapping, ScimUserMapping.user_id == User.id)
.where(
User.account_type.notin_([AccountType.BOT, AccountType.EXT_PERM_USER])
)
.where(User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]))
)
if scim_filter:
@@ -526,22 +521,6 @@ class ScimDAL(DAL):
self._session.add(group)
self._session.flush()
def add_permission_grant_to_group(
self,
group_id: int,
permission: Permission,
grant_source: GrantSource,
) -> None:
"""Grant a permission to a group and flush."""
self._session.add(
PermissionGrant(
group_id=group_id,
permission=permission,
grant_source=grant_source,
)
)
self._session.flush()
def update_group(
self,
group: UserGroup,

@@ -19,8 +19,6 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import GrantSource
from onyx.db.enums import Permission
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Credential__UserGroup
@@ -30,7 +28,6 @@ from onyx.db.models import DocumentSet
from onyx.db.models import DocumentSet__UserGroup
from onyx.db.models import FederatedConnector__DocumentSet
from onyx.db.models import LLMProvider__UserGroup
from onyx.db.models import PermissionGrant
from onyx.db.models import Persona
from onyx.db.models import Persona__UserGroup
from onyx.db.models import TokenRateLimit__UserGroup
@@ -39,8 +36,6 @@ from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.db.models import UserGroup__ConnectorCredentialPair
from onyx.db.models import UserRole
from onyx.db.permissions import recompute_permissions_for_group__no_commit
from onyx.db.permissions import recompute_user_permissions__no_commit
from onyx.db.users import fetch_user_by_id
from onyx.utils.logger import setup_logger
@@ -260,7 +255,6 @@ def fetch_user_groups(
db_session: Session,
only_up_to_date: bool = True,
eager_load_for_snapshot: bool = False,
include_default: bool = True,
) -> Sequence[UserGroup]:
"""
Fetches user groups from the database.
@@ -275,7 +269,6 @@ def fetch_user_groups(
to include only up to date user groups. Defaults to `True`.
eager_load_for_snapshot: If True, adds eager loading for all relationships
needed by UserGroup.from_model snapshot creation.
include_default: If False, excludes system default groups (is_default=True).
Returns:
Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
@@ -283,8 +276,6 @@ def fetch_user_groups(
stmt = select(UserGroup)
if only_up_to_date:
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
if not include_default:
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
@@ -295,7 +286,6 @@ def fetch_user_groups_for_user(
user_id: UUID,
only_curator_groups: bool = False,
eager_load_for_snapshot: bool = False,
include_default: bool = True,
) -> Sequence[UserGroup]:
stmt = (
select(UserGroup)
@@ -305,8 +295,6 @@ def fetch_user_groups_for_user(
)
if only_curator_groups:
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
if not include_default:
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
@@ -490,16 +478,6 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
db_session.add(db_user_group)
db_session.flush() # give the group an ID
# Every group gets the "basic" permission by default
db_session.add(
PermissionGrant(
group_id=db_user_group.id,
permission=Permission.BASIC_ACCESS,
grant_source=GrantSource.SYSTEM,
)
)
db_session.flush()
_add_user__user_group_relationships__no_commit(
db_session=db_session,
user_group_id=db_user_group.id,
@@ -511,8 +489,6 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
cc_pair_ids=user_group.cc_pair_ids,
)
recompute_user_permissions__no_commit(user_group.user_ids, db_session)
db_session.commit()
return db_user_group
@@ -820,10 +796,6 @@ def update_user_group(
# update "time_updated" to now
db_user_group.time_last_modified_by_user = func.now()
recompute_user_permissions__no_commit(
list(set(added_user_ids) | set(removed_user_ids)), db_session
)
db_session.commit()
return db_user_group
@@ -863,19 +835,6 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
_check_user_group_is_modifiable(db_user_group)
# Collect affected user IDs before cleanup deletes the relationships
affected_user_ids: list[UUID] = [
uid
for uid in db_session.execute(
select(User__UserGroup.user_id).where(
User__UserGroup.user_group_id == user_group_id
)
)
.scalars()
.all()
if uid is not None
]
_mark_user_group__cc_pair_relationships_outdated__no_commit(
db_session=db_session, user_group_id=user_group_id
)
@@ -904,10 +863,6 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
db_session=db_session, user_group_id=user_group_id
)
# Recompute permissions for affected users now that their
# membership in this group has been removed
recompute_user_permissions__no_commit(affected_user_ids, db_session)
db_user_group.is_up_to_date = False
db_user_group.is_up_for_deletion = True
db_session.commit()
@@ -953,46 +908,3 @@ def delete_user_group_cc_pair_relationship__no_commit(
UserGroup__ConnectorCredentialPair.cc_pair_id == cc_pair_id,
)
db_session.execute(delete_stmt)
def set_group_permission__no_commit(
group_id: int,
permission: Permission,
enabled: bool,
granted_by: UUID,
db_session: Session,
) -> None:
"""Grant or revoke a single permission for a group using soft-delete.
Does NOT commit — caller must commit the session.
"""
existing = db_session.execute(
select(PermissionGrant)
.where(
PermissionGrant.group_id == group_id,
PermissionGrant.permission == permission,
)
.with_for_update()
).scalar_one_or_none()
if enabled:
if existing is not None:
if existing.is_deleted:
existing.is_deleted = False
existing.granted_by = granted_by
existing.granted_at = func.now()
else:
db_session.add(
PermissionGrant(
group_id=group_id,
permission=permission,
grant_source=GrantSource.USER,
granted_by=granted_by,
)
)
else:
if existing is not None and not existing.is_deleted:
existing.is_deleted = True
db_session.flush()
recompute_permissions_for_group__no_commit(group_id, db_session)
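
The grant/revoke branching in `set_group_permission__no_commit` above follows a soft-delete pattern: a revoked grant is flagged rather than removed, and a later grant revives the same row. The sketch below models only that branching in memory. `Grant`, `GrantStore`, and `active_permissions` are hypothetical names, and row locking, timestamps, and the permission recompute step are deliberately left out.

```python
from dataclasses import dataclass


@dataclass
class Grant:
    """Hypothetical stand-in for a PermissionGrant row."""

    group_id: int
    permission: str
    granted_by: str
    is_deleted: bool = False


class GrantStore:
    """In-memory stand-in for the grant table, keyed by (group, permission)."""

    def __init__(self) -> None:
        self._rows: dict[tuple[int, str], Grant] = {}

    def set_permission(
        self, group_id: int, permission: str, enabled: bool, granted_by: str
    ) -> None:
        key = (group_id, permission)
        existing = self._rows.get(key)
        if enabled:
            if existing is None:
                # First grant: create the row.
                self._rows[key] = Grant(group_id, permission, granted_by)
            elif existing.is_deleted:
                # Re-grant: revive the soft-deleted row and record who did it.
                existing.is_deleted = False
                existing.granted_by = granted_by
            # An already-active grant is left untouched.
        elif existing is not None and not existing.is_deleted:
            # Revoke: flag the row instead of deleting it.
            existing.is_deleted = True

    def active_permissions(self, group_id: int) -> set[str]:
        return {
            perm
            for (gid, perm), row in self._rows.items()
            if gid == group_id and not row.is_deleted
        }
```

Keeping the row around means a grant, revoke, grant sequence leaves exactly one record per (group, permission) pair, which is the property the branching in the diff appears to preserve.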

@@ -155,7 +155,7 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, license_router)
# Unified billing API - always registered in EE.
# Each endpoint is protected by admin permission checks.
# Each endpoint is protected by the `current_admin_user` dependency (admin auth).
include_router_with_global_prefix_prepended(application, billing_router)
if MULTI_TENANT:

@@ -17,10 +17,10 @@ from ee.onyx.db.analytics import fetch_persona_message_analytics
from ee.onyx.db.analytics import fetch_persona_unique_users
from ee.onyx.db.analytics import fetch_query_analytics
from ee.onyx.db.analytics import user_can_view_assistant_stats
from onyx.auth.permissions import require_permission
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import Permission
from onyx.db.models import User
router = APIRouter(prefix="/analytics", tags=PUBLIC_API_TAGS)
@@ -40,7 +40,7 @@ class QueryAnalyticsResponse(BaseModel):
def get_query_analytics(
start: datetime.datetime | None = None,
end: datetime.datetime | None = None,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[QueryAnalyticsResponse]:
daily_query_usage_info = fetch_query_analytics(
@@ -71,7 +71,7 @@ class UserAnalyticsResponse(BaseModel):
def get_user_analytics(
start: datetime.datetime | None = None,
end: datetime.datetime | None = None,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[UserAnalyticsResponse]:
daily_query_usage_info_per_user = fetch_per_user_query_analytics(
@@ -105,7 +105,7 @@ class OnyxbotAnalyticsResponse(BaseModel):
def get_onyxbot_analytics(
start: datetime.datetime | None = None,
end: datetime.datetime | None = None,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[OnyxbotAnalyticsResponse]:
daily_onyxbot_info = fetch_onyxbot_analytics(
@@ -141,7 +141,7 @@ def get_persona_messages(
persona_id: int,
start: datetime.datetime | None = None,
end: datetime.datetime | None = None,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[PersonaMessageAnalyticsResponse]:
"""Fetch daily message counts for a single persona within the given time range."""
@@ -179,7 +179,7 @@ def get_persona_unique_users(
persona_id: int,
start: datetime.datetime,
end: datetime.datetime,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[PersonaUniqueUsersResponse]:
"""Get unique users per day for a single persona."""
@@ -218,7 +218,7 @@ def get_assistant_stats(
assistant_id: int,
start: datetime.datetime | None = None,
end: datetime.datetime | None = None,
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> AssistantStatsResponse:
"""

@@ -29,6 +29,7 @@ from fastapi import Depends
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.auth.users import current_admin_user
from ee.onyx.db.license import get_license
from ee.onyx.db.license import get_used_seats
from ee.onyx.server.billing.models import BillingInformationResponse
@@ -50,13 +51,11 @@ from ee.onyx.server.billing.service import (
get_billing_information as get_billing_service,
)
from ee.onyx.server.billing.service import update_seat_count as update_seat_service
from onyx.auth.permissions import require_permission
from onyx.auth.users import User
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import Permission
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.redis.redis_pool import get_shared_redis_client
@@ -148,7 +147,7 @@ def _get_tenant_id() -> str | None:
@router.post("/create-checkout-session")
async def create_checkout_session(
request: CreateCheckoutSessionRequest | None = None,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> CreateCheckoutSessionResponse:
"""Create a Stripe checkout session for new subscription or renewal.
@@ -192,7 +191,7 @@ async def create_checkout_session(
@router.post("/create-customer-portal-session")
async def create_customer_portal_session(
request: CreateCustomerPortalSessionRequest | None = None,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> CreateCustomerPortalSessionResponse:
"""Create a Stripe customer portal session for managing subscription.
@@ -217,7 +216,7 @@ async def create_customer_portal_session(
@router.get("/billing-information")
async def get_billing_information(
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> BillingInformationResponse | SubscriptionStatusResponse:
"""Get billing information for the current subscription.
@@ -259,7 +258,7 @@ async def get_billing_information(
@router.post("/seats/update")
async def update_seats(
request: SeatUpdateRequest,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> SeatUpdateResponse:
"""Update the seat count for the current subscription.
@@ -365,7 +364,7 @@ class ResetConnectionResponse(BaseModel):
@router.post("/reset-connection")
async def reset_stripe_connection(
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
) -> ResetConnectionResponse:
"""Reset the Stripe connection circuit breaker.

@@ -27,12 +27,11 @@ from ee.onyx.server.scim.auth import generate_scim_token
from ee.onyx.server.scim.models import ScimTokenCreate
from ee.onyx.server.scim.models import ScimTokenCreatedResponse
from ee.onyx.server.scim.models import ScimTokenResponse
from onyx.auth.permissions import require_permission
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user_with_expired_token
from onyx.auth.users import get_user_manager
from onyx.auth.users import UserManager
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import Permission
from onyx.db.models import User
from onyx.file_store.file_store import get_default_file_store
from onyx.server.utils import BasicAuthenticationError
@@ -121,8 +120,7 @@ async def refresh_access_token(
@admin_router.put("")
def admin_ee_put_settings(
settings: EnterpriseSettings,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
settings: EnterpriseSettings, _: User = Depends(current_admin_user)
) -> None:
store_settings(settings)
@@ -141,7 +139,7 @@ def ee_fetch_settings() -> EnterpriseSettings:
def put_logo(
file: UploadFile,
is_logotype: bool = False,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
) -> None:
upload_logo(file=file, is_logotype=is_logotype)
@@ -198,8 +196,7 @@ def fetch_logo(
@admin_router.put("/custom-analytics-script")
def upload_custom_analytics_script(
script_upload: AnalyticsScriptUpload,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
script_upload: AnalyticsScriptUpload, _: User = Depends(current_admin_user)
) -> None:
try:
store_analytics_script(script_upload)
@@ -223,7 +220,7 @@ def _get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
@admin_router.get("/scim/token")
def get_active_scim_token(
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
dal: ScimDAL = Depends(_get_scim_dal),
) -> ScimTokenResponse:
"""Return the currently active SCIM token's metadata, or 404 if none."""
@@ -253,7 +250,7 @@ def get_active_scim_token(
@admin_router.post("/scim/token", status_code=201)
def create_scim_token(
body: ScimTokenCreate,
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
user: User = Depends(current_admin_user),
dal: ScimDAL = Depends(_get_scim_dal),
) -> ScimTokenCreatedResponse:
"""Create a new SCIM bearer token.

@@ -4,13 +4,12 @@ from fastapi import Depends
from fastapi import Query
from sqlalchemy.orm import Session
from onyx.auth.permissions import require_permission
from onyx.auth.users import current_admin_user
from onyx.auth.users import User
from onyx.db.constants import UNSET
from onyx.db.constants import UnsetType
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import Permission
from onyx.db.hook import create_hook__no_commit
from onyx.db.hook import delete_hook__no_commit
from onyx.db.hook import get_hook_by_id
@@ -179,7 +178,7 @@ router = APIRouter(prefix="/admin/hooks")
@router.get("/specs")
def get_hook_point_specs(
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
) -> list[HookPointMetaResponse]:
return [
@@ -200,7 +199,7 @@ def get_hook_point_specs(
@router.get("")
def list_hooks(
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> list[HookResponse]:
@@ -211,7 +210,7 @@ def list_hooks(
@router.post("")
def create_hook(
req: HookCreateRequest,
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
user: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookResponse:
@@ -247,7 +246,7 @@ def create_hook(
@router.get("/{hook_id}")
def get_hook(
hook_id: int,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookResponse:
@@ -259,7 +258,7 @@ def get_hook(
def update_hook(
hook_id: int,
req: HookUpdateRequest,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookResponse:
@@ -329,7 +328,7 @@ def update_hook(
@router.delete("/{hook_id}")
def delete_hook(
hook_id: int,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> None:
@@ -340,7 +339,7 @@ def delete_hook(
@router.post("/{hook_id}/activate")
def activate_hook(
hook_id: int,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookResponse:
@@ -382,7 +381,7 @@ def activate_hook(
@router.post("/{hook_id}/validate")
def validate_hook(
hook_id: int,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookValidateResponse:
@@ -410,7 +409,7 @@ def validate_hook(
@router.post("/{hook_id}/deactivate")
def deactivate_hook(
hook_id: int,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookResponse:
@@ -433,7 +432,7 @@ def deactivate_hook(
def list_hook_execution_logs(
hook_id: int,
limit: int = Query(default=10, ge=1, le=100),
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> list[HookExecutionRecord]:
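
The endpoint hunks above alternate between two dependency shapes: a fixed callable (`current_admin_user`) and a factory that is called with a permission and returns a callable (`require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)`). The sketch below shows the difference between the two shapes in plain Python, without FastAPI. It does not reproduce the project's implementations: `SketchUser`, `PermissionDenied`, the `is_admin` flag, and the string permissions are assumptions made for illustration.

```python
from dataclasses import dataclass, field
from typing import Callable


class PermissionDenied(Exception):
    """Raised when a check fails (hypothetical; stands in for an HTTP 403)."""


@dataclass
class SketchUser:
    """Hypothetical user carrying both a role flag and a permission set."""

    email: str
    is_admin: bool = False
    permissions: set[str] = field(default_factory=set)


def current_admin_user(user: SketchUser) -> SketchUser:
    """Fixed dependency: one hard-coded role check."""
    if not user.is_admin:
        raise PermissionDenied("admin role required")
    return user


def require_permission(permission: str) -> Callable[[SketchUser], SketchUser]:
    """Dependency factory: returns a check bound to one permission."""

    def dependency(user: SketchUser) -> SketchUser:
        if permission not in user.permissions:
            raise PermissionDenied(f"missing permission: {permission}")
        return user

    return dependency
```

With the factory shape each endpoint names the permission it needs, so the same mechanism covers both the admin-only and the basic-access endpoints; with the fixed shape a separate callable is needed per access level.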

@@ -17,6 +17,7 @@ from fastapi import File
from fastapi import UploadFile
from sqlalchemy.orm import Session
from ee.onyx.auth.users import current_admin_user
from ee.onyx.configs.app_configs import CLOUD_DATA_PLANE_URL
from ee.onyx.db.license import delete_license as db_delete_license
from ee.onyx.db.license import get_license
@@ -31,10 +32,8 @@ from ee.onyx.server.license.models import LicenseStatusResponse
from ee.onyx.server.license.models import LicenseUploadResponse
from ee.onyx.server.license.models import SeatUsageResponse
from ee.onyx.utils.license import verify_license_signature
from onyx.auth.permissions import require_permission
from onyx.auth.users import User
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import Permission
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.utils.logger import setup_logger
@@ -61,7 +60,7 @@ def _strip_pem_delimiters(content: str) -> str:
@router.get("")
async def get_license_status(
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> LicenseStatusResponse:
"""Get current license status and seat usage."""
@@ -85,7 +84,7 @@ async def get_license_status(
@router.get("/seats")
async def get_seat_usage(
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> SeatUsageResponse:
"""Get detailed seat usage information."""
@@ -108,7 +107,7 @@ async def get_seat_usage(
@router.post("/claim")
async def claim_license(
session_id: str | None = None,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> LicenseResponse:
"""
@@ -216,7 +215,7 @@ async def claim_license(
@router.post("/upload")
async def upload_license(
license_file: UploadFile = File(...),
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> LicenseUploadResponse:
"""
@@ -264,7 +263,7 @@ async def upload_license(
@router.post("/refresh")
async def refresh_license_cache_endpoint(
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> LicenseStatusResponse:
"""
@@ -293,7 +292,7 @@ async def refresh_license_cache_endpoint(
@router.delete("")
async def delete_license(
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> dict[str, bool]:
"""

@@ -12,9 +12,8 @@ from ee.onyx.db.standard_answer import insert_standard_answer_category
from ee.onyx.db.standard_answer import remove_standard_answer
from ee.onyx.db.standard_answer import update_standard_answer
from ee.onyx.db.standard_answer import update_standard_answer_category
from onyx.auth.permissions import require_permission
from onyx.auth.users import current_admin_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import Permission
from onyx.db.models import User
from onyx.server.manage.models import StandardAnswer
from onyx.server.manage.models import StandardAnswerCategory
@@ -28,7 +27,7 @@ router = APIRouter(prefix="/manage")
def create_standard_answer(
standard_answer_creation_request: StandardAnswerCreationRequest,
db_session: Session = Depends(get_session),
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
) -> StandardAnswer:
standard_answer_model = insert_standard_answer(
keyword=standard_answer_creation_request.keyword,
@@ -44,7 +43,7 @@ def create_standard_answer(
@router.get("/admin/standard-answer")
def list_standard_answers(
db_session: Session = Depends(get_session),
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
) -> list[StandardAnswer]:
standard_answer_models = fetch_standard_answers(db_session=db_session)
return [
@@ -58,7 +57,7 @@ def patch_standard_answer(
standard_answer_id: int,
standard_answer_creation_request: StandardAnswerCreationRequest,
db_session: Session = Depends(get_session),
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
) -> StandardAnswer:
existing_standard_answer = fetch_standard_answer(
standard_answer_id=standard_answer_id,
@@ -84,7 +83,7 @@ def patch_standard_answer(
def delete_standard_answer(
standard_answer_id: int,
db_session: Session = Depends(get_session),
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
) -> None:
return remove_standard_answer(
standard_answer_id=standard_answer_id,
@@ -96,7 +95,7 @@ def delete_standard_answer(
def create_standard_answer_category(
standard_answer_category_creation_request: StandardAnswerCategoryCreationRequest,
db_session: Session = Depends(get_session),
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
) -> StandardAnswerCategory:
standard_answer_category_model = insert_standard_answer_category(
category_name=standard_answer_category_creation_request.name,
@@ -108,7 +107,7 @@ def create_standard_answer_category(
@router.get("/admin/standard-answer/category")
def list_standard_answer_categories(
db_session: Session = Depends(get_session),
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
) -> list[StandardAnswerCategory]:
standard_answer_category_models = fetch_standard_answer_categories(
db_session=db_session
@@ -124,7 +123,7 @@ def patch_standard_answer_category(
standard_answer_category_id: int,
standard_answer_category_creation_request: StandardAnswerCategoryCreationRequest,
db_session: Session = Depends(get_session),
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
) -> StandardAnswerCategory:
existing_standard_answer_category = fetch_standard_answer_category(
standard_answer_category_id=standard_answer_category_id,

@@ -9,10 +9,9 @@ from ee.onyx.server.oauth.api_router import router
from ee.onyx.server.oauth.confluence_cloud import ConfluenceCloudOAuth
from ee.onyx.server.oauth.google_drive import GoogleDriveOAuth
from ee.onyx.server.oauth.slack import SlackOAuth
from onyx.auth.permissions import require_permission
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.constants import DocumentSource
from onyx.db.enums import Permission
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
@@ -25,7 +24,7 @@ logger = setup_logger()
def prepare_authorization_request(
connector: DocumentSource,
redirect_on_success: str | None,
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
user: User = Depends(current_admin_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""Used by the frontend to generate the url for the user's browser during auth request.

@@ -15,7 +15,7 @@ from pydantic import ValidationError
from sqlalchemy.orm import Session
from ee.onyx.server.oauth.api_router import router
from onyx.auth.permissions import require_permission
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
from onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
@@ -26,7 +26,6 @@ from onyx.db.credentials import create_credential
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.credentials import update_credential_json
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import Permission
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
@@ -147,7 +146,7 @@ class ConfluenceCloudOAuth:
def confluence_oauth_callback(
code: str,
state: str,
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
@@ -259,7 +258,7 @@ def confluence_oauth_callback(
@router.get("/connector/confluence/accessible-resources")
def confluence_oauth_accessible_resources(
credential_id: int,
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id), # noqa: ARG001
) -> JSONResponse:
@@ -326,7 +325,7 @@ def confluence_oauth_finalize(
cloud_id: str,
cloud_name: str,
cloud_url: str,
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id), # noqa: ARG001
) -> JSONResponse:

@@ -12,7 +12,7 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.server.oauth.api_router import router
from onyx.auth.permissions import require_permission
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
from onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
@@ -34,7 +34,6 @@ from onyx.connectors.google_utils.shared_constants import (
)
from onyx.db.credentials import create_credential
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import Permission
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
@@ -115,7 +114,7 @@ class GoogleDriveOAuth:
def handle_google_drive_oauth_callback(
code: str,
state: str,
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:

@@ -10,7 +10,7 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.server.oauth.api_router import router
from onyx.auth.permissions import require_permission
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
from onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
@@ -18,7 +18,6 @@ from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.db.credentials import create_credential
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import Permission
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
@@ -99,7 +98,7 @@ class SlackOAuth:
def handle_slack_oauth_callback(
code: str,
state: str,
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:

@@ -8,9 +8,8 @@ from ee.onyx.onyxbot.slack.handlers.handle_standard_answers import (
)
from ee.onyx.server.query_and_chat.models import StandardAnswerRequest
from ee.onyx.server.query_and_chat.models import StandardAnswerResponse
from onyx.auth.permissions import require_permission
from onyx.auth.users import current_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import Permission
from onyx.db.models import User
from onyx.utils.logger import setup_logger
@@ -23,7 +22,7 @@ basic_router = APIRouter(prefix="/query")
def get_standard_answer(
request: StandardAnswerRequest,
db_session: Session = Depends(get_session),
_: User = Depends(require_permission(Permission.BASIC_ACCESS)),
_: User = Depends(current_user),
) -> StandardAnswerResponse:
try:
standard_answers = oneoff_standard_answers(

@@ -19,11 +19,10 @@ from ee.onyx.server.query_and_chat.models import SearchHistoryResponse
from ee.onyx.server.query_and_chat.models import SearchQueryResponse
from ee.onyx.server.query_and_chat.models import SendSearchQueryRequest
from ee.onyx.server.query_and_chat.streaming_models import SearchErrorPacket
from onyx.auth.permissions import require_permission
from onyx.auth.users import current_user
from onyx.configs.app_configs import ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import Permission
from onyx.db.models import User
from onyx.llm.factory import get_default_llm
from onyx.server.usage_limits import check_llm_cost_limit_for_provider
@@ -40,7 +39,7 @@ router = APIRouter(prefix="/search")
@router.post("/search-flow-classification")
def search_flow_classification(
request: SearchFlowClassificationRequest,
_: User = Depends(require_permission(Permission.BASIC_ACCESS)),
_: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> SearchFlowClassificationResponse:
query = request.user_query
@@ -80,7 +79,7 @@ def search_flow_classification(
)
def handle_send_search_message(
request: SendSearchQueryRequest,
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse | SearchFullResponse:
"""
@@ -130,7 +129,7 @@ def handle_send_search_message(
def get_search_history(
limit: int = 100,
filter_days: int | None = None,
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> SearchHistoryResponse:
"""


@@ -20,7 +20,7 @@ from ee.onyx.server.query_history.models import ChatSessionMinimal
from ee.onyx.server.query_history.models import ChatSessionSnapshot
from ee.onyx.server.query_history.models import MessageSnapshot
from ee.onyx.server.query_history.models import QueryHistoryExport
from onyx.auth.permissions import require_permission
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_display_email
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.background.task_utils import construct_query_history_report_name
@@ -39,7 +39,6 @@ from onyx.configs.constants import SessionType
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import Permission
from onyx.db.enums import TaskStatus
from onyx.db.file_record import get_query_history_export_files
from onyx.db.models import ChatSession
@@ -154,7 +153,7 @@ def snapshot_from_chat_session(
@router.get("/admin/chat-sessions")
def admin_get_chat_sessions(
user_id: UUID,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> ChatSessionsResponse:
# we specifically don't allow this endpoint if "anonymized" since
@@ -197,7 +196,7 @@ def get_chat_session_history(
feedback_type: QAFeedbackType | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> PaginatedReturn[ChatSessionMinimal]:
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
@@ -235,7 +234,7 @@ def get_chat_session_history(
@router.get("/admin/chat-session-history/{chat_session_id}")
def get_chat_session_admin(
chat_session_id: UUID,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> ChatSessionSnapshot:
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
@@ -270,7 +269,7 @@ def get_chat_session_admin(
@router.get("/admin/query-history/list")
def list_all_query_history_exports(
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[QueryHistoryExport]:
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
@@ -298,7 +297,7 @@ def list_all_query_history_exports(
@router.post("/admin/query-history/start-export", tags=PUBLIC_API_TAGS)
def start_query_history_export(
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
start: datetime | None = None,
end: datetime | None = None,
@@ -345,7 +344,7 @@ def start_query_history_export(
@router.get("/admin/query-history/export-status", tags=PUBLIC_API_TAGS)
def get_query_history_export_status(
request_id: str,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> dict[str, str]:
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
@@ -379,7 +378,7 @@ def get_query_history_export_status(
@router.get("/admin/query-history/download", tags=PUBLIC_API_TAGS)
def download_query_history_csv(
request_id: str,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse:
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])


@@ -12,11 +12,10 @@ from sqlalchemy.orm import Session
from ee.onyx.db.usage_export import get_all_usage_reports
from ee.onyx.db.usage_export import get_usage_report_data
from ee.onyx.db.usage_export import UsageReportMetadata
from onyx.auth.permissions import require_permission
from onyx.auth.users import current_admin_user
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import Permission
from onyx.db.models import User
from onyx.file_store.constants import STANDARD_CHUNK_SIZE
from shared_configs.contextvars import get_current_tenant_id
@@ -32,7 +31,7 @@ class GenerateUsageReportParams(BaseModel):
@router.post("/admin/usage-report", status_code=204)
def generate_report(
params: GenerateUsageReportParams,
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
user: User = Depends(current_admin_user),
) -> None:
# Validate period parameters
if params.period_from and params.period_to:
@@ -59,7 +58,7 @@ def generate_report(
@router.get("/admin/usage-report/{report_name}")
def read_usage_report(
report_name: str,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session), # noqa: ARG001
) -> Response:
try:
@@ -83,7 +82,7 @@ def read_usage_report(
@router.get("/admin/usage-report")
def fetch_usage_reports(
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[UsageReportMetadata]:
try:


@@ -52,25 +52,16 @@ from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import AccountType
from onyx.db.enums import GrantSource
from onyx.db.enums import Permission
from onyx.db.models import ScimToken
from onyx.db.models import ScimUserMapping
from onyx.db.models import User
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.db.permissions import recompute_permissions_for_group__no_commit
from onyx.db.permissions import recompute_user_permissions__no_commit
from onyx.db.users import assign_user_to_default_groups__no_commit
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
# Group names reserved for system default groups (seeded by migration).
_RESERVED_GROUP_NAMES = frozenset({"Admin", "Basic"})
class ScimJSONResponse(JSONResponse):
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
@@ -495,7 +486,6 @@ def create_user(
email=email,
hashed_password=_pw_helper.hash(_pw_helper.generate()),
role=UserRole.BASIC,
account_type=AccountType.STANDARD,
is_active=user_resource.active,
is_verified=True,
personal_name=personal_name,
@@ -516,25 +506,13 @@ def create_user(
scim_username=scim_username,
fields=fields,
)
dal.commit()
except IntegrityError:
dal.rollback()
return _scim_error_response(
409, f"User with email {email} already has a SCIM mapping"
)
# Assign user to default group BEFORE commit so everything is atomic.
# If this fails, the entire user creation rolls back and IdP can retry.
try:
assign_user_to_default_groups__no_commit(db_session, user)
except Exception:
dal.rollback()
logger.exception(f"Failed to assign SCIM user {email} to default groups")
return _scim_error_response(
500, f"Failed to assign user {email} to default group"
)
dal.commit()
return _scim_resource_response(
provider.build_user_resource(
user,
@@ -564,8 +542,7 @@ def replace_user(
user = result
# Handle activation (need seat check) / deactivation
is_reactivation = user_resource.active and not user.is_active
if is_reactivation:
if user_resource.active and not user.is_active:
seat_error = _check_seat_availability(dal)
if seat_error:
return _scim_error_response(403, seat_error)
@@ -579,12 +556,6 @@ def replace_user(
personal_name=personal_name,
)
# Reconcile default-group membership on reactivation
if is_reactivation:
assign_user_to_default_groups__no_commit(
db_session, user, is_admin=(user.role == UserRole.ADMIN)
)
new_external_id = user_resource.externalId
scim_username = user_resource.userName.strip()
fields = _fields_from_resource(user_resource)
@@ -650,7 +621,6 @@ def patch_user(
return _scim_error_response(e.status, e.detail)
# Apply changes back to the DB model
is_reactivation = patched.active and not user.is_active
if patched.active != user.is_active:
if patched.active:
seat_error = _check_seat_availability(dal)
@@ -679,12 +649,6 @@ def patch_user(
personal_name=personal_name,
)
# Reconcile default-group membership on reactivation
if is_reactivation:
assign_user_to_default_groups__no_commit(
db_session, user, is_admin=(user.role == UserRole.ADMIN)
)
# Build updated fields by merging PATCH enterprise data with current values
cf = current_fields or ScimMappingFields()
fields = ScimMappingFields(
@@ -893,11 +857,6 @@ def create_group(
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
if group_resource.displayName in _RESERVED_GROUP_NAMES:
return _scim_error_response(
409, f"'{group_resource.displayName}' is a reserved group name."
)
if dal.get_group_by_name(group_resource.displayName):
return _scim_error_response(
409, f"Group with name '{group_resource.displayName}' already exists"
@@ -920,18 +879,8 @@ def create_group(
409, f"Group with name '{group_resource.displayName}' already exists"
)
# Every group gets the "basic" permission by default.
dal.add_permission_grant_to_group(
group_id=db_group.id,
permission=Permission.BASIC_ACCESS,
grant_source=GrantSource.SYSTEM,
)
dal.upsert_group_members(db_group.id, member_uuids)
# Recompute permissions for initial members.
recompute_user_permissions__no_commit(member_uuids, db_session)
external_id = group_resource.externalId
if external_id:
dal.create_group_mapping(external_id=external_id, user_group_id=db_group.id)
@@ -962,36 +911,14 @@ def replace_group(
return result
group = result
if group.name in _RESERVED_GROUP_NAMES and group_resource.displayName != group.name:
return _scim_error_response(
409, f"'{group.name}' is a reserved group name and cannot be renamed."
)
if (
group_resource.displayName in _RESERVED_GROUP_NAMES
and group_resource.displayName != group.name
):
return _scim_error_response(
409, f"'{group_resource.displayName}' is a reserved group name."
)
member_uuids, err = _validate_and_parse_members(group_resource.members, dal)
if err:
return _scim_error_response(400, err)
# Capture old member IDs before replacing so we can recompute their
# permissions after they are removed from the group.
old_member_ids = {uid for uid, _ in dal.get_group_members(group.id)}
dal.update_group(group, name=group_resource.displayName)
dal.replace_group_members(group.id, member_uuids)
dal.sync_group_external_id(group.id, group_resource.externalId)
# Recompute permissions for current members (batch) and removed members.
recompute_permissions_for_group__no_commit(group.id, db_session)
removed_ids = list(old_member_ids - set(member_uuids))
recompute_user_permissions__no_commit(removed_ids, db_session)
dal.commit()
members = dal.get_group_members(group.id)
@@ -1034,19 +961,8 @@ def patch_group(
return _scim_error_response(e.status, e.detail)
new_name = patched.displayName if patched.displayName != group.name else None
if group.name in _RESERVED_GROUP_NAMES and new_name:
return _scim_error_response(
409, f"'{group.name}' is a reserved group name and cannot be renamed."
)
if new_name and new_name in _RESERVED_GROUP_NAMES:
return _scim_error_response(409, f"'{new_name}' is a reserved group name.")
dal.update_group(group, name=new_name)
affected_uuids: list[UUID] = []
if added_ids:
add_uuids = [UUID(mid) for mid in added_ids if _is_valid_uuid(mid)]
if add_uuids:
@@ -1057,15 +973,10 @@ def patch_group(
f"Member(s) not found: {', '.join(str(u) for u in missing)}",
)
dal.upsert_group_members(group.id, add_uuids)
affected_uuids.extend(add_uuids)
if removed_ids:
remove_uuids = [UUID(mid) for mid in removed_ids if _is_valid_uuid(mid)]
dal.remove_group_members(group.id, remove_uuids)
affected_uuids.extend(remove_uuids)
# Recompute permissions for all users whose group membership changed.
recompute_user_permissions__no_commit(affected_uuids, db_session)
dal.sync_group_external_id(group.id, patched.externalId)
dal.commit()
@@ -1091,21 +1002,11 @@ def delete_group(
return result
group = result
if group.name in _RESERVED_GROUP_NAMES:
return _scim_error_response(409, f"'{group.name}' is a reserved group name.")
# Capture member IDs before deletion so we can recompute their permissions.
affected_user_ids = [uid for uid, _ in dal.get_group_members(group.id)]
mapping = dal.get_group_mapping_by_group_id(group.id)
if mapping:
dal.delete_group_mapping(mapping.id)
dal.delete_group_with_members(group)
# Recompute permissions for users who lost this group membership.
recompute_user_permissions__no_commit(affected_user_ids, db_session)
dal.commit()
return Response(status_code=204)
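
Note on the recompute logic dropped from `replace_group` and `patch_group` in the hunks above: both collect the users whose membership changed and recompute their permissions. A minimal sketch of just that bookkeeping, assuming member IDs are UUIDs as in the diff. The helper names `removed_members` and `affected_by_patch` are hypothetical, and the actual recompute call is left out.

```python
from uuid import UUID


def removed_members(
    old_member_ids: set[UUID], new_member_ids: list[UUID]
) -> list[UUID]:
    """Users who were in the group before a full replace but are not afterwards."""
    # Sorted only to make the result deterministic; the diff uses a plain list().
    return sorted(old_member_ids - set(new_member_ids), key=str)


def affected_by_patch(added: list[UUID], removed: list[UUID]) -> list[UUID]:
    """Users touched by a PATCH: everyone added plus everyone removed."""
    return [*added, *removed]
```

The old member set has to be captured before the membership is replaced, since afterwards the removed users can no longer be found through the group.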


@@ -12,13 +12,12 @@ from ee.onyx.server.tenants.anonymous_user_path import (
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
from ee.onyx.server.tenants.models import AnonymousUserPath
from onyx.auth.permissions import require_permission
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import current_admin_user
from onyx.auth.users import User
from onyx.configs.constants import ANONYMOUS_USER_COOKIE_NAME
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.enums import Permission
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
@@ -29,7 +28,7 @@ router = APIRouter(prefix="/tenants")
@router.get("/anonymous-user-path")
async def get_anonymous_user_path_api(
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
) -> AnonymousUserPath:
tenant_id = get_current_tenant_id()
@@ -45,7 +44,7 @@ async def get_anonymous_user_path_api(
@router.post("/anonymous-user-path")
async def set_anonymous_user_path_api(
anonymous_user_path: str,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
try:


@@ -22,6 +22,7 @@ import httpx
from fastapi import APIRouter
from fastapi import Depends
from ee.onyx.auth.users import current_admin_user
from ee.onyx.server.tenants.access import control_plane_dep
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_customer_portal_session
@@ -37,12 +38,10 @@ from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import overwrite_full_gated_set
from ee.onyx.server.tenants.product_gating import store_product_gating
from onyx.auth.permissions import require_permission
from onyx.auth.users import User
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.db.enums import Permission
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.utils.logger import setup_logger
@@ -100,7 +99,7 @@ def gate_product_full_sync(
@router.get("/billing-information")
async def billing_information(
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
tenant_id = get_current_tenant_id()
@@ -109,7 +108,7 @@ async def billing_information(
@router.post("/create-customer-portal-session")
async def create_customer_portal_session(
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
) -> dict:
"""Create a Stripe customer portal session via the control plane."""
tenant_id = get_current_tenant_id()
@@ -131,7 +130,7 @@ async def create_customer_portal_session(
@router.post("/create-checkout-session")
async def create_checkout_session(
request: CreateCheckoutSessionRequest | None = None,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
) -> dict:
"""Create a Stripe checkout session via the control plane."""
tenant_id = get_current_tenant_id()
@@ -154,7 +153,7 @@ async def create_checkout_session(
@router.post("/create-subscription-session")
async def create_subscription_session(
request: CreateSubscriptionSessionRequest | None = None,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()


@@ -6,11 +6,10 @@ from sqlalchemy.orm import Session
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.permissions import require_permission
from onyx.auth.users import current_admin_user
from onyx.auth.users import User
from onyx.db.auth import get_user_count
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import Permission
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
@@ -25,9 +24,7 @@ router = APIRouter(prefix="/tenants")
@router.post("/leave-team")
async def leave_organization(
user_email: UserByEmail,
current_user: User = Depends(
require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)
),
current_user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
tenant_id = get_current_tenant_id()


@@ -3,9 +3,8 @@ from fastapi import Depends
from ee.onyx.server.tenants.models import TenantByDomainResponse
from ee.onyx.server.tenants.provisioning import get_tenant_by_domain_from_control_plane
from onyx.auth.permissions import require_permission
from onyx.auth.users import current_user
from onyx.auth.users import User
from onyx.db.enums import Permission
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
@@ -27,7 +26,7 @@ FORBIDDEN_COMMON_EMAIL_SUBSTRINGS = [
@router.get("/existing-team-by-domain")
def get_existing_tenant_by_domain(
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
user: User = Depends(current_user),
) -> TenantByDomainResponse | None:
domain = user.email.split("@")[1]
if any(substring in domain for substring in FORBIDDEN_COMMON_EMAIL_SUBSTRINGS):


@@ -10,9 +10,9 @@ from ee.onyx.server.tenants.user_mapping import approve_user_invite
from ee.onyx.server.tenants.user_mapping import deny_user_invite
from ee.onyx.server.tenants.user_mapping import invite_self_to_tenant
from onyx.auth.invited_users import get_pending_users
from onyx.auth.permissions import require_permission
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import User
from onyx.db.enums import Permission
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
@@ -24,7 +24,7 @@ router = APIRouter(prefix="/tenants")
@router.post("/users/invite/request")
async def request_invite(
invite_request: RequestInviteRequest,
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
user: User = Depends(current_admin_user),
) -> None:
try:
invite_self_to_tenant(user.email, invite_request.tenant_id)
@@ -37,7 +37,7 @@ async def request_invite(
@router.get("/users/pending")
def list_pending_users(
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
) -> list[PendingUserSnapshot]:
pending_emails = get_pending_users()
return [PendingUserSnapshot(email=email) for email in pending_emails]
@@ -46,7 +46,7 @@ def list_pending_users(
@router.post("/users/invite/approve")
async def approve_user(
approve_user_request: ApproveUserRequest,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
approve_user_invite(approve_user_request.email, tenant_id)
@@ -55,7 +55,7 @@ async def approve_user(
@router.post("/users/invite/accept")
async def accept_invite(
invite_request: RequestInviteRequest,
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
user: User = Depends(current_user),
) -> None:
"""
Accept an invitation to join a tenant.
@@ -70,7 +70,7 @@ async def accept_invite(
@router.post("/users/invite/deny")
async def deny_invite(
invite_request: RequestInviteRequest,
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
user: User = Depends(current_user),
) -> None:
"""
Deny an invitation to join a tenant.


@@ -7,11 +7,10 @@ from sqlalchemy.orm import Session
from ee.onyx.db.token_limit import fetch_all_user_group_token_rate_limits_by_group
from ee.onyx.db.token_limit import fetch_user_group_token_rate_limits_for_user
from ee.onyx.db.token_limit import insert_user_group_token_rate_limit
from onyx.auth.permissions import require_permission
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import Permission
from onyx.db.models import User
from onyx.db.token_limit import fetch_all_user_token_rate_limits
from onyx.db.token_limit import insert_user_token_rate_limit
@@ -29,7 +28,7 @@ Group Token Limit Settings
@router.get("/user-groups")
def get_all_group_token_limit_settings(
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> dict[str, list[TokenRateLimitDisplay]]:
user_groups_to_token_rate_limits = fetch_all_user_group_token_rate_limits_by_group(
@@ -65,7 +64,7 @@ def get_group_token_limit_settings(
def create_group_token_limit_settings(
group_id: int,
token_limit_settings: TokenRateLimitArgs,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> TokenRateLimitDisplay:
rate_limit_display = TokenRateLimitDisplay.from_db(
@@ -87,7 +86,7 @@ User Token Limit Settings
@router.get("/users")
def get_user_token_limit_settings(
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[TokenRateLimitDisplay]:
return [
@@ -99,7 +98,7 @@ def get_user_token_limit_settings(
@router.post("/users")
def create_user_token_limit_settings(
token_limit_settings: TokenRateLimitArgs,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> TokenRateLimitDisplay:
rate_limit_display = TokenRateLimitDisplay.from_db(


@@ -13,26 +13,22 @@ from ee.onyx.db.user_group import fetch_user_groups_for_user
from ee.onyx.db.user_group import insert_user_group
from ee.onyx.db.user_group import prepare_user_group_for_deletion
from ee.onyx.db.user_group import rename_user_group
from ee.onyx.db.user_group import set_group_permission__no_commit
from ee.onyx.db.user_group import update_user_curator_relationship
from ee.onyx.db.user_group import update_user_group
from ee.onyx.server.user_group.models import AddUsersToUserGroupRequest
from ee.onyx.server.user_group.models import MinimalUserGroupSnapshot
from ee.onyx.server.user_group.models import SetCuratorRequest
from ee.onyx.server.user_group.models import SetPermissionRequest
from ee.onyx.server.user_group.models import SetPermissionResponse
from ee.onyx.server.user_group.models import UpdateGroupAgentsRequest
from ee.onyx.server.user_group.models import UserGroup
from ee.onyx.server.user_group.models import UserGroupCreate
from ee.onyx.server.user_group.models import UserGroupRename
from ee.onyx.server.user_group.models import UserGroupUpdate
from onyx.auth.permissions import NON_TOGGLEABLE_PERMISSIONS
from onyx.auth.permissions import require_permission
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import Permission
from onyx.db.models import User
from onyx.db.models import UserRole
from onyx.db.persona import get_persona_by_id
@@ -47,16 +43,12 @@ router = APIRouter(prefix="/manage", tags=PUBLIC_API_TAGS)
@router.get("/admin/user-group")
def list_user_groups(
include_default: bool = False,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> list[UserGroup]:
if user.role == UserRole.ADMIN:
user_groups = fetch_user_groups(
db_session,
only_up_to_date=False,
eager_load_for_snapshot=True,
include_default=include_default,
db_session, only_up_to_date=False, eager_load_for_snapshot=True
)
else:
user_groups = fetch_user_groups_for_user(
@@ -64,81 +56,31 @@ def list_user_groups(
user_id=user.id,
only_curator_groups=user.role == UserRole.CURATOR,
eager_load_for_snapshot=True,
include_default=include_default,
)
return [UserGroup.from_model(user_group) for user_group in user_groups]
@router.get("/user-groups/minimal")
def list_minimal_user_groups(
include_default: bool = False,
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[MinimalUserGroupSnapshot]:
if user.role == UserRole.ADMIN:
user_groups = fetch_user_groups(
db_session,
only_up_to_date=False,
include_default=include_default,
)
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
else:
user_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=user.id,
include_default=include_default,
)
return [
MinimalUserGroupSnapshot.from_model(user_group) for user_group in user_groups
]
@router.get("/admin/user-group/{user_group_id}/permissions")
def get_user_group_permissions(
user_group_id: int,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
db_session: Session = Depends(get_session),
) -> list[Permission]:
group = fetch_user_group(db_session, user_group_id)
if group is None:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "User group not found")
return [
grant.permission for grant in group.permission_grants if not grant.is_deleted
]
@router.put("/admin/user-group/{user_group_id}/permissions")
def set_user_group_permission(
user_group_id: int,
request: SetPermissionRequest,
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
db_session: Session = Depends(get_session),
) -> SetPermissionResponse:
group = fetch_user_group(db_session, user_group_id)
if group is None:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "User group not found")
if request.permission in NON_TOGGLEABLE_PERMISSIONS:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
f"Permission '{request.permission}' cannot be toggled via this endpoint",
)
set_group_permission__no_commit(
group_id=user_group_id,
permission=request.permission,
enabled=request.enabled,
granted_by=user.id,
db_session=db_session,
)
db_session.commit()
return SetPermissionResponse(permission=request.permission, enabled=request.enabled)
@router.post("/admin/user-group")
def create_user_group(
user_group: UserGroupCreate,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> UserGroup:
try:
@@ -155,12 +97,9 @@ def create_user_group(
@router.patch("/admin/user-group/rename")
def rename_user_group_endpoint(
rename_request: UserGroupRename,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> UserGroup:
group = fetch_user_group(db_session, rename_request.id)
if group and group.is_default:
raise OnyxError(OnyxErrorCode.CONFLICT, "Cannot rename a default system group.")
try:
return UserGroup.from_model(
rename_user_group(
@@ -243,12 +182,9 @@ def set_user_curator(
@router.delete("/admin/user-group/{user_group_id}")
def delete_user_group(
user_group_id: int,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
group = fetch_user_group(db_session, user_group_id)
if group and group.is_default:
raise OnyxError(OnyxErrorCode.CONFLICT, "Cannot delete a default system group.")
try:
prepare_user_group_for_deletion(db_session, user_group_id)
except ValueError as e:
@@ -264,7 +200,7 @@ def delete_user_group(
def update_group_agents(
user_group_id: int,
request: UpdateGroupAgentsRequest,
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
for agent_id in request.added_agent_ids:


@@ -2,7 +2,6 @@ from uuid import UUID
from pydantic import BaseModel
from onyx.auth.permissions import Permission
from onyx.db.models import UserGroup as UserGroupModel
from onyx.server.documents.models import ConnectorCredentialPairDescriptor
from onyx.server.documents.models import ConnectorSnapshot
@@ -23,7 +22,6 @@ class UserGroup(BaseModel):
personas: list[PersonaSnapshot]
is_up_to_date: bool
is_up_for_deletion: bool
is_default: bool
@classmethod
def from_model(cls, user_group_model: UserGroupModel) -> "UserGroup":
@@ -76,21 +74,18 @@ class UserGroup(BaseModel):
],
is_up_to_date=user_group_model.is_up_to_date,
is_up_for_deletion=user_group_model.is_up_for_deletion,
is_default=user_group_model.is_default,
)
class MinimalUserGroupSnapshot(BaseModel):
id: int
name: str
is_default: bool
@classmethod
def from_model(cls, user_group_model: UserGroupModel) -> "MinimalUserGroupSnapshot":
return cls(
id=user_group_model.id,
name=user_group_model.name,
is_default=user_group_model.is_default,
)
@@ -122,13 +117,3 @@ class SetCuratorRequest(BaseModel):
class UpdateGroupAgentsRequest(BaseModel):
added_agent_ids: list[int]
removed_agent_ids: list[int]
class SetPermissionRequest(BaseModel):
permission: Permission
enabled: bool
class SetPermissionResponse(BaseModel):
permission: Permission
enabled: bool


@@ -1,125 +0,0 @@
"""
Permission resolution for group-based authorization.
Granted permissions are stored as a JSONB column on the User table and
loaded for free with every auth query. Implied permissions are expanded
at read time — only directly granted permissions are persisted.
"""
from collections.abc import Callable
from collections.abc import Coroutine
from typing import Any
from fastapi import Depends
from onyx.auth.users import current_user
from onyx.db.enums import Permission
from onyx.db.models import User
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.utils.logger import setup_logger
logger = setup_logger()
ALL_PERMISSIONS: frozenset[str] = frozenset(p.value for p in Permission)
# Implication map: granted permission -> set of permissions it implies.
IMPLIED_PERMISSIONS: dict[str, set[str]] = {
Permission.ADD_AGENTS.value: {Permission.READ_AGENTS.value},
Permission.MANAGE_AGENTS.value: {
Permission.ADD_AGENTS.value,
Permission.READ_AGENTS.value,
},
Permission.MANAGE_DOCUMENT_SETS.value: {
Permission.READ_DOCUMENT_SETS.value,
Permission.READ_CONNECTORS.value,
},
Permission.ADD_CONNECTORS.value: {Permission.READ_CONNECTORS.value},
Permission.MANAGE_CONNECTORS.value: {
Permission.ADD_CONNECTORS.value,
Permission.READ_CONNECTORS.value,
},
Permission.MANAGE_USER_GROUPS.value: {
Permission.READ_CONNECTORS.value,
Permission.READ_DOCUMENT_SETS.value,
Permission.READ_AGENTS.value,
Permission.READ_USERS.value,
},
}
# Permissions that cannot be toggled via the group-permission API.
# BASIC_ACCESS is always granted, FULL_ADMIN_PANEL_ACCESS is too broad,
# and READ_* permissions are implied (never stored directly).
NON_TOGGLEABLE_PERMISSIONS: frozenset[Permission] = frozenset(
{
Permission.BASIC_ACCESS,
Permission.FULL_ADMIN_PANEL_ACCESS,
Permission.READ_CONNECTORS,
Permission.READ_DOCUMENT_SETS,
Permission.READ_AGENTS,
Permission.READ_USERS,
}
)
def resolve_effective_permissions(granted: set[str]) -> set[str]:
"""Expand granted permissions with their implied permissions.
If "admin" is present, returns all 19 permissions.
"""
if Permission.FULL_ADMIN_PANEL_ACCESS.value in granted:
return set(ALL_PERMISSIONS)
effective = set(granted)
changed = True
while changed:
changed = False
for perm in list(effective):
implied = IMPLIED_PERMISSIONS.get(perm)
if implied and not implied.issubset(effective):
effective |= implied
changed = True
return effective
def get_effective_permissions(user: User) -> set[Permission]:
"""Read granted permissions from the column and expand implied permissions."""
granted: set[Permission] = set()
for p in user.effective_permissions:
try:
granted.add(Permission(p))
except ValueError:
logger.warning(f"Skipping unknown permission '{p}' for user {user.id}")
if Permission.FULL_ADMIN_PANEL_ACCESS in granted:
return set(Permission)
expanded = resolve_effective_permissions({p.value for p in granted})
return {Permission(p) for p in expanded}
def require_permission(
required: Permission,
) -> Callable[..., Coroutine[Any, Any, User]]:
"""FastAPI dependency factory for permission-based access control.
Usage:
@router.get("/endpoint")
def endpoint(user: User = Depends(require_permission(Permission.MANAGE_CONNECTORS))):
...
"""
async def dependency(user: User = Depends(current_user)) -> User:
effective = get_effective_permissions(user)
if Permission.FULL_ADMIN_PANEL_ACCESS in effective:
return user
if required not in effective:
raise OnyxError(
OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
"You do not have the required permissions for this action.",
)
return user
dependency._is_require_permission = True # type: ignore[attr-defined] # sentinel for auth_check detection
return dependency


@@ -5,8 +5,6 @@ from typing import Any
from fastapi_users import schemas
from typing_extensions import override
from onyx.db.enums import AccountType
class UserRole(str, Enum):
"""
@@ -43,7 +41,6 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
class UserCreate(schemas.BaseUserCreate):
role: UserRole = UserRole.BASIC
account_type: AccountType = AccountType.STANDARD
tenant_id: str | None = None
# Captcha token for cloud signup protection (optional, only used when captcha is enabled)
# Excluded from create_update_dict so it never reaches the DB layer
@@ -53,19 +50,19 @@ class UserCreate(schemas.BaseUserCreate):
def create_update_dict(self) -> dict[str, Any]:
d = super().create_update_dict()
d.pop("captcha_token", None)
# Force STANDARD for self-registration; only trusted paths
# (SCIM, API key creation) supply a different account_type directly.
d["account_type"] = AccountType.STANDARD
return d
@override
def create_update_dict_superuser(self) -> dict[str, Any]:
d = super().create_update_dict_superuser()
d.pop("captcha_token", None)
d.setdefault("account_type", self.account_type)
return d
class UserUpdateWithRole(schemas.BaseUserUpdate):
role: UserRole
class UserUpdate(schemas.BaseUserUpdate):
"""
Role updates are not allowed through the user update endpoint for security reasons


@@ -80,6 +80,7 @@ from onyx.auth.pat import get_hashed_pat_from_request
from onyx.auth.schemas import AuthBackend
from onyx.auth.schemas import UserCreate
from onyx.auth.schemas import UserRole
from onyx.auth.schemas import UserUpdateWithRole
from onyx.configs.app_configs import AUTH_BACKEND
from onyx.configs.app_configs import AUTH_COOKIE_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import AUTH_TYPE
@@ -119,15 +120,12 @@ from onyx.db.engine.async_sql_engine import get_async_session
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.enums import AccountType
from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.pat import fetch_user_for_pat
from onyx.db.users import assign_user_to_default_groups__no_commit
from onyx.db.users import get_user_by_email
from onyx.db.users import is_limited_user
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import log_onyx_error
from onyx.error_handling.exceptions import onyx_error_to_json_response
@@ -502,21 +500,18 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user = user_by_session
if (
user.account_type.is_web_login()
user.role.is_web_login()
or not isinstance(user_create, UserCreate)
or not user_create.account_type.is_web_login()
or not user_create.role.is_web_login()
):
raise exceptions.UserAlreadyExists()
# Cache id before expire — accessing attrs on an expired
# object triggers a sync lazy-load which raises MissingGreenlet
# in this async context.
user_id = user.id
self._upgrade_user_to_standard__sync(user_id, user_create)
# Expire so the async session re-fetches the row updated by
# the sync session above.
self.user_db.session.expire(user)
user = await self.user_db.get(user_id) # type: ignore[assignment]
user_update = UserUpdateWithRole(
password=user_create.password,
is_verified=user_create.is_verified,
role=user_create.role,
)
user = await self.update(user_update, user)
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
@@ -530,21 +525,18 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
# Handle case where user has used product outside of web and is now creating an account through web
if (
user.account_type.is_web_login()
user.role.is_web_login()
or not isinstance(user_create, UserCreate)
or not user_create.account_type.is_web_login()
or not user_create.role.is_web_login()
):
raise exceptions.UserAlreadyExists()
# Cache id before expire — accessing attrs on an expired
# object triggers a sync lazy-load which raises MissingGreenlet
# in this async context.
user_id = user.id
self._upgrade_user_to_standard__sync(user_id, user_create)
# Expire so the async session re-fetches the row updated by
# the sync session above.
self.user_db.session.expire(user)
user = await self.user_db.get(user_id) # type: ignore[assignment]
user_update = UserUpdateWithRole(
password=user_create.password,
is_verified=user_create.is_verified,
role=user_create.role,
)
user = await self.update(user_update, user)
if user_created:
await self._assign_default_pinned_assistants(user, db_session)
remove_user_from_invited_users(user_create.email)
@@ -581,38 +573,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
)
user.pinned_assistants = default_persona_ids
def _upgrade_user_to_standard__sync(
self,
user_id: uuid.UUID,
user_create: UserCreate,
) -> None:
"""Upgrade a non-web user to STANDARD and assign default groups atomically.
All writes happen in a single sync transaction so neither the field
update nor the group assignment is visible without the other.
"""
with get_session_with_current_tenant() as sync_db:
sync_user = sync_db.query(User).filter(User.id == user_id).first() # type: ignore[arg-type]
if sync_user:
sync_user.hashed_password = self.password_helper.hash(
user_create.password
)
sync_user.is_verified = user_create.is_verified or False
sync_user.role = user_create.role
sync_user.account_type = AccountType.STANDARD
assign_user_to_default_groups__no_commit(
sync_db,
sync_user,
is_admin=(user_create.role == UserRole.ADMIN),
)
sync_db.commit()
else:
logger.warning(
"User %s not found in sync session during upgrade to standard; "
"skipping upgrade",
user_id,
)
async def validate_password(self, password: str, _: schemas.UC | models.UP) -> None:
# Validate password according to configurable security policy (defined via environment variables)
if len(password) < PASSWORD_MIN_LENGTH:
@@ -734,7 +694,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
"email": account_email,
"hashed_password": self.password_helper.hash(password),
"is_verified": is_verified_by_default,
"account_type": AccountType.STANDARD,
}
user = await self.user_db.create(user_dict)
@@ -767,7 +726,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
)
# Handle case where user has used product outside of web and is now creating an account through web
if not user.account_type.is_web_login():
if not user.role.is_web_login():
# We must use the existing user in the session if it matches
# the user we just got by email/oauth. Note that this only applies
# to multi-tenant, due to the overwriting of the user_db
@@ -784,25 +743,14 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
with get_session_with_current_tenant() as sync_db:
enforce_seat_limit(sync_db)
# Upgrade the user and assign default groups in a single
# transaction so neither change is visible without the other.
was_inactive = not user.is_active
with get_session_with_current_tenant() as sync_db:
sync_user = sync_db.query(User).filter(User.id == user.id).first() # type: ignore[arg-type]
if sync_user:
sync_user.is_verified = is_verified_by_default
sync_user.role = UserRole.BASIC
sync_user.account_type = AccountType.STANDARD
if was_inactive:
sync_user.is_active = True
assign_user_to_default_groups__no_commit(sync_db, sync_user)
sync_db.commit()
# Refresh the async user object so downstream code
# (e.g. oidc_expiry check) sees the updated fields.
self.user_db.session.expire(user)
user = await self.user_db.get(user.id)
assert user is not None
await self.user_db.update(
user,
{
"is_verified": is_verified_by_default,
"role": UserRole.BASIC,
**({"is_active": True} if not user.is_active else {}),
},
)
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
# otherwise, the oidc expiry will always be old, and the user will never be able to login
@@ -888,16 +836,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
event=MilestoneRecordType.TENANT_CREATED,
)
# Assign user to the appropriate default group (Admin or Basic).
# Must happen inside the try block while tenant context is active,
# otherwise get_session_with_current_tenant() targets the wrong schema.
is_admin = user_count == 1 or user.email in get_default_admin_user_emails()
with get_session_with_current_tenant() as db_session:
assign_user_to_default_groups__no_commit(
db_session, user, is_admin=is_admin
)
db_session.commit()
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
@@ -1037,7 +975,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
self.password_helper.hash(credentials.password)
return None
if not user.account_type.is_web_login():
if not user.role.is_web_login():
raise BasicAuthenticationError(
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
)
@@ -1533,7 +1471,7 @@ async def _get_or_create_user_from_jwt(
if not user.is_active:
logger.warning("Inactive user %s attempted JWT login; skipping", email)
return None
if not user.account_type.is_web_login():
if not user.role.is_web_login():
raise exceptions.UserNotExists()
except exceptions.UserNotExists:
logger.info("Provisioning user %s from JWT login", email)
@@ -1554,7 +1492,7 @@ async def _get_or_create_user_from_jwt(
email,
)
return None
if not user.account_type.is_web_login():
if not user.role.is_web_login():
logger.warning(
"Non-web-login user %s attempted JWT login during provisioning race; skipping",
email,
@@ -1616,7 +1554,6 @@ def get_anonymous_user() -> User:
is_verified=True,
is_superuser=False,
role=UserRole.LIMITED,
account_type=AccountType.ANONYMOUS,
use_memories=False,
enable_memory_tool=False,
)
@@ -1682,9 +1619,9 @@ async def current_user(
) -> User:
user = await double_check_user(user)
if is_limited_user(user):
if user.role == UserRole.LIMITED:
raise BasicAuthenticationError(
detail="Access denied. User has limited permissions.",
detail="Access denied. User role is LIMITED. BASIC or higher permissions are required.",
)
return user
@@ -1701,6 +1638,15 @@ async def current_curator_or_admin_user(
return user
async def current_admin_user(user: User = Depends(current_user)) -> User:
if user.role != UserRole.ADMIN:
raise BasicAuthenticationError(
detail="Access denied. User must be an admin to perform this action.",
)
return user
async def _get_user_from_token_data(token_data: dict) -> User | None:
"""Shared logic: token data dict → User object.
@@ -1809,11 +1755,11 @@ async def current_user_from_websocket(
# Apply same checks as HTTP auth (verification, OIDC expiry, role)
user = await double_check_user(user)
# Block limited users (same as current_user)
if is_limited_user(user):
logger.warning(f"WS auth: user {user.email} is limited")
# Block LIMITED users (same as current_user)
if user.role == UserRole.LIMITED:
logger.warning(f"WS auth: user {user.email} has LIMITED role")
raise BasicAuthenticationError(
detail="Access denied. User has limited permissions.",
detail="Access denied. User role is LIMITED. BASIC or higher permissions are required.",
)
logger.debug(f"WS auth: authenticated {user.email}")


@@ -1,7 +1,6 @@
# Overview of Onyx Background Jobs
The background jobs take care of:
1. Pulling/Indexing documents (from connectors)
2. Updating document metadata (from connectors)
3. Cleaning up checkpoints and logic around indexing work (indexing checkpoints and index attempt metadata)
@@ -10,41 +9,37 @@ The background jobs take care of:
## Worker → Queue Mapping
| Worker | File | Queues |
| ------------------------- | ------------------------------ | -------------------------------------------------------------------------------------------------------------------- |
| Primary | `apps/primary.py` | `celery` |
| Light | `apps/light.py` | `vespa_metadata_sync`, `connector_deletion`, `doc_permissions_upsert`, `checkpoint_cleanup`, `index_attempt_cleanup` |
| Heavy | `apps/heavy.py` | `connector_pruning`, `connector_doc_permissions_sync`, `connector_external_group_sync`, `csv_generation`, `sandbox` |
| Docprocessing | `apps/docprocessing.py` | `docprocessing` |
| Docfetching | `apps/docfetching.py` | `connector_doc_fetching` |
| User File Processing | `apps/user_file_processing.py` | `user_file_processing`, `user_file_project_sync`, `user_file_delete` |
| Monitoring | `apps/monitoring.py` | `monitoring` |
| Background (consolidated) | `apps/background.py` | All queues above except `celery` |
| Worker | File | Queues |
|--------|------|--------|
| Primary | `apps/primary.py` | `celery` |
| Light | `apps/light.py` | `vespa_metadata_sync`, `connector_deletion`, `doc_permissions_upsert`, `checkpoint_cleanup`, `index_attempt_cleanup` |
| Heavy | `apps/heavy.py` | `connector_pruning`, `connector_doc_permissions_sync`, `connector_external_group_sync`, `csv_generation`, `sandbox` |
| Docprocessing | `apps/docprocessing.py` | `docprocessing` |
| Docfetching | `apps/docfetching.py` | `connector_doc_fetching` |
| User File Processing | `apps/user_file_processing.py` | `user_file_processing`, `user_file_project_sync`, `user_file_delete` |
| Monitoring | `apps/monitoring.py` | `monitoring` |
| Background (consolidated) | `apps/background.py` | All queues above except `celery` |
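As a rough illustration of the mapping above, the queue list of the consolidated background worker can be derived from the other rows. This is only a sketch: `WORKER_QUEUES` and `consolidated_background_queues` are hypothetical names, not identifiers from the codebase, and the real queue assignment lives in the `apps/*.py` files.

```python
# Hypothetical mirror of the Worker -> Queue table; illustration only.
WORKER_QUEUES: dict[str, list[str]] = {
    "primary": ["celery"],
    "light": [
        "vespa_metadata_sync",
        "connector_deletion",
        "doc_permissions_upsert",
        "checkpoint_cleanup",
        "index_attempt_cleanup",
    ],
    "heavy": [
        "connector_pruning",
        "connector_doc_permissions_sync",
        "connector_external_group_sync",
        "csv_generation",
        "sandbox",
    ],
    "docprocessing": ["docprocessing"],
    "docfetching": ["connector_doc_fetching"],
    "user_file_processing": [
        "user_file_processing",
        "user_file_project_sync",
        "user_file_delete",
    ],
    "monitoring": ["monitoring"],
}


def consolidated_background_queues(worker_queues: dict[str, list[str]]) -> list[str]:
    """Collect every queue except those owned by the primary worker."""
    queues: list[str] = []
    for worker, names in worker_queues.items():
        if worker == "primary":
            continue  # the default `celery` queue stays with Primary
        for name in names:
            if name not in queues:
                queues.append(name)
    return queues


background_queues = consolidated_background_queues(WORKER_QUEUES)
```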
## Non-Worker Apps
| App | File | Purpose |
| ---------- | ----------- | ----------------------------------------------------------------------------------------------------- |
| **Beat** | `beat.py` | Celery beat scheduler with `DynamicTenantScheduler` that generates per-tenant periodic task schedules |
| **Client** | `client.py` | Minimal app for task submission from non-worker processes (e.g., API server) |
| App | File | Purpose |
|-----|------|---------|
| **Beat** | `beat.py` | Celery beat scheduler with `DynamicTenantScheduler` that generates per-tenant periodic task schedules |
| **Client** | `client.py` | Minimal app for task submission from non-worker processes (e.g., API server) |
### Shared Module
`app_base.py` provides:
- `TenantAwareTask` - Base task class that sets tenant context
- Signal handlers for logging, cleanup, and lifecycle events
- Readiness probes and health checks
## Worker Details
### Primary (Coordinator and task dispatcher)
Primary is the only worker that handles tasks from the default `celery` queue. It is kept a singleton by the `PRIMARY_WORKER` Redis lock,
which it touches every `CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8` seconds (using Celery Bootsteps).
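A minimal sketch of the renewal cadence described above. The timeout value is an assumed example rather than the configured default, and both function names are hypothetical.

```python
# Assumed example value; the real setting is CELERY_PRIMARY_WORKER_LOCK_TIMEOUT.
EXAMPLE_LOCK_TIMEOUT_SECONDS = 120.0


def renewal_interval(lock_timeout: float) -> float:
    """How often the primary worker touches its singleton lock."""
    return lock_timeout / 8


def lock_is_held(last_touch: float, now: float, lock_timeout: float) -> bool:
    """The lock survives only while the last touch is younger than the timeout."""
    return (now - last_touch) < lock_timeout


interval = renewal_interval(EXAMPLE_LOCK_TIMEOUT_SECONDS)
```

Touching at one eighth of the timeout means several consecutive renewals would have to fail before the lock expires and another worker could claim it.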
On startup:
- waits for redis, postgres, document index to all be healthy
- acquires the singleton lock
- cleans all the redis states associated with background jobs
@@ -52,34 +47,34 @@ On startup:
Then it cycles through its tasks as scheduled by Celery Beat:
| Task | Frequency | Description |
| --------------------------------- | --------- | ------------------------------------------------------------------------------------------ |
| `check_for_indexing` | 15s | Scans for connectors needing indexing → dispatches to `DOCFETCHING` queue |
| `check_for_vespa_sync_task` | 20s | Finds stale documents/document sets → dispatches sync tasks to `VESPA_METADATA_SYNC` queue |
| `check_for_pruning` | 20s | Finds connectors due for pruning → dispatches to `CONNECTOR_PRUNING` queue |
| `check_for_connector_deletion` | 20s | Processes deletion requests → dispatches to `CONNECTOR_DELETION` queue |
| `check_for_user_file_processing` | 20s | Checks for user uploads → dispatches to `USER_FILE_PROCESSING` queue |
| `check_for_checkpoint_cleanup` | 1h | Cleans up old indexing checkpoints |
| `check_for_index_attempt_cleanup` | 30m | Cleans up old index attempts |
| `celery_beat_heartbeat` | 1m | Heartbeat for Beat watchdog |
| Task | Frequency | Description |
|------|-----------|-------------|
| `check_for_indexing` | 15s | Scans for connectors needing indexing → dispatches to `DOCFETCHING` queue |
| `check_for_vespa_sync_task` | 20s | Finds stale documents/document sets → dispatches sync tasks to `VESPA_METADATA_SYNC` queue |
| `check_for_pruning` | 20s | Finds connectors due for pruning → dispatches to `CONNECTOR_PRUNING` queue |
| `check_for_connector_deletion` | 20s | Processes deletion requests → dispatches to `CONNECTOR_DELETION` queue |
| `check_for_user_file_processing` | 20s | Checks for user uploads → dispatches to `USER_FILE_PROCESSING` queue |
| `check_for_checkpoint_cleanup` | 1h | Cleans up old indexing checkpoints |
| `check_for_index_attempt_cleanup` | 30m | Cleans up old index attempts |
| `kombu_message_cleanup_task` | periodic | Cleans orphaned Kombu messages from DB (Kombu being the messaging framework used by Celery) |
| `celery_beat_heartbeat` | 1m | Heartbeat for Beat watchdog |
Watchdog is a separate Python process managed by supervisord which runs alongside Celery workers. It checks the `ONYX_CELERY_BEAT_HEARTBEAT_KEY` in
Redis to ensure Celery Beat is not dead. Beat schedules the `celery_beat_heartbeat` task for Primary to touch the key and signal that it's still alive.
See `supervisord.conf` for watchdog config.
### Light
### Light
Fast, short-lived tasks that are not resource intensive. High concurrency:
up to 24 concurrent workers, each with a prefetch of 8, for a total of 192 tasks in flight at once.
Tasks it handles:
- Syncs access/permissions, document sets, boosts, hidden state
- Deletes documents that are marked for deletion in Postgres
- Cleanup of checkpoints and index attempts
### Heavy
### Heavy
Long-running, resource-intensive tasks; handles pruning and sandbox operations. Low concurrency: max concurrency of 4 with a prefetch of 1.
Does not interact with the Document Index; it handles syncs with external systems, making large volumes of API calls for pruning, fetching permissions, etc.
@@ -88,24 +83,16 @@ Generates CSV exports which may take a long time with significant data in Postgr
Sandbox (new feature) for running Next.js, Python virtual env, OpenCode AI Agent, and access to knowledge files
### Docprocessing, Docfetching, User File Processing
Docprocessing and Docfetching are for indexing documents:
- Docfetching runs connectors to pull documents from external APIs (Google Drive, Confluence, etc.), stores batches to file storage, and dispatches docprocessing tasks
- Docprocessing retrieves batches, runs the indexing pipeline (chunking, embedding), and indexes into the Document Index
- User Files come from uploads directly via the input bar
- Docprocessing retrieves batches, runs the indexing pipeline (chunking, embedding), and indexes into the Document Index
User Files come from uploads directly via the input bar
### Monitoring
Observability and metrics collection:
- Queue lengths, connector success/failure, connector latencies
- Queue lengths, connector success/failure, lconnector latencies
- Memory of supervisor managed processes (workers, beat, slack)
- Cloud- and multi-tenant-specific monitoring
## Prometheus Metrics
Workers can expose Prometheus metrics via a standalone HTTP server. Currently docfetching and docprocessing have push-based task lifecycle metrics; the monitoring worker runs pull-based collectors for queue depth and connector health.
For the full metric reference, integration guide, and PromQL examples, see [`docs/METRICS.md`](../../../docs/METRICS.md#celery-worker-metrics).


@@ -13,12 +13,6 @@ from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected
from onyx.server.metrics.celery_task_metrics import on_celery_task_retry
from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked
from onyx.server.metrics.metrics_server import start_metrics_server
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -40,7 +34,6 @@ def on_task_prerun(
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
on_celery_task_prerun(task_id, task)
@signals.task_postrun.connect
@@ -55,31 +48,6 @@ def on_task_postrun(
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
on_celery_task_postrun(task_id, task, state)
@signals.task_retry.connect
def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
task_id = getattr(getattr(sender, "request", None), "id", None)
on_celery_task_retry(task_id, sender)
@signals.task_revoked.connect
def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None:
task_name = getattr(sender, "name", None) or str(sender)
on_celery_task_revoked(kwargs.get("task_id"), task_name)
@signals.task_rejected.connect
def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
message = kwargs.get("message")
task_name: str | None = None
if message is not None:
headers = getattr(message, "headers", None) or {}
task_name = headers.get("task")
if task_name is None:
task_name = "unknown"
on_celery_task_rejected(None, task_name)
@celeryd_init.connect
@@ -108,7 +76,6 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
start_metrics_server("heavy")
app_base.on_worker_ready(sender, **kwargs)


@@ -317,12 +317,12 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.docprocessing",
"onyx.background.celery.tasks.evals",
"onyx.background.celery.tasks.hierarchyfetching",
"onyx.background.celery.tasks.periodic",
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.llm_model_update",
"onyx.background.celery.tasks.user_file_processing",
"onyx.server.features.proposal_review.engine",
]
)
)


@@ -1,4 +1,3 @@
import time
from collections.abc import Generator
from collections.abc import Iterator
from collections.abc import Sequence
@@ -31,8 +30,6 @@ from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import SlimDocument
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.server.metrics.pruning_metrics import inc_pruning_rate_limit_error
from onyx.server.metrics.pruning_metrics import observe_pruning_enumeration_duration
from onyx.utils.logger import setup_logger
@@ -133,7 +130,6 @@ def _extract_from_batch(
def extract_ids_from_runnable_connector(
runnable_connector: BaseConnector,
callback: IndexingHeartbeatInterface | None = None,
connector_type: str = "unknown",
) -> SlimConnectorExtractionResult:
"""
Extract document IDs and hierarchy nodes from a runnable connector.
@@ -183,38 +179,21 @@ def extract_ids_from_runnable_connector(
)
# process raw batches to extract both IDs and hierarchy nodes
enumeration_start = time.monotonic()
try:
for doc_list in raw_batch_generator:
if callback and callback.should_stop():
raise RuntimeError(
"extract_ids_from_runnable_connector: Stop signal detected"
)
for doc_list in raw_batch_generator:
if callback and callback.should_stop():
raise RuntimeError(
"extract_ids_from_runnable_connector: Stop signal detected"
)
batch_result = _extract_from_batch(doc_list)
batch_ids = batch_result.raw_id_to_parent
batch_nodes = batch_result.hierarchy_nodes
doc_batch_processing_func(batch_ids)
all_raw_id_to_parent.update(batch_ids)
all_hierarchy_nodes.extend(batch_nodes)
batch_result = _extract_from_batch(doc_list)
batch_ids = batch_result.raw_id_to_parent
batch_nodes = batch_result.hierarchy_nodes
doc_batch_processing_func(batch_ids)
all_raw_id_to_parent.update(batch_ids)
all_hierarchy_nodes.extend(batch_nodes)
if callback:
callback.progress("extract_ids_from_runnable_connector", len(batch_ids))
except Exception as e:
# Best-effort rate limit detection via string matching.
# Connectors surface rate limits inconsistently — some raise HTTP 429,
# some use SDK-specific exceptions (e.g. google.api_core.exceptions.ResourceExhausted)
# that may or may not include "rate limit" or "429" in the message.
# TODO(Bo): replace with a standard ConnectorRateLimitError exception that all
# connectors raise when rate limited, making this check precise.
error_str = str(e)
if "rate limit" in error_str.lower() or "429" in error_str:
inc_pruning_rate_limit_error(connector_type)
raise
finally:
observe_pruning_enumeration_duration(
time.monotonic() - enumeration_start, connector_type
)
if callback:
callback.progress("extract_ids_from_runnable_connector", len(batch_ids))
return SlimConnectorExtractionResult(
raw_id_to_parent=all_raw_id_to_parent,


@@ -75,8 +75,6 @@ beat_task_templates: list[dict] = [
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
# Run on gated tenants too — they may still have stale checkpoints to clean.
"skip_gated": False,
},
},
{
@@ -86,8 +84,6 @@ beat_task_templates: list[dict] = [
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
# Run on gated tenants too — they may still have stale index attempts.
"skip_gated": False,
},
},
{
@@ -97,8 +93,6 @@ beat_task_templates: list[dict] = [
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
# Gated tenants may still have connectors awaiting deletion.
"skip_gated": False,
},
},
{
@@ -142,14 +136,7 @@ beat_task_templates: list[dict] = [
{
"name": "cleanup-idle-sandboxes",
"task": OnyxCeleryTask.CLEANUP_IDLE_SANDBOXES,
# SANDBOX_IDLE_TIMEOUT_SECONDS defaults to 1 hour, so there is no
# functional reason to scan more often than every ~15 minutes. In the
# cloud this is multiplied by CLOUD_BEAT_MULTIPLIER_DEFAULT (=8) so
# the effective cadence becomes ~2 hours, which still meets the
# idle-detection SLA. The previous 1-minute base schedule produced
# an 8-minute per-tenant fan-out and was the dominant source of
# background DB load on the cloud cluster.
"schedule": timedelta(minutes=15),
"schedule": timedelta(minutes=1),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
@@ -279,7 +266,7 @@ def make_cloud_generator_task(task: dict[str, Any]) -> dict[str, Any]:
cloud_task["kwargs"] = {}
cloud_task["kwargs"]["task_name"] = task["task"]
optional_fields = ["queue", "priority", "expires", "skip_gated"]
optional_fields = ["queue", "priority", "expires"]
for field in optional_fields:
if field in task["options"]:
cloud_task["kwargs"][field] = task["options"][field]
@@ -315,7 +302,7 @@ beat_cloud_tasks: list[dict] = [
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-available-tenants",
"task": OnyxCeleryTask.CLOUD_CHECK_AVAILABLE_TENANTS,
"schedule": timedelta(minutes=2),
"schedule": timedelta(minutes=10),
"options": {
"queue": OnyxCeleryQueues.MONITORING,
"priority": OnyxCeleryPriority.HIGH,
@@ -372,13 +359,7 @@ if not MULTI_TENANT:
]
)
# `skip_gated` is a cloud-only hint consumed by `cloud_beat_task_generator`. Strip
# it before extending the self-hosted schedule so it doesn't leak into apply_async
# as an unrecognised option on every fired task message.
for _template in beat_task_templates:
_self_hosted_template = copy.deepcopy(_template)
_self_hosted_template["options"].pop("skip_gated", None)
tasks_to_schedule.append(_self_hosted_template)
tasks_to_schedule.extend(beat_task_templates)
def generate_cloud_tasks(


@@ -36,7 +36,6 @@ from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.opensearch_migration import build_sanitized_to_original_doc_id_mapping
from onyx.db.opensearch_migration import get_vespa_visit_state
from onyx.db.opensearch_migration import is_migration_completed
from onyx.db.opensearch_migration import (
mark_migration_completed_time_if_not_set_with_commit,
)
@@ -107,19 +106,14 @@ def migrate_chunks_from_vespa_to_opensearch_task(
acquired; effectively a no-op. True if the task completed
successfully. False if the task errored.
"""
# 1. Check if we should run the task.
# 1.a. If OpenSearch indexing is disabled, we don't run the task.
if not ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
task_logger.warning(
"OpenSearch migration is not enabled, skipping chunk migration task."
)
return None
task_logger.info("Starting chunk-level migration from Vespa to OpenSearch.")
task_start_time = time.monotonic()
# 1.b. Only one instance of this task may run per tenant at a time. If we
# fail to acquire the lock, we assume another task holds it and we exit.
r = get_redis_client()
lock: RedisLock = r.lock(
name=OnyxRedisLocks.OPENSEARCH_MIGRATION_BEAT_LOCK,
@@ -142,11 +136,10 @@ def migrate_chunks_from_vespa_to_opensearch_task(
f"Token: {lock.local.token}"
)
# 2. Prepare to migrate.
total_chunks_migrated_this_task = 0
total_chunks_errored_this_task = 0
try:
# 2.a. Double-check that tenant info is correct.
# Double check that tenant info is correct.
if tenant_id != get_current_tenant_id():
err_str = (
f"Tenant ID mismatch in the OpenSearch migration task: "
@@ -155,62 +148,16 @@ def migrate_chunks_from_vespa_to_opensearch_task(
task_logger.error(err_str)
return False
# Do as much as we can with a DB session in one spot to not hold a
# session during a migration batch.
with get_session_with_current_tenant() as db_session:
# 2.b. Immediately check to see if this tenant is done, to save
# having to do any other work. This function does not require a
# migration record to necessarily exist.
if is_migration_completed(db_session):
return True
# 2.c. Try to insert the OpenSearchTenantMigrationRecord table if it
# does not exist.
with (
get_session_with_current_tenant() as db_session,
get_vespa_http_client(
timeout=VESPA_MIGRATION_REQUEST_TIMEOUT_S
) as vespa_client,
):
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
# 2.d. Get search settings.
search_settings = get_current_search_settings(db_session)
indexing_setting = IndexingSetting.from_db_model(search_settings)
# 2.e. Build sanitized to original doc ID mapping to check for
# conflicts in the event we sanitize a doc ID to an
# already-existing doc ID.
# We reconstruct this mapping for every task invocation because
# a document may have been added in the time between two tasks.
sanitized_doc_start_time = time.monotonic()
sanitized_to_original_doc_id_mapping = (
build_sanitized_to_original_doc_id_mapping(db_session)
)
task_logger.debug(
f"Built sanitized_to_original_doc_id_mapping with {len(sanitized_to_original_doc_id_mapping)} entries "
f"in {time.monotonic() - sanitized_doc_start_time:.3f} seconds."
)
# 2.f. Get the current migration state.
continuation_token_map, total_chunks_migrated = get_vespa_visit_state(
db_session
)
# 2.f.1. Double-check that the migration state does not imply
# completion. Really we should never have to enter this block as we
# would expect is_migration_completed to return True, but in the
# strange event that the migration is complete but the migration
# completed time was never stamped, we do so here.
if is_continuation_token_done_for_all_slices(continuation_token_map):
task_logger.info(
f"OpenSearch migration COMPLETED for tenant {tenant_id}. Total chunks migrated: {total_chunks_migrated}."
)
mark_migration_completed_time_if_not_set_with_commit(db_session)
return True
task_logger.debug(
f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. "
f"Continuation token map: {continuation_token_map}"
)
with get_vespa_http_client(
timeout=VESPA_MIGRATION_REQUEST_TIMEOUT_S
) as vespa_client:
# 2.g. Create the OpenSearch and Vespa document indexes.
tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
indexing_setting = IndexingSetting.from_db_model(search_settings)
opensearch_document_index = OpenSearchDocumentIndex(
tenant_state=tenant_state,
index_name=search_settings.index_name,
@@ -224,14 +171,22 @@ def migrate_chunks_from_vespa_to_opensearch_task(
httpx_client=vespa_client,
)
# 2.h. Get the approximate chunk count in Vespa as of this time to
# update the migration record.
sanitized_doc_start_time = time.monotonic()
# We reconstruct this mapping for every task invocation because a
# document may have been added in the time between two tasks.
sanitized_to_original_doc_id_mapping = (
build_sanitized_to_original_doc_id_mapping(db_session)
)
task_logger.debug(
f"Built sanitized_to_original_doc_id_mapping with {len(sanitized_to_original_doc_id_mapping)} entries "
f"in {time.monotonic() - sanitized_doc_start_time:.3f} seconds."
)
approx_chunk_count_in_vespa: int | None = None
get_chunk_count_start_time = time.monotonic()
try:
approx_chunk_count_in_vespa = vespa_document_index.get_chunk_count()
except Exception:
# This failure should not be blocking.
task_logger.exception(
"Error getting approximate chunk count in Vespa. Moving on..."
)
@@ -240,12 +195,25 @@ def migrate_chunks_from_vespa_to_opensearch_task(
f"approximate chunk count in Vespa. Got {approx_chunk_count_in_vespa}."
)
# 3. Do the actual migration in batches until we run out of time.
while (
time.monotonic() - task_start_time < MIGRATION_TASK_SOFT_TIME_LIMIT_S
and lock.owned()
):
# 3.a. Get the next batch of raw chunks from Vespa.
(
continuation_token_map,
total_chunks_migrated,
) = get_vespa_visit_state(db_session)
if is_continuation_token_done_for_all_slices(continuation_token_map):
task_logger.info(
f"OpenSearch migration COMPLETED for tenant {tenant_id}. Total chunks migrated: {total_chunks_migrated}."
)
mark_migration_completed_time_if_not_set_with_commit(db_session)
break
task_logger.debug(
f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. "
f"Continuation token map: {continuation_token_map}"
)
get_vespa_chunks_start_time = time.monotonic()
raw_vespa_chunks, next_continuation_token_map = (
vespa_document_index.get_all_raw_document_chunks_paginated(
@@ -258,7 +226,6 @@ def migrate_chunks_from_vespa_to_opensearch_task(
f"seconds. Next continuation token map: {next_continuation_token_map}"
)
# 3.b. Transform the raw chunks to OpenSearch chunks in memory.
opensearch_document_chunks, errored_chunks = (
transform_vespa_chunks_to_opensearch_chunks(
raw_vespa_chunks,
@@ -273,7 +240,6 @@ def migrate_chunks_from_vespa_to_opensearch_task(
"errored."
)
# 3.c. Index the OpenSearch chunks into OpenSearch.
index_opensearch_chunks_start_time = time.monotonic()
opensearch_document_index.index_raw_chunks(
chunks=opensearch_document_chunks
@@ -285,38 +251,12 @@ def migrate_chunks_from_vespa_to_opensearch_task(
total_chunks_migrated_this_task += len(opensearch_document_chunks)
total_chunks_errored_this_task += len(errored_chunks)
# Do as much as we can with a DB session in one spot to not hold a
# session during a migration batch.
with get_session_with_current_tenant() as db_session:
# 3.d. Update the migration state.
update_vespa_visit_progress_with_commit(
db_session,
continuation_token_map=next_continuation_token_map,
chunks_processed=len(opensearch_document_chunks),
chunks_errored=len(errored_chunks),
approx_chunk_count_in_vespa=approx_chunk_count_in_vespa,
)
# 3.e. Get the current migration state. Even though we
# technically have it in-memory since we just wrote it, we
# want to reference the DB as the source of truth at all
# times.
continuation_token_map, total_chunks_migrated = (
get_vespa_visit_state(db_session)
)
# 3.e.1. Check if the migration is done.
if is_continuation_token_done_for_all_slices(
continuation_token_map
):
task_logger.info(
f"OpenSearch migration COMPLETED for tenant {tenant_id}. Total chunks migrated: {total_chunks_migrated}."
)
mark_migration_completed_time_if_not_set_with_commit(db_session)
return True
task_logger.debug(
f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. "
f"Continuation token map: {continuation_token_map}"
update_vespa_visit_progress_with_commit(
db_session,
continuation_token_map=next_continuation_token_map,
chunks_processed=len(opensearch_document_chunks),
chunks_errored=len(errored_chunks),
approx_chunk_count_in_vespa=approx_chunk_count_in_vespa,
)
except Exception:
traceback.print_exc()

@@ -0,0 +1,138 @@
#####
# Periodic Tasks
#####
import json
from typing import Any
from celery import shared_task
from celery.contrib.abortable import AbortableTask # type: ignore
from celery.exceptions import TaskRevokedError
from sqlalchemy import inspect
from sqlalchemy import text
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import PostgresAdvisoryLocks
from onyx.db.engine.sql_engine import get_session_with_current_tenant
@shared_task(
name=OnyxCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
soft_time_limit=JOB_TIMEOUT,
bind=True,
base=AbortableTask,
)
def kombu_message_cleanup_task(self: Any, tenant_id: str) -> int: # noqa: ARG001
"""Runs periodically to clean up the kombu_message table"""
# we will select messages older than this amount to clean up
KOMBU_MESSAGE_CLEANUP_AGE = 7 # days
KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT = 1000
ctx = {}
ctx["last_processed_id"] = 0
ctx["deleted"] = 0
ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
with get_session_with_current_tenant() as db_session:
# Exit the task if we can't take the advisory lock
result = db_session.execute(
text("SELECT pg_try_advisory_lock(:id)"),
{"id": PostgresAdvisoryLocks.KOMBU_MESSAGE_CLEANUP_LOCK_ID.value},
).scalar()
if not result:
return 0
while True:
if self.is_aborted():
raise TaskRevokedError("kombu_message_cleanup_task was aborted.")
b = kombu_message_cleanup_task_helper(ctx, db_session)
if not b:
break
db_session.commit()
if ctx["deleted"] > 0:
task_logger.info(
f"Deleted {ctx['deleted']} orphaned messages from kombu_message."
)
return ctx["deleted"]
def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool:
"""
Helper function to clean up old messages from the `kombu_message` table that are no longer relevant.
This function retrieves messages from the `kombu_message` table that are no longer visible and
older than a specified interval. It checks if the corresponding task_id exists in the
`celery_taskmeta` table. If the task_id does not exist, the message is deleted.
Args:
ctx (dict): A context dictionary containing configuration parameters such as:
- 'cleanup_age' (int): The age in days after which messages are considered old.
- 'page_limit' (int): The maximum number of messages to process in one batch.
- 'last_processed_id' (int): The ID of the last processed message to handle pagination.
- 'deleted' (int): A counter to track the number of deleted messages.
db_session (Session): The SQLAlchemy database session for executing queries.
Returns:
bool: Returns True if there are more rows to process, False if not.
"""
inspector = inspect(db_session.bind)
if not inspector:
return False
# With the move to redis as celery's broker and backend, kombu tables may not even exist.
# We can fail silently.
if not inspector.has_table("kombu_message"):
return False
query = text(
"""
SELECT id, timestamp, payload
FROM kombu_message WHERE visible = 'false'
AND timestamp < CURRENT_TIMESTAMP - INTERVAL :interval_days
AND id > :last_processed_id
ORDER BY id
LIMIT :page_limit
"""
)
kombu_messages = db_session.execute(
query,
{
"interval_days": f"{ctx['cleanup_age']} days",
"page_limit": ctx["page_limit"],
"last_processed_id": ctx["last_processed_id"],
},
).fetchall()
if len(kombu_messages) == 0:
return False
for msg in kombu_messages:
payload = json.loads(msg[2])
task_id = payload["headers"]["id"]
# Check if task_id exists in celery_taskmeta
task_exists = db_session.execute(
text("SELECT 1 FROM celery_taskmeta WHERE task_id = :task_id"),
{"task_id": task_id},
).fetchone()
# If task_id does not exist, delete the message
if not task_exists:
result = db_session.execute(
text("DELETE FROM kombu_message WHERE id = :message_id"),
{"message_id": msg[0]},
)
if result.rowcount > 0: # type: ignore
ctx["deleted"] += 1
ctx["last_processed_id"] = msg[0]
return True
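
The helper above pages through `kombu_message` by remembering the last processed id (keyset pagination) instead of using an offset, so rows deleted mid-scan do not shift later pages. A minimal, self-contained sketch of the same pattern follows. It uses an in-memory SQLite table rather than Postgres, and the `message` table, `orphaned` column, and function names are all hypothetical stand-ins for illustration.

```python
import sqlite3


def build_demo_db() -> sqlite3.Connection:
    """Create a throwaway table with ids 1..10; odd ids are marked orphaned."""
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE message (id INTEGER PRIMARY KEY, orphaned INTEGER NOT NULL)"
    )
    conn.executemany(
        "INSERT INTO message (id, orphaned) VALUES (?, ?)",
        [(i, i % 2) for i in range(1, 11)],
    )
    return conn


def delete_in_pages(conn: sqlite3.Connection, page_limit: int) -> int:
    """Delete orphaned rows one page at a time, resuming after the last seen id."""
    last_processed_id = 0
    deleted = 0
    while True:
        rows = conn.execute(
            "SELECT id, orphaned FROM message WHERE id > ? ORDER BY id LIMIT ?",
            (last_processed_id, page_limit),
        ).fetchall()
        if not rows:
            break  # no more rows past the cursor
        for row_id, orphaned in rows:
            if orphaned:
                cur = conn.execute("DELETE FROM message WHERE id = ?", (row_id,))
                deleted += cur.rowcount
            # Advance the cursor whether or not the row was deleted.
            last_processed_id = row_id
    conn.commit()
    return deleted
```

Because the cursor advances on every row, each row is visited at most once per run regardless of how many are deleted along the way.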

@@ -72,7 +72,6 @@ from onyx.redis.redis_hierarchy import get_source_node_id_from_cache
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.metrics.pruning_metrics import observe_pruning_diff_duration
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.server.utils import make_short_id
from onyx.utils.logger import format_error_for_logging
@@ -218,7 +217,7 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
try:
# the entire task needs to run frequently in order to finalize pruning
# but pruning only kicks off once per min
# but pruning only kicks off once per hour
if not r.exists(OnyxRedisSignals.BLOCK_PRUNING):
task_logger.info("Checking for pruning due")
@@ -571,9 +570,8 @@ def connector_pruning_generator_task(
)
# Extract docs and hierarchy nodes from the source
connector_type = cc_pair.connector.source.value
extraction_result = extract_ids_from_runnable_connector(
runnable_connector, callback, connector_type=connector_type
runnable_connector, callback
)
all_connector_doc_ids = extraction_result.raw_id_to_parent
@@ -638,46 +636,40 @@ def connector_pruning_generator_task(
commit=True,
)
diff_start = time.monotonic()
try:
# a list of docs in our local index
all_indexed_document_ids = {
doc.id
for doc in get_documents_for_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
}
# a list of docs in our local index
all_indexed_document_ids = {
doc.id
for doc in get_documents_for_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
}
# generate list of docs to remove (no longer in the source)
doc_ids_to_remove = list(
all_indexed_document_ids - all_connector_doc_ids.keys()
)
# generate list of docs to remove (no longer in the source)
doc_ids_to_remove = list(
all_indexed_document_ids - all_connector_doc_ids.keys()
)
task_logger.info(
"Pruning set collected: "
f"cc_pair={cc_pair_id} "
f"connector_source={cc_pair.connector.source} "
f"docs_to_remove={len(doc_ids_to_remove)}"
)
task_logger.info(
"Pruning set collected: "
f"cc_pair={cc_pair_id} "
f"connector_source={cc_pair.connector.source} "
f"docs_to_remove={len(doc_ids_to_remove)}"
)
task_logger.info(
f"RedisConnector.prune.generate_tasks starting. cc_pair={cc_pair_id}"
)
tasks_generated = redis_connector.prune.generate_tasks(
set(doc_ids_to_remove), self.app, db_session, None
)
if tasks_generated is None:
return None
task_logger.info(
f"RedisConnector.prune.generate_tasks starting. cc_pair={cc_pair_id}"
)
tasks_generated = redis_connector.prune.generate_tasks(
set(doc_ids_to_remove), self.app, db_session, None
)
if tasks_generated is None:
return None
task_logger.info(
f"RedisConnector.prune.generate_tasks finished. cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
)
finally:
observe_pruning_diff_duration(
time.monotonic() - diff_start, connector_type
)
task_logger.info(
f"RedisConnector.prune.generate_tasks finished. cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
)
redis_connector.prune.generator_complete = tasks_generated

@@ -996,7 +996,6 @@ def _run_models(
def _run_model(model_idx: int) -> None:
"""Run one LLM loop inside a worker thread, writing packets to ``merged_queue``."""
model_emitter = Emitter(
model_idx=model_idx,
merged_queue=merged_queue,
@@ -1103,33 +1102,33 @@ def _run_models(
finally:
merged_queue.put((model_idx, _MODEL_DONE))
def _save_errored_message(model_idx: int, context: str) -> None:
"""Save an error message to a reserved ChatMessage that failed during execution."""
def _delete_orphaned_message(model_idx: int, context: str) -> None:
"""Delete a reserved ChatMessage that was never populated due to a model error."""
try:
msg = db_session.get(ChatMessage, setup.reserved_messages[model_idx].id)
if msg is not None:
error_text = f"Error from {setup.model_display_names[model_idx]}: model encountered an error during generation."
msg.message = error_text
msg.error = error_text
orphaned = db_session.get(
ChatMessage, setup.reserved_messages[model_idx].id
)
if orphaned is not None:
db_session.delete(orphaned)
db_session.commit()
except Exception:
logger.exception(
"%s error save failed for model %d (%s)",
"%s orphan cleanup failed for model %d (%s)",
context,
model_idx,
setup.model_display_names[model_idx],
)
# Each worker thread needs its own Context copy — a single Context object
# cannot be entered concurrently by multiple threads (RuntimeError).
# Copy contextvars before submitting futures — ThreadPoolExecutor does NOT
# auto-propagate contextvars in Python 3.11; threads would inherit a blank context.
worker_context = contextvars.copy_context()
executor = ThreadPoolExecutor(
max_workers=n_models, thread_name_prefix="multi-model"
)
completion_persisted: bool = False
try:
for i in range(n_models):
ctx = contextvars.copy_context()
executor.submit(ctx.run, _run_model, i)
executor.submit(worker_context.run, _run_model, i)
# ── Main thread: merge and yield packets ────────────────────────────
models_remaining = n_models
@@ -1146,7 +1145,7 @@ def _run_models(
# save "stopped by user" for a model that actually threw an exception.
for i in range(n_models):
if model_errored[i]:
_save_errored_message(i, "stop-button")
_delete_orphaned_message(i, "stop-button")
continue
try:
succeeded = model_succeeded[i]
@@ -1212,7 +1211,7 @@ def _run_models(
for i in range(n_models):
if not model_succeeded[i]:
# Model errored — delete its orphaned reserved message.
_save_errored_message(i, "normal")
_delete_orphaned_message(i, "normal")
continue
try:
llm_loop_completion_handle(
@@ -1265,7 +1264,7 @@ def _run_models(
setup.model_display_names[i],
)
elif model_errored[i]:
_save_errored_message(i, "disconnect")
_delete_orphaned_message(i, "disconnect")
# 4. Drain buffered packets from memory — no consumer is running.
while not merged_queue.empty():
try:

@@ -379,14 +379,6 @@ POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "127.0.0.1"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
AWS_REGION_NAME = os.environ.get("AWS_REGION_NAME") or "us-east-2"
# Comma-separated replica / multi-host list. If unset, defaults to POSTGRES_HOST
# only.
_POSTGRES_HOSTS_STR = os.environ.get("POSTGRES_HOSTS", "").strip()
POSTGRES_HOSTS: list[str] = (
[h.strip() for h in _POSTGRES_HOSTS_STR.split(",") if h.strip()]
if _POSTGRES_HOSTS_STR
else [POSTGRES_HOST]
)
POSTGRES_API_SERVER_POOL_SIZE = int(
os.environ.get("POSTGRES_API_SERVER_POOL_SIZE") or 40

@@ -12,11 +12,6 @@ SLACK_USER_TOKEN_PREFIX = "xoxp-"
SLACK_BOT_TOKEN_PREFIX = "xoxb-"
ONYX_EMAILABLE_LOGO_MAX_DIM = 512
# The mask_string() function in encryption.py uses "•" (U+2022 BULLET) to mask secrets.
MASK_CREDENTIAL_CHAR = "\u2022"
# Pattern produced by mask_string for strings >= 14 chars: "abcd...wxyz" (exactly 11 chars)
MASK_CREDENTIAL_LONG_RE = re.compile(r"^.{4}\.{3}.{4}$")
SOURCE_TYPE = "source_type"
# stored in the `metadata` of a chunk. Used to signify that this chunk should
# not be used for QA. For example, Google Drive file types which can't be parsed
@@ -396,6 +391,10 @@ class MilestoneRecordType(str, Enum):
REQUESTED_CONNECTOR = "requested_connector"
class PostgresAdvisoryLocks(Enum):
KOMBU_MESSAGE_CLEANUP_LOCK_ID = auto()
class OnyxCeleryQueues:
# "celery" is the default queue defined by celery and also the queue
# we are running in the primary worker to run system tasks
@@ -578,6 +577,7 @@ class OnyxCeleryTask:
MONITOR_PROCESS_MEMORY = "monitor_process_memory"
CELERY_BEAT_HEARTBEAT = "celery_beat_heartbeat"
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
"connector_permission_sync_generator_task"
)

@@ -42,9 +42,6 @@ from onyx.connectors.google_drive.file_retrieval import (
get_all_files_in_my_drive_and_shared,
)
from onyx.connectors.google_drive.file_retrieval import get_external_access_for_folder
from onyx.connectors.google_drive.file_retrieval import (
get_files_by_web_view_links_batch,
)
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from onyx.connectors.google_drive.file_retrieval import get_folder_metadata
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
@@ -73,13 +70,11 @@ from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import NormalizationResult
from onyx.connectors.interfaces import Resolver
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import EntityFailure
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import SlimDocument
@@ -207,9 +202,7 @@ class DriveIdStatus(Enum):
class GoogleDriveConnector(
SlimConnectorWithPermSync,
CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint],
Resolver,
SlimConnectorWithPermSync, CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint]
):
def __init__(
self,
@@ -1672,82 +1665,6 @@ class GoogleDriveConnector(
start, end, checkpoint, include_permissions=True
)
@override
def resolve_errors(
self,
errors: list[ConnectorFailure],
include_permissions: bool = False,
) -> Generator[Document | ConnectorFailure | HierarchyNode, None, None]:
if self._creds is None or self._primary_admin_email is None:
raise RuntimeError(
"Credentials missing, should not call this method before calling load_credentials"
)
logger.info(f"Resolving {len(errors)} errors")
doc_ids = [
failure.failed_document.document_id
for failure in errors
if failure.failed_document
]
service = get_drive_service(self.creds, self.primary_admin_email)
field_type = (
DriveFileFieldType.WITH_PERMISSIONS
if include_permissions or self.exclude_domain_link_only
else DriveFileFieldType.STANDARD
)
batch_result = get_files_by_web_view_links_batch(service, doc_ids, field_type)
for doc_id, error in batch_result.errors.items():
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=doc_id,
document_link=doc_id,
),
failure_message=f"Failed to retrieve file during error resolution: {error}",
exception=error,
)
permission_sync_context = (
PermissionSyncContext(
primary_admin_email=self.primary_admin_email,
google_domain=self.google_domain,
)
if include_permissions
else None
)
retrieved_files = [
RetrievedDriveFile(
drive_file=file,
user_email=self.primary_admin_email,
completion_stage=DriveRetrievalStage.DONE,
)
for file in batch_result.files.values()
]
yield from self._get_new_ancestors_for_files(
files=retrieved_files,
seen_hierarchy_node_raw_ids=ThreadSafeSet(),
fully_walked_hierarchy_node_raw_ids=ThreadSafeSet(),
permission_sync_context=permission_sync_context,
add_prefix=True,
)
func_with_args = [
(
self._convert_retrieved_file_to_document,
(rf, permission_sync_context),
)
for rf in retrieved_files
]
results = cast(
list[Document | ConnectorFailure | None],
run_functions_tuples_in_parallel(func_with_args, max_workers=8),
)
for result in results:
if result is not None:
yield result
def _extract_slim_docs_from_google_drive(
self,
checkpoint: GoogleDriveCheckpoint,

@@ -9,7 +9,6 @@ from urllib.parse import urlparse
from googleapiclient.discovery import Resource # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from googleapiclient.http import BatchHttpRequest # type: ignore
from onyx.access.models import ExternalAccess
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
@@ -61,8 +60,6 @@ SLIM_FILE_FIELDS = (
)
FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)"
MAX_BATCH_SIZE = 100
HIERARCHY_FIELDS = "id, name, parents, webViewLink, mimeType, driveId"
HIERARCHY_FIELDS_WITH_PERMISSIONS = (
@@ -219,7 +216,7 @@ def get_external_access_for_folder(
def _get_fields_for_file_type(field_type: DriveFileFieldType) -> str:
"""Get the appropriate fields string for files().list() based on the field type enum."""
"""Get the appropriate fields string based on the field type enum"""
if field_type == DriveFileFieldType.SLIM:
return SLIM_FILE_FIELDS
elif field_type == DriveFileFieldType.WITH_PERMISSIONS:
@@ -228,25 +225,6 @@ def _get_fields_for_file_type(field_type: DriveFileFieldType) -> str:
return FILE_FIELDS
def _extract_single_file_fields(list_fields: str) -> str:
"""Convert a files().list() fields string to one suitable for files().get().
List fields look like "nextPageToken, files(field1, field2, ...)"
Single-file fields should be just "field1, field2, ..."
"""
start = list_fields.find("files(")
if start == -1:
return list_fields
inner_start = start + len("files(")
inner_end = list_fields.rfind(")")
return list_fields[inner_start:inner_end]
def _get_single_file_fields(field_type: DriveFileFieldType) -> str:
"""Get the appropriate fields string for files().get() based on the field type enum."""
return _extract_single_file_fields(_get_fields_for_file_type(field_type))
def _get_files_in_parent(
service: Resource,
parent_id: str,
@@ -558,74 +536,3 @@ def get_file_by_web_view_link(
)
.execute()
)
class BatchRetrievalResult:
"""Result of a batch file retrieval, separating successes from errors."""
def __init__(self) -> None:
self.files: dict[str, GoogleDriveFileType] = {}
self.errors: dict[str, Exception] = {}
def get_files_by_web_view_links_batch(
service: GoogleDriveService,
web_view_links: list[str],
field_type: DriveFileFieldType,
) -> BatchRetrievalResult:
"""Retrieve multiple Google Drive files by webViewLink using the batch API.
Returns a BatchRetrievalResult containing successful file retrievals
and errors for any files that could not be fetched.
Automatically splits into chunks of MAX_BATCH_SIZE.
"""
fields = _get_single_file_fields(field_type)
if len(web_view_links) <= MAX_BATCH_SIZE:
return _get_files_by_web_view_links_batch(service, web_view_links, fields)
combined = BatchRetrievalResult()
for i in range(0, len(web_view_links), MAX_BATCH_SIZE):
chunk = web_view_links[i : i + MAX_BATCH_SIZE]
chunk_result = _get_files_by_web_view_links_batch(service, chunk, fields)
combined.files.update(chunk_result.files)
combined.errors.update(chunk_result.errors)
return combined
def _get_files_by_web_view_links_batch(
service: GoogleDriveService,
web_view_links: list[str],
fields: str,
) -> BatchRetrievalResult:
"""Single-batch implementation."""
result = BatchRetrievalResult()
def callback(
request_id: str,
response: GoogleDriveFileType,
exception: Exception | None,
) -> None:
if exception:
logger.warning(f"Error retrieving file {request_id}: {exception}")
result.errors[request_id] = exception
else:
result.files[request_id] = response
batch = cast(BatchHttpRequest, service.new_batch_http_request(callback=callback))
for web_view_link in web_view_links:
try:
file_id = _extract_file_id_from_web_view_link(web_view_link)
request = service.files().get(
fileId=file_id,
supportsAllDrives=True,
fields=fields,
)
batch.add(request, request_id=web_view_link)
except ValueError as e:
logger.warning(f"Failed to extract file ID from {web_view_link}: {e}")
result.errors[web_view_link] = e
batch.execute()
return result

@@ -298,22 +298,6 @@ class CheckpointedConnectorWithPermSync(CheckpointedConnector[CT]):
raise NotImplementedError
class Resolver(BaseConnector):
@abc.abstractmethod
def resolve_errors(
self,
errors: list[ConnectorFailure],
include_permissions: bool = False,
) -> Generator[Document | ConnectorFailure | HierarchyNode, None, None]:
"""Attempts to yield back ALL the documents described by the errors, no checkpointing.
Caller's responsibility is to delete the old ConnectorFailures and replace with the new ones.
If include_permissions is True, the documents will have permissions synced.
May also yield HierarchyNode objects for ancestor folders of resolved documents.
"""
raise NotImplementedError
class HierarchyConnector(BaseConnector):
@abc.abstractmethod
def load_hierarchy(

@@ -8,7 +8,6 @@ from collections.abc import Iterator
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from io import BytesIO
from typing import Any
import requests
@@ -41,7 +40,6 @@ from onyx.connectors.jira.utils import best_effort_basic_expert_info
from onyx.connectors.jira.utils import best_effort_get_field_from_issue
from onyx.connectors.jira.utils import build_jira_client
from onyx.connectors.jira.utils import build_jira_url
from onyx.connectors.jira.utils import CustomFieldExtractor
from onyx.connectors.jira.utils import extract_text_from_adf
from onyx.connectors.jira.utils import get_comment_strs
from onyx.connectors.jira.utils import JIRA_CLOUD_API_VERSION
@@ -54,7 +52,6 @@ from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.db.enums import HierarchyNodeType
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -63,11 +60,8 @@ logger = setup_logger()
ONE_HOUR = 3600
_MAX_RESULTS_FETCH_IDS = 5000
_MAX_RESULTS_FETCH_IDS = 5000 # 5000
_JIRA_FULL_PAGE_SIZE = 50
# https://developer.atlassian.com/cloud/jira/platform/rest/v3/api-group-issues/
_JIRA_BULK_FETCH_LIMIT = 100
_MAX_ATTACHMENT_SIZE_BYTES = 50 * 1024 * 1024 # 50 MB
# Constants for Jira field names
_FIELD_REPORTER = "reporter"
@@ -261,13 +255,15 @@ def _bulk_fetch_request(
return resp.json()["issues"]
def _bulk_fetch_batch(
jira_client: JIRA, issue_ids: list[str], fields: str | None
) -> list[dict[str, Any]]:
"""Fetch a single batch (must be <= _JIRA_BULK_FETCH_LIMIT).
On JSONDecodeError, recursively bisects until it succeeds or reaches size 1."""
def bulk_fetch_issues(
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
) -> list[Issue]:
# TODO(evan): move away from this jira library if they continue to not support
# the endpoints we need. Using private fields is not ideal, but
# is likely fine for now since we pin the library version
try:
return _bulk_fetch_request(jira_client, issue_ids, fields)
raw_issues = _bulk_fetch_request(jira_client, issue_ids, fields)
except requests.exceptions.JSONDecodeError:
if len(issue_ids) <= 1:
logger.exception(
@@ -281,25 +277,12 @@ def _bulk_fetch_batch(
f"Jira bulk-fetch JSON decode failed for batch of {len(issue_ids)} issues. "
f"Splitting into sub-batches of {mid} and {len(issue_ids) - mid}."
)
left = _bulk_fetch_batch(jira_client, issue_ids[:mid], fields)
right = _bulk_fetch_batch(jira_client, issue_ids[mid:], fields)
left = bulk_fetch_issues(jira_client, issue_ids[:mid], fields)
right = bulk_fetch_issues(jira_client, issue_ids[mid:], fields)
return left + right
def bulk_fetch_issues(
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
) -> list[Issue]:
# TODO(evan): move away from this jira library if they continue to not support
# the endpoints we need. Using private fields is not ideal, but
# is likely fine for now since we pin the library version
raw_issues: list[dict[str, Any]] = []
for batch in chunked(issue_ids, _JIRA_BULK_FETCH_LIMIT):
try:
raw_issues.extend(_bulk_fetch_batch(jira_client, list(batch), fields))
except Exception as e:
logger.error(f"Error fetching issues: {e}")
raise
except Exception as e:
logger.error(f"Error fetching issues: {e}")
raise
return [
Issue(jira_client._options, jira_client._session, raw=issue)
@@ -381,7 +364,6 @@ def process_jira_issue(
comment_email_blacklist: tuple[str, ...] = (),
labels_to_skip: set[str] | None = None,
parent_hierarchy_raw_node_id: str | None = None,
custom_fields_mapping: dict[str, str] | None = None,
) -> Document | None:
if labels_to_skip:
if any(label in issue.fields.labels for label in labels_to_skip):
@@ -467,24 +449,6 @@ def process_jira_issue(
else:
logger.error(f"Project should exist but does not for {issue.key}")
# Merge custom fields into metadata if a mapping was provided
if custom_fields_mapping:
try:
custom_fields = CustomFieldExtractor.get_issue_custom_fields(
issue, custom_fields_mapping
)
# Filter out custom fields that collide with existing metadata keys
for key in list(custom_fields.keys()):
if key in metadata_dict:
logger.warning(
f"Custom field '{key}' on {issue.key} collides with "
f"standard metadata key; skipping custom field value"
)
del custom_fields[key]
metadata_dict.update(custom_fields)
except Exception as e:
logger.warning(f"Failed to extract custom fields for {issue.key}: {e}")
return Document(
id=page_url,
sections=[TextSection(link=page_url, text=ticket_content)],
@@ -527,12 +491,6 @@ class JiraConnector(
# Custom JQL query to filter Jira issues
jql_query: str | None = None,
scoped_token: bool = False,
# When True, extract custom fields from Jira issues and include them
# in document metadata with human-readable field names.
extract_custom_fields: bool = False,
# When True, download attachments from Jira issues and yield them
# as separate Documents linked to the parent ticket.
fetch_attachments: bool = False,
) -> None:
self.batch_size = batch_size
@@ -546,11 +504,7 @@ class JiraConnector(
self.labels_to_skip = set(labels_to_skip)
self.jql_query = jql_query
self.scoped_token = scoped_token
self.extract_custom_fields = extract_custom_fields
self.fetch_attachments = fetch_attachments
self._jira_client: JIRA | None = None
# Mapping of custom field IDs to human-readable names (populated on load_credentials)
self._custom_fields_mapping: dict[str, str] = {}
# Cache project permissions to avoid fetching them repeatedly across runs
self._project_permissions_cache: dict[str, Any] = {}
@@ -711,134 +665,12 @@ class JiraConnector(
# the document belongs directly under the project in the hierarchy
return project_key
def _process_attachments(
self,
issue: Issue,
parent_hierarchy_raw_node_id: str | None,
include_permissions: bool = False,
project_key: str | None = None,
) -> Generator[Document | ConnectorFailure, None, None]:
"""Download and yield Documents for each attachment on a Jira issue.
Each attachment becomes a separate Document whose text is extracted
from the downloaded file content. Failures on individual attachments
are logged and yielded as ConnectorFailure so they never break the
overall indexing run.
"""
attachments = best_effort_get_field_from_issue(issue, "attachment")
if not attachments:
return
issue_url = build_jira_url(self.jira_base, issue.key)
for attachment in attachments:
try:
filename = getattr(attachment, "filename", "unknown")
try:
size = int(getattr(attachment, "size", 0) or 0)
except (ValueError, TypeError):
size = 0
content_url = getattr(attachment, "content", None)
attachment_id = getattr(attachment, "id", filename)
mime_type = getattr(attachment, "mimeType", "application/octet-stream")
created = getattr(attachment, "created", None)
if size > _MAX_ATTACHMENT_SIZE_BYTES:
logger.warning(
f"Skipping attachment '{filename}' on {issue.key}: "
f"size {size} bytes exceeds {_MAX_ATTACHMENT_SIZE_BYTES} byte limit"
)
continue
if not content_url:
logger.warning(
f"Skipping attachment '{filename}' on {issue.key}: "
f"no content URL available"
)
continue
# Download the attachment using the public API on the
# python-jira Attachment resource (avoids private _session access
# and the double-copy from response.content + BytesIO wrapping).
file_content = attachment.get()
# Extract text from the downloaded file
try:
text = extract_file_text(
file=BytesIO(file_content),
file_name=filename,
)
except Exception as e:
logger.warning(
f"Could not extract text from attachment '{filename}' "
f"on {issue.key}: {e}"
)
continue
if not text or not text.strip():
logger.info(
f"Skipping attachment '{filename}' on {issue.key}: "
f"no text content could be extracted"
)
continue
doc_id = f"{issue_url}/attachments/{attachment_id}"
attachment_doc = Document(
id=doc_id,
sections=[TextSection(link=issue_url, text=text)],
source=DocumentSource.JIRA,
semantic_identifier=f"{issue.key}: {filename}",
title=filename,
doc_updated_at=(time_str_to_utc(created) if created else None),
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
metadata={
"parent_ticket": issue.key,
"attachment_filename": filename,
"attachment_mime_type": mime_type,
"attachment_size": str(size),
},
)
if include_permissions and project_key:
attachment_doc.external_access = self._get_project_permissions(
project_key,
add_prefix=True,
)
yield attachment_doc
except Exception as e:
logger.error(f"Failed to process attachment on {issue.key}: {e}")
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=f"{issue_url}/attachments/{getattr(attachment, 'id', 'unknown')}",
document_link=issue_url,
),
failure_message=f"Failed to process attachment '{getattr(attachment, 'filename', 'unknown')}': {str(e)}",
exception=e,
)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self._jira_client = build_jira_client(
credentials=credentials,
jira_base=self.jira_base,
scoped_token=self.scoped_token,
)
# Fetch the custom field ID-to-name mapping once at credential load time.
# This avoids repeated API calls during issue processing.
if self.extract_custom_fields:
try:
self._custom_fields_mapping = (
CustomFieldExtractor.get_all_custom_fields(self._jira_client)
)
logger.info(
f"Loaded {len(self._custom_fields_mapping)} custom field definitions"
)
except Exception as e:
logger.warning(
f"Failed to fetch custom field definitions; "
f"custom field extraction will be skipped: {e}"
)
self._custom_fields_mapping = {}
return None
def _get_jql_query(
@@ -969,11 +801,6 @@ class JiraConnector(
comment_email_blacklist=self.comment_email_blacklist,
labels_to_skip=self.labels_to_skip,
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
custom_fields_mapping=(
self._custom_fields_mapping
if self._custom_fields_mapping
else None
),
):
# Add permission information to the document if requested
if include_permissions:
@@ -983,15 +810,6 @@ class JiraConnector(
)
yield document
# Yield attachment documents if enabled
if self.fetch_attachments:
yield from self._process_attachments(
issue=issue,
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
include_permissions=include_permissions,
project_key=project_key,
)
except Exception as e:
yield ConnectorFailure(
failed_document=DocumentFailure(
@@ -1099,41 +917,20 @@ class JiraConnector(
issue_key = best_effort_get_field_from_issue(issue, _FIELD_KEY)
doc_id = build_jira_url(self.jira_base, issue_key)
parent_hierarchy_raw_node_id = (
self._get_parent_hierarchy_raw_node_id(issue, project_key)
if project_key
else None
)
project_perms = self._get_project_permissions(
project_key, add_prefix=False
)
slim_doc_batch.append(
SlimDocument(
id=doc_id,
# Permission sync path - don't prefix, upsert_document_external_perms handles it
external_access=project_perms,
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
external_access=self._get_project_permissions(
project_key, add_prefix=False
),
parent_hierarchy_raw_node_id=(
self._get_parent_hierarchy_raw_node_id(issue, project_key)
if project_key
else None
),
)
)
# Also emit SlimDocument entries for each attachment
if self.fetch_attachments:
attachments = best_effort_get_field_from_issue(issue, "attachment")
if attachments:
for attachment in attachments:
attachment_id = getattr(
attachment,
"id",
getattr(attachment, "filename", "unknown"),
)
slim_doc_batch.append(
SlimDocument(
id=f"{doc_id}/attachments/{attachment_id}",
external_access=project_perms,
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
)
)
current_offset += 1
if len(slim_doc_batch) >= JIRA_SLIM_PAGE_SIZE:
yield slim_doc_batch
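
Note on the bulk-fetch hunk above: on a JSON decode failure, the batch is split in half and each half is retried recursively. The sketch below isolates that bisect pattern under stated assumptions. `fetch_with_bisect` and `fake_fetch` are hypothetical names, `ValueError` stands in for `requests.exceptions.JSONDecodeError`, and re-raising at size 1 is an assumption, since the diff elides what happens after the `logger.exception(` call.

```python
from typing import Callable


def fetch_with_bisect(
    ids: list[str],
    fetch: Callable[[list[str]], list[dict]],
) -> list[dict]:
    """Fetch `ids` in one call; on a decode failure, halve the batch and retry each half."""
    try:
        return fetch(ids)
    except ValueError:
        # ValueError is a stand-in for requests.exceptions.JSONDecodeError (assumption).
        if len(ids) <= 1:
            # A single ID that still fails cannot be split further (assumed behavior).
            raise
        mid = len(ids) // 2
        left = fetch_with_bisect(ids[:mid], fetch)
        right = fetch_with_bisect(ids[mid:], fetch)
        return left + right


def fake_fetch(ids: list[str]) -> list[dict]:
    """Hypothetical backend that cannot return more than two issues at once."""
    if len(ids) > 2:
        raise ValueError("response body could not be decoded")
    return [{"id": issue_id} for issue_id in ids]
```

Because the halves are concatenated left then right, the result keeps the input order even when several levels of splitting are needed.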

@@ -44,7 +44,7 @@ _NOTION_CALL_TIMEOUT = 30 # 30 seconds
_MAX_PAGES = 1000
# TODO: Pages need to have their metadata ingested
# TODO: Tables need to be ingested, Pages need to have their metadata ingested
class NotionPage(BaseModel):
@@ -452,19 +452,6 @@ class NotionConnector(LoadConnector, PollConnector):
sub_inner_dict: dict[str, Any] | list[Any] | str = inner_dict
while isinstance(sub_inner_dict, dict) and "type" in sub_inner_dict:
type_name = sub_inner_dict["type"]
# Notion user objects (people properties, created_by, etc.) have
# "name" at the same level as "type": "person"/"bot". If we drill
# into the person/bot sub-dict we lose the name. Capture it here
# before descending, but skip "title"-type properties where "name"
# is not the display value we want.
if (
"name" in sub_inner_dict
and isinstance(sub_inner_dict["name"], str)
and type_name not in ("title",)
):
return sub_inner_dict["name"]
sub_inner_dict = sub_inner_dict[type_name]
# If the innermost layer is None, the value is not set
@@ -676,19 +663,6 @@ class NotionConnector(LoadConnector, PollConnector):
text = rich_text["text"]["content"]
cur_result_text_arr.append(text)
# table_row blocks store content in "cells" (list of lists
# of rich text objects) rather than "rich_text"
if "cells" in result_obj:
row_cells: list[str] = []
for cell in result_obj["cells"]:
cell_texts = [
rt.get("plain_text", "")
for rt in cell
if isinstance(rt, dict)
]
row_cells.append(" ".join(cell_texts))
cur_result_text_arr.append("\t".join(row_cells))
if result["has_children"]:
if result_type == "child_page":
# Child pages will not be included at this top level, it will be a separate document.
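
Note on the `cells` handling in the hunk above: a Notion `table_row` block stores its content as a list of cells, each of which is a list of rich-text objects. The sketch below reproduces the flattening logic in isolation. `table_row_to_text` and `sample_row` are hypothetical names used only for illustration.

```python
from typing import Any


def table_row_to_text(block: dict[str, Any]) -> str:
    """Join each cell's rich-text fragments with spaces, then join cells with tabs."""
    row_cells: list[str] = []
    for cell in block.get("cells", []):
        # Ignore anything that is not a rich-text dict; missing plain_text becomes "".
        cell_texts = [rt.get("plain_text", "") for rt in cell if isinstance(rt, dict)]
        row_cells.append(" ".join(cell_texts))
    return "\t".join(row_cells)


# Illustrative payload: three cells, the last one empty.
sample_row = {
    "cells": [
        [{"plain_text": "Name"}],
        [{"plain_text": "Due"}, {"plain_text": "date"}],
        [],
    ]
}
```

An empty cell contributes an empty string, so the column count of the row is preserved in the tab-separated output.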

@@ -1,4 +1,3 @@
from dataclasses import dataclass
from datetime import datetime
from typing import TypedDict
@@ -7,14 +6,6 @@ from pydantic import BaseModel
from onyx.onyxbot.slack.models import ChannelType
@dataclass(frozen=True)
class DirectThreadFetch:
"""Request to fetch a Slack thread directly by channel and timestamp."""
channel_id: str
thread_ts: str
class ChannelMetadata(TypedDict):
"""Type definition for cached channel metadata."""

@@ -19,7 +19,6 @@ from onyx.configs.chat_configs import DOC_TIME_DECAY
from onyx.connectors.models import IndexingDocument
from onyx.connectors.models import TextSection
from onyx.context.search.federated.models import ChannelMetadata
from onyx.context.search.federated.models import DirectThreadFetch
from onyx.context.search.federated.models import SlackMessage
from onyx.context.search.federated.slack_search_utils import ALL_CHANNEL_TYPES
from onyx.context.search.federated.slack_search_utils import build_channel_query_filter
@@ -50,6 +49,7 @@ from onyx.server.federated.models import FederatedConnectorDetail
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.timing import log_function_time
from shared_configs.configs import DOC_EMBEDDING_CONTEXT_SIZE
logger = setup_logger()
@@ -58,6 +58,7 @@ HIGHLIGHT_END_CHAR = "\ue001"
CHANNEL_METADATA_CACHE_TTL = 60 * 60 * 24 # 24 hours
USER_PROFILE_CACHE_TTL = 60 * 60 * 24 # 24 hours
SLACK_THREAD_CONTEXT_WINDOW = 3 # Number of messages before matched message to include
CHANNEL_METADATA_MAX_RETRIES = 3 # Maximum retry attempts for channel metadata fetching
CHANNEL_METADATA_RETRY_DELAY = 1 # Initial retry delay in seconds (exponential backoff)
@@ -420,94 +421,6 @@ class SlackQueryResult(BaseModel):
filtered_channels: list[str] # Channels filtered out during this query
def _fetch_thread_from_url(
thread_fetch: DirectThreadFetch,
access_token: str,
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
) -> SlackQueryResult:
"""Fetch a thread directly from a Slack URL via conversations.replies."""
channel_id = thread_fetch.channel_id
thread_ts = thread_fetch.thread_ts
slack_client = WebClient(token=access_token)
try:
response = slack_client.conversations_replies(
channel=channel_id,
ts=thread_ts,
)
response.validate()
messages: list[dict[str, Any]] = response.get("messages", [])
except SlackApiError as e:
logger.warning(
f"Failed to fetch thread from URL (channel={channel_id}, ts={thread_ts}): {e}"
)
return SlackQueryResult(messages=[], filtered_channels=[])
if not messages:
logger.warning(
f"No messages found for URL override (channel={channel_id}, ts={thread_ts})"
)
return SlackQueryResult(messages=[], filtered_channels=[])
# Build thread text from all messages
thread_text = _build_thread_text(messages, access_token, None, slack_client)
# Get channel name from metadata cache or API
channel_name = "unknown"
if channel_metadata_dict and channel_id in channel_metadata_dict:
channel_name = channel_metadata_dict[channel_id].get("name", "unknown")
else:
try:
ch_response = slack_client.conversations_info(channel=channel_id)
ch_response.validate()
channel_info: dict[str, Any] = ch_response.get("channel", {})
channel_name = channel_info.get("name", "unknown")
except SlackApiError:
pass
# Build the SlackMessage
parent_msg = messages[0]
message_ts = parent_msg.get("ts", thread_ts)
username = parent_msg.get("user", "unknown_user")
parent_text = parent_msg.get("text", "")
snippet = (
parent_text[:50].rstrip() + "..." if len(parent_text) > 50 else parent_text
).replace("\n", " ")
doc_time = datetime.fromtimestamp(float(message_ts))
decay_factor = DOC_TIME_DECAY
doc_age_years = (datetime.now() - doc_time).total_seconds() / (365 * 24 * 60 * 60)
recency_bias = max(1 / (1 + decay_factor * doc_age_years), 0.75)
permalink = (
f"https://slack.com/archives/{channel_id}/p{message_ts.replace('.', '')}"
)
slack_message = SlackMessage(
document_id=f"{channel_id}_{message_ts}",
channel_id=channel_id,
message_id=message_ts,
thread_id=None, # Prevent double-enrichment in thread context fetch
link=permalink,
metadata={
"channel": channel_name,
"time": doc_time.isoformat(),
},
timestamp=doc_time,
recency_bias=recency_bias,
semantic_identifier=f"{username} in #{channel_name}: {snippet}",
text=thread_text,
highlighted_texts=set(),
slack_score=100000.0, # High priority — user explicitly asked for this thread
)
logger.info(
f"URL override: fetched thread from channel={channel_id}, ts={thread_ts}, {len(messages)} messages"
)
return SlackQueryResult(messages=[slack_message], filtered_channels=[])
def query_slack(
query_string: str,
access_token: str,
@@ -519,6 +432,7 @@ def query_slack(
available_channels: list[str] | None = None,
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
) -> SlackQueryResult:
# Check if query has channel override (user specified channels in query)
has_channel_override = query_string.startswith("__CHANNEL_OVERRIDE__")
@@ -748,6 +662,7 @@ def _fetch_thread_context(
"""
channel_id = message.channel_id
thread_id = message.thread_id
message_id = message.message_id
# If not a thread, return original text as success
if thread_id is None:
@@ -780,37 +695,62 @@ def _fetch_thread_context(
if len(messages) <= 1:
return ThreadContextResult.success(message.text)
# Build thread text from thread starter + all replies
thread_text = _build_thread_text(messages, access_token, team_id, slack_client)
# Build thread text from thread starter + context window around matched message
thread_text = _build_thread_text(
messages, message_id, thread_id, access_token, team_id, slack_client
)
return ThreadContextResult.success(thread_text)
def _build_thread_text(
messages: list[dict[str, Any]],
message_id: str,
thread_id: str,
access_token: str,
team_id: str | None,
slack_client: WebClient,
) -> str:
"""Build thread text including all replies.
Includes the thread parent message followed by all replies in order.
"""
"""Build the thread text from messages."""
msg_text = messages[0].get("text", "")
msg_sender = messages[0].get("user", "")
thread_text = f"<@{msg_sender}>: {msg_text}"
# All messages after index 0 are replies
replies = messages[1:]
if not replies:
return thread_text
logger.debug(f"Thread {messages[0].get('ts')}: {len(replies)} replies included")
thread_text += "\n\nReplies:"
if thread_id == message_id:
message_id_idx = 0
else:
message_id_idx = next(
(i for i, msg in enumerate(messages) if msg.get("ts") == message_id), 0
)
if not message_id_idx:
return thread_text
for msg in replies:
start_idx = max(1, message_id_idx - SLACK_THREAD_CONTEXT_WINDOW)
if start_idx > 1:
thread_text += "\n..."
for i in range(start_idx, message_id_idx):
msg_text = messages[i].get("text", "")
msg_sender = messages[i].get("user", "")
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
msg_text = messages[message_id_idx].get("text", "")
msg_sender = messages[message_id_idx].get("user", "")
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
# Add following replies
len_replies = 0
for msg in messages[message_id_idx + 1 :]:
msg_text = msg.get("text", "")
msg_sender = msg.get("user", "")
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
reply = f"\n\n<@{msg_sender}>: {msg_text}"
thread_text += reply
len_replies += len(reply)
if len_replies >= DOC_EMBEDDING_CONTEXT_SIZE * 4:
thread_text += "\n..."
break
# Replace user IDs with names using cached lookups
userids: set[str] = set(re.findall(r"<@([A-Z0-9]+)>", thread_text))
@@ -1036,16 +976,7 @@ def slack_retrieval(
# Query slack with entity filtering
llm = get_default_llm()
query_items = build_slack_queries(query, llm, entities, available_channels)
# Partition into direct thread fetches and search query strings
direct_fetches: list[DirectThreadFetch] = []
query_strings: list[str] = []
for item in query_items:
if isinstance(item, DirectThreadFetch):
direct_fetches.append(item)
else:
query_strings.append(item)
query_strings = build_slack_queries(query, llm, entities, available_channels)
# Determine filtering based on entities OR context (bot)
include_dm = False
@@ -1062,16 +993,8 @@ def slack_retrieval(
f"Private channel context: will only allow messages from {allowed_private_channel} + public channels"
)
# Build search tasks — direct thread fetches + keyword searches
search_tasks: list[tuple] = [
(
_fetch_thread_from_url,
(fetch, access_token, channel_metadata_dict),
)
for fetch in direct_fetches
]
search_tasks.extend(
# Build search tasks
search_tasks = [
(
query_slack,
(
@@ -1087,7 +1010,7 @@ def slack_retrieval(
),
)
for query_string in query_strings
)
]
# If include_dm is True AND we're not already searching all channels,
# add additional searches without channel filters.
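
Note on the context-window variant of `_build_thread_text` above: it keeps the thread starter, up to `SLACK_THREAD_CONTEXT_WINDOW` messages before the matched message, the matched message itself, and following replies until a character budget is reached. The sketch below is a simplified stand-alone version. `build_window`, `fmt`, and `MAX_REPLY_CHARS` are hypothetical; the diff uses `DOC_EMBEDDING_CONTEXT_SIZE * 4` as the budget and also resolves user IDs to names, which is omitted here.

```python
CONTEXT_WINDOW = 3  # mirrors SLACK_THREAD_CONTEXT_WINDOW in the diff
MAX_REPLY_CHARS = 40  # hypothetical budget for replies after the match


def fmt(msg: dict[str, str]) -> str:
    """Render one message as '<@user>: text'."""
    return f"<@{msg.get('user', '')}>: {msg.get('text', '')}"


def build_window(messages: list[dict[str, str]], matched_ts: str) -> str:
    """Thread starter, a window before the match, the match, then budgeted replies."""
    text = fmt(messages[0])
    idx = next((i for i, m in enumerate(messages) if m.get("ts") == matched_ts), 0)
    if idx == 0:
        # The match is the starter (or was not found): return the starter only.
        return text

    start = max(1, idx - CONTEXT_WINDOW)
    if start > 1:
        text += "\n..."  # mark the skipped messages between starter and window
    for msg in messages[start : idx + 1]:
        text += "\n\n" + fmt(msg)

    used = 0
    for msg in messages[idx + 1 :]:
        reply = "\n\n" + fmt(msg)
        text += reply
        used += len(reply)
        if used >= MAX_REPLY_CHARS:
            text += "\n..."  # budget exhausted: stop adding replies
            break
    return text
```

Note that the budget check runs after appending, so the reply that crosses the limit is still included in full.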

@@ -10,7 +10,6 @@ from pydantic import ValidationError
from onyx.configs.app_configs import MAX_SLACK_QUERY_EXPANSIONS
from onyx.context.search.federated.models import ChannelMetadata
from onyx.context.search.federated.models import DirectThreadFetch
from onyx.context.search.models import ChunkIndexRequest
from onyx.federated_connectors.slack.models import SlackEntities
from onyx.llm.interfaces import LLM
@@ -639,38 +638,12 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
return [query_text]
SLACK_URL_PATTERN = re.compile(
r"https?://[a-z0-9-]+\.slack\.com/archives/([A-Z0-9]+)/p(\d{16})"
)
def extract_slack_message_urls(
query_text: str,
) -> list[tuple[str, str]]:
"""Extract Slack message URLs from query text.
Parses URLs like:
https://onyx-company.slack.com/archives/C097NBWMY8Y/p1775491616524769
Returns list of (channel_id, thread_ts) tuples.
The 16-digit timestamp is converted to Slack ts format (with dot).
"""
results = []
for match in SLACK_URL_PATTERN.finditer(query_text):
channel_id = match.group(1)
raw_ts = match.group(2)
# Convert p1775491616524769 -> 1775491616.524769
thread_ts = f"{raw_ts[:10]}.{raw_ts[10:]}"
results.append((channel_id, thread_ts))
return results
def build_slack_queries(
query: ChunkIndexRequest,
llm: LLM,
entities: dict[str, Any] | None = None,
available_channels: list[str] | None = None,
) -> list[str | DirectThreadFetch]:
) -> list[str]:
"""Build Slack query strings with date filtering and query expansion."""
default_search_days = 30
if entities:
@@ -695,15 +668,6 @@ def build_slack_queries(
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_back)
time_filter = f" after:{cutoff_date.strftime('%Y-%m-%d')}"
# Check for Slack message URLs — if found, add direct fetch requests
url_fetches: list[DirectThreadFetch] = []
slack_urls = extract_slack_message_urls(query.query)
for channel_id, thread_ts in slack_urls:
url_fetches.append(
DirectThreadFetch(channel_id=channel_id, thread_ts=thread_ts)
)
logger.info(f"Detected Slack URL: channel={channel_id}, ts={thread_ts}")
# ALWAYS extract channel references from the query (not just for recency queries)
channel_references = extract_channel_references_from_query(query.query)
@@ -720,9 +684,7 @@ def build_slack_queries(
# If valid channels detected, use ONLY those channels with NO keywords
# Return query with ONLY time filter + channel filter (no keywords)
return url_fetches + [
build_channel_override_query(channel_references, time_filter)
]
return [build_channel_override_query(channel_references, time_filter)]
except ValueError as e:
# If validation fails, log the error and continue with normal flow
logger.warning(f"Channel reference validation failed: {e}")
@@ -740,8 +702,7 @@ def build_slack_queries(
rephrased_queries = expand_query_with_llm(query.query, llm)
# Build final query strings with time filters
search_queries = [
return [
rephrased_query.strip() + time_filter
for rephrased_query in rephrased_queries[:MAX_SLACK_QUERY_EXPANSIONS]
]
return url_fetches + search_queries
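
Note on `extract_slack_message_urls` above: the 16-digit `p…` suffix of a Slack permalink is the message timestamp with the dot removed. The sketch below reuses the pattern from the diff and adds a hypothetical inverse helper, `ts_to_permalink_suffix`, to show the round trip; the example URL is the one from the docstring.

```python
import re

SLACK_URL_PATTERN = re.compile(
    r"https?://[a-z0-9-]+\.slack\.com/archives/([A-Z0-9]+)/p(\d{16})"
)


def extract_slack_message_urls(query_text: str) -> list[tuple[str, str]]:
    """Return (channel_id, thread_ts) pairs for every Slack message URL in the text."""
    results: list[tuple[str, str]] = []
    for match in SLACK_URL_PATTERN.finditer(query_text):
        channel_id, raw_ts = match.group(1), match.group(2)
        # 10 digits of seconds, then 6 digits of microseconds.
        results.append((channel_id, f"{raw_ts[:10]}.{raw_ts[10:]}"))
    return results


def ts_to_permalink_suffix(thread_ts: str) -> str:
    """Hypothetical inverse: turn a Slack ts back into the 'p…' URL suffix."""
    return "p" + thread_ts.replace(".", "")


query = "summarize https://onyx-company.slack.com/archives/C097NBWMY8Y/p1775491616524769 please"
```

The pattern only matches lowercase workspace subdomains and exactly 16 digits, so URLs with query parameters such as `?thread_ts=` still match on their leading part.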

@@ -1,7 +1,6 @@
import uuid
from fastapi_users.password import PasswordHelper
from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
@@ -11,23 +10,14 @@ from onyx.auth.api_key import ApiKeyDescriptor
from onyx.auth.api_key import build_displayable_api_key
from onyx.auth.api_key import generate_api_key
from onyx.auth.api_key import hash_api_key
from onyx.auth.schemas import UserRole
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
from onyx.db.enums import AccountType
from onyx.db.models import ApiKey
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.db.permissions import recompute_user_permissions__no_commit
from onyx.db.users import assign_user_to_default_groups__no_commit
from onyx.server.api_key.models import APIKeyArgs
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
def get_api_key_email_pattern() -> str:
return DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
@@ -95,7 +85,6 @@ def insert_api_key(
is_superuser=False,
is_verified=True,
role=api_key_args.role,
account_type=AccountType.SERVICE_ACCOUNT,
)
db_session.add(api_key_user_row)
@@ -108,18 +97,7 @@ def insert_api_key(
)
db_session.add(api_key_row)
# Assign the API key virtual user to the appropriate default group
# before commit so everything is atomic.
# Only ADMIN and BASIC roles get default group membership.
if api_key_args.role in (UserRole.ADMIN, UserRole.BASIC):
assign_user_to_default_groups__no_commit(
db_session,
api_key_user_row,
is_admin=(api_key_args.role == UserRole.ADMIN),
)
db_session.commit()
return ApiKeyDescriptor(
api_key_id=api_key_row.id,
api_key_role=api_key_user_row.role,
@@ -146,33 +124,7 @@ def update_api_key(
email_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER
api_key_user.email = get_api_key_fake_email(email_name, str(api_key_user.id))
old_role = api_key_user.role
api_key_user.role = api_key_args.role
# Reconcile default-group membership when the role changes.
if old_role != api_key_args.role:
# Remove from all default groups first.
delete_stmt = delete(User__UserGroup).where(
User__UserGroup.user_id == api_key_user.id,
User__UserGroup.user_group_id.in_(
select(UserGroup.id).where(UserGroup.is_default.is_(True))
),
)
db_session.execute(delete_stmt)
# Re-assign to the correct default group (only for ADMIN/BASIC).
if api_key_args.role in (UserRole.ADMIN, UserRole.BASIC):
assign_user_to_default_groups__no_commit(
db_session,
api_key_user,
is_admin=(api_key_args.role == UserRole.ADMIN),
)
else:
# No group assigned for LIMITED, but we still need to recompute
# since we just removed the old default-group membership above.
recompute_user_permissions__no_commit(api_key_user.id, db_session)
db_session.commit()
return ApiKeyDescriptor(

@@ -190,23 +190,16 @@ def delete_messages_and_files_from_chat_session(
chat_session_id: UUID, db_session: Session
) -> None:
# Select messages older than cutoff_time with files
messages_with_files = (
db_session.execute(
select(ChatMessage.id, ChatMessage.files).where(
ChatMessage.chat_session_id == chat_session_id,
)
messages_with_files = db_session.execute(
select(ChatMessage.id, ChatMessage.files).where(
ChatMessage.chat_session_id == chat_session_id,
)
.tuples()
.all()
)
).fetchall()
file_store = get_default_file_store()
for _, files in messages_with_files:
file_store = get_default_file_store()
for file_info in files or []:
if file_info.get("user_file_id"):
# user files are managed by the user file lifecycle
continue
file_store.delete_file(file_id=file_info["id"], error_on_missing=False)
file_store.delete_file(file_id=file_info.get("id"))
# Delete ChatMessage records - CASCADE constraints will automatically handle:
# - ChatMessage__StandardAnswer relationship records

@@ -13,26 +13,19 @@ class AccountType(str, PyEnum):
BOT, EXT_PERM_USER, ANONYMOUS → fixed behavior
"""
STANDARD = "STANDARD"
BOT = "BOT"
EXT_PERM_USER = "EXT_PERM_USER"
SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
ANONYMOUS = "ANONYMOUS"
def is_web_login(self) -> bool:
"""Whether this account type supports interactive web login."""
return self not in (
AccountType.BOT,
AccountType.EXT_PERM_USER,
)
STANDARD = "standard"
BOT = "bot"
EXT_PERM_USER = "ext_perm_user"
SERVICE_ACCOUNT = "service_account"
ANONYMOUS = "anonymous"
class GrantSource(str, PyEnum):
"""How a permission grant was created."""
USER = "USER"
SCIM = "SCIM"
SYSTEM = "SYSTEM"
USER = "user"
SCIM = "scim"
SYSTEM = "system"
class IndexingStatus(str, PyEnum):
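
Note on the `AccountType` hunk above: one side of the diff carries an `is_web_login` helper that treats every account type except `BOT` and `EXT_PERM_USER` as capable of interactive login. A minimal stand-alone sketch of that variant, using only the standard library (the real class subclasses `PyEnum`, which is assumed here to be an alias of `enum.Enum`):

```python
from enum import Enum


class AccountType(str, Enum):
    STANDARD = "STANDARD"
    BOT = "BOT"
    EXT_PERM_USER = "EXT_PERM_USER"
    SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
    ANONYMOUS = "ANONYMOUS"

    def is_web_login(self) -> bool:
        """Whether this account type supports interactive web login."""
        return self not in (AccountType.BOT, AccountType.EXT_PERM_USER)
```

Because the class also inherits from `str`, members compare equal to their string values, so `AccountType("BOT")` and the literal `"BOT"` are interchangeable in comparisons.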

@@ -8,8 +8,6 @@ from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.configs.constants import FederatedConnectorSource
from onyx.configs.constants import MASK_CREDENTIAL_CHAR
from onyx.configs.constants import MASK_CREDENTIAL_LONG_RE
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import DocumentSet
from onyx.db.models import FederatedConnector
@@ -47,23 +45,6 @@ def fetch_all_federated_connectors_parallel() -> list[FederatedConnector]:
return fetch_all_federated_connectors(db_session)
def _reject_masked_credentials(credentials: dict[str, Any]) -> None:
"""Raise if any credential string value contains mask placeholder characters.
mask_string() has two output formats:
- Short strings (< 14 chars): "••••••••••••" (U+2022 BULLET)
- Long strings (>= 14 chars): "abcd...wxyz" (first4 + "..." + last4)
Both must be rejected.
"""
for key, val in credentials.items():
if isinstance(val, str) and (
MASK_CREDENTIAL_CHAR in val or MASK_CREDENTIAL_LONG_RE.match(val)
):
raise ValueError(
f"Credential field '{key}' contains masked placeholder characters. Please provide the actual credential value."
)
def validate_federated_connector_credentials(
source: FederatedConnectorSource,
credentials: dict[str, Any],
@@ -85,8 +66,6 @@ def create_federated_connector(
config: dict[str, Any] | None = None,
) -> FederatedConnector:
"""Create a new federated connector with credential and config validation."""
_reject_masked_credentials(credentials)
# Validate credentials before creating
if not validate_federated_connector_credentials(source, credentials):
raise ValueError(
@@ -298,8 +277,6 @@ def update_federated_connector(
)
if credentials is not None:
_reject_masked_credentials(credentials)
# Validate credentials before updating
if not validate_federated_connector_credentials(
federated_connector.source, credentials
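
Note on `_reject_masked_credentials` above: the check guards against a client echoing back a masked credential instead of the real value. The sketch below is self-contained, so the two constants are stand-ins. `MASK_CREDENTIAL_CHAR` and `MASK_CREDENTIAL_LONG_RE` live in `onyx.configs.constants` and are not shown in this diff; the values here are assumptions derived from the docstring (U+2022 bullets for short strings, first four characters + "..." + last four for long ones).

```python
import re
from typing import Any

# Assumed stand-ins for the constants imported from onyx.configs.constants.
MASK_CREDENTIAL_CHAR = "\u2022"
MASK_CREDENTIAL_LONG_RE = re.compile(r"^.{4}\.\.\..{4}$")


def reject_masked_credentials(credentials: dict[str, Any]) -> None:
    """Raise ValueError if any string credential looks like mask_string() output."""
    for key, val in credentials.items():
        if isinstance(val, str) and (
            MASK_CREDENTIAL_CHAR in val or MASK_CREDENTIAL_LONG_RE.match(val)
        ):
            raise ValueError(
                f"Credential field '{key}' contains masked placeholder characters. "
                "Please provide the actual credential value."
            )
```

Non-string values are skipped, so nested structures or numeric settings pass through unchecked.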

@@ -236,15 +236,14 @@ def upsert_llm_provider(
db_session.add(existing_llm_provider)
# Filter out empty strings and None values from custom_config to allow
# providers like Bedrock to fall back to IAM roles when credentials are not provided.
# NOTE: An empty dict ({}) is preserved as-is — it signals that the provider was
# created via the custom modal and must be reopened with CustomModal, not a
# provider-specific modal. Only None means "no custom config at all".
# providers like Bedrock to fall back to IAM roles when credentials are not provided
custom_config = llm_provider_upsert_request.custom_config
if custom_config:
custom_config = {
k: v for k, v in custom_config.items() if v is not None and v.strip() != ""
}
# Set to None if the dict is empty after filtering
custom_config = custom_config or None
api_base = llm_provider_upsert_request.api_base or None
existing_llm_provider.provider = llm_provider_upsert_request.provider
@@ -304,7 +303,16 @@ def upsert_llm_provider(
).delete(synchronize_session="fetch")
db_session.flush()
# Import here to avoid circular imports
from onyx.llm.utils import get_max_input_tokens
for model_config in llm_provider_upsert_request.model_configurations:
max_input_tokens = model_config.max_input_tokens
if max_input_tokens is None:
max_input_tokens = get_max_input_tokens(
model_name=model_config.name,
model_provider=llm_provider_upsert_request.provider,
)
supported_flows = [LLMModelFlowType.CHAT]
if model_config.supports_image_input:
@@ -317,7 +325,7 @@ def upsert_llm_provider(
model_configuration_id=existing.id,
supported_flows=supported_flows,
is_visible=model_config.is_visible,
max_input_tokens=model_config.max_input_tokens,
max_input_tokens=max_input_tokens,
display_name=model_config.display_name,
)
else:
@@ -327,7 +335,7 @@ def upsert_llm_provider(
model_name=model_config.name,
supported_flows=supported_flows,
is_visible=model_config.is_visible,
max_input_tokens=model_config.max_input_tokens,
max_input_tokens=max_input_tokens,
display_name=model_config.display_name,
)

@@ -305,11 +305,8 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
role: Mapped[UserRole] = mapped_column(
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
)
account_type: Mapped[AccountType] = mapped_column(
Enum(AccountType, native_enum=False),
nullable=False,
default=AccountType.STANDARD,
server_default="STANDARD",
account_type: Mapped[AccountType | None] = mapped_column(
Enum(AccountType, native_enum=False), nullable=True
)
"""
@@ -356,13 +353,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
postgresql.JSONB(), nullable=True, default=None
)
effective_permissions: Mapped[list[str]] = mapped_column(
postgresql.JSONB(),
nullable=False,
default=list,
server_default=text("'[]'::jsonb"),
)
oidc_expiry: Mapped[datetime.datetime] = mapped_column(
TIMESTAMPAware(timezone=True), nullable=True
)
@@ -4026,12 +4016,7 @@ class PermissionGrant(Base):
ForeignKey("user_group.id", ondelete="CASCADE"), nullable=False
)
permission: Mapped[Permission] = mapped_column(
Enum(
Permission,
native_enum=False,
values_callable=lambda x: [e.value for e in x],
),
nullable=False,
Enum(Permission, native_enum=False), nullable=False
)
grant_source: Mapped[GrantSource] = mapped_column(
Enum(GrantSource, native_enum=False), nullable=False
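
Note on `values_callable` in the `PermissionGrant` hunk above: when SQLAlchemy's `Enum` type is given a Python enum class, it persists member names by default, and `values_callable` replaces that list with whatever the callable returns. The sketch below shows the difference using only the standard library, with `GrantSource` (lowercase-value variant from this diff) as the example; `persisted_by_name` is a hypothetical helper that imitates the default.

```python
from enum import Enum


class GrantSource(str, Enum):
    USER = "user"
    SCIM = "scim"
    SYSTEM = "system"


def persisted_by_name(enum_cls: type[Enum]) -> list[str]:
    """Imitates the default: the strings stored are the member names."""
    return [member.name for member in enum_cls]


def values_callable(enum_cls: type[Enum]) -> list[str]:
    """Same shape as the lambda in the diff: store member values instead."""
    return [member.value for member in enum_cls]
```

When names and values differ only in case, as here, dropping or adding `values_callable` changes which strings the column accepts, so existing rows need to match the new spelling.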


@@ -324,15 +324,6 @@ def mark_migration_completed_time_if_not_set_with_commit(
db_session.commit()
def is_migration_completed(db_session: Session) -> bool:
"""Returns True if the migration is completed.
Can be run even if the migration record does not exist.
"""
record = db_session.query(OpenSearchTenantMigrationRecord).first()
return record is not None and record.migration_completed_at is not None
def build_sanitized_to_original_doc_id_mapping(
db_session: Session,
) -> dict[str, str]:


@@ -1,95 +0,0 @@
"""
DB operations for recomputing user effective_permissions.
These live in onyx/db/ (not onyx/auth/) because they are pure DB operations
that query PermissionGrant rows and update the User.effective_permissions
JSONB column. Keeping them here avoids circular imports when called from
other onyx/db/ modules such as users.py.
"""
from collections import defaultdict
from uuid import UUID
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
from onyx.db.models import PermissionGrant
from onyx.db.models import User
from onyx.db.models import User__UserGroup
def recompute_user_permissions__no_commit(
user_ids: UUID | str | list[UUID] | list[str], db_session: Session
) -> None:
"""Recompute granted permissions for one or more users.
Accepts a single UUID or a list. Uses a single query regardless of
how many users are passed, avoiding N+1 issues.
Stores only directly granted permissions — implication expansion
happens at read time via get_effective_permissions().
Does NOT commit — caller must commit the session.
"""
if isinstance(user_ids, (UUID, str)):
uid_list = [user_ids]
else:
uid_list = list(user_ids)
if not uid_list:
return
# Single query to fetch ALL permissions for these users across ALL their
# groups (a user may belong to multiple groups with different grants).
rows = db_session.execute(
select(User__UserGroup.user_id, PermissionGrant.permission)
.join(
PermissionGrant,
PermissionGrant.group_id == User__UserGroup.user_group_id,
)
.where(
User__UserGroup.user_id.in_(uid_list),
PermissionGrant.is_deleted.is_(False),
)
).all()
# Group permissions by user; users with no grants get an empty set.
perms_by_user: dict[UUID | str, set[str]] = defaultdict(set)
for uid in uid_list:
perms_by_user[uid] # ensure every user has an entry
for uid, perm in rows:
perms_by_user[uid].add(perm.value)
for uid, perms in perms_by_user.items():
db_session.execute(
update(User)
.where(User.id == uid) # type: ignore[arg-type]
.values(effective_permissions=sorted(perms))
)
def recompute_permissions_for_group__no_commit(
group_id: int, db_session: Session
) -> None:
"""Recompute granted permissions for all users in a group.
Does NOT commit — caller must commit the session.
"""
user_ids: list[UUID] = [
uid
for uid in db_session.execute(
select(User__UserGroup.user_id).where(
User__UserGroup.user_group_id == group_id,
User__UserGroup.user_id.isnot(None),
)
)
.scalars()
.all()
if uid is not None
]
if not user_ids:
return
recompute_user_permissions__no_commit(user_ids, db_session)
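Reviewer note: the core step of the removed module is grouping (user_id, permission) rows by user while still keeping an entry for users who have no grants. A minimal sketch of that grouping in plain Python, with no database involved; `group_permissions` and the sample rows are hypothetical names used only for illustration and are not part of the codebase:

```python
from collections import defaultdict


def group_permissions(user_ids, rows):
    """Group (user_id, permission) rows by user.

    Hypothetical helper for illustration only. Users with no rows still
    get an entry (an empty list), mirroring the "ensure every user has an
    entry" step in the removed function.
    """
    perms_by_user = defaultdict(set)
    for uid in user_ids:
        perms_by_user[uid]  # touching the key creates an empty set
    for uid, perm in rows:
        perms_by_user[uid].add(perm)  # the set drops duplicate grants
    # Sort so the stored value is deterministic.
    return {uid: sorted(perms) for uid, perms in perms_by_user.items()}


result = group_permissions(
    ["u1", "u2"],
    [("u1", "read"), ("u1", "write"), ("u1", "read")],
)
```

Here `u2` ends up with an empty list rather than being dropped, which is what lets a later UPDATE clear stale permissions for users who lost all of their grants.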


@@ -5,11 +5,11 @@ from urllib.parse import urlencode
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
from onyx.configs.app_configs import INSTANCE_TYPE
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.configs.constants import NotificationType
from onyx.configs.constants import ONYX_UTM_SOURCE
from onyx.db.enums import AccountType
from onyx.db.models import User
from onyx.db.notification import batch_create_notifications
from onyx.server.features.release_notes.constants import DOCS_CHANGELOG_BASE_URL
@@ -49,7 +49,7 @@ def create_release_notifications_for_versions(
db_session.scalars(
select(User.id).where( # type: ignore
User.is_active == True, # noqa: E712
User.account_type.notin_([AccountType.BOT, AccountType.EXT_PERM_USER]),
User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]),
User.email.endswith(DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN).is_(False), # type: ignore[attr-defined]
)
).all()


@@ -9,18 +9,12 @@ from sqlalchemy import update
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
from onyx.db.enums import AccountType
from onyx.db.enums import DefaultAppMode
from onyx.db.enums import ThemePreference
from onyx.db.models import AccessToken
from onyx.db.models import Assistant__UserSpecificConfig
from onyx.db.models import Memory
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.db.permissions import recompute_user_permissions__no_commit
from onyx.db.users import assign_user_to_default_groups__no_commit
from onyx.db.users import is_limited_user
from onyx.server.manage.models import MemoryItem
from onyx.server.manage.models import UserSpecificAssistantPreference
from onyx.utils.logger import setup_logger
@@ -29,56 +23,13 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_ROLE_TO_ACCOUNT_TYPE: dict[UserRole, AccountType] = {
UserRole.SLACK_USER: AccountType.BOT,
UserRole.EXT_PERM_USER: AccountType.EXT_PERM_USER,
}
def update_user_role(
user: User,
new_role: UserRole,
db_session: Session,
) -> None:
"""Update a user's role in the database.
Dual-writes account_type to keep it in sync with role and
reconciles default-group membership (Admin / Basic)."""
old_role = user.role
"""Update a user's role in the database."""
user.role = new_role
# Note: setting account_type to BOT or EXT_PERM_USER causes
# assign_user_to_default_groups__no_commit to early-return, which is
# intentional — these account types should not be in default groups.
if new_role in _ROLE_TO_ACCOUNT_TYPE:
user.account_type = _ROLE_TO_ACCOUNT_TYPE[new_role]
elif user.account_type in (AccountType.BOT, AccountType.EXT_PERM_USER):
# Upgrading from a non-web-login account type to a web role
user.account_type = AccountType.STANDARD
# Reconcile default-group membership when the role changes.
if old_role != new_role:
# Remove from all default groups first.
db_session.execute(
delete(User__UserGroup).where(
User__UserGroup.user_id == user.id,
User__UserGroup.user_group_id.in_(
select(UserGroup.id).where(UserGroup.is_default.is_(True))
),
)
)
# Re-assign to the correct default group.
# assign_user_to_default_groups__no_commit internally skips
# ANONYMOUS, BOT, and EXT_PERM_USER account types.
# Also skip limited users (no group assignment).
if not is_limited_user(user):
assign_user_to_default_groups__no_commit(
db_session,
user,
is_admin=(new_role == UserRole.ADMIN),
)
recompute_user_permissions__no_commit(user.id, db_session)
db_session.commit()
@@ -96,19 +47,8 @@ def activate_user(
user: User,
db_session: Session,
) -> None:
"""Activate a user by setting is_active to True.
Also reconciles default-group membership — the user may have been
created while inactive or deactivated before the backfill migration.
"""
"""Activate a user by setting is_active to True."""
user.is_active = True
# assign_user_to_default_groups__no_commit internally skips
# ANONYMOUS, BOT, and EXT_PERM_USER account types.
# Also skip limited users (no group assignment).
if not is_limited_user(user):
assign_user_to_default_groups__no_commit(
db_session, user, is_admin=(user.role == UserRole.ADMIN)
)
db_session.add(user)
db_session.commit()


@@ -17,9 +17,8 @@ from sqlalchemy.sql.expression import or_
from onyx.auth.invited_users import remove_user_from_invited_users
from onyx.auth.schemas import UserRole
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.configs.constants import NO_AUTH_PLACEHOLDER_USER_EMAIL
from onyx.db.enums import AccountType
from onyx.db.api_key import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.db.models import DocumentSet
from onyx.db.models import DocumentSet__User
from onyx.db.models import Persona
@@ -28,35 +27,11 @@ from onyx.db.models import SamlAccount
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
def is_limited_user(user: User) -> bool:
"""Check if a user is effectively limited — i.e. should be denied
access by ``current_user`` and should not receive default-group
membership.
A user is limited when they are:
* an anonymous user, or
* a service account with no effective permissions (no group membership).
"""
if user.account_type == AccountType.ANONYMOUS:
return True
if (
user.account_type == AccountType.SERVICE_ACCOUNT
and not user.effective_permissions
):
return True
return False
def validate_user_role_update(
requested_role: UserRole,
current_account_type: AccountType,
explicit_override: bool = False,
requested_role: UserRole, current_role: UserRole, explicit_override: bool = False
) -> None:
"""
Validate that a user role update is valid.
@@ -66,27 +41,28 @@ def validate_user_role_update(
- requested role is a slack user
- requested role is an external permissioned user
- requested role is a limited user
- current account type is BOT (slack user)
- current account type is EXT_PERM_USER
- current account type is ANONYMOUS or SERVICE_ACCOUNT
- current role is a slack user
- current role is an external permissioned user
- current role is a limited user
"""
if current_account_type == AccountType.BOT:
if current_role == UserRole.SLACK_USER:
raise HTTPException(
status_code=400,
detail="To change a Slack User's role, they must first login to Onyx via the web app.",
)
if current_account_type == AccountType.EXT_PERM_USER:
if current_role == UserRole.EXT_PERM_USER:
# This shouldn't happen, but just in case
raise HTTPException(
status_code=400,
detail="To change an External Permissioned User's role, they must first login to Onyx via the web app.",
)
if current_account_type in (AccountType.ANONYMOUS, AccountType.SERVICE_ACCOUNT):
if current_role == UserRole.LIMITED:
raise HTTPException(
status_code=400,
detail="Cannot change the role of an anonymous or service account user.",
detail="To change a Limited User's role, they must first login to Onyx via the web app.",
)
if explicit_override:
@@ -322,7 +298,6 @@ def _generate_slack_user(email: str) -> User:
email=email,
hashed_password=hashed_pass,
role=UserRole.SLACK_USER,
account_type=AccountType.BOT,
)
@@ -331,9 +306,8 @@ def add_slack_user_if_not_exists(db_session: Session, email: str) -> User:
user = get_user_by_email(email, db_session)
if user is not None:
# If the user is an external permissioned user, we update it to a slack user
if user.account_type == AccountType.EXT_PERM_USER:
if user.role == UserRole.EXT_PERM_USER:
user.role = UserRole.SLACK_USER
user.account_type = AccountType.BOT
db_session.commit()
return user
@@ -370,7 +344,6 @@ def _generate_ext_permissioned_user(email: str) -> User:
email=email,
hashed_password=hashed_pass,
role=UserRole.EXT_PERM_USER,
account_type=AccountType.EXT_PERM_USER,
)
@@ -402,81 +375,6 @@ def batch_add_ext_perm_user_if_not_exists(
return all_users
def assign_user_to_default_groups__no_commit(
db_session: Session,
user: User,
is_admin: bool = False,
) -> None:
"""Assign a newly created user to the appropriate default group.
Does NOT commit — callers must commit the session themselves so that
group assignment can be part of the same transaction as user creation.
Args:
is_admin: If True, assign to Admin default group; otherwise Basic.
Callers determine this from their own context (e.g. user_count,
admin email list, explicit choice). Defaults to False (Basic).
"""
if user.account_type in (
AccountType.BOT,
AccountType.EXT_PERM_USER,
AccountType.ANONYMOUS,
):
return
target_group_name = "Admin" if is_admin else "Basic"
default_group = (
db_session.query(UserGroup)
.filter(
UserGroup.name == target_group_name,
UserGroup.is_default.is_(True),
)
.first()
)
if default_group is None:
raise RuntimeError(
f"Default group '{target_group_name}' not found. "
f"Cannot assign user {user.email} to a group. "
f"Ensure the seed_default_groups migration has run."
)
# Check if the user is already in the group
existing = (
db_session.query(User__UserGroup)
.filter(
User__UserGroup.user_id == user.id,
User__UserGroup.user_group_id == default_group.id,
)
.first()
)
if existing is not None:
return
savepoint = db_session.begin_nested()
try:
db_session.add(
User__UserGroup(
user_id=user.id,
user_group_id=default_group.id,
)
)
db_session.flush()
except IntegrityError:
# Race condition: another transaction inserted this membership
# between our SELECT and INSERT. The savepoint isolates the failure
# so the outer transaction (user creation) stays intact.
savepoint.rollback()
return
from onyx.db.permissions import recompute_user_permissions__no_commit
recompute_user_permissions__no_commit(user.id, db_session)
logger.info(f"Assigned user {user.email} to default group '{default_group.name}'")
def delete_user_from_db(
user_to_delete: User,
db_session: Session,
@@ -523,14 +421,13 @@ def delete_user_from_db(
def batch_get_user_groups(
db_session: Session,
user_ids: list[UUID],
include_default: bool = False,
) -> dict[UUID, list[tuple[int, str]]]:
"""Fetch group memberships for a batch of users in a single query.
Returns a mapping of user_id -> list of (group_id, group_name) tuples."""
if not user_ids:
return {}
stmt = (
rows = db_session.execute(
select(
User__UserGroup.user_id,
UserGroup.id,
@@ -538,11 +435,7 @@ def batch_get_user_groups(
)
.join(UserGroup, UserGroup.id == User__UserGroup.user_group_id)
.where(User__UserGroup.user_id.in_(user_ids))
)
if not include_default:
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
rows = db_session.execute(stmt).all()
).all()
result: dict[UUID, list[tuple[int, str]]] = {uid: [] for uid in user_ids}
for user_id, group_id, group_name in rows:


@@ -1,4 +1,3 @@
import hashlib
from datetime import datetime
from datetime import timezone
from typing import Any
@@ -21,13 +20,9 @@ from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
from onyx.document_index.opensearch.constants import EF_CONSTRUCTION
from onyx.document_index.opensearch.constants import EF_SEARCH
from onyx.document_index.opensearch.constants import M
from onyx.document_index.opensearch.string_filtering import DocumentIDTooLongError
from onyx.document_index.opensearch.string_filtering import (
filter_and_validate_document_id,
)
from onyx.document_index.opensearch.string_filtering import (
MAX_DOCUMENT_ID_ENCODED_LENGTH,
)
from onyx.utils.tenant import get_tenant_id_short_string
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
@@ -80,50 +75,17 @@ def get_opensearch_doc_chunk_id(
This will be the string used to identify the chunk in OpenSearch. Any direct
chunk queries should use this function.
If the document ID is too long, a hash of the ID is used instead.
"""
opensearch_doc_chunk_id_suffix: str = f"__{max_chunk_size}__{chunk_index}"
encoded_suffix_length: int = len(opensearch_doc_chunk_id_suffix.encode("utf-8"))
max_encoded_permissible_doc_id_length: int = (
MAX_DOCUMENT_ID_ENCODED_LENGTH - encoded_suffix_length
sanitized_document_id = filter_and_validate_document_id(document_id)
opensearch_doc_chunk_id = (
f"{sanitized_document_id}__{max_chunk_size}__{chunk_index}"
)
opensearch_doc_chunk_id_tenant_prefix: str = ""
if tenant_state.multitenant:
short_tenant_id: str = get_tenant_id_short_string(tenant_state.tenant_id)
# Use tenant ID because in multitenant mode each tenant has its own
# Documents table, so there is a very small chance that doc IDs are not
# actually unique across all tenants.
opensearch_doc_chunk_id_tenant_prefix = f"{short_tenant_id}__"
encoded_prefix_length: int = len(
opensearch_doc_chunk_id_tenant_prefix.encode("utf-8")
)
max_encoded_permissible_doc_id_length -= encoded_prefix_length
try:
sanitized_document_id: str = filter_and_validate_document_id(
document_id, max_encoded_length=max_encoded_permissible_doc_id_length
)
except DocumentIDTooLongError:
# If the document ID is too long, use a hash instead.
# We use blake2b because it is faster and equally secure as SHA256, and
# accepts digest_size which controls the number of bytes returned in the
# hash.
# digest_size is the size of the returned hash in bytes. Since we're
# decoding the hash bytes as a hex string, the digest_size should be
# half the max target size of the hash string.
# Subtract 1 because filter_and_validate_document_id compares on >= on
# max_encoded_length.
# 64 is the max digest_size blake2b returns.
digest_size: int = min((max_encoded_permissible_doc_id_length - 1) // 2, 64)
sanitized_document_id = hashlib.blake2b(
document_id.encode("utf-8"), digest_size=digest_size
).hexdigest()
opensearch_doc_chunk_id: str = (
f"{opensearch_doc_chunk_id_tenant_prefix}{sanitized_document_id}{opensearch_doc_chunk_id_suffix}"
)
short_tenant_id = get_tenant_id_short_string(tenant_state.tenant_id)
opensearch_doc_chunk_id = f"{short_tenant_id}__{opensearch_doc_chunk_id}"
# Do one more validation to ensure we haven't exceeded the max length.
opensearch_doc_chunk_id = filter_and_validate_document_id(opensearch_doc_chunk_id)
return opensearch_doc_chunk_id
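Reviewer note: the hash fallback that appears in this hunk sizes the blake2b digest so that its hex string stays strictly below the length limit. A minimal sketch of that arithmetic, assuming a limit of at least 3 bytes so that `digest_size` is at least 1; `hashed_doc_id` is a hypothetical name, not a function in the codebase:

```python
import hashlib


def hashed_doc_id(document_id: str, max_encoded_length: int) -> str:
    """Return a hex digest that stays strictly under max_encoded_length bytes.

    Hypothetical helper for illustration; it mirrors the digest_size
    arithmetic shown in the hunk above.
    """
    # Each digest byte becomes two hex characters, and the validator rejects
    # lengths >= the limit, hence (limit - 1) // 2. blake2b accepts 1..64.
    digest_size = min((max_encoded_length - 1) // 2, 64)
    return hashlib.blake2b(
        document_id.encode("utf-8"), digest_size=digest_size
    ).hexdigest()


short = hashed_doc_id("x" * 1000, max_encoded_length=40)
full = hashed_doc_id("x" * 1000, max_encoded_length=512)
```

With a 40-byte budget the digest is 19 bytes (38 hex characters); with the default 512-byte budget the 64-byte blake2b cap applies and the ID is 128 hex characters.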


@@ -1,15 +1,7 @@
import re
MAX_DOCUMENT_ID_ENCODED_LENGTH: int = 512
class DocumentIDTooLongError(ValueError):
"""Raised when a document ID is too long for OpenSearch after filtering."""
def filter_and_validate_document_id(
document_id: str, max_encoded_length: int = MAX_DOCUMENT_ID_ENCODED_LENGTH
) -> str:
def filter_and_validate_document_id(document_id: str) -> str:
"""
Filters and validates a document ID such that it can be used as an ID in
OpenSearch.
@@ -27,13 +19,9 @@ def filter_and_validate_document_id(
Args:
document_id: The document ID to filter and validate.
max_encoded_length: The maximum length of the document ID after
filtering in bytes. Compared with >= for extra resilience, so
encoded values of this length will fail.
Raises:
DocumentIDTooLongError: If the document ID is too long after filtering.
ValueError: If the document ID is empty after filtering.
ValueError: If the document ID is empty or too long after filtering.
Returns:
str: The filtered document ID.
@@ -41,8 +29,6 @@ def filter_and_validate_document_id(
filtered_document_id = re.sub(r"[^A-Za-z0-9_.\-~]", "", document_id)
if not filtered_document_id:
raise ValueError(f"Document ID {document_id} is empty after filtering.")
if len(filtered_document_id.encode("utf-8")) >= max_encoded_length:
raise DocumentIDTooLongError(
f"Document ID {document_id} is too long after filtering."
)
if len(filtered_document_id.encode("utf-8")) >= 512:
raise ValueError(f"Document ID {document_id} is too long after filtering.")
return filtered_document_id
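Reviewer note: both versions of the length check in this hunk compare with `>=`, so an ID of exactly 512 encoded bytes is rejected. A standalone sketch of the filter that makes the boundary easy to check; `filter_id` and `MAX_LEN` are hypothetical names, and the regex is copied from the hunk above:

```python
import re

MAX_LEN = 512  # hypothetical constant standing in for the limit in the diff


def filter_id(document_id: str) -> str:
    """Strip characters outside A-Za-z0-9_.-~ and enforce the length limit."""
    filtered = re.sub(r"[^A-Za-z0-9_.\-~]", "", document_id)
    if not filtered:
        raise ValueError("Document ID is empty after filtering.")
    # >= means MAX_LEN itself is already too long.
    if len(filtered.encode("utf-8")) >= MAX_LEN:
        raise ValueError("Document ID is too long after filtering.")
    return filtered


def is_rejected(document_id: str) -> bool:
    """Return True when filter_id raises ValueError for this ID."""
    try:
        filter_id(document_id)
    except ValueError:
        return True
    return False
```

Because the filter only keeps ASCII characters, the encoded byte length always equals the character count of the filtered string.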


@@ -52,21 +52,9 @@ KNOWN_OPENPYXL_BUGS = [
def get_markitdown_converter() -> "MarkItDown":
global _MARKITDOWN_CONVERTER
from markitdown import MarkItDown
if _MARKITDOWN_CONVERTER is None:
from markitdown import MarkItDown
# Patch this function to effectively no-op because we were seeing this
# module take an inordinate amount of time to convert charts to markdown,
# making some powerpoint files with many or complicated charts nearly
# unindexable.
from markitdown.converters._pptx_converter import PptxConverter
setattr(
PptxConverter,
"_convert_chart_to_markdown",
lambda self, chart: "\n\n[chart omitted]\n\n", # noqa: ARG005
)
_MARKITDOWN_CONVERTER = MarkItDown(enable_plugins=False)
return _MARKITDOWN_CONVERTER
@@ -217,26 +205,18 @@ def read_pdf_file(
try:
pdf_reader = PdfReader(file)
if pdf_reader.is_encrypted:
# Try the explicit password first, then fall back to an empty
# string. Owner-password-only PDFs (permission restrictions but
# no open password) decrypt successfully with "".
# See https://github.com/onyx-dot-app/onyx/issues/9754
passwords = [p for p in [pdf_pass, ""] if p is not None]
if pdf_reader.is_encrypted and pdf_pass is not None:
decrypt_success = False
for pw in passwords:
try:
if pdf_reader.decrypt(pw) != 0:
decrypt_success = True
break
except Exception:
pass
try:
decrypt_success = pdf_reader.decrypt(pdf_pass) != 0
except Exception:
logger.error("Unable to decrypt pdf")
if not decrypt_success:
logger.error(
"Encrypted PDF could not be decrypted, returning empty text."
)
return "", metadata, []
elif pdf_reader.is_encrypted:
logger.warning("No Password for an encrypted PDF, returning empty text.")
return "", metadata, []
# Basic PDF metadata
if pdf_reader.metadata is not None:


@@ -33,20 +33,8 @@ def is_pdf_protected(file: IO[Any]) -> bool:
with preserve_position(file):
reader = PdfReader(file)
if not reader.is_encrypted:
return False
# PDFs with only an owner password (permission restrictions like
# print/copy disabled) use an empty user password — any viewer can open
# them without prompting. decrypt("") returns 0 only when a real user
# password is required. See https://github.com/onyx-dot-app/onyx/issues/9754
try:
return reader.decrypt("") == 0
except Exception:
logger.exception(
"Failed to evaluate PDF encryption; treating as password protected"
)
return True
return bool(reader.is_encrypted)
def is_docx_protected(file: IO[Any]) -> bool:


@@ -136,14 +136,12 @@ class FileStore(ABC):
"""
@abstractmethod
def delete_file(self, file_id: str, error_on_missing: bool = True) -> None:
def delete_file(self, file_id: str) -> None:
"""
Delete a file by its ID.
Parameters:
- file_id: ID of file to delete
- error_on_missing: If False, silently return when the file record
does not exist instead of raising.
- file_name: Name of file to delete
"""
@abstractmethod
@@ -454,23 +452,12 @@ class S3BackedFileStore(FileStore):
logger.warning(f"Error getting file size for {file_id}: {e}")
return None
def delete_file(
self,
file_id: str,
error_on_missing: bool = True,
db_session: Session | None = None,
) -> None:
def delete_file(self, file_id: str, db_session: Session | None = None) -> None:
with get_session_with_current_tenant_if_none(db_session) as db_session:
try:
file_record = get_filerecord_by_file_id_optional(
file_record = get_filerecord_by_file_id(
file_id=file_id, db_session=db_session
)
if file_record is None:
if error_on_missing:
raise RuntimeError(
f"File by id {file_id} does not exist or was deleted"
)
return
if not file_record.bucket_name:
logger.error(
f"File record {file_id} with key {file_record.object_key} "


@@ -222,23 +222,12 @@ class PostgresBackedFileStore(FileStore):
logger.warning(f"Error getting file size for {file_id}: {e}")
return None
def delete_file(
self,
file_id: str,
error_on_missing: bool = True,
db_session: Session | None = None,
) -> None:
def delete_file(self, file_id: str, db_session: Session | None = None) -> None:
with get_session_with_current_tenant_if_none(db_session) as session:
try:
file_content = get_file_content_by_file_id_optional(
file_content = get_file_content_by_file_id(
file_id=file_id, db_session=session
)
if file_content is None:
if error_on_missing:
raise RuntimeError(
f"File content for file_id {file_id} does not exist or was deleted"
)
return
raw_conn = _get_raw_connection(session)
try:


@@ -26,7 +26,6 @@ class LlmProviderNames(str, Enum):
MISTRAL = "mistral"
LITELLM_PROXY = "litellm_proxy"
BIFROST = "bifrost"
OPENAI_COMPATIBLE = "openai_compatible"
def __str__(self) -> str:
"""Needed so things like:
@@ -47,7 +46,6 @@ WELL_KNOWN_PROVIDER_NAMES = [
LlmProviderNames.LM_STUDIO,
LlmProviderNames.LITELLM_PROXY,
LlmProviderNames.BIFROST,
LlmProviderNames.OPENAI_COMPATIBLE,
]
@@ -66,7 +64,6 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
LlmProviderNames.LM_STUDIO: "LM Studio",
LlmProviderNames.LITELLM_PROXY: "LiteLLM Proxy",
LlmProviderNames.BIFROST: "Bifrost",
LlmProviderNames.OPENAI_COMPATIBLE: "OpenAI-Compatible",
"groq": "Groq",
"anyscale": "Anyscale",
"deepseek": "DeepSeek",
@@ -87,44 +84,6 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
"gemini": "Gemini",
"stability": "Stability",
"writer": "Writer",
# Custom provider display names (used in the custom provider picker)
"aiml": "AI/ML",
"assemblyai": "AssemblyAI",
"aws_polly": "AWS Polly",
"azure_ai": "Azure AI",
"chatgpt": "ChatGPT",
"cohere_chat": "Cohere Chat",
"datarobot": "DataRobot",
"deepgram": "Deepgram",
"deepinfra": "DeepInfra",
"elevenlabs": "ElevenLabs",
"fal_ai": "fal.ai",
"featherless_ai": "Featherless AI",
"fireworks_ai": "Fireworks AI",
"friendliai": "FriendliAI",
"gigachat": "GigaChat",
"github_copilot": "GitHub Copilot",
"gradient_ai": "Gradient AI",
"huggingface": "HuggingFace",
"jina_ai": "Jina AI",
"lambda_ai": "Lambda AI",
"llamagate": "LlamaGate",
"meta_llama": "Meta Llama",
"minimax": "MiniMax",
"nlp_cloud": "NLP Cloud",
"nvidia_nim": "NVIDIA NIM",
"oci": "OCI",
"ovhcloud": "OVHcloud",
"palm": "PaLM",
"publicai": "PublicAI",
"runwayml": "RunwayML",
"sambanova": "SambaNova",
"together_ai": "Together AI",
"vercel_ai_gateway": "Vercel AI Gateway",
"volcengine": "Volcengine",
"wandb": "W&B",
"watsonx": "IBM watsonx",
"zai": "ZAI",
}
# Map vendors to their brand names (used for provider_display_name generation)
@@ -157,7 +116,6 @@ AGGREGATOR_PROVIDERS: set[str] = {
LlmProviderNames.AZURE,
LlmProviderNames.LITELLM_PROXY,
LlmProviderNames.BIFROST,
LlmProviderNames.OPENAI_COMPATIBLE,
}
# Model family name mappings for display name generation


@@ -327,19 +327,12 @@ class LitellmLLM(LLM):
):
model_kwargs[VERTEX_LOCATION_KWARG] = "global"
# Bifrost and OpenAI-compatible: OpenAI-compatible proxies that send
# model names directly to the endpoint. We route through LiteLLM's
# openai provider with the server's base URL, and ensure /v1 is appended.
if model_provider in (
LlmProviderNames.BIFROST,
LlmProviderNames.OPENAI_COMPATIBLE,
):
# Bifrost: OpenAI-compatible proxy that expects model names in
# provider/model format (e.g. "anthropic/claude-sonnet-4-6").
# We route through LiteLLM's openai provider with the Bifrost base URL,
# and ensure /v1 is appended.
if model_provider == LlmProviderNames.BIFROST:
self._custom_llm_provider = "openai"
# LiteLLM's OpenAI client requires an api_key to be set.
# Many OpenAI-compatible servers don't need auth, so supply a
# placeholder to prevent LiteLLM from raising AuthenticationError.
if not self._api_key:
model_kwargs.setdefault("api_key", "not-needed")
if self._api_base is not None:
base = self._api_base.rstrip("/")
self._api_base = base if base.endswith("/v1") else f"{base}/v1"
@@ -456,20 +449,17 @@ class LitellmLLM(LLM):
optional_kwargs: dict[str, Any] = {}
# Model name
is_openai_compatible_proxy = self._model_provider in (
LlmProviderNames.BIFROST,
LlmProviderNames.OPENAI_COMPATIBLE,
)
is_bifrost = self._model_provider == LlmProviderNames.BIFROST
model_provider = (
f"{self.config.model_provider}/responses"
if is_openai_model # Uses litellm's completions -> responses bridge
else self.config.model_provider
)
if is_openai_compatible_proxy:
# OpenAI-compatible proxies (Bifrost, generic OpenAI-compatible
# servers) expect model names sent directly to their endpoint.
# We use custom_llm_provider="openai" so LiteLLM doesn't try
# to route based on the provider prefix.
if is_bifrost:
# Bifrost expects model names in provider/model format
# (e.g. "anthropic/claude-sonnet-4-6") sent directly to its
# OpenAI-compatible endpoint. We use custom_llm_provider="openai"
# so LiteLLM doesn't try to route based on the provider prefix.
model = self.config.deployment_name or self.config.model_name
else:
model = f"{model_provider}/{self.config.deployment_name or self.config.model_name}"
@@ -560,10 +550,7 @@ class LitellmLLM(LLM):
if structured_response_format:
optional_kwargs["response_format"] = structured_response_format
if (
not (is_claude_model or is_ollama or is_mistral)
or is_openai_compatible_proxy
):
if not (is_claude_model or is_ollama or is_mistral) or is_bifrost:
# Litellm bug: tool_choice is dropped silently if not specified here for OpenAI
# However, this param breaks Anthropic and Mistral models,
# so it must be conditionally included unless the request is


@@ -15,8 +15,6 @@ LITELLM_PROXY_PROVIDER_NAME = "litellm_proxy"
BIFROST_PROVIDER_NAME = "bifrost"
OPENAI_COMPATIBLE_PROVIDER_NAME = "openai_compatible"
# Providers that use optional Bearer auth from custom_config
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING: dict[str, str] = {
LlmProviderNames.OLLAMA_CHAT: OLLAMA_API_KEY_CONFIG_KEY,


@@ -19,7 +19,6 @@ from onyx.llm.well_known_providers.constants import BIFROST_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import LITELLM_PROXY_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import LM_STUDIO_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OLLAMA_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OPENAI_COMPATIBLE_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OPENROUTER_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import VERTEXAI_PROVIDER_NAME
@@ -52,7 +51,6 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
OPENROUTER_PROVIDER_NAME: [], # Dynamic - fetched from OpenRouter API
LITELLM_PROXY_PROVIDER_NAME: [], # Dynamic - fetched from LiteLLM proxy API
BIFROST_PROVIDER_NAME: [], # Dynamic - fetched from Bifrost API
OPENAI_COMPATIBLE_PROVIDER_NAME: [], # Dynamic - fetched from OpenAI-compatible API
}
@@ -338,7 +336,6 @@ def get_provider_display_name(provider_name: str) -> str:
VERTEXAI_PROVIDER_NAME: "Google Vertex AI",
OPENROUTER_PROVIDER_NAME: "OpenRouter",
LITELLM_PROXY_PROVIDER_NAME: "LiteLLM Proxy",
OPENAI_COMPATIBLE_PROVIDER_NAME: "OpenAI-Compatible",
}
if provider_name in _ONYX_PROVIDER_DISPLAY_NAMES:


@@ -96,9 +96,6 @@ from onyx.server.features.persona.api import admin_router as admin_persona_route
from onyx.server.features.persona.api import agents_router
from onyx.server.features.persona.api import basic_router as persona_router
from onyx.server.features.projects.api import router as projects_router
from onyx.server.features.proposal_review.api.api import (
router as proposal_review_router,
)
from onyx.server.features.tool.api import admin_router as admin_tool_router
from onyx.server.features.tool.api import router as tool_router
from onyx.server.features.user_oauth_token.api import router as user_oauth_token_router
@@ -472,7 +469,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
include_router_with_global_prefix_prepended(application, projects_router)
include_router_with_global_prefix_prepended(application, public_build_router)
include_router_with_global_prefix_prepended(application, build_router)
include_router_with_global_prefix_prepended(application, proposal_review_router)
include_router_with_global_prefix_prepended(application, document_set_router)
include_router_with_global_prefix_prepended(application, hierarchy_router)
include_router_with_global_prefix_prepended(application, search_settings_router)

Some files were not shown because too many files have changed in this diff.